| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251 |
- import json
- import uuid
- from datetime import datetime, timezone
- from pathlib import Path
- from typing import Any
- from fastapi import UploadFile
- from app.config import get_settings
- from app.core.db import async_session, DatasetRecord
- from app.core.logging import logger
- from app.schemas.dataset import DatasetDownloadRequest, DatasetDownloadResponse
- settings = get_settings()  # evaluated once at import time; the functions below read settings.processed_dir and settings.uploads_dir
async def download_dataset(req: DatasetDownloadRequest) -> DatasetDownloadResponse:
    """Fetch a dataset with ``datasets.load_dataset`` and store it as JSONL.

    The ``train`` split is exported when present, otherwise the first split
    returned by the loader. Rows are written one JSON object per line to
    ``<processed_dir>/hf_<dataset id with '/' replaced by '_'>/data.jsonl``
    and a ``DatasetRecord`` row describing the file is committed.

    Args:
        req: Download request; only ``req.dataset_id`` is read here.

    Returns:
        A response with ``status="completed"`` and the output path on
        success, or ``status="failed"`` and the error text when any
        ``Exception`` is raised along the way (it is logged, not re-raised).
    """
    dataset_id = req.dataset_id
    try:
        from datasets import load_dataset

        loaded = load_dataset(dataset_id)
        target_dir = settings.processed_dir / f"hf_{dataset_id.replace('/', '_')}"
        target_dir.mkdir(parents=True, exist_ok=True)

        # Prefer the "train" split; fall back to whichever split is listed first.
        split = loaded["train"] if "train" in loaded else loaded[list(loaded.keys())[0]]

        # Export as JSONL: one JSON document per line.
        output_path = target_dir / "data.jsonl"
        with open(output_path, "w", encoding="utf-8") as sink:
            sink.writelines(
                json.dumps(item, ensure_ascii=False) + "\n" for item in split
            )

        # Register the exported file in the database.
        record_count = len(split)
        async with async_session() as session:
            session.add(
                DatasetRecord(
                    id=str(uuid.uuid4()),
                    name=dataset_id,
                    format="jsonl",
                    record_count=record_count,
                    file_path=str(output_path),
                    created_at=datetime.now(timezone.utc),
                )
            )
            await session.commit()

        logger.info(f"Downloaded dataset: {dataset_id} ({record_count} records)")
        return DatasetDownloadResponse(
            dataset_id=dataset_id,
            status="completed",
            path=str(output_path),
        )
    except Exception as exc:
        logger.error(f"Dataset download failed: {exc}")
        return DatasetDownloadResponse(
            dataset_id=dataset_id,
            status="failed",
            error=str(exc),
        )
async def upload_dataset(file: UploadFile) -> dict[str, Any]:
    """Save an uploaded file under the uploads directory and register it.

    The client-supplied filename is untrusted input. Only its final path
    component is used, so a name such as ``../../etc/passwd`` or an absolute
    path cannot place the file outside ``settings.uploads_dir``. If a file
    with the resulting name already exists, a random hex prefix is added
    instead of overwriting it.

    Args:
        file: The uploaded file; its ``filename`` and ``read()`` are used.

    Returns:
        A dict with the keys ``id``, ``name``, ``format``, ``record_count``,
        ``file_path`` and ``created_at`` (ISO-8601 string) describing the
        stored dataset.
    """
    upload_dir = settings.uploads_dir
    upload_dir.mkdir(parents=True, exist_ok=True)

    # Normalise Windows separators first so "..\\..\\x" is reduced to "x" on
    # POSIX hosts too, then keep only the basename.
    raw_name = (file.filename or "").replace("\\", "/")
    safe_name = Path(raw_name).name
    if safe_name in ("", ".", ".."):
        # Nothing usable is left (missing name, bare directory reference).
        safe_name = "unknown"

    # Avoid clobbering an earlier upload that used the same name.
    file_path = upload_dir / safe_name
    if file_path.exists():
        file_path = upload_dir / f"{uuid.uuid4().hex}_{safe_name}"

    content = await file.read()
    file_path.write_bytes(content)

    fmt = _detect_format(safe_name)
    record_count = _count_records(file_path, fmt)
    record_id = str(uuid.uuid4())
    record = DatasetRecord(
        id=record_id,
        name=safe_name,
        format=fmt,
        record_count=record_count,
        file_path=str(file_path),
        created_at=datetime.now(timezone.utc),
    )
    # Read the timestamp before the commit so the response does not depend on
    # the instance's attribute state afterwards.
    created_at = record.created_at

    async with async_session() as session:
        session.add(record)
        await session.commit()

    logger.info(f"Uploaded dataset: {safe_name} ({record_count} records, format={fmt})")
    return {
        "id": record_id,
        "name": safe_name,
        "format": fmt,
        "record_count": record_count,
        "file_path": str(file_path),
        "created_at": created_at.isoformat(),
    }
async def preview_dataset(dataset_id: str, rows: int = 10) -> dict[str, Any]:
    """Return the first ``rows`` records of a registered dataset.

    Args:
        dataset_id: Primary key of the ``DatasetRecord`` to preview.
        rows: Maximum number of records to read from the file.

    Returns:
        A dict with ``total_records`` (the stored record count),
        ``preview_rows`` (a list of ``{"row_index", "data"}`` entries) and
        ``columns`` (the keys of the first previewed record). When the record
        or its file is missing, the count is 0 and both lists are empty.
    """
    from sqlalchemy import select

    lookup = select(DatasetRecord).where(DatasetRecord.id == dataset_id)
    async with async_session() as session:
        record = (await session.execute(lookup)).scalar_one_or_none()

    # Unknown id and missing file are reported the same way: an empty preview.
    if not record or not Path(record.file_path).exists():
        return {"total_records": 0, "preview_rows": [], "columns": []}

    sample = _read_records(Path(record.file_path), record.format, rows)
    columns = list(sample[0].keys()) if sample else []
    preview_rows = [
        {"row_index": index, "data": row} for index, row in enumerate(sample)
    ]
    return {
        "total_records": record.record_count,
        "preview_rows": preview_rows,
        "columns": columns,
    }
async def validate_dataset(dataset_id: str) -> dict[str, Any]:
    """Check a registered dataset's format and sample its content.

    Args:
        dataset_id: Primary key of the ``DatasetRecord`` to validate.

    Returns:
        A dict with ``is_valid`` (true when no errors were collected),
        ``errors`` and ``warnings``. A missing record or file yields a single
        error. An unsupported format or an unreadable file adds an error; an
        empty sample, or a first record with none of the common SFT field
        names, adds a warning only.
    """
    from sqlalchemy import select

    lookup = select(DatasetRecord).where(DatasetRecord.id == dataset_id)
    async with async_session() as session:
        record = (await session.execute(lookup)).scalar_one_or_none()

    if not record:
        return {"is_valid": False, "errors": ["Dataset not found"], "warnings": []}

    file_path = Path(record.file_path)
    if not file_path.exists():
        return {"is_valid": False, "errors": ["File not found"], "warnings": []}

    errors = []
    warnings = []

    # Format check.
    fmt = record.format
    if fmt not in ("jsonl", "csv", "json", "parquet"):
        errors.append(f"Unsupported format: {fmt}")

    # Content check on the first few records.
    sft_keys = ("instruction", "prompt", "text", "input", "output", "completion")
    try:
        sample = _read_records(file_path, fmt, 5)
        if sample:
            first = sample[0]
            # At least one field commonly used by SFT datasets is expected.
            if not any(key in first for key in sft_keys):
                warnings.append(f"No common SFT fields found. Keys: {list(first.keys())}")
        else:
            warnings.append("Dataset appears to be empty")
    except Exception as exc:
        errors.append(f"Failed to read file: {exc}")

    return {"is_valid": not errors, "errors": errors, "warnings": warnings}
async def list_datasets() -> list[dict[str, Any]]:
    """Return every ``DatasetRecord`` row as a plain dict, newest first.

    Returns:
        One dict per record with the keys ``id``, ``name``, ``format``,
        ``record_count``, ``file_path`` and ``created_at`` (ISO-8601 string),
        ordered by ``created_at`` descending.
    """
    from sqlalchemy import select

    def summarize(rec) -> dict[str, Any]:
        # Flatten one ORM row into a JSON-friendly dict.
        return {
            "id": rec.id,
            "name": rec.name,
            "format": rec.format,
            "record_count": rec.record_count,
            "file_path": rec.file_path,
            "created_at": rec.created_at.isoformat(),
        }

    newest_first = select(DatasetRecord).order_by(DatasetRecord.created_at.desc())
    async with async_session() as session:
        records = (await session.execute(newest_first)).scalars().all()

    return [summarize(rec) for rec in records]
async def delete_dataset(dataset_id: str) -> dict[str, Any]:
    """Delete a dataset's file and its database record.

    Args:
        dataset_id: Primary key of the ``DatasetRecord`` to delete.

    Returns:
        ``{"status": "deleted"}``. The same value is returned when no record
        matches ``dataset_id``; in that case nothing is removed.
    """
    from sqlalchemy import select

    async with async_session() as session:
        result = await session.execute(
            select(DatasetRecord).where(DatasetRecord.id == dataset_id)
        )
        record = result.scalar_one_or_none()
        if record:
            # Capture the name up front so the log line does not have to read
            # the instance after it has been deleted and committed.
            name = record.name

            # missing_ok removes the exists()/unlink() race: a file that is
            # already gone (or vanishes concurrently) is not an error.
            Path(record.file_path).unlink(missing_ok=True)

            await session.delete(record)
            await session.commit()
            logger.info(f"Deleted dataset: {name}")
    return {"status": "deleted"}
def _detect_format(filename: str) -> str:
    """Map a filename's extension to a dataset format name.

    Args:
        filename: File name or path; only its last suffix is inspected,
            case-insensitively.

    Returns:
        ``"jsonl"``, ``"csv"``, ``"parquet"`` or ``"json"`` when the suffix
        matches one of them, otherwise ``"unknown"``.
    """
    supported = ("jsonl", "csv", "parquet", "json")
    suffix = Path(filename).suffix
    extension = suffix.lower().lstrip(".")
    return extension if extension in supported else "unknown"
def _count_records(file_path: Path, fmt: str) -> int:
    """Count the records stored in a dataset file.

    Counting rules per format:
      * ``jsonl``   - number of non-blank lines.
      * ``json``    - length of the top-level list, or 1 for any other value.
      * ``csv``     - number of data rows after the header; blank rows are
        skipped, matching how ``csv.DictReader`` yields rows.
      * ``parquet`` - number of rows in the table.

    Args:
        file_path: Path of the file to inspect.
        fmt: Format name as produced by ``_detect_format``.

    Returns:
        The record count, never negative. 0 is returned for an unsupported
        format and whenever the file cannot be read or parsed.
    """
    try:
        if fmt == "jsonl":
            # Context manager so the handle is closed deterministically.
            with open(file_path, encoding="utf-8") as f:
                return sum(1 for line in f if line.strip())
        if fmt == "json":
            with open(file_path, encoding="utf-8") as f:
                data = json.load(f)
            return len(data) if isinstance(data, list) else 1
        if fmt == "csv":
            import csv

            # newline="" is what the csv module expects for correct handling
            # of newlines embedded in quoted fields.
            with open(file_path, encoding="utf-8", newline="") as f:
                # DictReader consumes the header itself, so an empty file
                # counts as 0 rather than -1.
                return sum(1 for _ in csv.DictReader(f))
        if fmt == "parquet":
            import pandas as pd

            return len(pd.read_parquet(file_path))
    except Exception:
        # Deliberately best-effort: the count is informational, so a file
        # that cannot be read or parsed is reported as 0 records instead of
        # failing the caller.
        return 0
    return 0
def _read_records(file_path: Path, fmt: str, n: int) -> list[dict]:
    """Read up to ``n`` records from the start of a dataset file.

    Reading rules per format:
      * ``jsonl``   - the first ``n`` non-blank lines, each parsed as JSON.
        Blank lines are skipped and do not count toward ``n``.
      * ``json``    - the first ``n`` items of a top-level list; any other
        top-level value is returned as a single record.
      * ``csv``     - the first ``n`` data rows as dicts keyed by the header.
        Reading stops after ``n`` rows instead of scanning the whole file.
      * ``parquet`` - the first ``n`` rows as dicts keyed by column name.

    Args:
        file_path: Path of the file to read.
        fmt: Format name as produced by ``_detect_format``.
        n: Maximum number of records to return.

    Returns:
        A list of at most ``n`` records. Empty when ``n`` is zero or negative
        or when ``fmt`` is not one of the formats above.

    Raises:
        Whatever the underlying reader raises: ``OSError`` for an unreadable
        file, ``json.JSONDecodeError`` for malformed JSON, and so on.
        Callers decide how to report it.
    """
    from itertools import islice

    if n <= 0:
        # A negative n would otherwise slice from the end (list[:n],
        # DataFrame.head(n)) and return the wrong rows.
        return []

    if fmt == "jsonl":
        with open(file_path, encoding="utf-8") as f:
            non_blank = (line for line in f if line.strip())
            return [json.loads(line) for line in islice(non_blank, n)]
    if fmt == "json":
        with open(file_path, encoding="utf-8") as f:
            data = json.load(f)
        return data[:n] if isinstance(data, list) else [data]
    if fmt == "csv":
        import csv

        # newline="" is what the csv module expects for correct handling of
        # newlines embedded in quoted fields.
        with open(file_path, encoding="utf-8", newline="") as f:
            return [dict(row) for row in islice(csv.DictReader(f), n)]
    if fmt == "parquet":
        import pandas as pd

        df = pd.read_parquet(file_path)
        return df.head(n).to_dict(orient="records")
    return []
|