"""Dataset service: download (ModelScope / Hugging Face), upload, preview, validate and delete datasets."""

import asyncio
import base64
import csv
import glob
import io
import itertools
import json
import os
import shutil
import subprocess
import tarfile
import urllib.parse
import urllib.request
import uuid
import zipfile
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

from fastapi import UploadFile
from sqlalchemy import select

from app.config import get_settings
from app.core.background_tasks import background_task_manager
from app.core.db import async_session, DatasetRecord, DatasetDownloadTask
from app.core.logging import logger
from app.schemas.dataset import DatasetDownloadRequest, DatasetDownloadResponse

settings = get_settings()

# Known metadata filenames that are NOT training data
META_FILENAMES = frozenset({
    "configuration.json", "configuration.yaml", "README.md", ".mdl", ".msc", ".mv",
    "model_index.json", "generation_config.json", "special_tokens_map.json",
    "tokenizer_config.json", "added_tokens.json", "vocab.json", "merges.txt",
    "config.json", "preprocessor_config.json", "dataset_infos.json",
    "dataset_info.json", "state.json", "card_data.json",
})

# File size threshold: files smaller than this (bytes) are likely metadata
META_SIZE_THRESHOLD = 500

# Keys whose presence in a JSON record marks the file as training data.
_DATA_KEYS = frozenset({
    "input", "output", "conversations", "instruction", "prompt", "text",
    "completion", "source", "target", "query", "response",
})
# Keys typical of metadata documents; only consulted when no _DATA_KEYS matched.
_META_KEYS = frozenset({
    "framework", "task", "license", "base_model", "model_type", "language",
    "domains", "tags", "authors",
})

# Archive suffixes recognised by _extract_archives (matched case-insensitively).
_TAR_SUFFIXES = (".tar.gz", ".tgz", ".tar.bz2", ".tbz2", ".tar")
_ARCHIVE_SUFFIXES = (".tar.gz", ".tar.bz2", ".tgz", ".tbz2", ".tar", ".zip")

# String values ending in one of these are treated as image file paths.
_IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".gif", ".bmp")

# Conversation roles treated as the "user" side when flattening ShareGPT data.
_SHAREGPT_USER_ROLES = ("human", "user")
# Maximum characters shown per cell in the ShareGPT preview.
_PREVIEW_TEXT_LIMIT = 500

# Fields whose presence in the first record suggests an SFT-style dataset.
_SFT_FIELDS = (
    "instruction", "prompt", "text", "input", "output", "completion",
    "conversations", "messages",
)


# ---------------------------------------------------------------------------
# Generic helpers
# ---------------------------------------------------------------------------

def _is_within(path: Path, root: Path) -> bool:
    """Return True if *path* (after resolving ``..``/symlinks) lies inside *root*."""
    return path.resolve().is_relative_to(root.resolve())


def _read_first_line(path: Path) -> str:
    """Return the stripped first line of *path*, decoding UTF-8 and ignoring bad bytes.

    Only the first line is read, so large data files are not loaded into memory
    (a single-line/minified file is of course still read in full).
    """
    with open(path, encoding="utf-8", errors="ignore") as f:
        return f.readline().strip()


def _write_jsonl(path: Path, records) -> int:
    """Write each item of *records* as one JSON line to *path*; return the number written."""
    count = 0
    with open(path, "w", encoding="utf-8") as f:
        for item in records:
            f.write(json.dumps(item, ensure_ascii=False) + "\n")
            count += 1
    return count


def _to_jpeg_compatible(img):
    """Return *img* converted to RGB unless its mode is already plain RGB or L.

    JPEG cannot store alpha/palette modes (RGBA, P, LA) nor 16/32-bit modes
    (I, I;16, F); converting everything else to RGB avoids save() failures.
    """
    if img.mode not in ("RGB", "L"):
        return img.convert("RGB")
    return img


# ---------------------------------------------------------------------------
# Training-data file detection
# ---------------------------------------------------------------------------

def _is_training_data_file(path: Path) -> bool:
    """Heuristically decide whether *path* holds training data rather than config/metadata.

    Rules:
      * names in META_FILENAMES are never data;
      * ``.jsonl`` / ``.parquet`` / ``.csv`` are data unless smaller than META_SIZE_THRESHOLD
        (e.g. the tiny data.jsonl stub the ModelScope CLI writes);
      * ``.json`` files are sniffed via their first line: records with data keys are data,
        records with only metadata keys are not, a pretty-printed array (first line ``[``)
        is data, anything else unparsable is not;
      * suffix-less files are data when their first line parses as JSON.
    """
    if path.name in META_FILENAMES:
        return False

    if path.suffix in (".jsonl", ".parquet", ".csv"):
        return path.stat().st_size >= META_SIZE_THRESHOLD

    if path.suffix == ".json":
        # Small JSON files are usually configs.
        if path.stat().st_size < META_SIZE_THRESHOLD:
            return False
        try:
            first_line = _read_first_line(path)
        except OSError:
            return False
        try:
            obj = json.loads(first_line)
        except json.JSONDecodeError:
            # A pretty-printed JSON array puts "[" (or "[{...},") on the first line;
            # large arrays are almost always record lists, so treat them as data.
            return first_line.startswith("[")
        if isinstance(obj, dict):
            if not _DATA_KEYS.isdisjoint(obj):
                return True
            if not _META_KEYS.isdisjoint(obj):
                return False
        # Large JSON files default to data.
        return True

    if not path.suffix:
        # Suffix-less file: data if it looks like JSON / JSONL.
        try:
            json.loads(_read_first_line(path))
        except (OSError, json.JSONDecodeError):
            return False
        return True

    return False


# ---------------------------------------------------------------------------
# Archive extraction
# ---------------------------------------------------------------------------

def _archive_stem(name: str) -> str:
    """Strip the first matching archive suffix (case-insensitive) from *name*."""
    lower = name.lower()
    for ext in _ARCHIVE_SUFFIXES:
        if lower.endswith(ext):
            return name[:-len(ext)]
    return name


def _safe_extract_tar(tf: tarfile.TarFile, dest: Path) -> None:
    """Extract *tf* into *dest*, refusing members that would escape *dest*.

    Uses the stdlib ``"data"`` extraction filter when available (Python 3.12+ and
    security backports); otherwise validates member paths manually and rejects links.
    Raises on unsafe members.
    """
    if hasattr(tarfile, "data_filter"):
        tf.extractall(dest, filter="data")
        return
    dest_root = dest.resolve()
    for member in tf.getmembers():
        if member.issym() or member.islnk():
            raise ValueError(f"Refusing to extract link member: {member.name}")
        if not _is_within(dest_root / member.name, dest_root):
            raise ValueError(f"Refusing to extract path outside {dest}: {member.name}")
    tf.extractall(dest)


def _extract_archives(ds_dir: Path):
    """Find and extract archives (zip / tar / tar.gz / tgz / tar.bz2 / tbz2) under *ds_dir*.

    Image datasets usually ship their images inside archives, which must be unpacked
    for the preview to show them. Each archive is extracted next to itself; extraction
    is skipped when a sibling directory named after the archive (suffix stripped)
    already exists. Failures are logged and do not abort the remaining archives.

    NOTE(review): members are extracted into the archive's parent directory, so the
    "already extracted" check only works for archives whose top-level directory
    matches the archive name.
    """
    extracted_any = False
    # Snapshot the listing so newly extracted files are not revisited.
    for f in list(ds_dir.rglob("*")):
        if not f.is_file():
            continue
        name_lower = f.name.lower()
        is_zip = name_lower.endswith(".zip")
        is_tar = name_lower.endswith(_TAR_SUFFIXES)
        if not (is_zip or is_tar):
            continue

        extract_dir = f.parent / _archive_stem(f.name)
        if extract_dir.exists():
            logger.info(f"Archive already extracted, skipping: {f.name}")
            continue

        logger.info(f"Extracting archive: {f.name} -> {extract_dir}")
        try:
            if is_zip:
                # zipfile.extractall already strips absolute paths and ".." components.
                with zipfile.ZipFile(f, "r") as zf:
                    zf.extractall(f.parent)
            else:
                with tarfile.open(f, "r:*") as tf:
                    _safe_extract_tar(tf, f.parent)
            extracted_any = True
            logger.info(f"Successfully extracted: {f.name}")
        except Exception as e:
            logger.warning(f"Failed to extract {f.name}: {e}")

    if extracted_any:
        logger.info(f"Archive extraction completed for {ds_dir}")


# ---------------------------------------------------------------------------
# ModelScope "data files" download (archives referenced by the dataset config)
# ---------------------------------------------------------------------------

def _modelscope_repo_file_url(dataset_id: str, file_path: str,
                              api_base: str = "https://www.modelscope.cn") -> str:
    """Build the ModelScope repo-file download URL for *file_path* (query properly encoded)."""
    params = urllib.parse.urlencode({
        "Source": "SDK",
        "Revision": "master",
        "FilePath": file_path,
        "View": "false",
    })
    return f"{api_base}/api/v1/datasets/{dataset_id}/repo?{params}"


def _collect_data_file_refs(config: Any) -> set[str]:
    """Return every non-empty string ``file`` entry in a ``{subset: {split: {...}}}`` config."""
    if not isinstance(config, dict):
        return set()
    return {
        split_info["file"]
        for subset in config.values() if isinstance(subset, dict)
        for split_info in subset.values()
        if isinstance(split_info, dict)
        and isinstance(split_info.get("file"), str)
        and split_info["file"]
    }


def _download_modelscope_data_files(dataset_id: str, ds_dir: Path):
    """Download the archives of a ModelScope image dataset via the repo API.

    Image datasets carry a ``{dataset_name}.json`` config mapping each split to its
    metadata file and archive, e.g.::

        {"default": {"train": {"meta": "train.csv", "file": "train.zip"},
                     "validation": {"meta": "val.csv", "file": "val.zip"}}}

    ``modelscope download`` only fetches git-repo files (CSV metadata etc.); the
    archives must be fetched separately from
    ``/api/v1/datasets/{ns}/{name}/repo?FilePath=...``.

    Archives are streamed to a ``.part`` file and renamed on success, so a failed
    download never leaves a truncated archive that would later be skipped as
    "already present". Archive names that would resolve outside *ds_dir* are skipped.
    """
    ds_name = dataset_id.split("/", 1)[-1]

    # Step 1: locate (or fetch) the {dataset_name}.json config.
    config_files = [p for p in ds_dir.glob("*.json") if p.name not in META_FILENAMES]
    if not config_files:
        config_url = _modelscope_repo_file_url(dataset_id, f"{ds_name}.json")
        try:
            logger.info(f"尝试下载配置文件: {ds_name}.json")
            req = urllib.request.Request(config_url, headers={"User-Agent": "FineTuning-Backend"})
            with urllib.request.urlopen(req, timeout=30) as resp:
                config_data = json.loads(resp.read().decode())
            config_path = ds_dir / f"{ds_name}.json"
            config_path.write_text(json.dumps(config_data, ensure_ascii=False), encoding="utf-8")
            config_files = [config_path]
        except Exception as e:
            logger.info(f"未找到配置文件 {ds_name}.json,跳过数据文件下载: {e}")
            return

    # Pick the first config that references at least one data file.
    config = None
    archive_files: set[str] = set()
    for cf in config_files:
        try:
            data = json.loads(cf.read_text(encoding="utf-8"))
        except (json.JSONDecodeError, UnicodeDecodeError, OSError):
            continue
        refs = _collect_data_file_refs(data)
        if refs:
            config, archive_files = data, refs
            break

    if not config:
        logger.info("未找到包含数据文件引用的配置文件,跳过")
        return

    logger.info(f"找到数据文件配置: {json.dumps(config, ensure_ascii=False)}")

    # Step 2 is done above: archive_files holds every referenced archive name.
    if not archive_files:
        logger.info("配置中未找到数据文件,跳过")
        return

    # Step 3: download each archive not already present anywhere under ds_dir.
    existing = {p.name for p in ds_dir.rglob("*") if p.is_file()}
    downloaded = []
    for fname in sorted(archive_files):
        if Path(fname).name in existing:
            logger.info(f"压缩包已存在,跳过: {fname}")
            continue

        dest = ds_dir / fname
        # The config comes from the network; never let it write outside ds_dir.
        if not _is_within(dest, ds_dir):
            logger.warning(f"Skipping data file outside dataset dir: {fname}")
            continue

        dl_url = _modelscope_repo_file_url(dataset_id, fname)
        logger.info(f"下载数据文件: {fname}")
        logger.info(f" URL: {dl_url}")
        tmp_path = dest.with_name(dest.name + ".part")
        try:
            dest.parent.mkdir(parents=True, exist_ok=True)
            req = urllib.request.Request(dl_url, headers={"User-Agent": "FineTuning-Backend"})
            # Stream to disk instead of buffering a possibly multi-GB archive in memory.
            with urllib.request.urlopen(req, timeout=600) as resp, open(tmp_path, "wb") as out:
                shutil.copyfileobj(resp, out, 1024 * 1024)
            tmp_path.replace(dest)
            downloaded.append(fname)
            logger.info(f" 下载完成: {fname} ({dest.stat().st_size / 1024 / 1024:.1f}MB)")
        except Exception as e:
            tmp_path.unlink(missing_ok=True)
            logger.warning(f" 下载失败 {fname}: {e}")

    if downloaded:
        logger.info(f"共下载 {len(downloaded)} 个数据文件: {downloaded}")
    else:
        logger.info("没有需要下载的数据文件")


# ---------------------------------------------------------------------------
# Download tasks (DB-backed + in-memory background tasks)
# ---------------------------------------------------------------------------

async def download_dataset(req: DatasetDownloadRequest) -> DatasetDownloadResponse:
    """Persist a pending download task, start it in the background and return its task_id."""
    task_id = str(uuid.uuid4())

    record = DatasetDownloadTask(
        id=task_id,
        dataset_id=req.dataset_id,
        use_modelscope=1 if req.use_modelscope else 0,
        status="pending",
    )
    async with async_session() as session:
        session.add(record)
        await session.commit()

    background_task_manager.register_task(task_id, "dataset_download", {"dataset_id": req.dataset_id})
    await background_task_manager.run(
        task_id, "dataset_download",
        _execute_dataset_download(task_id, req)
    )

    logger.info(f"Dataset download task started: {req.dataset_id} (task_id={task_id})")
    return DatasetDownloadResponse(
        dataset_id=req.dataset_id, status="pending", task_id=task_id, path=task_id
    )


def _download_hf_dataset(dataset_id: str) -> tuple[Path, Path, int]:
    """Download a Hugging Face dataset and export its train (or first) split as JSONL.

    Blocking; call it through ``asyncio.to_thread``.

    Returns:
        ``(dataset_dir, jsonl_path, record_count)``.
    """
    from datasets import load_dataset

    ds = load_dataset(dataset_id)
    ds_dir = settings.processed_dir / f"hf_{dataset_id.replace('/', '_')}"
    ds_dir.mkdir(parents=True, exist_ok=True)

    split = ds["train"] if "train" in ds else ds[list(ds.keys())[0]]
    jsonl_path = ds_dir / "data.jsonl"
    record_count = _write_jsonl(jsonl_path, split)
    return ds_dir, jsonl_path, record_count


async def _execute_dataset_download(task_id: str, req: DatasetDownloadRequest) -> dict:
    """Background body of a download task.

    Runs the (blocking) ModelScope or Hugging Face download in a worker thread,
    registers the resulting JSONL as a DatasetRecord and updates the task status
    (``downloading`` -> ``completed`` / ``failed``). Errors are recorded, not raised.
    """
    try:
        await _update_dataset_download_status(task_id, "downloading")
        downloader = _download_modelscope_dataset if req.use_modelscope else _download_hf_dataset
        # Both downloaders block on network and disk I/O; keep the event loop free.
        _ds_dir, jsonl_path, record_count = await asyncio.to_thread(downloader, req.dataset_id)

        db_record = DatasetRecord(
            id=str(uuid.uuid4()),
            name=req.dataset_id,
            format="jsonl",
            record_count=record_count,
            file_path=str(jsonl_path),
            created_at=datetime.utcnow(),
        )
        async with async_session() as session:
            session.add(db_record)
            await session.commit()

        await _update_dataset_download_status(task_id, "completed", path=str(jsonl_path),
                                              record_count=record_count)
        logger.info(f"Dataset downloaded: {req.dataset_id} ({record_count} records)")
        return {"path": str(jsonl_path), "record_count": record_count}

    except Exception as e:
        logger.error(f"Dataset download failed: {e}")
        await _update_dataset_download_status(task_id, "failed", error=str(e))
        return {"error": str(e)}


async def _update_dataset_download_status(task_id: str, status: str, path: str | None = None,
                                          error: str | None = None, record_count: int = 0):
    """Update the task row in the DB and mirror the change into the in-memory task manager.

    Only truthy ``path`` / ``error`` / ``record_count`` values overwrite DB columns;
    terminal statuses (completed / failed) also stamp ``finished_at``.
    """
    async with async_session() as session:
        result = await session.execute(select(DatasetDownloadTask).where(DatasetDownloadTask.id == task_id))
        record = result.scalar_one_or_none()
        if record:
            record.status = status
            if path:
                record.path = path
            if error:
                record.error = error
            if record_count:
                record.record_count = record_count
            if status in ("completed", "failed"):
                record.finished_at = datetime.utcnow()
            await session.commit()

    background_task_manager.update_task(
        task_id, status=status, path=path, error=error, record_count=record_count,
    )


def _download_task_to_dict(record: DatasetDownloadTask) -> dict[str, Any]:
    """Serialise a DatasetDownloadTask row for the API."""
    return {
        "task_id": record.id,
        "dataset_id": record.dataset_id,
        "status": record.status,
        "use_modelscope": bool(record.use_modelscope),
        "path": record.path,
        "error": record.error,
        "record_count": record.record_count,
        "created_at": record.created_at.isoformat() if record.created_at else "",
    }


async def get_dataset_download_status(task_id: str) -> dict[str, Any]:
    """Return a task's status from the DB, falling back to the in-memory task manager."""
    async with async_session() as session:
        result = await session.execute(select(DatasetDownloadTask).where(DatasetDownloadTask.id == task_id))
        record = result.scalar_one_or_none()
        if record:
            return _download_task_to_dict(record)

    mem = background_task_manager.get_task(task_id)
    if mem:
        return {
            "task_id": task_id,
            "dataset_id": mem.get("dataset_id", ""),
            "status": mem["status"],
            "path": mem.get("path"),
            "error": mem.get("error"),
            "record_count": mem.get("record_count", 0),
        }
    return {"task_id": task_id, "status": "not_found"}


async def list_dataset_downloads() -> list[dict[str, Any]]:
    """List all download tasks, newest first."""
    async with async_session() as session:
        result = await session.execute(
            select(DatasetDownloadTask).order_by(DatasetDownloadTask.created_at.desc())
        )
        records = result.scalars().all()
        return [_download_task_to_dict(r) for r in records]


async def cancel_dataset_download(task_id: str) -> dict[str, Any]:
    """Cancel the background task and mark the DB row cancelled if it is still active."""
    background_task_manager.cancel_task(task_id)

    async with async_session() as session:
        result = await session.execute(select(DatasetDownloadTask).where(DatasetDownloadTask.id == task_id))
        record = result.scalar_one_or_none()
        if record and record.status in ("pending", "downloading"):
            record.status = "cancelled"
            record.error = "Cancelled by user"
            record.finished_at = datetime.utcnow()
            await session.commit()

    return {"task_id": task_id, "status": "cancelled"}


async def recover_stale_downloads() -> None:
    """Mark tasks left pending/downloading by a previous process as failed."""
    async with async_session() as session:
        result = await session.execute(
            select(DatasetDownloadTask).where(
                DatasetDownloadTask.status.in_(["pending", "downloading"])
            )
        )
        records = result.scalars().all()
        for record in records:
            record.status = "failed"
            record.error = "Server restarted, task interrupted"
            record.finished_at = datetime.utcnow()
        if records:
            await session.commit()
            logger.info(f"Recovered {len(records)} stale dataset download tasks")


# ---------------------------------------------------------------------------
# ModelScope dataset download
# ---------------------------------------------------------------------------

def _download_modelscope_dataset(dataset_id: str) -> tuple[Path, Path, int]:
    """Download a ModelScope dataset and export it as ``data.jsonl``.

    Prefers ``MsDataset.load()``, which also fetches images from the "data files"
    area; falls back to the ``modelscope download`` CLI if that fails or yields
    no records.

    Returns:
        ``(dataset_dir, jsonl_path, record_count)``.
    """
    ds_dir = settings.processed_dir / f"ms_{dataset_id.replace('/', '_')}"
    ds_dir.mkdir(parents=True, exist_ok=True)

    try:
        records, record_count = _download_via_msdataset(dataset_id, ds_dir)
        if records:
            jsonl_path = ds_dir / "data.jsonl"
            _write_jsonl(jsonl_path, records)
            logger.info(f"MsDataset.load() 成功: {dataset_id} ({record_count} records)")
            return ds_dir, jsonl_path, record_count
    except Exception as e:
        logger.warning(f"MsDataset.load() failed for {dataset_id}: {e}, falling back to CLI")

    # Fallback: the CLI only fetches git-repo files (no images from the data-files area).
    return _download_modelscope_dataset_cli(dataset_id, ds_dir)


def _download_via_msdataset(dataset_id: str, ds_dir: Path) -> tuple[list[dict], int]:
    """Load a dataset with ``MsDataset.load()`` and convert it to plain dict records.

    Image values are materialised under ``ds_dir/images``: ``PIL.Image`` objects are
    saved as JPEG and absolute, existing image-file paths are copied; the record then
    stores the relative path ``images/<name>``.

    Returns:
        ``(records, len(records))``, or ``([], 0)`` when nothing could be loaded.
    """
    from modelscope.msdatasets import MsDataset
    from PIL import Image

    namespace, ds_name = dataset_id.split("/", 1) if "/" in dataset_id else ("", dataset_id)
    images_dir = ds_dir / "images"

    def load(**kwargs):
        if namespace:
            return MsDataset.load(ds_name, namespace=namespace, **kwargs)
        return MsDataset.load(dataset_id, **kwargs)

    # Try common splits in order; keep the first that loads a truthy dataset.
    ds = None
    for split in ("train", "validation", "test"):
        try:
            ds = load(split=split)
            if ds:
                logger.info(f"MsDataset.load() loaded split '{split}': {len(ds) if hasattr(ds, '__len__') else '?'} records")
                break
        except Exception as e:
            logger.debug(f"split '{split}' failed: {e}")

    if not ds:
        # Retry without a split argument.
        try:
            ds = load()
        except Exception as e:
            logger.warning(f"MsDataset.load() without split also failed: {e}")
            return [], 0

    if not ds or not hasattr(ds, '__iter__'):
        return [], 0

    records = []
    img_counter = 0
    for row in ds:
        if not isinstance(row, dict):
            continue

        record = {}
        for k, v in row.items():
            if isinstance(v, Image.Image):
                # In-memory image: save to disk and record the relative path.
                images_dir.mkdir(parents=True, exist_ok=True)
                img_name = f"{img_counter:06d}.jpg"
                _to_jpeg_compatible(v).save(str(images_dir / img_name), format="JPEG", quality=90)
                record[k] = f"images/{img_name}"
                img_counter += 1
            elif (isinstance(v, str) and v.lower().endswith(_IMAGE_EXTENSIONS)
                  and os.path.isabs(v) and os.path.exists(v)):
                # Absolute path to a cached image: copy it into the dataset directory.
                images_dir.mkdir(parents=True, exist_ok=True)
                img_name = f"{img_counter:06d}{os.path.splitext(v)[1]}"
                try:
                    shutil.copy2(v, images_dir / img_name)
                    record[k] = f"images/{img_name}"
                    img_counter += 1
                except Exception as e:
                    logger.warning(f"Failed to copy image {v}: {e}")
                    record[k] = v
            else:
                # Relative paths and non-image values are kept as-is.
                record[k] = v

        records.append(record)
        if len(records) % 500 == 0:
            logger.info(f" 处理中... {len(records)} records, {img_counter} images saved")

    if img_counter > 0:
        logger.info(f"共保存 {img_counter} 张图片到 {images_dir}")

    return records, len(records)


def _parse_json_lines(content: str, *, strict: bool) -> list:
    """Parse non-blank lines of *content* as JSON.

    With ``strict=True`` the first invalid line raises ``json.JSONDecodeError``;
    otherwise invalid lines are skipped.
    """
    records = []
    for line in content.splitlines():
        line = line.strip()
        if not line:
            continue
        try:
            records.append(json.loads(line))
        except json.JSONDecodeError:
            if strict:
                raise
    return records


def _load_records_from_file(target: Path) -> list:
    """Read all records from a data file chosen by the CLI fallback.

    * ``.jsonl`` / no suffix: JSONL; if any line is invalid, try the whole file as one
      JSON document; if that also fails, keep the valid lines.
    * ``.json``: whole-document JSON first, then tolerant JSONL.
    * ``.csv``: ``csv.DictReader`` rows (handles quoted multi-line fields).
    * ``.parquet``: rows converted through pandas' JSON encoder (numpy/timestamps safe).
    Text files are decoded as UTF-8, tolerating a leading BOM.
    """
    suffix = target.suffix

    if suffix == ".csv":
        with open(target, encoding="utf-8-sig", newline="") as f:
            return [dict(row) for row in csv.DictReader(f)]

    if suffix == ".parquet":
        import pandas as pd
        df = pd.read_parquet(target)
        return json.loads(df.to_json(orient="records", force_ascii=False,
                                     date_format="iso", default_handler=str))

    if suffix not in (".jsonl", ".json", ""):
        return []

    content = target.read_text(encoding="utf-8-sig")

    if suffix == ".json":
        try:
            data = json.loads(content)
        except json.JSONDecodeError:
            # Possibly JSONL content with a .json suffix.
            return _parse_json_lines(content, strict=False)
        return data if isinstance(data, list) else [data]

    # .jsonl or suffix-less file.
    try:
        return _parse_json_lines(content, strict=True)
    except json.JSONDecodeError:
        pass
    try:
        data = json.loads(content)
    except json.JSONDecodeError:
        # Neither clean JSONL nor one JSON document: keep every line that parses
        # instead of discarding the whole file because of one bad line.
        return _parse_json_lines(content, strict=False)
    return data if isinstance(data, list) else [data]


def _download_modelscope_dataset_cli(dataset_id: str, ds_dir: Path) -> tuple[Path, Path, int]:
    """Fallback download through the ``modelscope download`` CLI (git-repo files only).

    The largest file that looks like training data is converted to ``data.jsonl``.

    Raises:
        RuntimeError: the CLI exited non-zero.
        ValueError: no usable data file was downloaded.
        subprocess.TimeoutExpired: the CLI ran longer than 600 s.
    """
    cmd = ["modelscope", "download", "--dataset", dataset_id, "--local_dir", str(ds_dir)]
    logger.info(f"Fallback CLI: {' '.join(cmd)}")
    proc = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
    if proc.returncode != 0:
        logger.error(f"ModelScope CLI download failed (code={proc.returncode}): {proc.stderr[:500]}")
        raise RuntimeError(f"ModelScope download failed: {proc.stderr[:500]}")

    # The CLI only fetches repo metadata files; data-file-area images need MsDataset.load.
    all_files = [p for p in ds_dir.rglob("*") if p.is_file()]
    logger.info(f"CLI downloaded {len(all_files)} files to {ds_dir}")

    data_files = [f for f in all_files if _is_training_data_file(f)]
    if not data_files:
        fallback = [f for f in all_files
                    if f.suffix in (".json", ".jsonl") and f.name not in META_FILENAMES and f.name != "README.md"]
        logger.warning(f"No training data files found in {dataset_id}. "
                       f"Available JSON files: {[f.name for f in fallback]}")
        if not fallback:
            logger.error(f"All downloaded files: {[str(f.relative_to(ds_dir)) for f in all_files]}")
            raise ValueError(f"No JSON/JSONL data files found in dataset {dataset_id}. "
                             f"Available files: {[f.name for f in all_files]}")
        data_files = fallback

    # The real dataset is usually the largest candidate file.
    target = max(data_files, key=lambda f: f.stat().st_size)
    logger.info(f"Selected data file: {target} (size={target.stat().st_size})")

    records = _load_records_from_file(target)
    jsonl_path = ds_dir / "data.jsonl"
    record_count = _write_jsonl(jsonl_path, records)
    return ds_dir, jsonl_path, record_count


# ---------------------------------------------------------------------------
# Upload
# ---------------------------------------------------------------------------

async def upload_dataset(file: UploadFile) -> dict[str, Any]:
    """Save an uploaded file into ``settings.uploads_dir`` and register it in the DB.

    The client-supplied filename is reduced to its final path component, so it
    cannot write outside the upload directory. An existing file with the same
    name is not overwritten; a random prefix is added instead.
    """
    upload_dir = settings.uploads_dir
    upload_dir.mkdir(parents=True, exist_ok=True)

    # Normalise Windows separators, then keep only the base name.
    raw_name = (file.filename or "").replace("\\", "/")
    safe_name = Path(raw_name).name
    if safe_name in ("", ".", ".."):
        safe_name = "unknown"

    file_path = upload_dir / safe_name
    if file_path.exists():
        file_path = upload_dir / f"{uuid.uuid4().hex}_{safe_name}"

    content = await file.read()
    file_path.write_bytes(content)

    fmt = _detect_format(safe_name)
    record_count = _count_records(file_path, fmt)

    record_id = str(uuid.uuid4())
    record = DatasetRecord(
        id=record_id,
        name=safe_name,
        format=fmt,
        record_count=record_count,
        file_path=str(file_path),
        created_at=datetime.utcnow(),
    )
    async with async_session() as session:
        session.add(record)
        await session.commit()

    logger.info(f"Uploaded dataset: {safe_name} ({record_count} records, format={fmt})")
    return {
        "id": record_id,
        "name": safe_name,
        "format": fmt,
        "record_count": record_count,
        "file_path": str(file_path),
        "created_at": record.created_at.isoformat(),
    }


# ---------------------------------------------------------------------------
# Preview helpers
# ---------------------------------------------------------------------------

def _detect_image_column(columns: list[str]) -> str | None:
    """Return the column most likely to hold image paths, or None.

    Exact well-known names win; otherwise the first column whose name contains
    "image", or contains "path" but not "label".
    """
    candidates = ["image_path", "image", "img_path", "img", "file_path", "filename", "path", "file"]
    for c in candidates:
        if c in columns:
            return c
    for c in columns:
        cl = c.lower()
        if "image" in cl or ("path" in cl and "label" not in cl):
            return c
    return None


def _resolve_image_path(path_str: str, data_dir: Path) -> Path | None:
    """Resolve an image path stored in a dataset to an existing file, or None.

    Absolute paths are returned if they exist. Relative paths are tried relative to
    *data_dir*, then as a bare file name directly in *data_dir*, then by searching
    *data_dir* recursively for that file name (taken literally, not as a glob).
    """
    if not path_str:
        return None

    p = Path(path_str)
    if p.is_absolute():
        return p if p.exists() else None

    candidate = data_dir / p
    if candidate.is_file():
        return candidate

    if p.name:
        flat = data_dir / p.name
        if flat.is_file():
            return flat
        # glob.escape: file names like "img[1].jpg" must not be treated as patterns.
        for child in data_dir.rglob(glob.escape(p.name)):
            if child.is_file():
                return child

    logger.debug(f"Image not found: '{path_str}' (searched in {data_dir})")
    return None


def _encode_image_base64(image_path: Path, max_size: int = 200) -> str | None:
    """Return a JPEG thumbnail (at most max_size x max_size) of the image as a data URI.

    Returns None (and logs a warning) if the image cannot be read or encoded.
    """
    try:
        from PIL import Image

        # Context manager closes the underlying file handle.
        with Image.open(image_path) as src:
            src.thumbnail((max_size, max_size))
            img = _to_jpeg_compatible(src)
            buf = io.BytesIO()
            img.save(buf, format="JPEG", quality=85)
        b64 = base64.b64encode(buf.getvalue()).decode("ascii")
        return f"data:image/jpeg;base64,{b64}"
    except Exception as e:
        logger.warning(f"Failed to encode image {image_path}: {e}")
        return None


def _format_value(value) -> str:
    """Render dicts/lists as indented JSON and everything else with ``str()``."""
    if isinstance(value, (dict, list)):
        return json.dumps(value, ensure_ascii=False, indent=2)
    return str(value)


def _is_sharegpt_format(records: list[dict]) -> bool:
    """True if the first record has a ``conversations`` list of ``{"from", "value"}`` dicts."""
    if not records:
        return False
    first = records[0]
    if "conversations" in first and isinstance(first["conversations"], list):
        if len(first["conversations"]) > 0 and isinstance(first["conversations"][0], dict):
            conv = first["conversations"][0]
            return "from" in conv and "value" in conv
    return False


def _truncate_preview(text: str, limit: int = _PREVIEW_TEXT_LIMIT) -> str:
    """Cut *text* to *limit* characters, appending "..." when truncated."""
    return text[:limit] + "..." if len(text) > limit else text


def _flatten_sharegpt(records: list[dict]) -> tuple[list[dict], list[str]]:
    """Flatten ShareGPT conversations into ``{"input", "output"}`` preview rows.

    System turns are skipped. Each user/human turn becomes ``input`` and is paired
    with the immediately following non-user turn as ``output`` (empty if there is
    none). A non-user turn without a preceding user turn becomes a row with an
    empty ``input``. Both cells are truncated for display.
    """
    flat_rows = []
    for row in records:
        if not isinstance(row, dict):
            continue
        turns = [t for t in (row.get("conversations") or [])
                 if isinstance(t, dict) and t.get("from") != "system"]
        i = 0
        while i < len(turns):
            turn = turns[i]
            if turn.get("from") in _SHAREGPT_USER_ROLES:
                input_text = str(turn.get("value", ""))
                nxt = turns[i + 1] if i + 1 < len(turns) else None
                if nxt is not None and nxt.get("from") not in _SHAREGPT_USER_ROLES:
                    output_text = str(nxt.get("value", ""))
                    i += 2
                else:
                    output_text = ""
                    i += 1
            else:
                input_text = ""
                output_text = str(turn.get("value", ""))
                i += 1
            flat_rows.append({
                "input": _truncate_preview(input_text),
                "output": _truncate_preview(output_text),
            })
    return flat_rows, ["input", "output"]


async def preview_dataset(dataset_id: str, rows: int = 10) -> dict[str, Any]:
    """Return the first *rows* records of a dataset for display.

    ShareGPT data is flattened to input/output columns. If an image column is
    detected, its values are replaced by base64 thumbnails where the file can be
    resolved (otherwise the raw value is shown).
    """
    async with async_session() as session:
        result = await session.execute(select(DatasetRecord).where(DatasetRecord.id == dataset_id))
        record = result.scalar_one_or_none()
        if not record:
            return {"total_records": 0, "preview_rows": [], "columns": [], "image_column": None}

    file_path = Path(record.file_path)
    if not file_path.exists():
        return {"total_records": 0, "preview_rows": [], "columns": [], "image_column": None}

    fmt = record.format
    preview_data = _read_records(file_path, fmt, rows)

    if _is_sharegpt_format(preview_data):
        preview_data, columns = _flatten_sharegpt(preview_data)
    else:
        columns = list(preview_data[0].keys()) if preview_data else []

    image_column = _detect_image_column(columns)
    data_dir = file_path.parent

    preview_rows = []
    for i, row in enumerate(preview_data):
        data = {}
        for k, v in row.items():
            if k == image_column:
                # Skip the filesystem search for missing values.
                img_path = _resolve_image_path(str(v), data_dir) if v is not None else None
                encoded = _encode_image_base64(img_path) if img_path else None
                # Fall back to the raw path text when the image cannot be embedded.
                data[k] = encoded if encoded else str(v)
            else:
                data[k] = _format_value(v)
        preview_rows.append({"row_index": i, "data": data})

    return {
        "total_records": record.record_count,
        "preview_rows": preview_rows,
        "columns": columns,
        "image_column": image_column,
    }


async def validate_dataset(dataset_id: str) -> dict[str, Any]:
    """Check that a dataset exists, has a supported format and looks like SFT data.

    Unsupported format or unreadable content are errors; an empty file or a first
    record without any common SFT field (including ShareGPT ``conversations`` and
    ``messages``) are warnings.
    """
    async with async_session() as session:
        result = await session.execute(select(DatasetRecord).where(DatasetRecord.id == dataset_id))
        record = result.scalar_one_or_none()
        if not record:
            return {"is_valid": False, "errors": ["Dataset not found"], "warnings": []}

    file_path = Path(record.file_path)
    if not file_path.exists():
        return {"is_valid": False, "errors": ["File not found"], "warnings": []}

    errors = []
    warnings = []
    fmt = record.format

    if fmt not in ("jsonl", "csv", "json", "parquet"):
        errors.append(f"Unsupported format: {fmt}")

    try:
        preview = _read_records(file_path, fmt, 5)
        if not preview:
            warnings.append("Dataset appears to be empty")
        else:
            first = preview[0]
            if not any(k in first for k in _SFT_FIELDS):
                warnings.append(f"No common SFT fields found. Keys: {list(first.keys())}")
    except Exception as e:
        errors.append(f"Failed to read file: {str(e)}")

    return {"is_valid": len(errors) == 0, "errors": errors, "warnings": warnings}


async def list_datasets() -> list[dict[str, Any]]:
    """List all registered datasets, newest first."""
    async with async_session() as session:
        result = await session.execute(select(DatasetRecord).order_by(DatasetRecord.created_at.desc()))
        records = result.scalars().all()
        return [
            {
                "id": r.id,
                "name": r.name,
                "format": r.format,
                "record_count": r.record_count,
                "file_path": r.file_path,
                "created_at": r.created_at.isoformat() if r.created_at else "",
            }
            for r in records
        ]


async def delete_dataset(dataset_id: str) -> dict[str, Any]:
    """Delete a dataset record and its files.

    Downloaded datasets live in their own sub-directory of ``settings.processed_dir``;
    that whole directory is removed. Everything else (uploads, or a file placed
    directly in processed_dir) only has the file itself removed, so the shared
    processed_dir is never deleted.
    """
    async with async_session() as session:
        result = await session.execute(select(DatasetRecord).where(DatasetRecord.id == dataset_id))
        record = result.scalar_one_or_none()
        if record:
            file_path = Path(record.file_path)
            if file_path.exists():
                ds_dir = file_path.parent
                processed_root = settings.processed_dir.resolve()
                # ds_dir must be strictly inside processed_dir, never processed_dir itself.
                if processed_root in ds_dir.resolve().parents:
                    shutil.rmtree(ds_dir, ignore_errors=True)
                    logger.info(f"Deleted dataset directory: {ds_dir}")
                else:
                    file_path.unlink()
            await session.delete(record)
            await session.commit()
            logger.info(f"Deleted dataset: {record.name}")
    return {"status": "deleted"}


# ---------------------------------------------------------------------------
# File format helpers
# ---------------------------------------------------------------------------

def _detect_format(filename: str) -> str:
    """Map a file extension to jsonl/csv/parquet/json, or "unknown"."""
    ext = Path(filename).suffix.lower().lstrip(".")
    if ext in ("jsonl", "csv", "parquet", "json"):
        return ext
    return "unknown"


def _count_records(file_path: Path, fmt: str) -> int:
    """Count records in a data file; returns 0 for unknown formats or unreadable files.

    jsonl counts non-blank lines; json counts list items (1 for a single object);
    csv counts data rows (never negative, quoted multi-line fields count once).
    """
    try:
        if fmt == "jsonl":
            with open(file_path, encoding="utf-8") as f:
                return sum(1 for line in f if line.strip())
        elif fmt == "json":
            with open(file_path, encoding="utf-8") as f:
                data = json.load(f)
            return len(data) if isinstance(data, list) else 1
        elif fmt == "csv":
            with open(file_path, encoding="utf-8", newline="") as f:
                return sum(1 for _ in csv.DictReader(f))
        elif fmt == "parquet":
            import pandas as pd
            return len(pd.read_parquet(file_path))
    except Exception:
        # Best-effort: a count of 0 is reported for files that cannot be parsed.
        pass
    return 0


def _read_records(file_path: Path, fmt: str, n: int) -> list[dict]:
    """Return up to *n* records from *file_path* in format *fmt* ([] for unknown formats).

    Blank JSONL lines do not count toward *n*; CSV reading stops after *n* rows.
    Parse errors propagate to the caller.
    """
    if fmt == "jsonl":
        records = []
        with open(file_path, encoding="utf-8") as f:
            for line in f:
                if len(records) >= n:
                    break
                line = line.strip()
                if line:
                    records.append(json.loads(line))
        return records
    elif fmt == "json":
        with open(file_path, encoding="utf-8") as f:
            data = json.load(f)
        return data[:n] if isinstance(data, list) else [data]
    elif fmt == "csv":
        with open(file_path, encoding="utf-8", newline="") as f:
            reader = csv.DictReader(f)
            return [dict(row) for row in itertools.islice(reader, max(n, 0))]
    elif fmt == "parquet":
        import pandas as pd
        df = pd.read_parquet(file_path)
        return df.head(n).to_dict(orient="records")
    return []