dataset_service.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
  1. import json
  2. import uuid
  3. from datetime import datetime, timezone
  4. from pathlib import Path
  5. from typing import Any
  6. from fastapi import UploadFile
  7. from app.config import get_settings
  8. from app.core.db import async_session, DatasetRecord
  9. from app.core.logging import logger
  10. from app.schemas.dataset import DatasetDownloadRequest, DatasetDownloadResponse
  11. settings = get_settings()
  12. async def download_dataset(req: DatasetDownloadRequest) -> DatasetDownloadResponse:
  13. """从 HuggingFace 或 ModelScope 下载数据集。"""
  14. try:
  15. from datasets import load_dataset
  16. ds = load_dataset(req.dataset_id)
  17. ds_dir = settings.processed_dir / f"hf_{req.dataset_id.replace('/', '_')}"
  18. ds_dir.mkdir(parents=True, exist_ok=True)
  19. # 保存为 JSONL
  20. if "train" in ds:
  21. split = ds["train"]
  22. else:
  23. split = ds[list(ds.keys())[0]]
  24. output_path = ds_dir / "data.jsonl"
  25. with open(output_path, "w", encoding="utf-8") as f:
  26. for item in split:
  27. f.write(json.dumps(item, ensure_ascii=False) + "\n")
  28. # 写入数据库
  29. record = DatasetRecord(
  30. id=str(uuid.uuid4()),
  31. name=req.dataset_id,
  32. format="jsonl",
  33. record_count=len(split),
  34. file_path=str(output_path),
  35. created_at=datetime.now(timezone.utc),
  36. )
  37. async with async_session() as session:
  38. session.add(record)
  39. await session.commit()
  40. logger.info(f"Downloaded dataset: {req.dataset_id} ({len(split)} records)")
  41. return DatasetDownloadResponse(dataset_id=req.dataset_id, status="completed", path=str(output_path))
  42. except Exception as e:
  43. logger.error(f"Dataset download failed: {e}")
  44. return DatasetDownloadResponse(dataset_id=req.dataset_id, status="failed", error=str(e))
  45. async def upload_dataset(file: UploadFile) -> dict[str, Any]:
  46. """保存上传文件并写入数据库。"""
  47. upload_dir = settings.uploads_dir
  48. upload_dir.mkdir(parents=True, exist_ok=True)
  49. # 避免文件名冲突
  50. safe_name = file.filename or "unknown"
  51. file_path = upload_dir / safe_name
  52. if file_path.exists():
  53. file_path = upload_dir / f"{uuid.uuid4().hex}_{safe_name}"
  54. content = await file.read()
  55. file_path.write_bytes(content)
  56. fmt = _detect_format(safe_name)
  57. record_count = _count_records(file_path, fmt)
  58. record_id = str(uuid.uuid4())
  59. record = DatasetRecord(
  60. id=record_id,
  61. name=safe_name,
  62. format=fmt,
  63. record_count=record_count,
  64. file_path=str(file_path),
  65. created_at=datetime.now(timezone.utc),
  66. )
  67. async with async_session() as session:
  68. session.add(record)
  69. await session.commit()
  70. logger.info(f"Uploaded dataset: {safe_name} ({record_count} records, format={fmt})")
  71. return {
  72. "id": record_id,
  73. "name": safe_name,
  74. "format": fmt,
  75. "record_count": record_count,
  76. "file_path": str(file_path),
  77. "created_at": record.created_at.isoformat(),
  78. }
  79. async def preview_dataset(dataset_id: str, rows: int = 10) -> dict[str, Any]:
  80. """预览数据集前 N 行。"""
  81. async with async_session() as session:
  82. from sqlalchemy import select
  83. result = await session.execute(select(DatasetRecord).where(DatasetRecord.id == dataset_id))
  84. record = result.scalar_one_or_none()
  85. if not record:
  86. return {"total_records": 0, "preview_rows": [], "columns": []}
  87. file_path = Path(record.file_path)
  88. if not file_path.exists():
  89. return {"total_records": 0, "preview_rows": [], "columns": []}
  90. fmt = record.format
  91. preview_data = _read_records(file_path, fmt, rows)
  92. columns = list(preview_data[0].keys()) if preview_data else []
  93. return {
  94. "total_records": record.record_count,
  95. "preview_rows": [{"row_index": i, "data": row} for i, row in enumerate(preview_data)],
  96. "columns": columns,
  97. }
  98. async def validate_dataset(dataset_id: str) -> dict[str, Any]:
  99. """校验数据集格式和 Schema。"""
  100. async with async_session() as session:
  101. from sqlalchemy import select
  102. result = await session.execute(select(DatasetRecord).where(DatasetRecord.id == dataset_id))
  103. record = result.scalar_one_or_none()
  104. if not record:
  105. return {"is_valid": False, "errors": ["Dataset not found"], "warnings": []}
  106. file_path = Path(record.file_path)
  107. if not file_path.exists():
  108. return {"is_valid": False, "errors": ["File not found"], "warnings": []}
  109. errors = []
  110. warnings = []
  111. # 检查格式
  112. fmt = record.format
  113. if fmt not in ("jsonl", "csv", "json", "parquet"):
  114. errors.append(f"Unsupported format: {fmt}")
  115. # 检查内容
  116. try:
  117. preview = _read_records(file_path, fmt, 5)
  118. if not preview:
  119. warnings.append("Dataset appears to be empty")
  120. else:
  121. # 检查必需字段(SFT 格式)
  122. first = preview[0]
  123. has_sft_fields = any(k in first for k in ("instruction", "prompt", "text", "input", "output", "completion"))
  124. if not has_sft_fields:
  125. warnings.append(f"No common SFT fields found. Keys: {list(first.keys())}")
  126. except Exception as e:
  127. errors.append(f"Failed to read file: {str(e)}")
  128. return {"is_valid": len(errors) == 0, "errors": errors, "warnings": warnings}
  129. async def list_datasets() -> list[dict[str, Any]]:
  130. """列出所有已上传数据集。"""
  131. async with async_session() as session:
  132. from sqlalchemy import select
  133. result = await session.execute(select(DatasetRecord).order_by(DatasetRecord.created_at.desc()))
  134. records = result.scalars().all()
  135. return [
  136. {
  137. "id": r.id,
  138. "name": r.name,
  139. "format": r.format,
  140. "record_count": r.record_count,
  141. "file_path": r.file_path,
  142. "created_at": r.created_at.isoformat(),
  143. }
  144. for r in records
  145. ]
  146. async def delete_dataset(dataset_id: str) -> dict[str, Any]:
  147. """删除数据集。"""
  148. async with async_session() as session:
  149. from sqlalchemy import select
  150. result = await session.execute(select(DatasetRecord).where(DatasetRecord.id == dataset_id))
  151. record = result.scalar_one_or_none()
  152. if record:
  153. # 删除文件
  154. file_path = Path(record.file_path)
  155. if file_path.exists():
  156. file_path.unlink()
  157. await session.delete(record)
  158. await session.commit()
  159. logger.info(f"Deleted dataset: {record.name}")
  160. return {"status": "deleted"}
  161. def _detect_format(filename: str) -> str:
  162. ext = Path(filename).suffix.lower().lstrip(".")
  163. if ext in ("jsonl", "csv", "parquet", "json"):
  164. return ext
  165. return "unknown"
  166. def _count_records(file_path: Path, fmt: str) -> int:
  167. try:
  168. if fmt == "jsonl":
  169. return sum(1 for line in open(file_path, encoding="utf-8") if line.strip())
  170. elif fmt == "json":
  171. with open(file_path, encoding="utf-8") as f:
  172. data = json.load(f)
  173. return len(data) if isinstance(data, list) else 1
  174. elif fmt == "csv":
  175. import csv
  176. with open(file_path, encoding="utf-8") as f:
  177. return sum(1 for _ in csv.reader(f)) - 1 # minus header
  178. elif fmt == "parquet":
  179. import pandas as pd
  180. return len(pd.read_parquet(file_path))
  181. except Exception:
  182. pass
  183. return 0
  184. def _read_records(file_path: Path, fmt: str, n: int) -> list[dict]:
  185. if fmt == "jsonl":
  186. records = []
  187. with open(file_path, encoding="utf-8") as f:
  188. for i, line in enumerate(f):
  189. if i >= n:
  190. break
  191. line = line.strip()
  192. if line:
  193. records.append(json.loads(line))
  194. return records
  195. elif fmt == "json":
  196. with open(file_path, encoding="utf-8") as f:
  197. data = json.load(f)
  198. return data[:n] if isinstance(data, list) else [data]
  199. elif fmt == "csv":
  200. import csv
  201. with open(file_path, encoding="utf-8") as f:
  202. reader = csv.DictReader(f)
  203. return [dict(row) for i, row in enumerate(reader) if i < n]
  204. elif fmt == "parquet":
  205. import pandas as pd
  206. df = pd.read_parquet(file_path)
  207. return df.head(n).to_dict(orient="records")
  208. return []