lxylxy123321 1 неделя назад
Родитель
Commit
67f7211d8b
1 изменённый файл: 61 добавлено и 26 удалено
  1. 61 26
      backend/app/services/dataset_service.py

+ 61 - 26
backend/app/services/dataset_service.py

@@ -15,12 +15,57 @@ from app.schemas.dataset import DatasetDownloadRequest, DatasetDownloadResponse
 
 settings = get_settings()
 
+# Metadata filenames that are NOT training data; matched on the full name (".mdl" is a dotfile).
+META_FILENAMES = frozenset({
+    "configuration.json", "configuration.yaml", "README.md",
+    ".mdl", ".msc", ".mv", "model_index.json", "generation_config.json",
+    "special_tokens_map.json", "tokenizer_config.json",
+    "added_tokens.json", "vocab.json", "merges.txt",
+    "config.json", "preprocessor_config.json",
+})
+
+# Size threshold in bytes: a .json file smaller than this is treated as metadata, not data
+META_SIZE_THRESHOLD = 500
+
+
+def _is_training_data_file(path: Path) -> bool:
+    """Return True when *path* looks like training data rather than config/metadata.
+
+    Names in ``META_FILENAMES`` are rejected first, so dotfile entries such as
+    ``.mdl`` (whose ``Path.suffix`` is empty) are never sniffed as data. ``.jsonl``,
+    ``.parquet`` and ``.csv`` are accepted on extension alone. A ``.json`` file must
+    be at least ``META_SIZE_THRESHOLD`` bytes and either start with a complete JSON
+    value on its first line or open a multi-line object/array (pretty-printed data).
+    A file without extension is accepted only if its first line parses as JSON.
+    Any other extension, and any file that cannot be read, yields False.
+    """
+    if path.name in META_FILENAMES:
+        return False
+    if path.suffix in (".jsonl", ".parquet", ".csv"):
+        return True
+    if path.suffix not in ("", ".json"):
+        return False
+    is_json_ext = path.suffix == ".json"
+    try:
+        # Small JSON files are usually configuration rather than data.
+        if is_json_ext and path.stat().st_size < META_SIZE_THRESHOLD:
+            return False
+        with path.open(encoding="utf-8", errors="ignore") as fh:
+            first_line = fh.readline().strip()  # first line only, not the whole file
+    except OSError:
+        return False
+    try:
+        json.loads(first_line)
+    except (ValueError, RecursionError):
+        # Not a one-line JSON value: a pretty-printed .json document still counts.
+        return is_json_ext and first_line.startswith(("{", "["))
+    return True
+
 
 async def download_dataset(req: DatasetDownloadRequest) -> DatasetDownloadResponse:
     """从 HuggingFace 或 ModelScope 下载数据集。"""
     try:
         if req.use_modelscope:
-            # 用 asyncio.to_thread 包裹同步下载,避免阻塞事件循环
             ds_dir, jsonl_path, record_count = await asyncio.to_thread(_download_modelscope_dataset, req.dataset_id)
         else:
             from datasets import load_dataset
@@ -39,7 +84,6 @@ async def download_dataset(req: DatasetDownloadRequest) -> DatasetDownloadRespon
             jsonl_path = output_path
             record_count = len(split) if hasattr(split, "__len__") else 0
 
-        # 写入数据库
         record = DatasetRecord(
             id=str(uuid.uuid4()),
             name=req.dataset_id,
@@ -66,25 +110,21 @@ def _download_modelscope_dataset(dataset_id: str) -> tuple[Path, Path, int]:
     ds_dir = settings.processed_dir / f"ms_{dataset_id.replace('/', '_')}"
     ds_dir.mkdir(parents=True, exist_ok=True)
 
-    # 用 snapshot_download 下载数据集文件到本地
     local_path = snapshot_download(dataset_id, cache_dir=str(settings.processed_dir))
 
-    # 扫描下载目录中的 JSON/JSONL 文件
-    data_files = []
-    for p in Path(local_path).rglob("*"):
-        if p.is_file() and p.suffix in (".json", ".jsonl"):
-            data_files.append(p)
+    # 扫描所有文件,识别训练数据文件
+    all_files = [p for p in Path(local_path).rglob("*") if p.is_file()]
+    data_files = [f for f in all_files if _is_training_data_file(f)]
 
     if not data_files:
-        raise ValueError(f"No JSON/JSONL files found in dataset {dataset_id}")
-
-    # 排除元数据文件(framework/task/configuration 等),只保留数据文件
-    meta_patterns = ("configuration.json", ".mdl", ".msc", ".mv")
-    data_files = [f for f in data_files if f.name not in meta_patterns
-                  and not (f.suffix == ".json" and len(f.read_bytes()) < 200)]
-
-    if not data_files:
-        raise ValueError(f"No training data files found in dataset {dataset_id}")
+        # 回退:列出所有 JSON/JSONL 文件方便调试
+        fallback = [f for f in all_files if f.suffix in (".json", ".jsonl")]
+        logger.warning(f"No training data files found in {dataset_id}. "
+                       f"Available JSON files: {[f.name for f in fallback]}")
+        if fallback:
+            data_files = fallback
+        else:
+            raise ValueError(f"No JSON/JSONL data files found in dataset {dataset_id}")
 
     # 优先取 train / data 开头的文件
     target = None
@@ -96,9 +136,11 @@ def _download_modelscope_dataset(dataset_id: str) -> tuple[Path, Path, int]:
         if target:
             break
     if not target:
-        # 优先取数据量大的文件(通常是实际训练数据)
+        # 优先取数据量大的文件
         target = sorted(data_files, key=lambda f: f.stat().st_size, reverse=True)[0]
 
+    logger.info(f"Selected data file: {target} (size={target.stat().st_size})")
+
     # 读取并统一转为 JSONL
     jsonl_path = ds_dir / "data.jsonl"
     record_count = 0
@@ -123,7 +165,6 @@ async def upload_dataset(file: UploadFile) -> dict[str, Any]:
     upload_dir = settings.uploads_dir
     upload_dir.mkdir(parents=True, exist_ok=True)
 
-    # 避免文件名冲突
     safe_name = file.filename or "unknown"
     file_path = upload_dir / safe_name
     if file_path.exists():
@@ -183,7 +224,6 @@ def _flatten_sharegpt(records: list[dict]) -> tuple[list[dict], list[str]]:
     flat_rows = []
     for row in records:
         conversations = row.get("conversations", [])
-        # 每轮 user+assistant 对话作为一行
         for i in range(0, len(conversations) - 1, 2):
             user_turn = conversations[i]
             assistant_turn = conversations[i + 1] if i + 1 < len(conversations) else None
@@ -195,7 +235,6 @@ def _flatten_sharegpt(records: list[dict]) -> tuple[list[dict], list[str]]:
                 input_text = str(assistant_turn.get("value", "")) if assistant_turn else ""
                 output_text = str(user_turn.get("value", ""))
 
-            # 截断过长文本
             if len(input_text) > 500:
                 input_text = input_text[:500] + "..."
             if len(output_text) > 500:
@@ -259,18 +298,15 @@ async def validate_dataset(dataset_id: str) -> dict[str, Any]:
     errors = []
     warnings = []
 
-    # 检查格式
     fmt = record.format
     if fmt not in ("jsonl", "csv", "json", "parquet"):
         errors.append(f"Unsupported format: {fmt}")
 
-    # 检查内容
     try:
         preview = _read_records(file_path, fmt, 5)
         if not preview:
             warnings.append("Dataset appears to be empty")
         else:
-            # 检查必需字段(SFT 格式)
             first = preview[0]
             has_sft_fields = any(k in first for k in ("instruction", "prompt", "text", "input", "output", "completion"))
             if not has_sft_fields:
@@ -310,7 +346,6 @@ async def delete_dataset(dataset_id: str) -> dict[str, Any]:
         result = await session.execute(select(DatasetRecord).where(DatasetRecord.id == dataset_id))
         record = result.scalar_one_or_none()
         if record:
-            # 删除文件
             file_path = Path(record.file_path)
             if file_path.exists():
                 file_path.unlink()
@@ -339,7 +374,7 @@ def _count_records(file_path: Path, fmt: str) -> int:
         elif fmt == "csv":
             import csv
             with open(file_path, encoding="utf-8") as f:
-                return sum(1 for _ in csv.reader(f)) - 1  # minus header
+                return sum(1 for _ in csv.reader(f)) - 1
         elif fmt == "parquet":
             import pandas as pd
             return len(pd.read_parquet(file_path))