dataset_service.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431
  1. import asyncio
  2. import json
  3. import uuid
  4. from datetime import datetime, timezone
  5. from pathlib import Path
  6. from typing import Any
  7. from fastapi import UploadFile
  8. from app.config import get_settings
  9. from app.core.db import async_session, DatasetRecord
  10. from app.core.logging import logger
  11. from app.schemas.dataset import DatasetDownloadRequest, DatasetDownloadResponse
# Application settings, loaded once when this module is imported; the
# functions below read settings.processed_dir and settings.uploads_dir from it.
settings = get_settings()
  13. # Known metadata filenames that are NOT training data
  14. META_FILENAMES = frozenset({
  15. "configuration.json", "configuration.yaml", "README.md",
  16. ".mdl", ".msc", ".mv", "model_index.json", "generation_config.json",
  17. "special_tokens_map.json", "tokenizer_config.json",
  18. "added_tokens.json", "vocab.json", "merges.txt",
  19. "config.json", "preprocessor_config.json",
  20. })
  21. # File size threshold: files smaller than this (bytes) are likely metadata
  22. META_SIZE_THRESHOLD = 500
  23. def _is_training_data_file(path: Path) -> bool:
  24. """判断文件是否可能是训练数据文件(而非配置/元数据)。"""
  25. if path.suffix in (".jsonl", ".parquet", ".csv"):
  26. return True
  27. if path.suffix == ".json":
  28. if path.name in META_FILENAMES:
  29. return False
  30. # 小 JSON 文件通常是配置
  31. if path.stat().st_size < META_SIZE_THRESHOLD:
  32. return False
  33. # 尝试读取首行判断格式
  34. try:
  35. first_line = path.read_text(encoding="utf-8", errors="ignore").splitlines()[0].strip()
  36. obj = json.loads(first_line)
  37. # 如果有 input/output/conversation/instruction 等字段,则是训练数据
  38. if isinstance(obj, dict):
  39. data_keys = {"input", "output", "conversations", "instruction", "prompt",
  40. "text", "completion", "source", "target", "query", "response"}
  41. if data_keys & set(obj.keys()):
  42. return True
  43. return True # 大 JSON 文件默认是数据
  44. except Exception:
  45. return False
  46. # 无后缀文件:尝试读取判断是否为 JSON/JSONL
  47. if not path.suffix:
  48. try:
  49. first_line = path.read_text(encoding="utf-8", errors="ignore").splitlines()[0].strip()
  50. json.loads(first_line)
  51. return True
  52. except Exception:
  53. return False
  54. return False
  55. async def download_dataset(req: DatasetDownloadRequest) -> DatasetDownloadResponse:
  56. """从 HuggingFace 或 ModelScope 下载数据集。"""
  57. try:
  58. if req.use_modelscope:
  59. ds_dir, jsonl_path, record_count = await asyncio.to_thread(_download_modelscope_dataset, req.dataset_id)
  60. else:
  61. from datasets import load_dataset
  62. ds = load_dataset(req.dataset_id)
  63. ds_dir = settings.processed_dir / f"hf_{req.dataset_id.replace('/', '_')}"
  64. ds_dir.mkdir(parents=True, exist_ok=True)
  65. if "train" in ds:
  66. split = ds["train"]
  67. else:
  68. split = ds[list(ds.keys())[0]]
  69. output_path = ds_dir / "data.jsonl"
  70. with open(output_path, "w", encoding="utf-8") as f:
  71. for item in split:
  72. f.write(json.dumps(item, ensure_ascii=False) + "\n")
  73. jsonl_path = output_path
  74. record_count = len(split) if hasattr(split, "__len__") else 0
  75. record = DatasetRecord(
  76. id=str(uuid.uuid4()),
  77. name=req.dataset_id,
  78. format="jsonl",
  79. record_count=record_count,
  80. file_path=str(jsonl_path),
  81. created_at=datetime.now(timezone.utc),
  82. )
  83. async with async_session() as session:
  84. session.add(record)
  85. await session.commit()
  86. logger.info(f"Downloaded dataset: {req.dataset_id} ({record_count} records, source={'ModelScope' if req.use_modelscope else 'HuggingFace'})")
  87. return DatasetDownloadResponse(dataset_id=req.dataset_id, status="completed", path=str(jsonl_path))
  88. except Exception as e:
  89. logger.error(f"Dataset download failed: {e}")
  90. return DatasetDownloadResponse(dataset_id=req.dataset_id, status="failed", error=str(e))
  91. def _download_modelscope_dataset(dataset_id: str) -> tuple[Path, Path, int]:
  92. """用 snapshot_download 下载数据集文件,完全绕过 datasets 库,避免版本兼容问题。"""
  93. from modelscope import snapshot_download
  94. ds_dir = settings.processed_dir / f"ms_{dataset_id.replace('/', '_')}"
  95. ds_dir.mkdir(parents=True, exist_ok=True)
  96. local_path = snapshot_download(dataset_id, cache_dir=str(settings.processed_dir))
  97. # ModelScope 的 snapshot_download 把实际数据存到 cache_dir/downloads/<hash> 里
  98. # 而 local_path 指向的目录只有元数据文件,需要额外扫描 downloads 目录
  99. all_files = [p for p in Path(local_path).rglob("*") if p.is_file()]
  100. downloads_dir = settings.processed_dir / "downloads"
  101. if downloads_dir.exists():
  102. for p in downloads_dir.rglob("*"):
  103. if p.is_file() and str(p.parent) != str(ds_dir):
  104. all_files.append(p)
  105. # 识别训练数据文件
  106. data_files = [f for f in all_files if _is_training_data_file(f)]
  107. if not data_files:
  108. fallback = [f for f in all_files if f.suffix in (".json", ".jsonl")]
  109. logger.warning(f"No training data files found in {dataset_id}. "
  110. f"Available JSON files: {[f.name for f in fallback]}")
  111. if fallback:
  112. data_files = fallback
  113. else:
  114. raise ValueError(f"No JSON/JSONL data files found in dataset {dataset_id}")
  115. # 优先取 train / data 开头的文件
  116. target = None
  117. for name in ("train.jsonl", "train.json", "data.jsonl", "data.json"):
  118. for f in data_files:
  119. if f.name == name:
  120. target = f
  121. break
  122. if target:
  123. break
  124. if not target:
  125. # 优先取数据量最大的文件
  126. target = sorted(data_files, key=lambda f: f.stat().st_size, reverse=True)[0]
  127. logger.info(f"Selected data file: {target} (size={target.stat().st_size})")
  128. # 读取并统一转为 JSONL
  129. jsonl_path = ds_dir / "data.jsonl"
  130. record_count = 0
  131. content = target.read_text(encoding="utf-8")
  132. if target.suffix == ".jsonl" or not target.suffix:
  133. # JSONL 或无后缀文件:尝试逐行解析
  134. records = []
  135. for line in content.splitlines():
  136. line = line.strip()
  137. if not line:
  138. continue
  139. try:
  140. records.append(json.loads(line))
  141. except json.JSONDecodeError:
  142. # 如果逐行解析失败,尝试整体解析(可能是 JSON 数组)
  143. records = json.loads(content)
  144. if not isinstance(records, list):
  145. records = [records]
  146. break
  147. else:
  148. records = json.loads(content)
  149. if not isinstance(records, list):
  150. records = [records]
  151. with open(jsonl_path, "w", encoding="utf-8") as f:
  152. for item in records:
  153. f.write(json.dumps(item, ensure_ascii=False) + "\n")
  154. record_count += 1
  155. return ds_dir, jsonl_path, record_count
  156. async def upload_dataset(file: UploadFile) -> dict[str, Any]:
  157. """保存上传文件并写入数据库。"""
  158. upload_dir = settings.uploads_dir
  159. upload_dir.mkdir(parents=True, exist_ok=True)
  160. safe_name = file.filename or "unknown"
  161. file_path = upload_dir / safe_name
  162. if file_path.exists():
  163. file_path = upload_dir / f"{uuid.uuid4().hex}_{safe_name}"
  164. content = await file.read()
  165. file_path.write_bytes(content)
  166. fmt = _detect_format(safe_name)
  167. record_count = _count_records(file_path, fmt)
  168. record_id = str(uuid.uuid4())
  169. record = DatasetRecord(
  170. id=record_id,
  171. name=safe_name,
  172. format=fmt,
  173. record_count=record_count,
  174. file_path=str(file_path),
  175. created_at=datetime.now(timezone.utc),
  176. )
  177. async with async_session() as session:
  178. session.add(record)
  179. await session.commit()
  180. logger.info(f"Uploaded dataset: {safe_name} ({record_count} records, format={fmt})")
  181. return {
  182. "id": record_id,
  183. "name": safe_name,
  184. "format": fmt,
  185. "record_count": record_count,
  186. "file_path": str(file_path),
  187. "created_at": record.created_at.isoformat(),
  188. }
  189. def _format_value(value) -> str:
  190. """将复杂值格式化为可读字符串。"""
  191. if isinstance(value, (dict, list)):
  192. return json.dumps(value, ensure_ascii=False, indent=2)
  193. return str(value)
  194. def _is_sharegpt_format(records: list[dict]) -> bool:
  195. """检测是否为 ShareGPT 格式。"""
  196. if not records:
  197. return False
  198. first = records[0]
  199. if "conversations" in first and isinstance(first["conversations"], list):
  200. if len(first["conversations"]) > 0 and isinstance(first["conversations"][0], dict):
  201. conv = first["conversations"][0]
  202. return "from" in conv and "value" in conv
  203. return False
  204. def _flatten_sharegpt(records: list[dict]) -> tuple[list[dict], list[str]]:
  205. """将 ShareGPT 格式展平为 input/output 列。"""
  206. flat_rows = []
  207. for row in records:
  208. conversations = row.get("conversations", [])
  209. for i in range(0, len(conversations) - 1, 2):
  210. user_turn = conversations[i]
  211. assistant_turn = conversations[i + 1] if i + 1 < len(conversations) else None
  212. if user_turn.get("from") in ("human", "user"):
  213. input_text = str(user_turn.get("value", ""))
  214. output_text = str(assistant_turn.get("value", "")) if assistant_turn else ""
  215. else:
  216. input_text = str(assistant_turn.get("value", "")) if assistant_turn else ""
  217. output_text = str(user_turn.get("value", ""))
  218. if len(input_text) > 500:
  219. input_text = input_text[:500] + "..."
  220. if len(output_text) > 500:
  221. output_text = output_text[:500] + "..."
  222. flat_rows.append({"input": input_text, "output": output_text})
  223. return flat_rows, ["input", "output"]
  224. async def preview_dataset(dataset_id: str, rows: int = 10) -> dict[str, Any]:
  225. """预览数据集前 N 行。"""
  226. async with async_session() as session:
  227. from sqlalchemy import select
  228. result = await session.execute(select(DatasetRecord).where(DatasetRecord.id == dataset_id))
  229. record = result.scalar_one_or_none()
  230. if not record:
  231. return {"total_records": 0, "preview_rows": [], "columns": []}
  232. file_path = Path(record.file_path)
  233. if not file_path.exists():
  234. return {"total_records": 0, "preview_rows": [], "columns": []}
  235. fmt = record.format
  236. preview_data = _read_records(file_path, fmt, rows)
  237. # 检测是否为 ShareGPT 格式,如果是则展平为 input/output 列
  238. if _is_sharegpt_format(preview_data):
  239. preview_data, columns = _flatten_sharegpt(preview_data)
  240. else:
  241. columns = list(preview_data[0].keys()) if preview_data else []
  242. return {
  243. "total_records": record.record_count,
  244. "preview_rows": [
  245. {
  246. "row_index": i,
  247. "data": {k: _format_value(v) for k, v in row.items()},
  248. }
  249. for i, row in enumerate(preview_data)
  250. ],
  251. "columns": columns,
  252. }
  253. async def validate_dataset(dataset_id: str) -> dict[str, Any]:
  254. """校验数据集格式和 Schema。"""
  255. async with async_session() as session:
  256. from sqlalchemy import select
  257. result = await session.execute(select(DatasetRecord).where(DatasetRecord.id == dataset_id))
  258. record = result.scalar_one_or_none()
  259. if not record:
  260. return {"is_valid": False, "errors": ["Dataset not found"], "warnings": []}
  261. file_path = Path(record.file_path)
  262. if not file_path.exists():
  263. return {"is_valid": False, "errors": ["File not found"], "warnings": []}
  264. errors = []
  265. warnings = []
  266. fmt = record.format
  267. if fmt not in ("jsonl", "csv", "json", "parquet"):
  268. errors.append(f"Unsupported format: {fmt}")
  269. try:
  270. preview = _read_records(file_path, fmt, 5)
  271. if not preview:
  272. warnings.append("Dataset appears to be empty")
  273. else:
  274. first = preview[0]
  275. has_sft_fields = any(k in first for k in ("instruction", "prompt", "text", "input", "output", "completion"))
  276. if not has_sft_fields:
  277. warnings.append(f"No common SFT fields found. Keys: {list(first.keys())}")
  278. except Exception as e:
  279. errors.append(f"Failed to read file: {str(e)}")
  280. return {"is_valid": len(errors) == 0, "errors": errors, "warnings": warnings}
  281. async def list_datasets() -> list[dict[str, Any]]:
  282. """列出所有已上传数据集。"""
  283. async with async_session() as session:
  284. from sqlalchemy import select
  285. result = await session.execute(select(DatasetRecord).order_by(DatasetRecord.created_at.desc()))
  286. records = result.scalars().all()
  287. return [
  288. {
  289. "id": r.id,
  290. "name": r.name,
  291. "format": r.format,
  292. "record_count": r.record_count,
  293. "file_path": r.file_path,
  294. "created_at": r.created_at.isoformat(),
  295. }
  296. for r in records
  297. ]
  298. async def delete_dataset(dataset_id: str) -> dict[str, Any]:
  299. """删除数据集。"""
  300. async with async_session() as session:
  301. from sqlalchemy import select
  302. result = await session.execute(select(DatasetRecord).where(DatasetRecord.id == dataset_id))
  303. record = result.scalar_one_or_none()
  304. if record:
  305. file_path = Path(record.file_path)
  306. if file_path.exists():
  307. file_path.unlink()
  308. await session.delete(record)
  309. await session.commit()
  310. logger.info(f"Deleted dataset: {record.name}")
  311. return {"status": "deleted"}
  312. def _detect_format(filename: str) -> str:
  313. ext = Path(filename).suffix.lower().lstrip(".")
  314. if ext in ("jsonl", "csv", "parquet", "json"):
  315. return ext
  316. return "unknown"
  317. def _count_records(file_path: Path, fmt: str) -> int:
  318. try:
  319. if fmt == "jsonl":
  320. return sum(1 for line in open(file_path, encoding="utf-8") if line.strip())
  321. elif fmt == "json":
  322. with open(file_path, encoding="utf-8") as f:
  323. data = json.load(f)
  324. return len(data) if isinstance(data, list) else 1
  325. elif fmt == "csv":
  326. import csv
  327. with open(file_path, encoding="utf-8") as f:
  328. return sum(1 for _ in csv.reader(f)) - 1
  329. elif fmt == "parquet":
  330. import pandas as pd
  331. return len(pd.read_parquet(file_path))
  332. except Exception:
  333. pass
  334. return 0
  335. def _read_records(file_path: Path, fmt: str, n: int) -> list[dict]:
  336. if fmt == "jsonl":
  337. records = []
  338. with open(file_path, encoding="utf-8") as f:
  339. for i, line in enumerate(f):
  340. if i >= n:
  341. break
  342. line = line.strip()
  343. if line:
  344. records.append(json.loads(line))
  345. return records
  346. elif fmt == "json":
  347. with open(file_path, encoding="utf-8") as f:
  348. data = json.load(f)
  349. return data[:n] if isinstance(data, list) else [data]
  350. elif fmt == "csv":
  351. import csv
  352. with open(file_path, encoding="utf-8") as f:
  353. reader = csv.DictReader(f)
  354. return [dict(row) for i, row in enumerate(reader) if i < n]
  355. elif fmt == "parquet":
  356. import pandas as pd
  357. df = pd.read_parquet(file_path)
  358. return df.head(n).to_dict(orient="records")
  359. return []