dataset_service.py 36 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948
  1. import asyncio
  2. import json
  3. import uuid
  4. from datetime import datetime, timezone
  5. from pathlib import Path
  6. from typing import Any
  7. from fastapi import UploadFile
  8. from app.config import get_settings
  9. from app.core.background_tasks import background_task_manager
  10. from app.core.db import async_session, DatasetRecord, DatasetDownloadTask
  11. from app.core.logging import logger
  12. from app.schemas.dataset import DatasetDownloadRequest, DatasetDownloadResponse
  13. from sqlalchemy import select
  14. settings = get_settings()
  15. # Known metadata filenames that are NOT training data
  16. META_FILENAMES = frozenset({
  17. "configuration.json", "configuration.yaml", "README.md",
  18. ".mdl", ".msc", ".mv", "model_index.json", "generation_config.json",
  19. "special_tokens_map.json", "tokenizer_config.json",
  20. "added_tokens.json", "vocab.json", "merges.txt",
  21. "config.json", "preprocessor_config.json",
  22. "dataset_infos.json", "dataset_info.json",
  23. "state.json", "card_data.json",
  24. })
  25. # File size threshold: files smaller than this (bytes) are likely metadata
  26. META_SIZE_THRESHOLD = 500
  27. def _is_training_data_file(path: Path) -> bool:
  28. """判断文件是否可能是训练数据文件(而非配置/元数据)。"""
  29. if path.name in META_FILENAMES:
  30. return False
  31. if path.suffix in (".jsonl", ".parquet", ".csv"):
  32. # 小文件可能是元数据(如 ModelScope CLI 生成的 data.jsonl 只有几十字节)
  33. if path.stat().st_size < META_SIZE_THRESHOLD:
  34. return False
  35. return True
  36. if path.suffix == ".json":
  37. # 小 JSON 文件通常是配置
  38. if path.stat().st_size < META_SIZE_THRESHOLD:
  39. return False
  40. # 尝试读取首行判断格式
  41. try:
  42. first_line = path.read_text(encoding="utf-8", errors="ignore").splitlines()[0].strip()
  43. obj = json.loads(first_line)
  44. # 如果有 input/output/conversation/instruction 等字段,则是训练数据
  45. if isinstance(obj, dict):
  46. data_keys = {"input", "output", "conversations", "instruction", "prompt",
  47. "text", "completion", "source", "target", "query", "response"}
  48. if data_keys & set(obj.keys()):
  49. return True
  50. # 如果只有 framework/task/model_type 等字段,则是元数据
  51. meta_keys = {"framework", "task", "license", "base_model", "model_type",
  52. "language", "domains", "tags", "authors"}
  53. if meta_keys & set(obj.keys()):
  54. return False
  55. return True # 大 JSON 文件默认是数据
  56. except Exception:
  57. return False
  58. # 无后缀文件:尝试读取判断是否为 JSON/JSONL
  59. if not path.suffix:
  60. try:
  61. first_line = path.read_text(encoding="utf-8", errors="ignore").splitlines()[0].strip()
  62. json.loads(first_line)
  63. return True
  64. except Exception:
  65. return False
  66. return False
  67. def _extract_archives(ds_dir: Path):
  68. """检测并解压数据集目录中的压缩包(zip/tar.gz/tar.bz2/tgz),
  69. 图片数据集通常将图片存放在压缩包中,需要解压后才能在预览时显示。"""
  70. import zipfile
  71. import tarfile
  72. extracted_any = False
  73. for f in list(ds_dir.rglob("*")):
  74. if not f.is_file():
  75. continue
  76. # 判断是否为压缩包
  77. name_lower = f.name.lower()
  78. is_zip = name_lower.endswith(".zip")
  79. is_tar = any(name_lower.endswith(ext) for ext in
  80. (".tar.gz", ".tgz", ".tar.bz2", ".tbz2", ".tar"))
  81. if not is_zip and not is_tar:
  82. continue
  83. # 用压缩包名(去掉所有后缀)作为解压目标目录
  84. stem = f.name
  85. for ext in (".tar.gz", ".tar.bz2", ".tgz", ".tbz2", ".tar", ".zip"):
  86. if stem.lower().endswith(ext):
  87. stem = stem[:-len(ext)]
  88. break
  89. extract_dir = f.parent / stem
  90. if extract_dir.exists():
  91. logger.info(f"Archive already extracted, skipping: {f.name}")
  92. continue
  93. logger.info(f"Extracting archive: {f.name} -> {extract_dir}")
  94. try:
  95. if is_zip:
  96. with zipfile.ZipFile(f, "r") as zf:
  97. zf.extractall(f.parent)
  98. else:
  99. with tarfile.open(f, "r:*") as tf:
  100. tf.extractall(f.parent)
  101. extracted_any = True
  102. logger.info(f"Successfully extracted: {f.name}")
  103. except Exception as e:
  104. logger.warning(f"Failed to extract {f.name}: {e}")
  105. if extracted_any:
  106. logger.info(f"Archive extraction completed for {ds_dir}")
  107. def _download_modelscope_data_files(dataset_id: str, ds_dir: Path):
  108. """通过 ModelScope API 下载图片数据集的压缩包。
  109. 图片数据集有一个 {dataset_name}.json 配置文件,记录了各 split 对应的
  110. 元数据文件和压缩包名称,例如:
  111. {"default": {"train": {"meta": "train.csv", "file": "train.zip"},
  112. "validation": {"meta": "val.csv", "file": "val.zip"}}}
  113. CLI download 只下载 git 仓库文件(CSV 等元数据),
  114. 压缩包需要通过 /api/v1/datasets/{ns}/{name}/repo?FilePath=... 单独下载。
  115. """
  116. import urllib.request
  117. import urllib.parse
  118. api_base = "https://www.modelscope.cn"
  119. namespace, ds_name = dataset_id.split("/", 1) if "/" in dataset_id else ("", dataset_id)
  120. # Step 1: 找到配置文件 {dataset_name}.json 并读取
  121. config_files = [p for p in ds_dir.glob("*.json") if p.name not in META_FILENAMES]
  122. if not config_files:
  123. # 尝试通过 API 下载配置文件
  124. config_url = (f"{api_base}/api/v1/datasets/{dataset_id}/repo"
  125. f"?Source=SDK&Revision=master&FilePath={ds_name}.json&View=false")
  126. try:
  127. logger.info(f"尝试下载配置文件: {ds_name}.json")
  128. req = urllib.request.Request(config_url, headers={"User-Agent": "FineTuning-Backend"})
  129. with urllib.request.urlopen(req, timeout=30) as resp:
  130. config_data = json.loads(resp.read().decode())
  131. config_path = ds_dir / f"{ds_name}.json"
  132. config_path.write_text(json.dumps(config_data, ensure_ascii=False), encoding="utf-8")
  133. config_files = [config_path]
  134. except Exception as e:
  135. logger.info(f"未找到配置文件 {ds_name}.json,跳过数据文件下载: {e}")
  136. return
  137. # 在所有 json 配置文件中找到包含 "file" 字段的那个
  138. config = None
  139. for cf in config_files:
  140. try:
  141. data = json.loads(cf.read_text(encoding="utf-8"))
  142. # 检查是否包含 file 字段(数据集配置格式)
  143. if isinstance(data, dict):
  144. for subset in data.values():
  145. if isinstance(subset, dict):
  146. for split_info in subset.values():
  147. if isinstance(split_info, dict) and "file" in split_info and split_info["file"]:
  148. config = data
  149. break
  150. if config:
  151. break
  152. if config:
  153. break
  154. except (json.JSONDecodeError, UnicodeDecodeError):
  155. continue
  156. if not config:
  157. logger.info("未找到包含数据文件引用的配置文件,跳过")
  158. return
  159. logger.info(f"找到数据文件配置: {json.dumps(config, ensure_ascii=False)}")
  160. # Step 2: 收集所有需要下载的压缩包文件名
  161. archive_files = set()
  162. for subset in config.values():
  163. if not isinstance(subset, dict):
  164. continue
  165. for split_info in subset.values():
  166. if isinstance(split_info, dict):
  167. fname = split_info.get("file", "")
  168. if fname:
  169. archive_files.add(fname)
  170. if not archive_files:
  171. logger.info("配置中未找到数据文件,跳过")
  172. return
  173. # Step 3: 下载压缩包
  174. existing = {f.name for f in ds_dir.rglob("*") if f.is_file()}
  175. downloaded = []
  176. for fname in archive_files:
  177. if fname in existing:
  178. logger.info(f"压缩包已存在,跳过: {fname}")
  179. continue
  180. params = urllib.parse.urlencode({
  181. "Source": "SDK", "Revision": "master",
  182. "FilePath": fname, "View": "false",
  183. })
  184. dl_url = f"{api_base}/api/v1/datasets/{dataset_id}/repo?{params}"
  185. dest = ds_dir / fname
  186. logger.info(f"下载数据文件: {fname}")
  187. logger.info(f" URL: {dl_url}")
  188. try:
  189. req = urllib.request.Request(dl_url, headers={"User-Agent": "FineTuning-Backend"})
  190. with urllib.request.urlopen(req, timeout=600) as resp:
  191. dest.write_bytes(resp.read())
  192. downloaded.append(fname)
  193. logger.info(f" 下载完成: {fname} ({dest.stat().st_size / 1024 / 1024:.1f}MB)")
  194. except Exception as e:
  195. logger.warning(f" 下载失败 {fname}: {e}")
  196. if downloaded:
  197. logger.info(f"共下载 {len(downloaded)} 个数据文件: {downloaded}")
  198. else:
  199. logger.info("没有需要下载的数据文件")
  200. async def download_dataset(req: DatasetDownloadRequest) -> DatasetDownloadResponse:
  201. """启动数据集下载后台任务,立即返回 task_id。"""
  202. task_id = str(uuid.uuid4())
  203. # 写 DB
  204. record = DatasetDownloadTask(
  205. id=task_id,
  206. dataset_id=req.dataset_id,
  207. use_modelscope=1 if req.use_modelscope else 0,
  208. status="pending",
  209. )
  210. async with async_session() as session:
  211. session.add(record)
  212. await session.commit()
  213. # 注册并启动
  214. background_task_manager.register_task(task_id, "dataset_download", {"dataset_id": req.dataset_id})
  215. await background_task_manager.run(
  216. task_id, "dataset_download", _execute_dataset_download(task_id, req)
  217. )
  218. logger.info(f"Dataset download task started: {req.dataset_id} (task_id={task_id})")
  219. return DatasetDownloadResponse(
  220. dataset_id=req.dataset_id, status="pending", task_id=task_id, path=task_id
  221. )
  222. async def _execute_dataset_download(task_id: str, req: DatasetDownloadRequest) -> dict:
  223. """后台执行数据集下载。"""
  224. try:
  225. if req.use_modelscope:
  226. ds_dir, jsonl_path, record_count = await asyncio.to_thread(
  227. _download_modelscope_dataset, req.dataset_id
  228. )
  229. else:
  230. from datasets import load_dataset
  231. ds = load_dataset(req.dataset_id)
  232. ds_dir = settings.processed_dir / f"hf_{req.dataset_id.replace('/', '_')}"
  233. ds_dir.mkdir(parents=True, exist_ok=True)
  234. if "train" in ds:
  235. split = ds["train"]
  236. else:
  237. split = ds[list(ds.keys())[0]]
  238. output_path = ds_dir / "data.jsonl"
  239. with open(output_path, "w", encoding="utf-8") as f:
  240. for item in split:
  241. f.write(json.dumps(item, ensure_ascii=False) + "\n")
  242. jsonl_path = output_path
  243. record_count = len(split) if hasattr(split, "__len__") else 0
  244. db_record = DatasetRecord(
  245. id=str(uuid.uuid4()),
  246. name=req.dataset_id,
  247. format="jsonl",
  248. record_count=record_count,
  249. file_path=str(jsonl_path),
  250. created_at=datetime.utcnow(),
  251. )
  252. async with async_session() as session:
  253. session.add(db_record)
  254. await session.commit()
  255. await _update_dataset_download_status(task_id, "completed", path=str(jsonl_path), record_count=record_count)
  256. logger.info(f"Dataset downloaded: {req.dataset_id} ({record_count} records)")
  257. return {"path": str(jsonl_path), "record_count": record_count}
  258. except Exception as e:
  259. logger.error(f"Dataset download failed: {e}")
  260. await _update_dataset_download_status(task_id, "failed", error=str(e))
  261. return {"error": str(e)}
  262. async def _update_dataset_download_status(task_id: str, status: str, path: str = None, error: str = None, record_count: int = 0):
  263. async with async_session() as session:
  264. result = await session.execute(select(DatasetDownloadTask).where(DatasetDownloadTask.id == task_id))
  265. record = result.scalar_one_or_none()
  266. if record:
  267. record.status = status
  268. if path:
  269. record.path = path
  270. if error:
  271. record.error = error
  272. if record_count:
  273. record.record_count = record_count
  274. if status in ("completed", "failed"):
  275. record.finished_at = datetime.utcnow()
  276. await session.commit()
  277. background_task_manager.update_task(
  278. task_id, status=status, path=path, error=error, record_count=record_count,
  279. )
  280. async def get_dataset_download_status(task_id: str) -> dict[str, Any]:
  281. async with async_session() as session:
  282. result = await session.execute(select(DatasetDownloadTask).where(DatasetDownloadTask.id == task_id))
  283. record = result.scalar_one_or_none()
  284. if record:
  285. return {
  286. "task_id": record.id,
  287. "dataset_id": record.dataset_id,
  288. "status": record.status,
  289. "use_modelscope": bool(record.use_modelscope),
  290. "path": record.path,
  291. "error": record.error,
  292. "record_count": record.record_count,
  293. "created_at": record.created_at.isoformat() if record.created_at else "",
  294. }
  295. mem = background_task_manager.get_task(task_id)
  296. if mem:
  297. return {
  298. "task_id": task_id,
  299. "dataset_id": mem.get("dataset_id", ""),
  300. "status": mem["status"],
  301. "error": mem.get("error"),
  302. "record_count": mem.get("record_count", 0),
  303. }
  304. return {"task_id": task_id, "status": "not_found"}
  305. async def list_dataset_downloads() -> list[dict[str, Any]]:
  306. async with async_session() as session:
  307. result = await session.execute(
  308. select(DatasetDownloadTask).order_by(DatasetDownloadTask.created_at.desc())
  309. )
  310. records = result.scalars().all()
  311. return [
  312. {
  313. "task_id": r.id,
  314. "dataset_id": r.dataset_id,
  315. "status": r.status,
  316. "use_modelscope": bool(r.use_modelscope),
  317. "path": r.path,
  318. "error": r.error,
  319. "record_count": r.record_count,
  320. "created_at": r.created_at.isoformat() if r.created_at else "",
  321. }
  322. for r in records
  323. ]
  324. async def cancel_dataset_download(task_id: str) -> dict[str, Any]:
  325. background_task_manager.cancel_task(task_id)
  326. async with async_session() as session:
  327. result = await session.execute(select(DatasetDownloadTask).where(DatasetDownloadTask.id == task_id))
  328. record = result.scalar_one_or_none()
  329. if record and record.status in ("pending", "downloading"):
  330. record.status = "cancelled"
  331. record.error = "Cancelled by user"
  332. record.finished_at = datetime.utcnow()
  333. await session.commit()
  334. return {"task_id": task_id, "status": "cancelled"}
  335. async def recover_stale_downloads() -> None:
  336. async with async_session() as session:
  337. result = await session.execute(
  338. select(DatasetDownloadTask).where(
  339. DatasetDownloadTask.status.in_(["pending", "downloading"])
  340. )
  341. )
  342. records = result.scalars().all()
  343. for record in records:
  344. record.status = "failed"
  345. record.error = "Server restarted, task interrupted"
  346. record.finished_at = datetime.utcnow()
  347. if records:
  348. await session.commit()
  349. logger.info(f"Recovered {len(records)} stale dataset download tasks")
  350. def _download_modelscope_dataset(dataset_id: str) -> tuple[Path, Path, int]:
  351. """用 MsDataset.load() 下载数据集,支持图片数据集(自动从 CDN 下载图片)。
  352. 如果 MsDataset.load() 失败,fallback 到 CLI 方式。"""
  353. namespace, ds_name = dataset_id.split("/", 1) if "/" in dataset_id else ("", dataset_id)
  354. ds_dir = settings.processed_dir / f"ms_{dataset_id.replace('/', '_')}"
  355. ds_dir.mkdir(parents=True, exist_ok=True)
  356. # 优先用 MsDataset.load(),它能自动下载"数据文件"区的图片
  357. try:
  358. records, record_count = _download_via_msdataset(dataset_id, ds_dir)
  359. if records:
  360. jsonl_path = ds_dir / "data.jsonl"
  361. with open(jsonl_path, "w", encoding="utf-8") as f:
  362. for item in records:
  363. f.write(json.dumps(item, ensure_ascii=False) + "\n")
  364. logger.info(f"MsDataset.load() 成功: {dataset_id} ({record_count} records)")
  365. return ds_dir, jsonl_path, record_count
  366. except Exception as e:
  367. logger.warning(f"MsDataset.load() failed for {dataset_id}: {e}, falling back to CLI")
  368. # fallback: CLI 方式(只下载 git 仓库文件,不含数据文件区图片)
  369. return _download_modelscope_dataset_cli(dataset_id, ds_dir)
  370. def _download_via_msdataset(dataset_id: str, ds_dir: Path) -> tuple[list[dict], int]:
  371. """用 MsDataset.load() 下载数据集,处理图片列(复制图片文件到数据集目录)。"""
  372. from modelscope.msdatasets import MsDataset
  373. from PIL import Image
  374. import shutil
  375. import os
  376. namespace, ds_name = dataset_id.split("/", 1) if "/" in dataset_id else ("", dataset_id)
  377. images_dir = ds_dir / "images"
  378. # 尝试加载不同 split
  379. ds = None
  380. for split in ("train", "validation", "test"):
  381. try:
  382. if namespace:
  383. ds = MsDataset.load(ds_name, namespace=namespace, split=split)
  384. else:
  385. ds = MsDataset.load(dataset_id, split=split)
  386. if ds:
  387. logger.info(f"MsDataset.load() loaded split '{split}': {len(ds) if hasattr(ds, '__len__') else '?'} records")
  388. break
  389. except Exception as e:
  390. logger.debug(f"split '{split}' failed: {e}")
  391. if not ds:
  392. # 不带 split 参数试试
  393. try:
  394. if namespace:
  395. ds = MsDataset.load(ds_name, namespace=namespace)
  396. else:
  397. ds = MsDataset.load(dataset_id)
  398. except Exception as e:
  399. logger.warning(f"MsDataset.load() without split also failed: {e}")
  400. return [], 0
  401. if not ds:
  402. return [], 0
  403. # 检查是否 iterable
  404. if not hasattr(ds, '__iter__'):
  405. return [], 0
  406. records = []
  407. img_counter = 0
  408. columns = None
  409. for row in ds:
  410. if not isinstance(row, dict):
  411. continue
  412. if columns is None:
  413. columns = list(row.keys())
  414. record = {}
  415. for k, v in row.items():
  416. # 检查是否是 PIL.Image 对象
  417. if isinstance(v, Image.Image):
  418. # 图片对象:保存到磁盘,记录相对路径
  419. images_dir.mkdir(parents=True, exist_ok=True)
  420. img_name = f"{img_counter:06d}.jpg"
  421. img_path = images_dir / img_name
  422. if v.mode in ("RGBA", "P", "LA"):
  423. v = v.convert("RGB")
  424. v.save(str(img_path), format="JPEG", quality=90)
  425. record[k] = f"images/{img_name}"
  426. img_counter += 1
  427. # 检查是否是图片文件路径
  428. elif isinstance(v, str) and v.lower().endswith(('.jpg', '.jpeg', '.png', '.gif', '.bmp')):
  429. # 如果是绝对路径,复制文件到 images 目录
  430. if os.path.isabs(v) and os.path.exists(v):
  431. images_dir.mkdir(parents=True, exist_ok=True)
  432. ext = os.path.splitext(v)[1]
  433. img_name = f"{img_counter:06d}{ext}"
  434. dest_path = images_dir / img_name
  435. try:
  436. shutil.copy2(v, dest_path)
  437. record[k] = f"images/{img_name}"
  438. img_counter += 1
  439. except Exception as e:
  440. logger.warning(f"Failed to copy image {v}: {e}")
  441. record[k] = v
  442. else:
  443. # 相对路径或其他情况,保持原样
  444. record[k] = v
  445. else:
  446. record[k] = v
  447. records.append(record)
  448. # 进度日志
  449. if len(records) % 500 == 0:
  450. logger.info(f" 处理中... {len(records)} records, {img_counter} images saved")
  451. if img_counter > 0:
  452. logger.info(f"共保存 {img_counter} 张图片到 {images_dir}")
  453. return records, len(records)
  454. def _download_modelscope_dataset_cli(dataset_id: str, ds_dir: Path) -> tuple[Path, Path, int]:
  455. """CLI 方式下载数据集(fallback,只下载 git 仓库文件)。"""
  456. import subprocess
  457. cmd = ["modelscope", "download", "--dataset", dataset_id, "--local_dir", str(ds_dir)]
  458. logger.info(f"Fallback CLI: {' '.join(cmd)}")
  459. proc = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
  460. if proc.returncode != 0:
  461. logger.error(f"ModelScope CLI download failed (code={proc.returncode}): {proc.stderr[:500]}")
  462. raise RuntimeError(f"ModelScope download failed: {proc.stderr[:500]}")
  463. # CLI 下载完 git 仓库文件后(数据文件区的图片需要通过 MsDataset.load 获取,CLI 只能下载元数据)
  464. # 扫描下载目录中的所有文件
  465. all_files = [p for p in ds_dir.rglob("*") if p.is_file()]
  466. logger.info(f"CLI downloaded {len(all_files)} files to {ds_dir}")
  467. # 识别训练数据文件
  468. data_files = [f for f in all_files if _is_training_data_file(f)]
  469. if not data_files:
  470. fallback = [f for f in all_files if f.suffix in (".json", ".jsonl") and f.name not in META_FILENAMES and f.name != "README.md"]
  471. logger.warning(f"No training data files found in {dataset_id}. "
  472. f"Available JSON files: {[f.name for f in fallback]}")
  473. if fallback:
  474. data_files = fallback
  475. else:
  476. # 如果还是没有,列出所有文件供排查
  477. logger.error(f"All downloaded files: {[str(f.relative_to(ds_dir)) for f in all_files]}")
  478. raise ValueError(f"No JSON/JSONL data files found in dataset {dataset_id}. "
  479. f"Available files: {[f.name for f in all_files]}")
  480. # 按文件大小排序,取最大的文件作为训练数据(真正的数据集通常是最大的)
  481. target = sorted(data_files, key=lambda f: f.stat().st_size, reverse=True)[0]
  482. logger.info(f"Selected data file: {target} (size={target.stat().st_size})")
  483. # 读取并统一转为 JSONL
  484. jsonl_path = ds_dir / "data.jsonl"
  485. record_count = 0
  486. content = target.read_text(encoding="utf-8")
  487. if target.suffix == ".jsonl" or not target.suffix:
  488. # JSONL 或无后缀文件:逐行解析
  489. records = []
  490. for line in content.splitlines():
  491. line = line.strip()
  492. if not line:
  493. continue
  494. try:
  495. records.append(json.loads(line))
  496. except json.JSONDecodeError:
  497. # 单行解析失败,尝试整体解析
  498. try:
  499. data = json.loads(content)
  500. records = data if isinstance(data, list) else [data]
  501. except json.JSONDecodeError:
  502. records = []
  503. break
  504. elif target.suffix == ".json":
  505. # JSON 文件:先尝试 JSON 数组,失败再逐行解析(可能是 JSONL 格式)
  506. try:
  507. records = json.loads(content)
  508. if not isinstance(records, list):
  509. records = [records]
  510. except json.JSONDecodeError:
  511. records = []
  512. for line in content.splitlines():
  513. line = line.strip()
  514. if not line:
  515. continue
  516. try:
  517. records.append(json.loads(line))
  518. except json.JSONDecodeError:
  519. continue
  520. elif target.suffix == ".csv":
  521. import csv as _csv
  522. records = []
  523. reader = _csv.DictReader(content.splitlines())
  524. for row in reader:
  525. records.append(dict(row))
  526. else:
  527. records = []
  528. with open(jsonl_path, "w", encoding="utf-8") as f:
  529. for item in records:
  530. f.write(json.dumps(item, ensure_ascii=False) + "\n")
  531. record_count += 1
  532. return ds_dir, jsonl_path, record_count
  533. async def upload_dataset(file: UploadFile) -> dict[str, Any]:
  534. """保存上传文件并写入数据库。"""
  535. upload_dir = settings.uploads_dir
  536. upload_dir.mkdir(parents=True, exist_ok=True)
  537. safe_name = file.filename or "unknown"
  538. file_path = upload_dir / safe_name
  539. if file_path.exists():
  540. file_path = upload_dir / f"{uuid.uuid4().hex}_{safe_name}"
  541. content = await file.read()
  542. file_path.write_bytes(content)
  543. fmt = _detect_format(safe_name)
  544. record_count = _count_records(file_path, fmt)
  545. record_id = str(uuid.uuid4())
  546. record = DatasetRecord(
  547. id=record_id,
  548. name=safe_name,
  549. format=fmt,
  550. record_count=record_count,
  551. file_path=str(file_path),
  552. created_at=datetime.utcnow(),
  553. )
  554. async with async_session() as session:
  555. session.add(record)
  556. await session.commit()
  557. logger.info(f"Uploaded dataset: {safe_name} ({record_count} records, format={fmt})")
  558. return {
  559. "id": record_id,
  560. "name": safe_name,
  561. "format": fmt,
  562. "record_count": record_count,
  563. "file_path": str(file_path),
  564. "created_at": record.created_at.isoformat(),
  565. }
  566. def _detect_image_column(columns: list[str]) -> str | None:
  567. """检测哪一列是图片路径列。"""
  568. candidates = ["image_path", "image", "img_path", "img", "file_path", "filename", "path", "file"]
  569. for c in candidates:
  570. if c in columns:
  571. return c
  572. # 模糊匹配:列名包含 image 或 path
  573. for c in columns:
  574. cl = c.lower()
  575. if "image" in cl or ("path" in cl and "label" not in cl):
  576. return c
  577. return None
  578. def _resolve_image_path(path_str: str, data_dir: Path) -> Path | None:
  579. """解析图片路径,返回绝对路径。"""
  580. if not path_str:
  581. return None
  582. p = Path(path_str)
  583. # 已经是绝对路径
  584. if p.is_absolute():
  585. return p if p.exists() else None
  586. # 相对路径:先尝试相对于数据目录
  587. candidate = data_dir / p
  588. if candidate.exists():
  589. return candidate
  590. # 也可能直接在 data_dir 下(去掉目录前缀只保留文件名)
  591. if data_dir.joinpath(p.name).exists():
  592. return data_dir / p.name
  593. # 在 data_dir 的子目录中递归查找
  594. for child in data_dir.rglob(p.name):
  595. if child.is_file():
  596. return child
  597. logger.debug(f"Image not found: '{path_str}' (searched in {data_dir})")
  598. return None
  599. def _encode_image_base64(image_path: Path, max_size: int = 200) -> str | None:
  600. """将图片转为 base64 data URI,用于前端预览。"""
  601. import base64
  602. try:
  603. from PIL import Image
  604. img = Image.open(image_path)
  605. # 缩小尺寸用于预览
  606. img.thumbnail((max_size, max_size))
  607. if img.mode in ("RGBA", "P", "LA"):
  608. img = img.convert("RGB")
  609. import io
  610. buf = io.BytesIO()
  611. img.save(buf, format="JPEG", quality=85)
  612. b64 = base64.b64encode(buf.getvalue()).decode("ascii")
  613. return f"data:image/jpeg;base64,{b64}"
  614. except Exception as e:
  615. logger.warning(f"Failed to encode image {image_path}: {e}")
  616. return None
  617. def _format_value(value) -> str:
  618. """将复杂值格式化为可读字符串。"""
  619. if isinstance(value, (dict, list)):
  620. return json.dumps(value, ensure_ascii=False, indent=2)
  621. return str(value)
  622. def _is_sharegpt_format(records: list[dict]) -> bool:
  623. """检测是否为 ShareGPT 格式。"""
  624. if not records:
  625. return False
  626. first = records[0]
  627. if "conversations" in first and isinstance(first["conversations"], list):
  628. if len(first["conversations"]) > 0 and isinstance(first["conversations"][0], dict):
  629. conv = first["conversations"][0]
  630. return "from" in conv and "value" in conv
  631. return False
  632. def _flatten_sharegpt(records: list[dict]) -> tuple[list[dict], list[str]]:
  633. """将 ShareGPT 格式展平为 input/output 列。"""
  634. flat_rows = []
  635. for row in records:
  636. conversations = row.get("conversations", [])
  637. for i in range(0, len(conversations) - 1, 2):
  638. user_turn = conversations[i]
  639. assistant_turn = conversations[i + 1] if i + 1 < len(conversations) else None
  640. if user_turn.get("from") in ("human", "user"):
  641. input_text = str(user_turn.get("value", ""))
  642. output_text = str(assistant_turn.get("value", "")) if assistant_turn else ""
  643. else:
  644. input_text = str(assistant_turn.get("value", "")) if assistant_turn else ""
  645. output_text = str(user_turn.get("value", ""))
  646. if len(input_text) > 500:
  647. input_text = input_text[:500] + "..."
  648. if len(output_text) > 500:
  649. output_text = output_text[:500] + "..."
  650. flat_rows.append({"input": input_text, "output": output_text})
  651. return flat_rows, ["input", "output"]
  652. async def preview_dataset(dataset_id: str, rows: int = 10) -> dict[str, Any]:
  653. """预览数据集前 N 行。"""
  654. async with async_session() as session:
  655. from sqlalchemy import select
  656. result = await session.execute(select(DatasetRecord).where(DatasetRecord.id == dataset_id))
  657. record = result.scalar_one_or_none()
  658. if not record:
  659. return {"total_records": 0, "preview_rows": [], "columns": [], "image_column": None}
  660. file_path = Path(record.file_path)
  661. if not file_path.exists():
  662. return {"total_records": 0, "preview_rows": [], "columns": [], "image_column": None}
  663. fmt = record.format
  664. preview_data = _read_records(file_path, fmt, rows)
  665. # 检测是否为 ShareGPT 格式,如果是则展平为 input/output 列
  666. if _is_sharegpt_format(preview_data):
  667. preview_data, columns = _flatten_sharegpt(preview_data)
  668. else:
  669. columns = list(preview_data[0].keys()) if preview_data else []
  670. # 检测是否为视觉数据集(有图片路径列),将图片转为 base64 嵌入预览
  671. image_column = _detect_image_column(columns)
  672. data_dir = file_path.parent
  673. preview_rows = []
  674. for i, row in enumerate(preview_data):
  675. data = {}
  676. for k, v in row.items():
  677. if k == image_column:
  678. # 解析图片路径,转为 base64 嵌入
  679. img_path = _resolve_image_path(str(v), data_dir)
  680. if img_path:
  681. encoded = _encode_image_base64(img_path)
  682. data[k] = encoded if encoded else str(v)
  683. else:
  684. # 路径解析失败,保留原始路径文本
  685. data[k] = str(v)
  686. else:
  687. data[k] = _format_value(v)
  688. preview_rows.append({"row_index": i, "data": data})
  689. return {
  690. "total_records": record.record_count,
  691. "preview_rows": preview_rows,
  692. "columns": columns,
  693. "image_column": image_column,
  694. }
  695. async def validate_dataset(dataset_id: str) -> dict[str, Any]:
  696. """校验数据集格式和 Schema。"""
  697. async with async_session() as session:
  698. from sqlalchemy import select
  699. result = await session.execute(select(DatasetRecord).where(DatasetRecord.id == dataset_id))
  700. record = result.scalar_one_or_none()
  701. if not record:
  702. return {"is_valid": False, "errors": ["Dataset not found"], "warnings": []}
  703. file_path = Path(record.file_path)
  704. if not file_path.exists():
  705. return {"is_valid": False, "errors": ["File not found"], "warnings": []}
  706. errors = []
  707. warnings = []
  708. fmt = record.format
  709. if fmt not in ("jsonl", "csv", "json", "parquet"):
  710. errors.append(f"Unsupported format: {fmt}")
  711. try:
  712. preview = _read_records(file_path, fmt, 5)
  713. if not preview:
  714. warnings.append("Dataset appears to be empty")
  715. else:
  716. first = preview[0]
  717. has_sft_fields = any(k in first for k in ("instruction", "prompt", "text", "input", "output", "completion"))
  718. if not has_sft_fields:
  719. warnings.append(f"No common SFT fields found. Keys: {list(first.keys())}")
  720. except Exception as e:
  721. errors.append(f"Failed to read file: {str(e)}")
  722. return {"is_valid": len(errors) == 0, "errors": errors, "warnings": warnings}
  723. async def list_datasets() -> list[dict[str, Any]]:
  724. """列出所有已上传数据集。"""
  725. async with async_session() as session:
  726. from sqlalchemy import select
  727. result = await session.execute(select(DatasetRecord).order_by(DatasetRecord.created_at.desc()))
  728. records = result.scalars().all()
  729. return [
  730. {
  731. "id": r.id,
  732. "name": r.name,
  733. "format": r.format,
  734. "record_count": r.record_count,
  735. "file_path": r.file_path,
  736. "created_at": r.created_at.isoformat(),
  737. }
  738. for r in records
  739. ]
  740. async def delete_dataset(dataset_id: str) -> dict[str, Any]:
  741. """删除数据集,同时清理关联的目录文件。"""
  742. import shutil
  743. async with async_session() as session:
  744. from sqlalchemy import select
  745. result = await session.execute(select(DatasetRecord).where(DatasetRecord.id == dataset_id))
  746. record = result.scalar_one_or_none()
  747. if record:
  748. file_path = Path(record.file_path)
  749. # 下载的数据集(processed_dir 下的子目录):删除整个目录
  750. if file_path.exists() and settings.processed_dir in file_path.parents:
  751. ds_dir = file_path.parent
  752. shutil.rmtree(ds_dir, ignore_errors=True)
  753. logger.info(f"Deleted dataset directory: {ds_dir}")
  754. elif file_path.exists():
  755. # 上传的数据集:只删文件
  756. file_path.unlink()
  757. await session.delete(record)
  758. await session.commit()
  759. logger.info(f"Deleted dataset: {record.name}")
  760. return {"status": "deleted"}
  761. def _detect_format(filename: str) -> str:
  762. ext = Path(filename).suffix.lower().lstrip(".")
  763. if ext in ("jsonl", "csv", "parquet", "json"):
  764. return ext
  765. return "unknown"
  766. def _count_records(file_path: Path, fmt: str) -> int:
  767. try:
  768. if fmt == "jsonl":
  769. return sum(1 for line in open(file_path, encoding="utf-8") if line.strip())
  770. elif fmt == "json":
  771. with open(file_path, encoding="utf-8") as f:
  772. data = json.load(f)
  773. return len(data) if isinstance(data, list) else 1
  774. elif fmt == "csv":
  775. import csv
  776. with open(file_path, encoding="utf-8") as f:
  777. return sum(1 for _ in csv.reader(f)) - 1
  778. elif fmt == "parquet":
  779. import pandas as pd
  780. return len(pd.read_parquet(file_path))
  781. except Exception:
  782. pass
  783. return 0
  784. def _read_records(file_path: Path, fmt: str, n: int) -> list[dict]:
  785. if fmt == "jsonl":
  786. records = []
  787. with open(file_path, encoding="utf-8") as f:
  788. for i, line in enumerate(f):
  789. if i >= n:
  790. break
  791. line = line.strip()
  792. if line:
  793. records.append(json.loads(line))
  794. return records
  795. elif fmt == "json":
  796. with open(file_path, encoding="utf-8") as f:
  797. data = json.load(f)
  798. return data[:n] if isinstance(data, list) else [data]
  799. elif fmt == "csv":
  800. import csv
  801. with open(file_path, encoding="utf-8") as f:
  802. reader = csv.DictReader(f)
  803. return [dict(row) for i, row in enumerate(reader) if i < n]
  804. elif fmt == "parquet":
  805. import pandas as pd
  806. df = pd.read_parquet(file_path)
  807. return df.head(n).to_dict(orient="records")
  808. return []