dataset_service.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366
  1. import asyncio
  2. import json
  3. import uuid
  4. from datetime import datetime, timezone
  5. from pathlib import Path
  6. from typing import Any
  7. from fastapi import UploadFile
  8. from app.config import get_settings
  9. from app.core.db import async_session, DatasetRecord
  10. from app.core.logging import logger
  11. from app.schemas.dataset import DatasetDownloadRequest, DatasetDownloadResponse
# Application settings, evaluated once when this module is imported; this module
# reads settings.processed_dir (download output) and settings.uploads_dir (uploads).
settings = get_settings()
  13. async def download_dataset(req: DatasetDownloadRequest) -> DatasetDownloadResponse:
  14. """从 HuggingFace 或 ModelScope 下载数据集。"""
  15. try:
  16. if req.use_modelscope:
  17. # 用 asyncio.to_thread 包裹同步下载,避免阻塞事件循环
  18. ds_dir, jsonl_path, record_count = await asyncio.to_thread(_download_modelscope_dataset, req.dataset_id)
  19. else:
  20. from datasets import load_dataset
  21. ds = load_dataset(req.dataset_id)
  22. ds_dir = settings.processed_dir / f"hf_{req.dataset_id.replace('/', '_')}"
  23. ds_dir.mkdir(parents=True, exist_ok=True)
  24. if "train" in ds:
  25. split = ds["train"]
  26. else:
  27. split = ds[list(ds.keys())[0]]
  28. output_path = ds_dir / "data.jsonl"
  29. with open(output_path, "w", encoding="utf-8") as f:
  30. for item in split:
  31. f.write(json.dumps(item, ensure_ascii=False) + "\n")
  32. jsonl_path = output_path
  33. record_count = len(split) if hasattr(split, "__len__") else 0
  34. # 写入数据库
  35. record = DatasetRecord(
  36. id=str(uuid.uuid4()),
  37. name=req.dataset_id,
  38. format="jsonl",
  39. record_count=record_count,
  40. file_path=str(jsonl_path),
  41. created_at=datetime.now(timezone.utc),
  42. )
  43. async with async_session() as session:
  44. session.add(record)
  45. await session.commit()
  46. logger.info(f"Downloaded dataset: {req.dataset_id} ({record_count} records, source={'ModelScope' if req.use_modelscope else 'HuggingFace'})")
  47. return DatasetDownloadResponse(dataset_id=req.dataset_id, status="completed", path=str(jsonl_path))
  48. except Exception as e:
  49. logger.error(f"Dataset download failed: {e}")
  50. return DatasetDownloadResponse(dataset_id=req.dataset_id, status="failed", error=str(e))
  51. def _download_modelscope_dataset(dataset_id: str) -> tuple[Path, Path, int]:
  52. """用 snapshot_download 下载数据集文件,完全绕过 datasets 库,避免版本兼容问题。"""
  53. from modelscope import snapshot_download
  54. ds_dir = settings.processed_dir / f"ms_{dataset_id.replace('/', '_')}"
  55. ds_dir.mkdir(parents=True, exist_ok=True)
  56. # 用 snapshot_download 下载数据集文件到本地
  57. local_path = snapshot_download(dataset_id, cache_dir=str(settings.processed_dir))
  58. # 扫描下载目录中的 JSON/JSONL 文件
  59. data_files = []
  60. for p in Path(local_path).rglob("*"):
  61. if p.is_file() and p.suffix in (".json", ".jsonl"):
  62. data_files.append(p)
  63. if not data_files:
  64. raise ValueError(f"No JSON/JSONL files found in dataset {dataset_id}")
  65. # 优先取 train / data 开头的文件
  66. target = None
  67. for name in ("train.jsonl", "train.json", "data.jsonl", "data.json"):
  68. for f in data_files:
  69. if f.name == name:
  70. target = f
  71. break
  72. if target:
  73. break
  74. if not target:
  75. target = data_files[0]
  76. # 读取并统一转为 JSONL
  77. jsonl_path = ds_dir / "data.jsonl"
  78. record_count = 0
  79. content = target.read_text(encoding="utf-8")
  80. if target.suffix == ".jsonl":
  81. records = [json.loads(line.strip()) for line in content.splitlines() if line.strip()]
  82. else:
  83. records = json.loads(content)
  84. if not isinstance(records, list):
  85. records = [records]
  86. with open(jsonl_path, "w", encoding="utf-8") as f:
  87. for item in records:
  88. f.write(json.dumps(item, ensure_ascii=False) + "\n")
  89. record_count += 1
  90. return ds_dir, jsonl_path, record_count
  91. async def upload_dataset(file: UploadFile) -> dict[str, Any]:
  92. """保存上传文件并写入数据库。"""
  93. upload_dir = settings.uploads_dir
  94. upload_dir.mkdir(parents=True, exist_ok=True)
  95. # 避免文件名冲突
  96. safe_name = file.filename or "unknown"
  97. file_path = upload_dir / safe_name
  98. if file_path.exists():
  99. file_path = upload_dir / f"{uuid.uuid4().hex}_{safe_name}"
  100. content = await file.read()
  101. file_path.write_bytes(content)
  102. fmt = _detect_format(safe_name)
  103. record_count = _count_records(file_path, fmt)
  104. record_id = str(uuid.uuid4())
  105. record = DatasetRecord(
  106. id=record_id,
  107. name=safe_name,
  108. format=fmt,
  109. record_count=record_count,
  110. file_path=str(file_path),
  111. created_at=datetime.now(timezone.utc),
  112. )
  113. async with async_session() as session:
  114. session.add(record)
  115. await session.commit()
  116. logger.info(f"Uploaded dataset: {safe_name} ({record_count} records, format={fmt})")
  117. return {
  118. "id": record_id,
  119. "name": safe_name,
  120. "format": fmt,
  121. "record_count": record_count,
  122. "file_path": str(file_path),
  123. "created_at": record.created_at.isoformat(),
  124. }
  125. def _format_value(value) -> str:
  126. """将复杂值格式化为可读字符串。"""
  127. if isinstance(value, (dict, list)):
  128. return json.dumps(value, ensure_ascii=False, indent=2)
  129. return str(value)
  130. def _is_sharegpt_format(records: list[dict]) -> bool:
  131. """检测是否为 ShareGPT 格式。"""
  132. if not records:
  133. return False
  134. first = records[0]
  135. if "conversations" in first and isinstance(first["conversations"], list):
  136. if len(first["conversations"]) > 0 and isinstance(first["conversations"][0], dict):
  137. conv = first["conversations"][0]
  138. return "from" in conv and "value" in conv
  139. return False
  140. def _flatten_sharegpt(records: list[dict]) -> tuple[list[dict], list[str]]:
  141. """将 ShareGPT 格式展平为 input/output 列。"""
  142. flat_rows = []
  143. for row in records:
  144. conversations = row.get("conversations", [])
  145. # 每轮 user+assistant 对话作为一行
  146. for i in range(0, len(conversations) - 1, 2):
  147. user_turn = conversations[i]
  148. assistant_turn = conversations[i + 1] if i + 1 < len(conversations) else None
  149. if user_turn.get("from") in ("human", "user"):
  150. input_text = str(user_turn.get("value", ""))
  151. output_text = str(assistant_turn.get("value", "")) if assistant_turn else ""
  152. else:
  153. input_text = str(assistant_turn.get("value", "")) if assistant_turn else ""
  154. output_text = str(user_turn.get("value", ""))
  155. # 截断过长文本
  156. if len(input_text) > 500:
  157. input_text = input_text[:500] + "..."
  158. if len(output_text) > 500:
  159. output_text = output_text[:500] + "..."
  160. flat_rows.append({"input": input_text, "output": output_text})
  161. return flat_rows, ["input", "output"]
  162. async def preview_dataset(dataset_id: str, rows: int = 10) -> dict[str, Any]:
  163. """预览数据集前 N 行。"""
  164. async with async_session() as session:
  165. from sqlalchemy import select
  166. result = await session.execute(select(DatasetRecord).where(DatasetRecord.id == dataset_id))
  167. record = result.scalar_one_or_none()
  168. if not record:
  169. return {"total_records": 0, "preview_rows": [], "columns": []}
  170. file_path = Path(record.file_path)
  171. if not file_path.exists():
  172. return {"total_records": 0, "preview_rows": [], "columns": []}
  173. fmt = record.format
  174. preview_data = _read_records(file_path, fmt, rows)
  175. # 检测是否为 ShareGPT 格式,如果是则展平为 input/output 列
  176. if _is_sharegpt_format(preview_data):
  177. preview_data, columns = _flatten_sharegpt(preview_data)
  178. else:
  179. columns = list(preview_data[0].keys()) if preview_data else []
  180. return {
  181. "total_records": record.record_count,
  182. "preview_rows": [
  183. {
  184. "row_index": i,
  185. "data": {k: _format_value(v) for k, v in row.items()},
  186. }
  187. for i, row in enumerate(preview_data)
  188. ],
  189. "columns": columns,
  190. }
  191. async def validate_dataset(dataset_id: str) -> dict[str, Any]:
  192. """校验数据集格式和 Schema。"""
  193. async with async_session() as session:
  194. from sqlalchemy import select
  195. result = await session.execute(select(DatasetRecord).where(DatasetRecord.id == dataset_id))
  196. record = result.scalar_one_or_none()
  197. if not record:
  198. return {"is_valid": False, "errors": ["Dataset not found"], "warnings": []}
  199. file_path = Path(record.file_path)
  200. if not file_path.exists():
  201. return {"is_valid": False, "errors": ["File not found"], "warnings": []}
  202. errors = []
  203. warnings = []
  204. # 检查格式
  205. fmt = record.format
  206. if fmt not in ("jsonl", "csv", "json", "parquet"):
  207. errors.append(f"Unsupported format: {fmt}")
  208. # 检查内容
  209. try:
  210. preview = _read_records(file_path, fmt, 5)
  211. if not preview:
  212. warnings.append("Dataset appears to be empty")
  213. else:
  214. # 检查必需字段(SFT 格式)
  215. first = preview[0]
  216. has_sft_fields = any(k in first for k in ("instruction", "prompt", "text", "input", "output", "completion"))
  217. if not has_sft_fields:
  218. warnings.append(f"No common SFT fields found. Keys: {list(first.keys())}")
  219. except Exception as e:
  220. errors.append(f"Failed to read file: {str(e)}")
  221. return {"is_valid": len(errors) == 0, "errors": errors, "warnings": warnings}
  222. async def list_datasets() -> list[dict[str, Any]]:
  223. """列出所有已上传数据集。"""
  224. async with async_session() as session:
  225. from sqlalchemy import select
  226. result = await session.execute(select(DatasetRecord).order_by(DatasetRecord.created_at.desc()))
  227. records = result.scalars().all()
  228. return [
  229. {
  230. "id": r.id,
  231. "name": r.name,
  232. "format": r.format,
  233. "record_count": r.record_count,
  234. "file_path": r.file_path,
  235. "created_at": r.created_at.isoformat(),
  236. }
  237. for r in records
  238. ]
  239. async def delete_dataset(dataset_id: str) -> dict[str, Any]:
  240. """删除数据集。"""
  241. async with async_session() as session:
  242. from sqlalchemy import select
  243. result = await session.execute(select(DatasetRecord).where(DatasetRecord.id == dataset_id))
  244. record = result.scalar_one_or_none()
  245. if record:
  246. # 删除文件
  247. file_path = Path(record.file_path)
  248. if file_path.exists():
  249. file_path.unlink()
  250. await session.delete(record)
  251. await session.commit()
  252. logger.info(f"Deleted dataset: {record.name}")
  253. return {"status": "deleted"}
  254. def _detect_format(filename: str) -> str:
  255. ext = Path(filename).suffix.lower().lstrip(".")
  256. if ext in ("jsonl", "csv", "parquet", "json"):
  257. return ext
  258. return "unknown"
  259. def _count_records(file_path: Path, fmt: str) -> int:
  260. try:
  261. if fmt == "jsonl":
  262. return sum(1 for line in open(file_path, encoding="utf-8") if line.strip())
  263. elif fmt == "json":
  264. with open(file_path, encoding="utf-8") as f:
  265. data = json.load(f)
  266. return len(data) if isinstance(data, list) else 1
  267. elif fmt == "csv":
  268. import csv
  269. with open(file_path, encoding="utf-8") as f:
  270. return sum(1 for _ in csv.reader(f)) - 1 # minus header
  271. elif fmt == "parquet":
  272. import pandas as pd
  273. return len(pd.read_parquet(file_path))
  274. except Exception:
  275. pass
  276. return 0
  277. def _read_records(file_path: Path, fmt: str, n: int) -> list[dict]:
  278. if fmt == "jsonl":
  279. records = []
  280. with open(file_path, encoding="utf-8") as f:
  281. for i, line in enumerate(f):
  282. if i >= n:
  283. break
  284. line = line.strip()
  285. if line:
  286. records.append(json.loads(line))
  287. return records
  288. elif fmt == "json":
  289. with open(file_path, encoding="utf-8") as f:
  290. data = json.load(f)
  291. return data[:n] if isinstance(data, list) else [data]
  292. elif fmt == "csv":
  293. import csv
  294. with open(file_path, encoding="utf-8") as f:
  295. reader = csv.DictReader(f)
  296. return [dict(row) for i, row in enumerate(reader) if i < n]
  297. elif fmt == "parquet":
  298. import pandas as pd
  299. df = pd.read_parquet(file_path)
  300. return df.head(n).to_dict(orient="records")
  301. return []