| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427 |
- import asyncio
- import json
- import uuid
- from datetime import datetime
- from pathlib import Path
- from typing import Any
- from fastapi import UploadFile
- from app.config import get_settings
- from app.core.db import async_session, DatasetRecord
- from app.core.logging import logger
- from app.schemas.dataset import DatasetDownloadRequest, DatasetDownloadResponse
- settings = get_settings()
# Filenames that carry configuration or repository metadata, never training samples.
META_FILENAMES = frozenset(
    (
        # Model / tokenizer configuration and repository bookkeeping files
        "configuration.json",
        "configuration.yaml",
        "README.md",
        ".mdl",
        ".msc",
        ".mv",
        "model_index.json",
        "generation_config.json",
        "special_tokens_map.json",
        "tokenizer_config.json",
        "added_tokens.json",
        "vocab.json",
        "merges.txt",
        "config.json",
        "preprocessor_config.json",
        # HF/ModelScope dataset metadata
        "dataset_info.json",
        "dataset_infos.json",
        "dataset.json",
        "state.json",
        "dataset_dict.json",
    )
)

# Size cut-off in bytes: JSON files smaller than this are treated as likely metadata.
META_SIZE_THRESHOLD = 500
def _is_training_data_file(path: Path) -> bool:
    """Guess whether ``path`` holds training samples rather than config/metadata.

    Rules, in order:

    * ``.jsonl`` / ``.parquet`` / ``.csv`` files are always treated as data.
    * ``.json`` files are rejected when their name is in ``META_FILENAMES`` or
      when they are smaller than ``META_SIZE_THRESHOLD`` bytes; otherwise they
      count as data only if their first line parses as JSON.
    * Files without a suffix count as data only if their first line parses as
      JSON.
    * Every other suffix is rejected.

    Only the first line is read, so large dataset shards are not loaded into
    memory just to be classified.

    Args:
        path: File to classify.

    Returns:
        True when the file looks like training data, False otherwise.

    Raises:
        OSError: If a ``.json`` file cannot be stat'ed (e.g. it does not exist).
    """
    suffix = path.suffix
    if suffix in (".jsonl", ".parquet", ".csv"):
        return True
    if suffix not in (".json", ""):
        return False

    if suffix == ".json":
        if path.name in META_FILENAMES:
            return False
        # Small JSON files are usually configuration.
        if path.stat().st_size < META_SIZE_THRESHOLD:
            return False

    # Probe the first line only. Any parseable JSON value is accepted, whatever
    # keys it carries; an empty or unreadable file is rejected.
    try:
        with open(path, encoding="utf-8", errors="ignore") as f:
            first_line = f.readline().strip()
        json.loads(first_line)
    except (OSError, ValueError, RecursionError):
        return False
    return True
def _export_first_split_to_jsonl(dataset_id: str, dir_prefix: str) -> tuple[Path, int]:
    """Load a dataset and dump one split to ``data.jsonl`` (blocking helper).

    The ``train`` split is used when present, otherwise the first split.

    Args:
        dataset_id: Identifier passed to ``datasets.load_dataset``.
        dir_prefix: Prefix of the output directory name under
            ``settings.processed_dir`` (``"ms"`` or ``"hf"``).

    Returns:
        ``(path_to_jsonl, number_of_records_written)``.
    """
    # Imported lazily, as in the original code, so `datasets` is only needed
    # when a download is actually requested.
    from datasets import load_dataset

    ds = load_dataset(dataset_id)
    ds_dir = settings.processed_dir / f"{dir_prefix}_{dataset_id.replace('/', '_')}"
    ds_dir.mkdir(parents=True, exist_ok=True)

    split = ds["train"] if "train" in ds else ds[list(ds.keys())[0]]

    output_path = ds_dir / "data.jsonl"
    record_count = 0
    with open(output_path, "w", encoding="utf-8") as f:
        for item in split:
            f.write(json.dumps(item, ensure_ascii=False) + "\n")
            # Count what was written instead of relying on len(split), which
            # is unavailable for length-less datasets.
            record_count += 1
    return output_path, record_count


async def download_dataset(req: DatasetDownloadRequest) -> DatasetDownloadResponse:
    """Download a dataset from HuggingFace or ModelScope and register it.

    Both sources are loaded through the ``datasets`` library; the chosen split
    is exported to JSONL under ``settings.processed_dir`` and a
    ``DatasetRecord`` row is stored. The blocking download and export run in a
    worker thread so the event loop stays responsive.

    Args:
        req: Download request; ``dataset_id`` and ``use_modelscope`` are read.

    Returns:
        A response with ``status="completed"`` and the JSONL path on success,
        or ``status="failed"`` with the error text on any exception. This
        function never raises.
    """
    try:
        if req.use_modelscope:
            dir_prefix = "ms"
            empty_message = "Dataset loaded but returned 0 records"
        else:
            dir_prefix = "hf"
            empty_message = "HF dataset loaded but returned 0 records"

        jsonl_path, record_count = await asyncio.to_thread(
            _export_first_split_to_jsonl, req.dataset_id, dir_prefix
        )
        if record_count == 0:
            raise RuntimeError(empty_message)

        record = DatasetRecord(
            id=str(uuid.uuid4()),
            name=req.dataset_id,
            format="jsonl",
            record_count=record_count,
            file_path=str(jsonl_path),
            created_at=datetime.utcnow(),
        )
        async with async_session() as session:
            session.add(record)
            await session.commit()

        logger.info(f"Downloaded dataset: {req.dataset_id} ({record_count} records, source={'ModelScope' if req.use_modelscope else 'HuggingFace'})")
        return DatasetDownloadResponse(dataset_id=req.dataset_id, status="completed", path=str(jsonl_path))
    except Exception as e:
        # Boundary handler: failures are reported to the caller in the response.
        logger.error(f"Dataset download failed: {e}")
        return DatasetDownloadResponse(dataset_id=req.dataset_id, status="failed", error=str(e))
def _scan_and_convert_to_jsonl(ds_dir: Path) -> tuple[Path, int]:
    """Merge the data files found under ``ds_dir`` into ``ds_dir/data.jsonl``.

    ``*.jsonl``, ``*.json`` and ``*.csv`` files are collected recursively, in
    that order and sorted by path within each group so the output is
    deterministic. Files named in ``META_FILENAMES`` are skipped, and so is the
    output file itself: a ``data.jsonl`` left by a previous run would otherwise
    be read while it is being rewritten.

    Conversion rules:

    * ``.jsonl``: non-blank lines are copied verbatim.
    * ``.json``: a top-level list contributes one record per element; a
      top-level object contributes one record unless it has a ``features``,
      ``splits`` or ``dataset_name`` key (dataset metadata). Unreadable or
      invalid JSON files are skipped.
    * ``.csv``: each row becomes an object keyed by the header row.

    Args:
        ds_dir: Root directory of the downloaded dataset.

    Returns:
        ``(path_to_data_jsonl, number_of_records_written)``.

    Raises:
        RuntimeError: If no candidate data file exists under ``ds_dir``.
    """
    import csv

    jsonl_path = ds_dir / "data.jsonl"

    candidates = []
    for pattern in ("*.jsonl", "*.json", "*.csv"):
        candidates.extend(sorted(ds_dir.rglob(pattern)))
    data_files = [
        f
        for f in candidates
        if f.name not in META_FILENAMES and f != jsonl_path and f.is_file()
    ]
    if not data_files:
        raise RuntimeError(f"No dataset files found in {ds_dir}")

    record_count = 0
    with open(jsonl_path, "w", encoding="utf-8") as out:
        for data_file in data_files:
            if data_file.suffix == ".jsonl":
                with open(data_file, "r", encoding="utf-8") as f:
                    for line in f:
                        line = line.strip()
                        if line:
                            out.write(line + "\n")
                            record_count += 1
            elif data_file.suffix == ".json":
                try:
                    with open(data_file, "r", encoding="utf-8") as f:
                        data = json.load(f)
                except (OSError, ValueError, RecursionError):
                    # Best effort: a broken JSON file must not abort the scan.
                    continue
                if isinstance(data, list):
                    for item in data:
                        out.write(json.dumps(item, ensure_ascii=False) + "\n")
                        record_count += 1
                elif isinstance(data, dict):
                    # Skip HF/ModelScope dataset metadata (features/splits layout).
                    if "features" in data or "splits" in data or "dataset_name" in data:
                        continue
                    out.write(json.dumps(data, ensure_ascii=False) + "\n")
                    record_count += 1
            elif data_file.suffix == ".csv":
                # newline="" lets the csv module handle embedded line breaks.
                with open(data_file, "r", encoding="utf-8", newline="") as f:
                    for row in csv.DictReader(f):
                        out.write(json.dumps(dict(row), ensure_ascii=False) + "\n")
                        record_count += 1
    return jsonl_path, record_count
async def upload_dataset(file: UploadFile) -> dict[str, Any]:
    """Store an uploaded file under the uploads directory and register it.

    The client-supplied filename is reduced to its final path component before
    use, so values such as ``../../etc/passwd`` or absolute paths cannot escape
    ``settings.uploads_dir``. If a file with the same name already exists, a
    random hex prefix is added instead of overwriting it.

    Args:
        file: The uploaded file.

    Returns:
        A dict with ``id``, ``name``, ``format``, ``record_count``,
        ``file_path`` and ``created_at`` (ISO 8601 string) for the new record.
    """
    upload_dir = settings.uploads_dir
    upload_dir.mkdir(parents=True, exist_ok=True)

    # Normalise Windows separators first so the basename is extracted the same
    # way on every platform.
    raw_name = (file.filename or "").replace("\\", "/")
    safe_name = Path(raw_name).name
    if safe_name in ("", ".", ".."):
        safe_name = "unknown"

    file_path = upload_dir / safe_name
    if file_path.exists():
        file_path = upload_dir / f"{uuid.uuid4().hex}_{safe_name}"

    content = await file.read()
    file_path.write_bytes(content)

    fmt = _detect_format(safe_name)
    record_count = _count_records(file_path, fmt)
    record_id = str(uuid.uuid4())
    record = DatasetRecord(
        id=record_id,
        name=safe_name,
        format=fmt,
        record_count=record_count,
        file_path=str(file_path),
        created_at=datetime.utcnow(),
    )
    async with async_session() as session:
        session.add(record)
        await session.commit()

    logger.info(f"Uploaded dataset: {safe_name} ({record_count} records, format={fmt})")
    return {
        "id": record_id,
        "name": safe_name,
        "format": fmt,
        "record_count": record_count,
        "file_path": str(file_path),
        "created_at": record.created_at.isoformat(),
    }
def _format_value(value) -> str:
    """Render a value as display text; dicts and lists become indented JSON."""
    is_container = isinstance(value, (dict, list))
    if not is_container:
        return str(value)
    return json.dumps(value, ensure_ascii=False, indent=2)
def _is_sharegpt_format(records: list[dict]) -> bool:
    """Tell whether the first record looks like a ShareGPT conversation.

    The first record must have a non-empty ``conversations`` list whose first
    element is a dict containing both ``from`` and ``value`` keys. Only the
    first record is inspected.
    """
    if not records:
        return False

    head = records[0]
    if "conversations" not in head:
        return False

    turns = head["conversations"]
    if not isinstance(turns, list) or len(turns) == 0:
        return False

    opening = turns[0]
    if not isinstance(opening, dict):
        return False
    return "from" in opening and "value" in opening
def _flatten_sharegpt(records: list[dict], max_chars: int = 500) -> tuple[list[dict], list[str]]:
    """Flatten ShareGPT conversations into ``input`` / ``output`` rows.

    Turns are consumed in consecutive pairs. When the first turn of a pair
    comes from ``human`` or ``user`` it supplies the input and the second turn
    the output; otherwise the roles are swapped. A trailing unpaired turn is
    dropped.

    ``system`` turns are removed before pairing so a leading system prompt
    cannot shift the pairing and end up displayed as the output. Records whose
    ``conversations`` value is missing or not a list, and turns that are not
    dicts, are ignored.

    Args:
        records: Rows that each may carry a ``conversations`` list of
            ``{"from": ..., "value": ...}`` dicts.
        max_chars: Texts longer than this are cut to ``max_chars`` characters
            followed by ``"..."``.

    Returns:
        ``(rows, ["input", "output"])``, where every row is a dict with the
        string keys ``input`` and ``output``.
    """

    def _clip(text: str) -> str:
        # Keep previews readable by truncating long turns.
        return text[:max_chars] + "..." if len(text) > max_chars else text

    flat_rows = []
    for row in records:
        conversations = row.get("conversations")
        if not isinstance(conversations, list):
            continue
        turns = [
            turn
            for turn in conversations
            if isinstance(turn, dict) and turn.get("from") != "system"
        ]
        for i in range(0, len(turns) - 1, 2):
            first, second = turns[i], turns[i + 1]
            if first.get("from") in ("human", "user"):
                question, answer = first, second
            else:
                question, answer = second, first
            flat_rows.append(
                {
                    "input": _clip(str(question.get("value", ""))),
                    "output": _clip(str(answer.get("value", ""))),
                }
            )
    return flat_rows, ["input", "output"]
async def preview_dataset(dataset_id: str, rows: int = 10) -> dict[str, Any]:
    """Return the first ``rows`` records of a registered dataset for display.

    ShareGPT-style data is flattened into ``input`` / ``output`` columns;
    otherwise the columns are the keys of the first record. An empty preview
    (``total_records`` of 0) is returned when the dataset id is unknown or its
    file no longer exists.
    """
    from sqlalchemy import select

    lookup = select(DatasetRecord).where(DatasetRecord.id == dataset_id)
    async with async_session() as session:
        record = (await session.execute(lookup)).scalar_one_or_none()

    if not record:
        return {"total_records": 0, "preview_rows": [], "columns": []}

    file_path = Path(record.file_path)
    if not file_path.exists():
        return {"total_records": 0, "preview_rows": [], "columns": []}

    preview_data = _read_records(file_path, record.format, rows)

    # ShareGPT conversations are shown as flat input/output columns.
    if _is_sharegpt_format(preview_data):
        preview_data, columns = _flatten_sharegpt(preview_data)
    elif preview_data:
        columns = list(preview_data[0].keys())
    else:
        columns = []

    preview_rows = []
    for index, row in enumerate(preview_data):
        formatted = {key: _format_value(val) for key, val in row.items()}
        preview_rows.append({"row_index": index, "data": formatted})

    return {
        "total_records": record.record_count,
        "preview_rows": preview_rows,
        "columns": columns,
    }
async def validate_dataset(dataset_id: str) -> dict[str, Any]:
    """Check a registered dataset's format and basic schema.

    Errors make the dataset invalid; warnings do not. The checks are:

    * the record and its file must exist;
    * the stored format must be one of ``jsonl``, ``csv``, ``json``,
      ``parquet`` (the file is not read when it is not);
    * the first five records must be readable;
    * the first record must be a mapping;
    * a warning is added when the dataset is empty or when the first record
      has none of the common SFT fields. ``conversations`` counts as an SFT
      field because this module previews ShareGPT data.

    Args:
        dataset_id: Primary key of the ``DatasetRecord`` to validate.

    Returns:
        ``{"is_valid": bool, "errors": [...], "warnings": [...]}``.
    """
    from sqlalchemy import select

    async with async_session() as session:
        result = await session.execute(select(DatasetRecord).where(DatasetRecord.id == dataset_id))
        record = result.scalar_one_or_none()

    if not record:
        return {"is_valid": False, "errors": ["Dataset not found"], "warnings": []}

    file_path = Path(record.file_path)
    if not file_path.exists():
        return {"is_valid": False, "errors": ["File not found"], "warnings": []}

    errors: list[str] = []
    warnings: list[str] = []

    fmt = record.format
    if fmt not in ("jsonl", "csv", "json", "parquet"):
        # Nothing can be read for an unknown format, so stop here rather than
        # also reporting a misleading "empty dataset" warning.
        errors.append(f"Unsupported format: {fmt}")
        return {"is_valid": False, "errors": errors, "warnings": warnings}

    sft_fields = ("instruction", "prompt", "text", "input", "output", "completion", "conversations")
    try:
        preview = _read_records(file_path, fmt, 5)
    except Exception as e:
        # Boundary handler: any read or parse failure becomes a validation error.
        errors.append(f"Failed to read file: {str(e)}")
    else:
        if not preview:
            warnings.append("Dataset appears to be empty")
        else:
            first = preview[0]
            if not isinstance(first, dict):
                errors.append(f"Records must be JSON objects, got {type(first).__name__}")
            elif not any(k in first for k in sft_fields):
                warnings.append(f"No common SFT fields found. Keys: {list(first.keys())}")

    return {"is_valid": len(errors) == 0, "errors": errors, "warnings": warnings}
async def list_datasets() -> list[dict[str, Any]]:
    """Return a summary dict for every registered dataset, newest first."""
    from sqlalchemy import select

    newest_first = select(DatasetRecord).order_by(DatasetRecord.created_at.desc())
    async with async_session() as session:
        found = (await session.execute(newest_first)).scalars().all()

    summaries = []
    for rec in found:
        summaries.append(
            {
                "id": rec.id,
                "name": rec.name,
                "format": rec.format,
                "record_count": rec.record_count,
                "file_path": rec.file_path,
                "created_at": rec.created_at.isoformat(),
            }
        )
    return summaries
async def delete_dataset(dataset_id: str) -> dict[str, Any]:
    """Delete a dataset's file and its database record.

    An unknown ``dataset_id`` is not an error; the call is a no-op.

    Args:
        dataset_id: Primary key of the ``DatasetRecord`` to remove.

    Returns:
        ``{"status": "deleted"}``.
    """
    from sqlalchemy import select

    async with async_session() as session:
        result = await session.execute(select(DatasetRecord).where(DatasetRecord.id == dataset_id))
        record = result.scalar_one_or_none()
        if record:
            # Read the attributes before commit: afterwards the instance may be
            # expired or detached, and attribute access could fail.
            name = record.name
            file_path = Path(record.file_path)
            # missing_ok avoids the race between an exists() check and unlink().
            file_path.unlink(missing_ok=True)
            await session.delete(record)
            await session.commit()
            logger.info(f"Deleted dataset: {name}")
    return {"status": "deleted"}
def _detect_format(filename: str) -> str:
    """Map a filename to its dataset format by extension (case-insensitive).

    Returns ``"jsonl"``, ``"csv"``, ``"parquet"`` or ``"json"``; any other
    extension, or none at all, yields ``"unknown"``.
    """
    known_formats = ("jsonl", "csv", "parquet", "json")
    extension = Path(filename).suffix.lower().lstrip(".")
    return extension if extension in known_formats else "unknown"
def _count_records(file_path: Path, fmt: str) -> int:
    """Count the records in a dataset file, best effort.

    * ``jsonl``: number of non-blank lines.
    * ``json``: length of a top-level list, otherwise 1.
    * ``csv``: number of data rows as seen by ``csv.DictReader`` (header and
      blank rows excluded), matching how the rows are read elsewhere in this
      module. An empty file counts as 0.
    * ``parquet``: number of rows in the loaded DataFrame.

    Args:
        file_path: File to inspect.
        fmt: One of the format names above; anything else yields 0.

    Returns:
        The record count, or 0 when the format is unknown or the file cannot
        be read or parsed. This function never raises.
    """
    try:
        if fmt == "jsonl":
            # Context manager so the handle is closed deterministically.
            with open(file_path, encoding="utf-8") as f:
                return sum(1 for line in f if line.strip())
        if fmt == "json":
            with open(file_path, encoding="utf-8") as f:
                data = json.load(f)
            return len(data) if isinstance(data, list) else 1
        if fmt == "csv":
            import csv

            with open(file_path, encoding="utf-8", newline="") as f:
                return sum(1 for _ in csv.DictReader(f))
        if fmt == "parquet":
            import pandas as pd

            return len(pd.read_parquet(file_path))
    except Exception:
        # Deliberate best effort: an unreadable file is reported as 0 records
        # rather than failing the upload that triggered the count.
        return 0
    return 0
def _read_records(file_path: Path, fmt: str, n: int) -> list[dict]:
    """Read up to ``n`` records from the start of a dataset file.

    * ``jsonl``: the first ``n`` non-blank lines, each parsed as JSON. Blank
      lines do not count toward ``n``.
    * ``json``: the first ``n`` elements of a top-level list, or the single
      top-level value wrapped in a list.
    * ``csv``: the first ``n`` data rows as dicts keyed by the header row;
      reading stops as soon as ``n`` rows have been collected.
    * ``parquet``: the first ``n`` rows of the loaded DataFrame.

    Args:
        file_path: File to read.
        fmt: One of the format names above; anything else yields ``[]``.
        n: Maximum number of records to return.

    Returns:
        The records read, possibly fewer than ``n``.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If JSON or text decoding fails (includes
            ``json.JSONDecodeError``).
    """
    if fmt == "jsonl":
        records = []
        with open(file_path, encoding="utf-8") as f:
            for line in f:
                if len(records) >= n:
                    break
                line = line.strip()
                if line:
                    records.append(json.loads(line))
        return records
    if fmt == "json":
        with open(file_path, encoding="utf-8") as f:
            data = json.load(f)
        return data[:n] if isinstance(data, list) else [data]
    if fmt == "csv":
        import csv

        rows = []
        with open(file_path, encoding="utf-8", newline="") as f:
            for row in csv.DictReader(f):
                if len(rows) >= n:
                    break
                rows.append(dict(row))
        return rows
    if fmt == "parquet":
        import pandas as pd

        df = pd.read_parquet(file_path)
        return df.head(n).to_dict(orient="records")
    return []
|