|
|
@@ -15,12 +15,57 @@ from app.schemas.dataset import DatasetDownloadRequest, DatasetDownloadResponse
|
|
|
|
|
|
settings = get_settings()
|
|
|
|
|
|
+# Known metadata filenames that are NOT training data
|
|
|
+META_FILENAMES = frozenset({
|
|
|
+ "configuration.json", "configuration.yaml", "README.md",
|
|
|
+ ".mdl", ".msc", ".mv", "model_index.json", "generation_config.json",
|
|
|
+ "special_tokens_map.json", "tokenizer_config.json",
|
|
|
+ "added_tokens.json", "vocab.json", "merges.txt",
|
|
|
+ "config.json", "preprocessor_config.json",
|
|
|
+})
|
|
|
+
|
|
|
+# File size threshold: files smaller than this (bytes) are likely metadata
|
|
|
+META_SIZE_THRESHOLD = 500
|
|
|
+
|
|
|
+
|
|
|
+def _is_training_data_file(path: Path) -> bool:
|
|
|
+ """判断文件是否可能是训练数据文件(而非配置/元数据)。"""
|
|
|
+ if path.suffix in (".jsonl", ".parquet", ".csv"):
|
|
|
+ return True
|
|
|
+ if path.suffix == ".json":
|
|
|
+ if path.name in META_FILENAMES:
|
|
|
+ return False
|
|
|
+ # 小 JSON 文件通常是配置
|
|
|
+ if path.stat().st_size < META_SIZE_THRESHOLD:
|
|
|
+ return False
|
|
|
+ # 尝试读取首行判断格式
|
|
|
+ try:
|
|
|
+ first_line = path.read_text(encoding="utf-8", errors="ignore").splitlines()[0].strip()
|
|
|
+ obj = json.loads(first_line)
|
|
|
+ # 如果有 input/output/conversation/instruction 等字段,则是训练数据
|
|
|
+ if isinstance(obj, dict):
|
|
|
+ data_keys = {"input", "output", "conversations", "instruction", "prompt",
|
|
|
+ "text", "completion", "source", "target", "query", "response"}
|
|
|
+ if data_keys & set(obj.keys()):
|
|
|
+ return True
|
|
|
+ return True # 大 JSON 文件默认是数据
|
|
|
+ except Exception:
|
|
|
+ return False
|
|
|
+ # 无后缀文件:尝试读取判断是否为 JSON/JSONL
|
|
|
+ if not path.suffix:
|
|
|
+ try:
|
|
|
+ first_line = path.read_text(encoding="utf-8", errors="ignore").splitlines()[0].strip()
|
|
|
+ json.loads(first_line)
|
|
|
+ return True
|
|
|
+ except Exception:
|
|
|
+ return False
|
|
|
+ return False
|
|
|
+
|
|
|
|
|
|
async def download_dataset(req: DatasetDownloadRequest) -> DatasetDownloadResponse:
|
|
|
"""从 HuggingFace 或 ModelScope 下载数据集。"""
|
|
|
try:
|
|
|
if req.use_modelscope:
|
|
|
- # 用 asyncio.to_thread 包裹同步下载,避免阻塞事件循环
|
|
|
ds_dir, jsonl_path, record_count = await asyncio.to_thread(_download_modelscope_dataset, req.dataset_id)
|
|
|
else:
|
|
|
from datasets import load_dataset
|
|
|
@@ -39,7 +84,6 @@ async def download_dataset(req: DatasetDownloadRequest) -> DatasetDownloadRespon
|
|
|
jsonl_path = output_path
|
|
|
record_count = len(split) if hasattr(split, "__len__") else 0
|
|
|
|
|
|
- # 写入数据库
|
|
|
record = DatasetRecord(
|
|
|
id=str(uuid.uuid4()),
|
|
|
name=req.dataset_id,
|
|
|
@@ -66,25 +110,21 @@ def _download_modelscope_dataset(dataset_id: str) -> tuple[Path, Path, int]:
|
|
|
ds_dir = settings.processed_dir / f"ms_{dataset_id.replace('/', '_')}"
|
|
|
ds_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
- # 用 snapshot_download 下载数据集文件到本地
|
|
|
local_path = snapshot_download(dataset_id, cache_dir=str(settings.processed_dir))
|
|
|
|
|
|
- # 扫描下载目录中的 JSON/JSONL 文件
|
|
|
- data_files = []
|
|
|
- for p in Path(local_path).rglob("*"):
|
|
|
- if p.is_file() and p.suffix in (".json", ".jsonl"):
|
|
|
- data_files.append(p)
|
|
|
+ # 扫描所有文件,识别训练数据文件
|
|
|
+ all_files = [p for p in Path(local_path).rglob("*") if p.is_file()]
|
|
|
+ data_files = [f for f in all_files if _is_training_data_file(f)]
|
|
|
|
|
|
if not data_files:
|
|
|
- raise ValueError(f"No JSON/JSONL files found in dataset {dataset_id}")
|
|
|
-
|
|
|
- # 排除元数据文件(framework/task/configuration 等),只保留数据文件
|
|
|
- meta_patterns = ("configuration.json", ".mdl", ".msc", ".mv")
|
|
|
- data_files = [f for f in data_files if f.name not in meta_patterns
|
|
|
- and not (f.suffix == ".json" and len(f.read_bytes()) < 200)]
|
|
|
-
|
|
|
- if not data_files:
|
|
|
- raise ValueError(f"No training data files found in dataset {dataset_id}")
|
|
|
+ # 回退:列出所有 JSON/JSONL 文件方便调试
|
|
|
+ fallback = [f for f in all_files if f.suffix in (".json", ".jsonl")]
|
|
|
+ logger.warning(f"No training data files found in {dataset_id}. "
|
|
|
+ f"Available JSON files: {[f.name for f in fallback]}")
|
|
|
+ if fallback:
|
|
|
+ data_files = fallback
|
|
|
+ else:
|
|
|
+ raise ValueError(f"No JSON/JSONL data files found in dataset {dataset_id}")
|
|
|
|
|
|
# 优先取 train / data 开头的文件
|
|
|
target = None
|
|
|
@@ -96,9 +136,11 @@ def _download_modelscope_dataset(dataset_id: str) -> tuple[Path, Path, int]:
|
|
|
if target:
|
|
|
break
|
|
|
if not target:
|
|
|
- # 优先取数据量大的文件(通常是实际训练数据)
|
|
|
+ # 优先取数据量最大的文件
|
|
|
target = sorted(data_files, key=lambda f: f.stat().st_size, reverse=True)[0]
|
|
|
|
|
|
+ logger.info(f"Selected data file: {target} (size={target.stat().st_size})")
|
|
|
+
|
|
|
# 读取并统一转为 JSONL
|
|
|
jsonl_path = ds_dir / "data.jsonl"
|
|
|
record_count = 0
|
|
|
@@ -123,7 +165,6 @@ async def upload_dataset(file: UploadFile) -> dict[str, Any]:
|
|
|
upload_dir = settings.uploads_dir
|
|
|
upload_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
- # 避免文件名冲突
|
|
|
safe_name = file.filename or "unknown"
|
|
|
file_path = upload_dir / safe_name
|
|
|
if file_path.exists():
|
|
|
@@ -183,7 +224,6 @@ def _flatten_sharegpt(records: list[dict]) -> tuple[list[dict], list[str]]:
|
|
|
flat_rows = []
|
|
|
for row in records:
|
|
|
conversations = row.get("conversations", [])
|
|
|
- # 每轮 user+assistant 对话作为一行
|
|
|
for i in range(0, len(conversations) - 1, 2):
|
|
|
user_turn = conversations[i]
|
|
|
assistant_turn = conversations[i + 1] if i + 1 < len(conversations) else None
|
|
|
@@ -195,7 +235,6 @@ def _flatten_sharegpt(records: list[dict]) -> tuple[list[dict], list[str]]:
|
|
|
input_text = str(assistant_turn.get("value", "")) if assistant_turn else ""
|
|
|
output_text = str(user_turn.get("value", ""))
|
|
|
|
|
|
- # 截断过长文本
|
|
|
if len(input_text) > 500:
|
|
|
input_text = input_text[:500] + "..."
|
|
|
if len(output_text) > 500:
|
|
|
@@ -259,18 +298,15 @@ async def validate_dataset(dataset_id: str) -> dict[str, Any]:
|
|
|
errors = []
|
|
|
warnings = []
|
|
|
|
|
|
- # 检查格式
|
|
|
fmt = record.format
|
|
|
if fmt not in ("jsonl", "csv", "json", "parquet"):
|
|
|
errors.append(f"Unsupported format: {fmt}")
|
|
|
|
|
|
- # 检查内容
|
|
|
try:
|
|
|
preview = _read_records(file_path, fmt, 5)
|
|
|
if not preview:
|
|
|
warnings.append("Dataset appears to be empty")
|
|
|
else:
|
|
|
- # 检查必需字段(SFT 格式)
|
|
|
first = preview[0]
|
|
|
has_sft_fields = any(k in first for k in ("instruction", "prompt", "text", "input", "output", "completion"))
|
|
|
if not has_sft_fields:
|
|
|
@@ -310,7 +346,6 @@ async def delete_dataset(dataset_id: str) -> dict[str, Any]:
|
|
|
result = await session.execute(select(DatasetRecord).where(DatasetRecord.id == dataset_id))
|
|
|
record = result.scalar_one_or_none()
|
|
|
if record:
|
|
|
- # 删除文件
|
|
|
file_path = Path(record.file_path)
|
|
|
if file_path.exists():
|
|
|
file_path.unlink()
|
|
|
@@ -339,7 +374,7 @@ def _count_records(file_path: Path, fmt: str) -> int:
|
|
|
elif fmt == "csv":
|
|
|
import csv
|
|
|
with open(file_path, encoding="utf-8") as f:
|
|
|
- return sum(1 for _ in csv.reader(f)) - 1 # minus header
|
|
|
+ return sum(1 for _ in csv.reader(f)) - 1
|
|
|
elif fmt == "parquet":
|
|
|
import pandas as pd
|
|
|
return len(pd.read_parquet(file_path))
|