# dataset_service.py
import asyncio
import csv
import json
import uuid
from datetime import datetime
from itertools import islice
from pathlib import Path
from typing import Any

from fastapi import UploadFile

from app.config import get_settings
from app.core.db import async_session, DatasetRecord
from app.core.logging import logger
from app.schemas.dataset import DatasetDownloadRequest, DatasetDownloadResponse
  12. settings = get_settings()
  13. # Known metadata filenames that are NOT training data
  14. META_FILENAMES = frozenset({
  15. "configuration.json", "configuration.yaml", "README.md",
  16. ".mdl", ".msc", ".mv", "model_index.json", "generation_config.json",
  17. "special_tokens_map.json", "tokenizer_config.json",
  18. "added_tokens.json", "vocab.json", "merges.txt",
  19. "config.json", "preprocessor_config.json",
  20. # HF/ModelScope dataset metadata
  21. "dataset_info.json", "dataset_infos.json", "dataset.json",
  22. "state.json", "dataset_dict.json",
  23. })
  24. # File size threshold: files smaller than this (bytes) are likely metadata
  25. META_SIZE_THRESHOLD = 500
  26. def _is_training_data_file(path: Path) -> bool:
  27. """判断文件是否可能是训练数据文件(而非配置/元数据)。"""
  28. if path.suffix in (".jsonl", ".parquet", ".csv"):
  29. return True
  30. if path.suffix == ".json":
  31. if path.name in META_FILENAMES:
  32. return False
  33. # 小 JSON 文件通常是配置
  34. if path.stat().st_size < META_SIZE_THRESHOLD:
  35. return False
  36. # 尝试读取首行判断格式
  37. try:
  38. first_line = path.read_text(encoding="utf-8", errors="ignore").splitlines()[0].strip()
  39. obj = json.loads(first_line)
  40. # 如果有 input/output/conversation/instruction 等字段,则是训练数据
  41. if isinstance(obj, dict):
  42. data_keys = {"input", "output", "conversations", "instruction", "prompt",
  43. "text", "completion", "source", "target", "query", "response"}
  44. if data_keys & set(obj.keys()):
  45. return True
  46. return True # 大 JSON 文件默认是数据
  47. except Exception:
  48. return False
  49. # 无后缀文件:尝试读取判断是否为 JSON/JSONL
  50. if not path.suffix:
  51. try:
  52. first_line = path.read_text(encoding="utf-8", errors="ignore").splitlines()[0].strip()
  53. json.loads(first_line)
  54. return True
  55. except Exception:
  56. return False
  57. return False
  58. async def download_dataset(req: DatasetDownloadRequest) -> DatasetDownloadResponse:
  59. """从 HuggingFace 或 ModelScope 下载数据集。"""
  60. try:
  61. if req.use_modelscope:
  62. import subprocess
  63. ds_dir = settings.processed_dir / f"ms_{req.dataset_id.replace('/', '_')}"
  64. ds_dir.mkdir(parents=True, exist_ok=True)
  65. # 优先尝试用 modelscope 的 load_dataset API 加载数据
  66. try:
  67. from modelscope.msdatasets import MsDataset
  68. ms_ds = MsDataset.load(req.dataset_id)
  69. if hasattr(ms_ds, '__getitem__') and hasattr(ms_ds, '__len__'):
  70. split = ms_ds
  71. elif isinstance(ms_ds, dict):
  72. split = ms_ds.get("train") or ms_ds.get("default") or list(ms_ds.values())[0]
  73. else:
  74. split = ms_ds
  75. output_path = ds_dir / "data.jsonl"
  76. record_count = 0
  77. with open(output_path, "w", encoding="utf-8") as f:
  78. for item in split:
  79. f.write(json.dumps({k: str(v) for k, v in item.items()}, ensure_ascii=False) + "\n")
  80. record_count += 1
  81. if record_count == 0:
  82. raise RuntimeError("MsDataset loaded but returned 0 records")
  83. jsonl_path = output_path
  84. except (ImportError, RuntimeError) as e:
  85. # 回退到 CLI 下载方式
  86. logger.warning(f"MsDataset.load failed: {e}, falling back to CLI download")
  87. proc = subprocess.run(
  88. [
  89. "modelscope", "download",
  90. "--dataset", req.dataset_id,
  91. "--local_dir", str(ds_dir),
  92. ],
  93. capture_output=True, text=True, timeout=3600,
  94. )
  95. if proc.returncode != 0:
  96. raise RuntimeError(f"modelscope CLI failed: {proc.stderr}")
  97. jsonl_path, record_count = _scan_and_convert_to_jsonl(ds_dir)
  98. if record_count == 0:
  99. raise RuntimeError("No training data found in downloaded dataset files")
  100. else:
  101. from datasets import load_dataset
  102. ds = load_dataset(req.dataset_id)
  103. ds_dir = settings.processed_dir / f"hf_{req.dataset_id.replace('/', '_')}"
  104. ds_dir.mkdir(parents=True, exist_ok=True)
  105. if "train" in ds:
  106. split = ds["train"]
  107. else:
  108. split = ds[list(ds.keys())[0]]
  109. output_path = ds_dir / "data.jsonl"
  110. with open(output_path, "w", encoding="utf-8") as f:
  111. for item in split:
  112. f.write(json.dumps(item, ensure_ascii=False) + "\n")
  113. jsonl_path = output_path
  114. record_count = len(split) if hasattr(split, "__len__") else 0
  115. if record_count == 0:
  116. raise RuntimeError("HF dataset loaded but returned 0 records")
  117. record = DatasetRecord(
  118. id=str(uuid.uuid4()),
  119. name=req.dataset_id,
  120. format="jsonl",
  121. record_count=record_count,
  122. file_path=str(jsonl_path),
  123. created_at=datetime.utcnow(),
  124. )
  125. async with async_session() as session:
  126. session.add(record)
  127. await session.commit()
  128. logger.info(f"Downloaded dataset: {req.dataset_id} ({record_count} records, source={'ModelScope' if req.use_modelscope else 'HuggingFace'})")
  129. return DatasetDownloadResponse(dataset_id=req.dataset_id, status="completed", path=str(jsonl_path))
  130. except Exception as e:
  131. logger.error(f"Dataset download failed: {e}")
  132. return DatasetDownloadResponse(dataset_id=req.dataset_id, status="failed", error=str(e))
  133. def _scan_and_convert_to_jsonl(ds_dir: Path) -> tuple[Path, int]:
  134. """扫描 CLI 下载的数据集目录,找训练数据文件并转为 JSONL。"""
  135. # 找所有可能的数据文件
  136. data_files = []
  137. for ext in ("*.jsonl", "*.json", "*.csv"):
  138. data_files.extend(ds_dir.rglob(ext))
  139. # 过滤掉元数据文件
  140. data_files = [f for f in data_files if f.name not in META_FILENAMES]
  141. if not data_files:
  142. raise RuntimeError(f"No dataset files found in {ds_dir}")
  143. jsonl_path = ds_dir / "data.jsonl"
  144. record_count = 0
  145. with open(jsonl_path, "w", encoding="utf-8") as out:
  146. for data_file in data_files:
  147. if data_file.suffix == ".jsonl":
  148. with open(data_file, "r", encoding="utf-8") as f:
  149. for line in f:
  150. line = line.strip()
  151. if line:
  152. out.write(line + "\n")
  153. record_count += 1
  154. elif data_file.suffix == ".json":
  155. try:
  156. with open(data_file, "r", encoding="utf-8") as f:
  157. data = json.load(f)
  158. if isinstance(data, list):
  159. for item in data:
  160. out.write(json.dumps(item, ensure_ascii=False) + "\n")
  161. record_count += 1
  162. elif isinstance(data, dict):
  163. # 跳过 HF/ModelScope dataset metadata(features/splits 结构)
  164. if "features" in data or "splits" in data or "dataset_name" in data:
  165. continue
  166. out.write(json.dumps(data, ensure_ascii=False) + "\n")
  167. record_count += 1
  168. except Exception:
  169. pass
  170. elif data_file.suffix == ".csv":
  171. import csv
  172. with open(data_file, "r", encoding="utf-8") as f:
  173. reader = csv.DictReader(f)
  174. for row in reader:
  175. out.write(json.dumps(dict(row), ensure_ascii=False) + "\n")
  176. record_count += 1
  177. return jsonl_path, record_count
  178. async def upload_dataset(file: UploadFile) -> dict[str, Any]:
  179. """保存上传文件并写入数据库。"""
  180. upload_dir = settings.uploads_dir
  181. upload_dir.mkdir(parents=True, exist_ok=True)
  182. safe_name = file.filename or "unknown"
  183. file_path = upload_dir / safe_name
  184. if file_path.exists():
  185. file_path = upload_dir / f"{uuid.uuid4().hex}_{safe_name}"
  186. content = await file.read()
  187. file_path.write_bytes(content)
  188. fmt = _detect_format(safe_name)
  189. record_count = _count_records(file_path, fmt)
  190. record_id = str(uuid.uuid4())
  191. record = DatasetRecord(
  192. id=record_id,
  193. name=safe_name,
  194. format=fmt,
  195. record_count=record_count,
  196. file_path=str(file_path),
  197. created_at=datetime.utcnow(),
  198. )
  199. async with async_session() as session:
  200. session.add(record)
  201. await session.commit()
  202. logger.info(f"Uploaded dataset: {safe_name} ({record_count} records, format={fmt})")
  203. return {
  204. "id": record_id,
  205. "name": safe_name,
  206. "format": fmt,
  207. "record_count": record_count,
  208. "file_path": str(file_path),
  209. "created_at": record.created_at.isoformat(),
  210. }
  211. def _format_value(value) -> str:
  212. """将复杂值格式化为可读字符串。"""
  213. if isinstance(value, (dict, list)):
  214. return json.dumps(value, ensure_ascii=False, indent=2)
  215. return str(value)
  216. def _is_sharegpt_format(records: list[dict]) -> bool:
  217. """检测是否为 ShareGPT 格式。"""
  218. if not records:
  219. return False
  220. first = records[0]
  221. if "conversations" in first and isinstance(first["conversations"], list):
  222. if len(first["conversations"]) > 0 and isinstance(first["conversations"][0], dict):
  223. conv = first["conversations"][0]
  224. return "from" in conv and "value" in conv
  225. return False
  226. def _flatten_sharegpt(records: list[dict]) -> tuple[list[dict], list[str]]:
  227. """将 ShareGPT 格式展平为 input/output 列。"""
  228. flat_rows = []
  229. for row in records:
  230. conversations = row.get("conversations", [])
  231. for i in range(0, len(conversations) - 1, 2):
  232. user_turn = conversations[i]
  233. assistant_turn = conversations[i + 1] if i + 1 < len(conversations) else None
  234. if user_turn.get("from") in ("human", "user"):
  235. input_text = str(user_turn.get("value", ""))
  236. output_text = str(assistant_turn.get("value", "")) if assistant_turn else ""
  237. else:
  238. input_text = str(assistant_turn.get("value", "")) if assistant_turn else ""
  239. output_text = str(user_turn.get("value", ""))
  240. if len(input_text) > 500:
  241. input_text = input_text[:500] + "..."
  242. if len(output_text) > 500:
  243. output_text = output_text[:500] + "..."
  244. flat_rows.append({"input": input_text, "output": output_text})
  245. return flat_rows, ["input", "output"]
  246. async def preview_dataset(dataset_id: str, rows: int = 10) -> dict[str, Any]:
  247. """预览数据集前 N 行。"""
  248. async with async_session() as session:
  249. from sqlalchemy import select
  250. result = await session.execute(select(DatasetRecord).where(DatasetRecord.id == dataset_id))
  251. record = result.scalar_one_or_none()
  252. if not record:
  253. return {"total_records": 0, "preview_rows": [], "columns": []}
  254. file_path = Path(record.file_path)
  255. if not file_path.exists():
  256. return {"total_records": 0, "preview_rows": [], "columns": []}
  257. fmt = record.format
  258. preview_data = _read_records(file_path, fmt, rows)
  259. # 检测是否为 ShareGPT 格式,如果是则展平为 input/output 列
  260. if _is_sharegpt_format(preview_data):
  261. preview_data, columns = _flatten_sharegpt(preview_data)
  262. else:
  263. columns = list(preview_data[0].keys()) if preview_data else []
  264. return {
  265. "total_records": record.record_count,
  266. "preview_rows": [
  267. {
  268. "row_index": i,
  269. "data": {k: _format_value(v) for k, v in row.items()},
  270. }
  271. for i, row in enumerate(preview_data)
  272. ],
  273. "columns": columns,
  274. }
  275. async def validate_dataset(dataset_id: str) -> dict[str, Any]:
  276. """校验数据集格式和 Schema。"""
  277. async with async_session() as session:
  278. from sqlalchemy import select
  279. result = await session.execute(select(DatasetRecord).where(DatasetRecord.id == dataset_id))
  280. record = result.scalar_one_or_none()
  281. if not record:
  282. return {"is_valid": False, "errors": ["Dataset not found"], "warnings": []}
  283. file_path = Path(record.file_path)
  284. if not file_path.exists():
  285. return {"is_valid": False, "errors": ["File not found"], "warnings": []}
  286. errors = []
  287. warnings = []
  288. fmt = record.format
  289. if fmt not in ("jsonl", "csv", "json", "parquet"):
  290. errors.append(f"Unsupported format: {fmt}")
  291. try:
  292. preview = _read_records(file_path, fmt, 5)
  293. if not preview:
  294. warnings.append("Dataset appears to be empty")
  295. else:
  296. first = preview[0]
  297. has_sft_fields = any(k in first for k in ("instruction", "prompt", "text", "input", "output", "completion"))
  298. if not has_sft_fields:
  299. warnings.append(f"No common SFT fields found. Keys: {list(first.keys())}")
  300. except Exception as e:
  301. errors.append(f"Failed to read file: {str(e)}")
  302. return {"is_valid": len(errors) == 0, "errors": errors, "warnings": warnings}
  303. async def list_datasets() -> list[dict[str, Any]]:
  304. """列出所有已上传数据集。"""
  305. async with async_session() as session:
  306. from sqlalchemy import select
  307. result = await session.execute(select(DatasetRecord).order_by(DatasetRecord.created_at.desc()))
  308. records = result.scalars().all()
  309. return [
  310. {
  311. "id": r.id,
  312. "name": r.name,
  313. "format": r.format,
  314. "record_count": r.record_count,
  315. "file_path": r.file_path,
  316. "created_at": r.created_at.isoformat(),
  317. }
  318. for r in records
  319. ]
  320. async def delete_dataset(dataset_id: str) -> dict[str, Any]:
  321. """删除数据集。"""
  322. async with async_session() as session:
  323. from sqlalchemy import select
  324. result = await session.execute(select(DatasetRecord).where(DatasetRecord.id == dataset_id))
  325. record = result.scalar_one_or_none()
  326. if record:
  327. file_path = Path(record.file_path)
  328. if file_path.exists():
  329. file_path.unlink()
  330. await session.delete(record)
  331. await session.commit()
  332. logger.info(f"Deleted dataset: {record.name}")
  333. return {"status": "deleted"}
  334. def _detect_format(filename: str) -> str:
  335. ext = Path(filename).suffix.lower().lstrip(".")
  336. if ext in ("jsonl", "csv", "parquet", "json"):
  337. return ext
  338. return "unknown"
  339. def _count_records(file_path: Path, fmt: str) -> int:
  340. try:
  341. if fmt == "jsonl":
  342. return sum(1 for line in open(file_path, encoding="utf-8") if line.strip())
  343. elif fmt == "json":
  344. with open(file_path, encoding="utf-8") as f:
  345. data = json.load(f)
  346. return len(data) if isinstance(data, list) else 1
  347. elif fmt == "csv":
  348. import csv
  349. with open(file_path, encoding="utf-8") as f:
  350. return sum(1 for _ in csv.reader(f)) - 1
  351. elif fmt == "parquet":
  352. import pandas as pd
  353. return len(pd.read_parquet(file_path))
  354. except Exception:
  355. pass
  356. return 0
  357. def _read_records(file_path: Path, fmt: str, n: int) -> list[dict]:
  358. if fmt == "jsonl":
  359. records = []
  360. with open(file_path, encoding="utf-8") as f:
  361. for i, line in enumerate(f):
  362. if i >= n:
  363. break
  364. line = line.strip()
  365. if line:
  366. records.append(json.loads(line))
  367. return records
  368. elif fmt == "json":
  369. with open(file_path, encoding="utf-8") as f:
  370. data = json.load(f)
  371. return data[:n] if isinstance(data, list) else [data]
  372. elif fmt == "csv":
  373. import csv
  374. with open(file_path, encoding="utf-8") as f:
  375. reader = csv.DictReader(f)
  376. return [dict(row) for i, row in enumerate(reader) if i < n]
  377. elif fmt == "parquet":
  378. import pandas as pd
  379. df = pd.read_parquet(file_path)
  380. return df.head(n).to_dict(orient="records")
  381. return []