dataset_service.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580
  1. import asyncio
  2. import json
  3. import uuid
  4. from datetime import datetime, timezone
  5. from pathlib import Path
  6. from typing import Any
  7. from fastapi import UploadFile
  8. from app.config import get_settings
  9. from app.core.background_tasks import background_task_manager
  10. from app.core.db import async_session, DatasetRecord, DatasetDownloadTask
  11. from app.core.logging import logger
  12. from app.schemas.dataset import DatasetDownloadRequest, DatasetDownloadResponse
  13. from sqlalchemy import select
  14. settings = get_settings()
  15. # Known metadata filenames that are NOT training data
  16. META_FILENAMES = frozenset({
  17. "configuration.json", "configuration.yaml", "README.md",
  18. ".mdl", ".msc", ".mv", "model_index.json", "generation_config.json",
  19. "special_tokens_map.json", "tokenizer_config.json",
  20. "added_tokens.json", "vocab.json", "merges.txt",
  21. "config.json", "preprocessor_config.json",
  22. "dataset_infos.json", "dataset_info.json",
  23. "state.json", "card_data.json",
  24. })
  25. # File size threshold: files smaller than this (bytes) are likely metadata
  26. META_SIZE_THRESHOLD = 500
  27. def _is_training_data_file(path: Path) -> bool:
  28. """判断文件是否可能是训练数据文件(而非配置/元数据)。"""
  29. if path.name in META_FILENAMES:
  30. return False
  31. if path.suffix in (".jsonl", ".parquet", ".csv"):
  32. # 小文件可能是元数据(如 ModelScope CLI 生成的 data.jsonl 只有几十字节)
  33. if path.stat().st_size < META_SIZE_THRESHOLD:
  34. return False
  35. return True
  36. if path.suffix == ".json":
  37. # 小 JSON 文件通常是配置
  38. if path.stat().st_size < META_SIZE_THRESHOLD:
  39. return False
  40. # 尝试读取首行判断格式
  41. try:
  42. first_line = path.read_text(encoding="utf-8", errors="ignore").splitlines()[0].strip()
  43. obj = json.loads(first_line)
  44. # 如果有 input/output/conversation/instruction 等字段,则是训练数据
  45. if isinstance(obj, dict):
  46. data_keys = {"input", "output", "conversations", "instruction", "prompt",
  47. "text", "completion", "source", "target", "query", "response"}
  48. if data_keys & set(obj.keys()):
  49. return True
  50. # 如果只有 framework/task/model_type 等字段,则是元数据
  51. meta_keys = {"framework", "task", "license", "base_model", "model_type",
  52. "language", "domains", "tags", "authors"}
  53. if meta_keys & set(obj.keys()):
  54. return False
  55. return True # 大 JSON 文件默认是数据
  56. except Exception:
  57. return False
  58. # 无后缀文件:尝试读取判断是否为 JSON/JSONL
  59. if not path.suffix:
  60. try:
  61. first_line = path.read_text(encoding="utf-8", errors="ignore").splitlines()[0].strip()
  62. json.loads(first_line)
  63. return True
  64. except Exception:
  65. return False
  66. return False
  67. async def download_dataset(req: DatasetDownloadRequest) -> DatasetDownloadResponse:
  68. """启动数据集下载后台任务,立即返回 task_id。"""
  69. task_id = str(uuid.uuid4())
  70. # 写 DB
  71. record = DatasetDownloadTask(
  72. id=task_id,
  73. dataset_id=req.dataset_id,
  74. use_modelscope=1 if req.use_modelscope else 0,
  75. status="pending",
  76. )
  77. async with async_session() as session:
  78. session.add(record)
  79. await session.commit()
  80. # 注册并启动
  81. background_task_manager.register_task(task_id, "dataset_download", {"dataset_id": req.dataset_id})
  82. background_task_manager.run(
  83. task_id, "dataset_download", _execute_dataset_download(task_id, req)
  84. )
  85. logger.info(f"Dataset download task started: {req.dataset_id} (task_id={task_id})")
  86. return DatasetDownloadResponse(
  87. dataset_id=req.dataset_id, status="pending", path=task_id
  88. )
  89. async def _execute_dataset_download(task_id: str, req: DatasetDownloadRequest) -> dict:
  90. """后台执行数据集下载。"""
  91. try:
  92. if req.use_modelscope:
  93. ds_dir, jsonl_path, record_count = await asyncio.to_thread(
  94. _download_modelscope_dataset, req.dataset_id
  95. )
  96. else:
  97. from datasets import load_dataset
  98. ds = load_dataset(req.dataset_id)
  99. ds_dir = settings.processed_dir / f"hf_{req.dataset_id.replace('/', '_')}"
  100. ds_dir.mkdir(parents=True, exist_ok=True)
  101. if "train" in ds:
  102. split = ds["train"]
  103. else:
  104. split = ds[list(ds.keys())[0]]
  105. output_path = ds_dir / "data.jsonl"
  106. with open(output_path, "w", encoding="utf-8") as f:
  107. for item in split:
  108. f.write(json.dumps(item, ensure_ascii=False) + "\n")
  109. jsonl_path = output_path
  110. record_count = len(split) if hasattr(split, "__len__") else 0
  111. db_record = DatasetRecord(
  112. id=str(uuid.uuid4()),
  113. name=req.dataset_id,
  114. format="jsonl",
  115. record_count=record_count,
  116. file_path=str(jsonl_path),
  117. created_at=datetime.utcnow(),
  118. )
  119. async with async_session() as session:
  120. session.add(db_record)
  121. await session.commit()
  122. await _update_dataset_download_status(task_id, "completed", path=str(jsonl_path), record_count=record_count)
  123. logger.info(f"Dataset downloaded: {req.dataset_id} ({record_count} records)")
  124. return {"path": str(jsonl_path), "record_count": record_count}
  125. except Exception as e:
  126. logger.error(f"Dataset download failed: {e}")
  127. await _update_dataset_download_status(task_id, "failed", error=str(e))
  128. return {"error": str(e)}
  129. async def _update_dataset_download_status(task_id: str, status: str, path: str = None, error: str = None, record_count: int = 0):
  130. async with async_session() as session:
  131. result = await session.execute(select(DatasetDownloadTask).where(DatasetDownloadTask.id == task_id))
  132. record = result.scalar_one_or_none()
  133. if record:
  134. record.status = status
  135. if path:
  136. record.path = path
  137. if error:
  138. record.error = error
  139. if record_count:
  140. record.record_count = record_count
  141. if status in ("completed", "failed"):
  142. record.finished_at = datetime.utcnow()
  143. await session.commit()
  144. background_task_manager.update_task(
  145. task_id, status=status, path=path, error=error, record_count=record_count,
  146. )
  147. async def get_dataset_download_status(task_id: str) -> dict[str, Any]:
  148. async with async_session() as session:
  149. result = await session.execute(select(DatasetDownloadTask).where(DatasetDownloadTask.id == task_id))
  150. record = result.scalar_one_or_none()
  151. if record:
  152. return {
  153. "task_id": record.id,
  154. "dataset_id": record.dataset_id,
  155. "status": record.status,
  156. "use_modelscope": bool(record.use_modelscope),
  157. "path": record.path,
  158. "error": record.error,
  159. "record_count": record.record_count,
  160. "created_at": record.created_at.isoformat() if record.created_at else "",
  161. }
  162. mem = background_task_manager.get_task(task_id)
  163. if mem:
  164. return {
  165. "task_id": task_id,
  166. "dataset_id": mem.get("dataset_id", ""),
  167. "status": mem["status"],
  168. "error": mem.get("error"),
  169. "record_count": mem.get("record_count", 0),
  170. }
  171. return {"task_id": task_id, "status": "not_found"}
  172. async def list_dataset_downloads() -> list[dict[str, Any]]:
  173. async with async_session() as session:
  174. result = await session.execute(
  175. select(DatasetDownloadTask).order_by(DatasetDownloadTask.created_at.desc())
  176. )
  177. records = result.scalars().all()
  178. return [
  179. {
  180. "task_id": r.id,
  181. "dataset_id": r.dataset_id,
  182. "status": r.status,
  183. "use_modelscope": bool(r.use_modelscope),
  184. "path": r.path,
  185. "error": r.error,
  186. "record_count": r.record_count,
  187. "created_at": r.created_at.isoformat() if r.created_at else "",
  188. }
  189. for r in records
  190. ]
  191. async def cancel_dataset_download(task_id: str) -> dict[str, Any]:
  192. background_task_manager.cancel_task(task_id)
  193. async with async_session() as session:
  194. result = await session.execute(select(DatasetDownloadTask).where(DatasetDownloadTask.id == task_id))
  195. record = result.scalar_one_or_none()
  196. if record and record.status in ("pending", "downloading"):
  197. record.status = "cancelled"
  198. record.error = "Cancelled by user"
  199. record.finished_at = datetime.utcnow()
  200. await session.commit()
  201. return {"task_id": task_id, "status": "cancelled"}
  202. async def recover_stale_downloads() -> None:
  203. async with async_session() as session:
  204. result = await session.execute(
  205. select(DatasetDownloadTask).where(
  206. DatasetDownloadTask.status.in_(["pending", "downloading"])
  207. )
  208. )
  209. records = result.scalars().all()
  210. for record in records:
  211. record.status = "failed"
  212. record.error = "Server restarted, task interrupted"
  213. record.finished_at = datetime.utcnow()
  214. if records:
  215. await session.commit()
  216. logger.info(f"Recovered {len(records)} stale dataset download tasks")
  217. def _download_modelscope_dataset(dataset_id: str) -> tuple[Path, Path, int]:
  218. """用 modelscope CLI 下载数据集,完全绕过 datasets 库,避免版本兼容问题。"""
  219. import subprocess
  220. ds_dir = settings.processed_dir / f"ms_{dataset_id.replace('/', '_')}"
  221. ds_dir.mkdir(parents=True, exist_ok=True)
  222. # 使用 CLI 方式下载,避免 snapshot_download API 的路径问题
  223. cmd = ["modelscope", "download", "--dataset", dataset_id, "--local_dir", str(ds_dir)]
  224. logger.info(f"Running: {' '.join(cmd)}")
  225. proc = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
  226. if proc.returncode != 0:
  227. logger.error(f"ModelScope CLI download failed (code={proc.returncode}): {proc.stderr[:500]}")
  228. raise RuntimeError(f"ModelScope download failed: {proc.stderr[:500]}")
  229. # 扫描下载目录中的所有文件
  230. all_files = [p for p in ds_dir.rglob("*") if p.is_file()]
  231. logger.info(f"ModelScope CLI downloaded {len(all_files)} files to {ds_dir}")
  232. # 识别训练数据文件
  233. data_files = [f for f in all_files if _is_training_data_file(f)]
  234. if not data_files:
  235. fallback = [f for f in all_files if f.suffix in (".json", ".jsonl") and f.name not in META_FILENAMES and f.name != "README.md"]
  236. logger.warning(f"No training data files found in {dataset_id}. "
  237. f"Available JSON files: {[f.name for f in fallback]}")
  238. if fallback:
  239. data_files = fallback
  240. else:
  241. # 如果还是没有,列出所有文件供排查
  242. logger.error(f"All downloaded files: {[str(f.relative_to(ds_dir)) for f in all_files]}")
  243. raise ValueError(f"No JSON/JSONL data files found in dataset {dataset_id}. "
  244. f"Available files: {[f.name for f in all_files]}")
  245. # 按文件大小排序,取最大的文件作为训练数据(真正的数据集通常是最大的)
  246. target = sorted(data_files, key=lambda f: f.stat().st_size, reverse=True)[0]
  247. logger.info(f"Selected data file: {target} (size={target.stat().st_size})")
  248. # 读取并统一转为 JSONL
  249. jsonl_path = ds_dir / "data.jsonl"
  250. record_count = 0
  251. content = target.read_text(encoding="utf-8")
  252. if target.suffix == ".jsonl" or not target.suffix:
  253. # JSONL 或无后缀文件:逐行解析
  254. records = []
  255. for line in content.splitlines():
  256. line = line.strip()
  257. if not line:
  258. continue
  259. try:
  260. records.append(json.loads(line))
  261. except json.JSONDecodeError:
  262. records = json.loads(content)
  263. if not isinstance(records, list):
  264. records = [records]
  265. break
  266. elif target.suffix == ".json":
  267. # JSON 文件:先尝试 JSON 数组,失败再逐行解析(可能是 JSONL 格式)
  268. try:
  269. records = json.loads(content)
  270. if not isinstance(records, list):
  271. records = [records]
  272. except json.JSONDecodeError:
  273. records = []
  274. for line in content.splitlines():
  275. line = line.strip()
  276. if not line:
  277. continue
  278. try:
  279. records.append(json.loads(line))
  280. except json.JSONDecodeError:
  281. continue
  282. with open(jsonl_path, "w", encoding="utf-8") as f:
  283. for item in records:
  284. f.write(json.dumps(item, ensure_ascii=False) + "\n")
  285. record_count += 1
  286. return ds_dir, jsonl_path, record_count
  287. async def upload_dataset(file: UploadFile) -> dict[str, Any]:
  288. """保存上传文件并写入数据库。"""
  289. upload_dir = settings.uploads_dir
  290. upload_dir.mkdir(parents=True, exist_ok=True)
  291. safe_name = file.filename or "unknown"
  292. file_path = upload_dir / safe_name
  293. if file_path.exists():
  294. file_path = upload_dir / f"{uuid.uuid4().hex}_{safe_name}"
  295. content = await file.read()
  296. file_path.write_bytes(content)
  297. fmt = _detect_format(safe_name)
  298. record_count = _count_records(file_path, fmt)
  299. record_id = str(uuid.uuid4())
  300. record = DatasetRecord(
  301. id=record_id,
  302. name=safe_name,
  303. format=fmt,
  304. record_count=record_count,
  305. file_path=str(file_path),
  306. created_at=datetime.utcnow(),
  307. )
  308. async with async_session() as session:
  309. session.add(record)
  310. await session.commit()
  311. logger.info(f"Uploaded dataset: {safe_name} ({record_count} records, format={fmt})")
  312. return {
  313. "id": record_id,
  314. "name": safe_name,
  315. "format": fmt,
  316. "record_count": record_count,
  317. "file_path": str(file_path),
  318. "created_at": record.created_at.isoformat(),
  319. }
  320. def _format_value(value) -> str:
  321. """将复杂值格式化为可读字符串。"""
  322. if isinstance(value, (dict, list)):
  323. return json.dumps(value, ensure_ascii=False, indent=2)
  324. return str(value)
  325. def _is_sharegpt_format(records: list[dict]) -> bool:
  326. """检测是否为 ShareGPT 格式。"""
  327. if not records:
  328. return False
  329. first = records[0]
  330. if "conversations" in first and isinstance(first["conversations"], list):
  331. if len(first["conversations"]) > 0 and isinstance(first["conversations"][0], dict):
  332. conv = first["conversations"][0]
  333. return "from" in conv and "value" in conv
  334. return False
  335. def _flatten_sharegpt(records: list[dict]) -> tuple[list[dict], list[str]]:
  336. """将 ShareGPT 格式展平为 input/output 列。"""
  337. flat_rows = []
  338. for row in records:
  339. conversations = row.get("conversations", [])
  340. for i in range(0, len(conversations) - 1, 2):
  341. user_turn = conversations[i]
  342. assistant_turn = conversations[i + 1] if i + 1 < len(conversations) else None
  343. if user_turn.get("from") in ("human", "user"):
  344. input_text = str(user_turn.get("value", ""))
  345. output_text = str(assistant_turn.get("value", "")) if assistant_turn else ""
  346. else:
  347. input_text = str(assistant_turn.get("value", "")) if assistant_turn else ""
  348. output_text = str(user_turn.get("value", ""))
  349. if len(input_text) > 500:
  350. input_text = input_text[:500] + "..."
  351. if len(output_text) > 500:
  352. output_text = output_text[:500] + "..."
  353. flat_rows.append({"input": input_text, "output": output_text})
  354. return flat_rows, ["input", "output"]
  355. async def preview_dataset(dataset_id: str, rows: int = 10) -> dict[str, Any]:
  356. """预览数据集前 N 行。"""
  357. async with async_session() as session:
  358. from sqlalchemy import select
  359. result = await session.execute(select(DatasetRecord).where(DatasetRecord.id == dataset_id))
  360. record = result.scalar_one_or_none()
  361. if not record:
  362. return {"total_records": 0, "preview_rows": [], "columns": []}
  363. file_path = Path(record.file_path)
  364. if not file_path.exists():
  365. return {"total_records": 0, "preview_rows": [], "columns": []}
  366. fmt = record.format
  367. preview_data = _read_records(file_path, fmt, rows)
  368. # 检测是否为 ShareGPT 格式,如果是则展平为 input/output 列
  369. if _is_sharegpt_format(preview_data):
  370. preview_data, columns = _flatten_sharegpt(preview_data)
  371. else:
  372. columns = list(preview_data[0].keys()) if preview_data else []
  373. return {
  374. "total_records": record.record_count,
  375. "preview_rows": [
  376. {
  377. "row_index": i,
  378. "data": {k: _format_value(v) for k, v in row.items()},
  379. }
  380. for i, row in enumerate(preview_data)
  381. ],
  382. "columns": columns,
  383. }
  384. async def validate_dataset(dataset_id: str) -> dict[str, Any]:
  385. """校验数据集格式和 Schema。"""
  386. async with async_session() as session:
  387. from sqlalchemy import select
  388. result = await session.execute(select(DatasetRecord).where(DatasetRecord.id == dataset_id))
  389. record = result.scalar_one_or_none()
  390. if not record:
  391. return {"is_valid": False, "errors": ["Dataset not found"], "warnings": []}
  392. file_path = Path(record.file_path)
  393. if not file_path.exists():
  394. return {"is_valid": False, "errors": ["File not found"], "warnings": []}
  395. errors = []
  396. warnings = []
  397. fmt = record.format
  398. if fmt not in ("jsonl", "csv", "json", "parquet"):
  399. errors.append(f"Unsupported format: {fmt}")
  400. try:
  401. preview = _read_records(file_path, fmt, 5)
  402. if not preview:
  403. warnings.append("Dataset appears to be empty")
  404. else:
  405. first = preview[0]
  406. has_sft_fields = any(k in first for k in ("instruction", "prompt", "text", "input", "output", "completion"))
  407. if not has_sft_fields:
  408. warnings.append(f"No common SFT fields found. Keys: {list(first.keys())}")
  409. except Exception as e:
  410. errors.append(f"Failed to read file: {str(e)}")
  411. return {"is_valid": len(errors) == 0, "errors": errors, "warnings": warnings}
  412. async def list_datasets() -> list[dict[str, Any]]:
  413. """列出所有已上传数据集。"""
  414. async with async_session() as session:
  415. from sqlalchemy import select
  416. result = await session.execute(select(DatasetRecord).order_by(DatasetRecord.created_at.desc()))
  417. records = result.scalars().all()
  418. return [
  419. {
  420. "id": r.id,
  421. "name": r.name,
  422. "format": r.format,
  423. "record_count": r.record_count,
  424. "file_path": r.file_path,
  425. "created_at": r.created_at.isoformat(),
  426. }
  427. for r in records
  428. ]
  429. async def delete_dataset(dataset_id: str) -> dict[str, Any]:
  430. """删除数据集。"""
  431. async with async_session() as session:
  432. from sqlalchemy import select
  433. result = await session.execute(select(DatasetRecord).where(DatasetRecord.id == dataset_id))
  434. record = result.scalar_one_or_none()
  435. if record:
  436. file_path = Path(record.file_path)
  437. if file_path.exists():
  438. file_path.unlink()
  439. await session.delete(record)
  440. await session.commit()
  441. logger.info(f"Deleted dataset: {record.name}")
  442. return {"status": "deleted"}
  443. def _detect_format(filename: str) -> str:
  444. ext = Path(filename).suffix.lower().lstrip(".")
  445. if ext in ("jsonl", "csv", "parquet", "json"):
  446. return ext
  447. return "unknown"
  448. def _count_records(file_path: Path, fmt: str) -> int:
  449. try:
  450. if fmt == "jsonl":
  451. return sum(1 for line in open(file_path, encoding="utf-8") if line.strip())
  452. elif fmt == "json":
  453. with open(file_path, encoding="utf-8") as f:
  454. data = json.load(f)
  455. return len(data) if isinstance(data, list) else 1
  456. elif fmt == "csv":
  457. import csv
  458. with open(file_path, encoding="utf-8") as f:
  459. return sum(1 for _ in csv.reader(f)) - 1
  460. elif fmt == "parquet":
  461. import pandas as pd
  462. return len(pd.read_parquet(file_path))
  463. except Exception:
  464. pass
  465. return 0
  466. def _read_records(file_path: Path, fmt: str, n: int) -> list[dict]:
  467. if fmt == "jsonl":
  468. records = []
  469. with open(file_path, encoding="utf-8") as f:
  470. for i, line in enumerate(f):
  471. if i >= n:
  472. break
  473. line = line.strip()
  474. if line:
  475. records.append(json.loads(line))
  476. return records
  477. elif fmt == "json":
  478. with open(file_path, encoding="utf-8") as f:
  479. data = json.load(f)
  480. return data[:n] if isinstance(data, list) else [data]
  481. elif fmt == "csv":
  482. import csv
  483. with open(file_path, encoding="utf-8") as f:
  484. reader = csv.DictReader(f)
  485. return [dict(row) for i, row in enumerate(reader) if i < n]
  486. elif fmt == "parquet":
  487. import pandas as pd
  488. df = pd.read_parquet(file_path)
  489. return df.head(n).to_dict(orient="records")
  490. return []