"""Dataset service: download, upload, preview, validate, list and delete datasets.

Blocking work (CLI subprocesses, dataset libraries, file parsing) is pushed to
worker threads with ``asyncio.to_thread`` so the coroutines in this module do
not stall the event loop.
"""

import asyncio
import csv
import json
import subprocess
import uuid
from datetime import datetime
from itertools import islice
from pathlib import Path
from typing import Any, Optional

from fastapi import UploadFile
from sqlalchemy import select

from app.config import get_settings
from app.core.db import async_session, DatasetRecord
from app.core.logging import logger
from app.schemas.dataset import DatasetDownloadRequest, DatasetDownloadResponse

settings = get_settings()

# Known metadata filenames that are NOT training data
META_FILENAMES = frozenset({
    "configuration.json",
    "configuration.yaml",
    "README.md",
    ".mdl",
    ".msc",
    ".mv",
    "model_index.json",
    "generation_config.json",
    "special_tokens_map.json",
    "tokenizer_config.json",
    "added_tokens.json",
    "vocab.json",
    "merges.txt",
    "config.json",
    "preprocessor_config.json",
})

# File size threshold: files smaller than this (bytes) are likely metadata
META_SIZE_THRESHOLD = 500

# Upper bound, in seconds, for a single ``modelscope download`` invocation.
MODELSCOPE_DOWNLOAD_TIMEOUT = 3600

# Maximum number of characters shown per cell when previewing ShareGPT turns.
PREVIEW_TEXT_LIMIT = 500

# Formats that ``_read_records`` / ``_count_records`` know how to parse.
SUPPORTED_FORMATS = ("jsonl", "csv", "parquet", "json")


def _first_line_is_json(path: Path) -> bool:
    """Return True when the first line of *path* parses as a JSON document."""
    try:
        # Only the first line is needed, so avoid loading the whole file.
        with open(path, encoding="utf-8", errors="ignore") as f:
            first_line = f.readline().strip()
        json.loads(first_line)
    except (OSError, ValueError):
        # ValueError covers json.JSONDecodeError, including the empty-line case.
        return False
    return True


def _is_training_data_file(path: Path) -> bool:
    """Guess whether *path* is a training-data file rather than config/metadata.

    ``.jsonl``/``.parquet``/``.csv`` files are always accepted. A ``.json`` file
    is accepted when it is not a known metadata name, is at least
    ``META_SIZE_THRESHOLD`` bytes and its first line parses as JSON. A file with
    no suffix is accepted when its first line parses as JSON.

    NOTE(review): a pretty-printed ``.json`` file whose first line is just ``[``
    or ``{`` is rejected by the first-line probe; this matches the previous
    behaviour and is left unchanged here.
    """
    suffix = path.suffix
    if suffix in (".jsonl", ".parquet", ".csv"):
        return True
    if suffix == ".json":
        if path.name in META_FILENAMES:
            return False
        # Small JSON files are usually configuration.
        if path.stat().st_size < META_SIZE_THRESHOLD:
            return False
        return _first_line_is_json(path)
    if not suffix:
        # Extension-less file: probe whether it looks like JSON/JSONL.
        return _first_line_is_json(path)
    return False


def _download_from_modelscope(dataset_id: str) -> tuple[Path, int]:
    """Download *dataset_id* with the ``modelscope`` CLI and merge it to JSONL.

    Blocking; meant to be run in a worker thread.

    Returns:
        ``(jsonl_path, record_count)`` for the merged output file.

    Raises:
        RuntimeError: the CLI exited non-zero or no data files were found.
        subprocess.TimeoutExpired: the CLI exceeded the download timeout.
    """
    ds_dir = settings.processed_dir / f"ms_{dataset_id.replace('/', '_')}"
    ds_dir.mkdir(parents=True, exist_ok=True)
    proc = subprocess.run(
        [
            "modelscope", "download",
            "--dataset", dataset_id,
            "--local_dir", str(ds_dir),
        ],
        capture_output=True,
        text=True,
        timeout=MODELSCOPE_DOWNLOAD_TIMEOUT,
    )
    if proc.returncode != 0:
        raise RuntimeError(f"modelscope CLI failed: {proc.stderr}")
    # Scan what the CLI fetched and merge the data files.
    return _scan_and_convert_to_jsonl(ds_dir)


def _download_from_huggingface(dataset_id: str) -> tuple[Path, int]:
    """Load *dataset_id* via ``datasets.load_dataset`` and dump one split to JSONL.

    The ``train`` split is used when present, otherwise the first split.
    Blocking; meant to be run in a worker thread.

    Returns:
        ``(jsonl_path, record_count)`` where the count is the number of rows
        actually written.
    """
    # Imported lazily: the dependency is heavy and only needed on this path.
    from datasets import load_dataset

    ds = load_dataset(dataset_id)
    ds_dir = settings.processed_dir / f"hf_{dataset_id.replace('/', '_')}"
    ds_dir.mkdir(parents=True, exist_ok=True)

    split = ds["train"] if "train" in ds else ds[next(iter(ds.keys()))]

    output_path = ds_dir / "data.jsonl"
    record_count = 0
    with open(output_path, "w", encoding="utf-8") as f:
        for item in split:
            f.write(json.dumps(item, ensure_ascii=False) + "\n")
            record_count += 1
    return output_path, record_count


async def download_dataset(req: DatasetDownloadRequest) -> DatasetDownloadResponse:
    """Download a dataset from HuggingFace or ModelScope and register it.

    The blocking download/convert step runs in a worker thread so the event
    loop stays responsive. Any failure is logged and reported through a
    response with ``status="failed"`` instead of being raised.
    """
    try:
        fetch = _download_from_modelscope if req.use_modelscope else _download_from_huggingface
        jsonl_path, record_count = await asyncio.to_thread(fetch, req.dataset_id)

        record = DatasetRecord(
            id=str(uuid.uuid4()),
            name=req.dataset_id,
            format="jsonl",
            record_count=record_count,
            file_path=str(jsonl_path),
            created_at=datetime.utcnow(),
        )
        async with async_session() as session:
            session.add(record)
            await session.commit()

        logger.info(f"Downloaded dataset: {req.dataset_id} ({record_count} records, source={'ModelScope' if req.use_modelscope else 'HuggingFace'})")
        return DatasetDownloadResponse(dataset_id=req.dataset_id, status="completed", path=str(jsonl_path))
    except Exception as e:
        # Top-level boundary: callers receive a "failed" response, never an exception.
        logger.error(f"Dataset download failed: {e}")
        return DatasetDownloadResponse(dataset_id=req.dataset_id, status="failed", error=str(e))


def _scan_and_convert_to_jsonl(ds_dir: Path) -> tuple[Path, int]:
    """Merge the data files found under *ds_dir* into ``ds_dir/data.jsonl``.

    ``.jsonl`` lines are copied (blank lines dropped), ``.json`` arrays are
    expanded to one record per element (a top-level object becomes one record)
    and ``.csv`` rows are written as JSON objects keyed by the header.

    Returns:
        ``(jsonl_path, record_count)``.

    Raises:
        RuntimeError: no candidate data files exist under *ds_dir*.
    """
    jsonl_path = ds_dir / "data.jsonl"

    # Collect every candidate data file, grouped by extension.
    data_files: list[Path] = []
    for pattern in ("*.jsonl", "*.json", "*.csv"):
        data_files.extend(sorted(ds_dir.rglob(pattern)))
    # Drop metadata files, and drop our own output: on a re-download the
    # previous data.jsonl matches "*.jsonl" and would otherwise be truncated
    # and then read back while it is being written.
    data_files = [
        f for f in data_files
        if f.name not in META_FILENAMES and f != jsonl_path
    ]

    if not data_files:
        raise RuntimeError(f"No dataset files found in {ds_dir}")

    record_count = 0
    with open(jsonl_path, "w", encoding="utf-8") as out:
        for data_file in data_files:
            suffix = data_file.suffix
            if suffix == ".jsonl":
                with open(data_file, "r", encoding="utf-8") as f:
                    for line in f:
                        line = line.strip()
                        if line:
                            out.write(line + "\n")
                            record_count += 1
            elif suffix == ".json":
                try:
                    with open(data_file, "r", encoding="utf-8") as f:
                        data = json.load(f)
                except Exception as exc:
                    # Best effort: an unreadable JSON file is skipped, but no
                    # longer silently.
                    logger.warning(f"Skipping unreadable JSON file {data_file}: {exc}")
                    continue
                if isinstance(data, list):
                    for item in data:
                        out.write(json.dumps(item, ensure_ascii=False) + "\n")
                        record_count += 1
                elif isinstance(data, dict):
                    out.write(json.dumps(data, ensure_ascii=False) + "\n")
                    record_count += 1
            elif suffix == ".csv":
                # newline="" lets the csv module handle embedded newlines itself.
                with open(data_file, "r", encoding="utf-8", newline="") as f:
                    for row in csv.DictReader(f):
                        out.write(json.dumps(dict(row), ensure_ascii=False) + "\n")
                        record_count += 1

    return jsonl_path, record_count


async def upload_dataset(file: UploadFile) -> dict[str, Any]:
    """Persist an uploaded file and register it in the database.

    The client-supplied filename is reduced to its final path component before
    use, so a name such as ``../../x`` cannot escape the uploads directory.

    Returns:
        A dict with ``id``, ``name``, ``format``, ``record_count``,
        ``file_path`` and ``created_at`` (ISO-8601 string).
    """
    upload_dir = settings.uploads_dir
    upload_dir.mkdir(parents=True, exist_ok=True)

    # Untrusted input: keep only the basename (treating "\" as a separator too).
    raw_name = (file.filename or "").replace("\\", "/")
    safe_name = Path(raw_name).name
    if safe_name in ("", ".", ".."):
        safe_name = "unknown"

    file_path = upload_dir / safe_name
    if file_path.exists():
        # Avoid clobbering an earlier upload with the same name.
        file_path = upload_dir / f"{uuid.uuid4().hex}_{safe_name}"

    content = await file.read()
    await asyncio.to_thread(file_path.write_bytes, content)

    fmt = _detect_format(safe_name)
    record_count = await asyncio.to_thread(_count_records, file_path, fmt)

    record_id = str(uuid.uuid4())
    # Kept in a local so the response does not read from the ORM object after
    # the commit.
    created_at = datetime.utcnow()
    record = DatasetRecord(
        id=record_id,
        name=safe_name,
        format=fmt,
        record_count=record_count,
        file_path=str(file_path),
        created_at=created_at,
    )
    async with async_session() as session:
        session.add(record)
        await session.commit()

    logger.info(f"Uploaded dataset: {safe_name} ({record_count} records, format={fmt})")
    return {
        "id": record_id,
        "name": safe_name,
        "format": fmt,
        "record_count": record_count,
        "file_path": str(file_path),
        "created_at": created_at.isoformat(),
    }


def _format_value(value) -> str:
    """Render *value* as a readable string; dicts and lists become indented JSON."""
    if isinstance(value, (dict, list)):
        return json.dumps(value, ensure_ascii=False, indent=2)
    return str(value)


def _is_sharegpt_format(records: list[dict]) -> bool:
    """Return True when the first record looks like ShareGPT data.

    That is: it is a dict with a non-empty ``conversations`` list whose first
    element is a dict containing both ``from`` and ``value``.
    """
    if not records:
        return False
    first = records[0]
    if not isinstance(first, dict):
        return False
    conversations = first.get("conversations")
    if not isinstance(conversations, list) or not conversations:
        return False
    head = conversations[0]
    return isinstance(head, dict) and "from" in head and "value" in head


def _truncate(text: str, limit: int = PREVIEW_TEXT_LIMIT) -> str:
    """Cut *text* to *limit* characters, appending ``...`` when shortened."""
    return text[:limit] + "..." if len(text) > limit else text


def _flatten_sharegpt(records: list[dict]) -> tuple[list[dict], list[str]]:
    """Flatten ShareGPT conversations into ``input``/``output`` rows.

    Turns are consumed in consecutive pairs; a trailing unpaired turn is
    ignored. When the first turn of a pair comes from ``human``/``user`` it is
    the input and the second is the output, otherwise the two are swapped.
    Both texts are truncated to ``PREVIEW_TEXT_LIMIT`` characters.

    Returns:
        ``(rows, ["input", "output"])``.
    """
    flat_rows = []
    for row in records:
        # Tolerate a missing or null "conversations" field.
        conversations = row.get("conversations") or []
        for first_turn, second_turn in zip(conversations[0::2], conversations[1::2]):
            first_text = str(first_turn.get("value", ""))
            second_text = str(second_turn.get("value", ""))
            if first_turn.get("from") in ("human", "user"):
                input_text, output_text = first_text, second_text
            else:
                input_text, output_text = second_text, first_text
            flat_rows.append({
                "input": _truncate(input_text),
                "output": _truncate(output_text),
            })
    return flat_rows, ["input", "output"]


async def _get_record(session: Any, dataset_id: str) -> Optional[DatasetRecord]:
    """Fetch the ``DatasetRecord`` with primary key *dataset_id*, or ``None``."""
    result = await session.execute(
        select(DatasetRecord).where(DatasetRecord.id == dataset_id)
    )
    return result.scalar_one_or_none()


async def preview_dataset(dataset_id: str, rows: int = 10) -> dict[str, Any]:
    """Return the first *rows* records of a dataset for display.

    ShareGPT-style data is flattened to ``input``/``output`` columns. An unknown
    dataset id or a missing file yields an empty preview.

    Returns:
        A dict with ``total_records``, ``preview_rows`` (each with
        ``row_index`` and stringified ``data``) and ``columns``.
    """
    empty = {"total_records": 0, "preview_rows": [], "columns": []}

    async with async_session() as session:
        record = await _get_record(session, dataset_id)
        if not record:
            return empty
        # Read the attributes while the session is still open.
        file_path = Path(record.file_path)
        fmt = record.format
        total_records = record.record_count

    if not file_path.exists():
        return empty

    preview_data = await asyncio.to_thread(_read_records, file_path, fmt, rows)

    # ShareGPT data is flattened into input/output columns.
    if _is_sharegpt_format(preview_data):
        preview_data, columns = _flatten_sharegpt(preview_data)
    else:
        columns = list(preview_data[0].keys()) if preview_data else []

    return {
        "total_records": total_records,
        "preview_rows": [
            {
                "row_index": i,
                "data": {k: _format_value(v) for k, v in row.items()},
            }
            for i, row in enumerate(preview_data)
        ],
        "columns": columns,
    }


async def validate_dataset(dataset_id: str) -> dict[str, Any]:
    """Check a dataset's format and look for common SFT fields.

    Returns:
        A dict with ``is_valid`` (True when there are no errors), ``errors``
        and ``warnings``.
    """
    async with async_session() as session:
        record = await _get_record(session, dataset_id)
        if not record:
            return {"is_valid": False, "errors": ["Dataset not found"], "warnings": []}
        file_path = Path(record.file_path)
        fmt = record.format

    if not file_path.exists():
        return {"is_valid": False, "errors": ["File not found"], "warnings": []}

    errors = []
    warnings = []

    if fmt not in ("jsonl", "csv", "json", "parquet"):
        errors.append(f"Unsupported format: {fmt}")

    try:
        preview = await asyncio.to_thread(_read_records, file_path, fmt, 5)
        if not preview:
            warnings.append("Dataset appears to be empty")
        else:
            first = preview[0]
            # "conversations" is included so ShareGPT data, which the preview
            # path supports, is not flagged as lacking SFT fields.
            sft_fields = (
                "instruction", "prompt", "text", "input", "output",
                "completion", "conversations",
            )
            if not any(k in first for k in sft_fields):
                warnings.append(f"No common SFT fields found. Keys: {list(first.keys())}")
    except Exception as e:
        errors.append(f"Failed to read file: {str(e)}")

    return {"is_valid": len(errors) == 0, "errors": errors, "warnings": warnings}


async def list_datasets() -> list[dict[str, Any]]:
    """List every registered dataset, newest first."""
    async with async_session() as session:
        result = await session.execute(
            select(DatasetRecord).order_by(DatasetRecord.created_at.desc())
        )
        # Build the plain dicts while the session is still open.
        return [
            {
                "id": r.id,
                "name": r.name,
                "format": r.format,
                "record_count": r.record_count,
                "file_path": r.file_path,
                "created_at": r.created_at.isoformat(),
            }
            for r in result.scalars().all()
        ]


async def delete_dataset(dataset_id: str) -> dict[str, Any]:
    """Delete a dataset's file and database row.

    Returns ``{"status": "deleted"}`` whether or not a matching record existed.
    """
    async with async_session() as session:
        record = await _get_record(session, dataset_id)
        if record:
            # Capture the name first; the instance is not read after deletion.
            name = record.name
            Path(record.file_path).unlink(missing_ok=True)
            await session.delete(record)
            await session.commit()
            logger.info(f"Deleted dataset: {name}")
    return {"status": "deleted"}


def _detect_format(filename: str) -> str:
    """Map a filename's extension to a supported format name, else ``"unknown"``."""
    ext = Path(filename).suffix.lower().lstrip(".")
    return ext if ext in SUPPORTED_FORMATS else "unknown"


def _count_records(file_path: Path, fmt: str) -> int:
    """Count the records in *file_path* according to *fmt*.

    JSONL counts non-blank lines, JSON counts list elements (any other
    top-level value counts as 1), CSV counts rows minus the header (never
    below 0) and Parquet counts DataFrame rows. Unknown formats and any read
    failure yield 0.
    """
    try:
        if fmt == "jsonl":
            with open(file_path, encoding="utf-8") as f:
                return sum(1 for line in f if line.strip())
        if fmt == "json":
            with open(file_path, encoding="utf-8") as f:
                data = json.load(f)
            return len(data) if isinstance(data, list) else 1
        if fmt == "csv":
            with open(file_path, encoding="utf-8", newline="") as f:
                # Subtract the header row; an empty file must not report -1.
                return max(sum(1 for _ in csv.reader(f)) - 1, 0)
        if fmt == "parquet":
            import pandas as pd

            return len(pd.read_parquet(file_path))
    except Exception as exc:
        # Best effort: an uncountable file is still accepted, with count 0.
        logger.warning(f"Could not count records in {file_path}: {exc}")
    return 0


def _read_records(file_path: Path, fmt: str, n: int) -> list[dict]:
    """Read up to *n* records from *file_path* according to *fmt*.

    Unknown formats yield an empty list. Parse errors propagate to the caller.
    """
    if fmt == "jsonl":
        records = []
        with open(file_path, encoding="utf-8") as f:
            for line in f:
                # Stop on records collected, so blank lines do not eat into n.
                if len(records) >= n:
                    break
                line = line.strip()
                if line:
                    records.append(json.loads(line))
        return records
    if fmt == "json":
        with open(file_path, encoding="utf-8") as f:
            data = json.load(f)
        return data[:n] if isinstance(data, list) else [data]
    if fmt == "csv":
        with open(file_path, encoding="utf-8", newline="") as f:
            # islice stops after n rows instead of scanning the whole file.
            return [dict(row) for row in islice(csv.DictReader(f), max(n, 0))]
    if fmt == "parquet":
        import pandas as pd

        df = pd.read_parquet(file_path)
        return df.head(n).to_dict(orient="records")
    return []