| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948 |
- import asyncio
- import json
- import uuid
- from datetime import datetime, timezone
- from pathlib import Path
- from typing import Any
- from fastapi import UploadFile
- from app.config import get_settings
- from app.core.background_tasks import background_task_manager
- from app.core.db import async_session, DatasetRecord, DatasetDownloadTask
- from app.core.logging import logger
- from app.schemas.dataset import DatasetDownloadRequest, DatasetDownloadResponse
- from sqlalchemy import select
- settings = get_settings()
# File names that are repository metadata/configuration and never training data.
META_FILENAMES = frozenset(
    (
        # Repository and dataset-card descriptors
        "README.md",
        "configuration.json",
        "configuration.yaml",
        "dataset_info.json",
        "dataset_infos.json",
        "card_data.json",
        "state.json",
        # Hidden marker files (presumably hub bookkeeping -- confirm)
        ".mdl",
        ".msc",
        ".mv",
        # Model / tokenizer / preprocessor configuration
        "config.json",
        "model_index.json",
        "generation_config.json",
        "preprocessor_config.json",
        "tokenizer_config.json",
        "special_tokens_map.json",
        "added_tokens.json",
        "vocab.json",
        "merges.txt",
    )
)

# Size cut-off in bytes: anything smaller is treated as metadata, not data.
META_SIZE_THRESHOLD = 500
def _read_first_line(path: Path) -> str:
    """Return the first line of *path*, stripped, without loading the whole file."""
    with open(path, encoding="utf-8", errors="ignore") as fh:
        return fh.readline().strip()


def _is_training_data_file(path: Path) -> bool:
    """Guess whether *path* holds training data rather than config/metadata.

    Rules, in order:
      * names listed in ``META_FILENAMES`` are never data;
      * ``.jsonl`` / ``.parquet`` / ``.csv`` count as data unless smaller than
        ``META_SIZE_THRESHOLD`` bytes;
      * ``.json`` must reach the size threshold and its first line must parse
        as JSON; a dict with known data keys is data, a dict with only known
        metadata keys is not, anything else defaults to data;
      * extension-less files are data when their first line parses as JSON.

    Only the first line is read, so classifying a multi-gigabyte dataset file
    does not pull it into memory.

    Args:
        path: Existing file to classify.

    Returns:
        True when the file looks like training data.
    """
    if path.name in META_FILENAMES:
        return False

    suffix = path.suffix
    if suffix in (".jsonl", ".parquet", ".csv"):
        # Tiny files are likely placeholders (e.g. a few-byte data.jsonl).
        return path.stat().st_size >= META_SIZE_THRESHOLD

    # Best-effort classification: unreadable or non-JSON content means "not data".
    probe_errors = (OSError, ValueError, RecursionError)

    if suffix == ".json":
        # Small JSON files are usually configuration.
        if path.stat().st_size < META_SIZE_THRESHOLD:
            return False
        try:
            obj = json.loads(_read_first_line(path))
        except probe_errors:
            return False
        if isinstance(obj, dict):
            keys = set(obj)
            data_keys = {"input", "output", "conversations", "instruction", "prompt",
                         "text", "completion", "source", "target", "query", "response"}
            if data_keys & keys:
                return True
            meta_keys = {"framework", "task", "license", "base_model", "model_type",
                         "language", "domains", "tags", "authors"}
            if meta_keys & keys:
                return False
        # A large JSON file that is not recognisably metadata defaults to data.
        return True

    if not suffix:
        # No extension: accept it when the first line is valid JSON (JSON/JSONL).
        try:
            json.loads(_read_first_line(path))
        except probe_errors:
            return False
        return True

    return False
def _safe_extract_tar(tf, dest: Path) -> None:
    """Extract the open tar archive *tf* into *dest*, refusing unsafe members."""
    import tarfile

    if hasattr(tarfile, "data_filter"):
        # Interpreters with extraction filters: the "data" filter rejects
        # absolute paths, ".." escapes, links leaving *dest* and device nodes.
        tf.extractall(dest, filter="data")
        return

    # Older interpreters: validate member paths by hand; skip links and
    # special files, which cannot be checked as cheaply.
    root = dest.resolve()
    members = []
    for member in tf.getmembers():
        target = (root / member.name).resolve()
        if target != root and root not in target.parents:
            raise ValueError(f"Unsafe path in archive: {member.name}")
        if member.isfile() or member.isdir():
            members.append(member)
    tf.extractall(dest, members=members)


def _extract_archives(ds_dir: Path) -> None:
    """Unpack every zip/tar archive found under *ds_dir* next to the archive.

    Image datasets usually ship their pictures inside archives; they must be
    unpacked before previews can show them. An archive is skipped when a
    sibling directory named after it (extensions stripped) already exists.
    Extraction failures are logged and do not abort the scan.

    Archives come from the network, so tar members are extracted through
    ``_safe_extract_tar``. ``ZipFile.extractall`` already strips absolute
    paths and ``..`` components from member names.

    Args:
        ds_dir: Root directory of the downloaded dataset.
    """
    import tarfile
    import zipfile

    tar_exts = (".tar.gz", ".tar.bz2", ".tgz", ".tbz2", ".tar")
    extracted_any = False

    # Snapshot the tree first: extraction adds files while we iterate.
    for f in list(ds_dir.rglob("*")):
        if not f.is_file():
            continue

        name_lower = f.name.lower()
        is_zip = name_lower.endswith(".zip")
        is_tar = name_lower.endswith(tar_exts)
        if not (is_zip or is_tar):
            continue

        # The archive name without its (possibly double) extension marks
        # the "already extracted" directory.
        stem = f.name
        for ext in (*tar_exts, ".zip"):
            if name_lower.endswith(ext):
                stem = stem[:-len(ext)]
                break
        extract_dir = f.parent / stem
        if extract_dir.exists():
            logger.info(f"Archive already extracted, skipping: {f.name}")
            continue

        logger.info(f"Extracting archive: {f.name} -> {extract_dir}")
        try:
            if is_zip:
                with zipfile.ZipFile(f, "r") as zf:
                    zf.extractall(f.parent)
            else:
                with tarfile.open(f, "r:*") as tf:
                    _safe_extract_tar(tf, f.parent)
            extracted_any = True
            logger.info(f"Successfully extracted: {f.name}")
        except Exception as e:
            # Best effort: one broken archive must not stop the others.
            logger.warning(f"Failed to extract {f.name}: {e}")

    if extracted_any:
        logger.info(f"Archive extraction completed for {ds_dir}")
def _config_references_data_file(data: Any) -> bool:
    """Return True when *data* looks like ``{subset: {split: {"file": name}}}``."""
    if not isinstance(data, dict):
        return False
    return any(
        isinstance(split_info, dict) and split_info.get("file")
        for subset in data.values()
        if isinstance(subset, dict)
        for split_info in subset.values()
    )


def _download_modelscope_data_files(dataset_id: str, ds_dir: Path) -> None:
    """Download the archives an image dataset references in its JSON config.

    Image datasets carry a ``{dataset_name}.json`` config mapping each split to
    a metadata file and an archive, e.g.::

        {"default": {"train": {"meta": "train.csv", "file": "train.zip"},
                     "validation": {"meta": "val.csv", "file": "val.zip"}}}

    The CLI download only fetches the git repository files (CSV metadata etc.);
    archives are fetched separately through
    ``/api/v1/datasets/{ns}/{name}/repo?FilePath=...``.

    Each archive is streamed to a ``.part`` file and renamed on success, so
    large files are not buffered in memory and an interrupted download cannot
    leave a truncated archive that later looks "already downloaded". File
    names that would resolve outside *ds_dir* are skipped.

    Args:
        dataset_id: ``namespace/name`` (or bare name) of the dataset.
        ds_dir: Local dataset directory; archives are written beneath it.
    """
    import shutil
    import urllib.parse
    import urllib.request

    api_base = "https://www.modelscope.cn"
    headers = {"User-Agent": "FineTuning-Backend"}
    ds_name = dataset_id.split("/", 1)[1] if "/" in dataset_id else dataset_id

    def _repo_url(file_path: str) -> str:
        # urlencode so names with spaces or reserved characters stay valid.
        params = urllib.parse.urlencode({
            "Source": "SDK", "Revision": "master",
            "FilePath": file_path, "View": "false",
        })
        return f"{api_base}/api/v1/datasets/{dataset_id}/repo?{params}"

    # Step 1: locate the dataset config locally, else try fetching it.
    config_files = [p for p in ds_dir.glob("*.json") if p.name not in META_FILENAMES]
    if not config_files:
        try:
            logger.info(f"尝试下载配置文件: {ds_name}.json")
            req = urllib.request.Request(_repo_url(f"{ds_name}.json"), headers=headers)
            with urllib.request.urlopen(req, timeout=30) as resp:
                config_data = json.loads(resp.read().decode())
            config_path = ds_dir / f"{ds_name}.json"
            config_path.write_text(json.dumps(config_data, ensure_ascii=False), encoding="utf-8")
            config_files = [config_path]
        except Exception as e:
            logger.info(f"未找到配置文件 {ds_name}.json,跳过数据文件下载: {e}")
            return

    # Use the first JSON file that references a data file.
    config = None
    for cf in config_files:
        try:
            data = json.loads(cf.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            # Unreadable or not JSON / not UTF-8: cannot be the config.
            continue
        if _config_references_data_file(data):
            config = data
            break

    if not config:
        logger.info("未找到包含数据文件引用的配置文件,跳过")
        return

    logger.info(f"找到数据文件配置: {json.dumps(config, ensure_ascii=False)}")

    # Step 2: collect the distinct archive names across all subsets/splits.
    archive_files = set()
    for subset in config.values():
        if not isinstance(subset, dict):
            continue
        for split_info in subset.values():
            if isinstance(split_info, dict):
                fname = split_info.get("file", "")
                if fname and isinstance(fname, str):
                    archive_files.add(fname)

    if not archive_files:
        logger.info("配置中未找到数据文件,跳过")
        return

    # Step 3: download what is not on disk yet.
    root = ds_dir.resolve()
    existing = {f.name for f in ds_dir.rglob("*") if f.is_file()}
    downloaded = []
    for fname in sorted(archive_files):
        dest = ds_dir / fname
        # The name comes from remote config: never write outside ds_dir.
        if root not in dest.resolve().parents:
            logger.warning(f"跳过不安全的文件路径: {fname}")
            continue
        if fname in existing or dest.exists():
            logger.info(f"压缩包已存在,跳过: {fname}")
            continue

        dl_url = _repo_url(fname)
        logger.info(f"下载数据文件: {fname}")
        logger.info(f"  URL: {dl_url}")
        tmp = dest.with_name(dest.name + ".part")
        try:
            dest.parent.mkdir(parents=True, exist_ok=True)
            req = urllib.request.Request(dl_url, headers=headers)
            with urllib.request.urlopen(req, timeout=600) as resp, open(tmp, "wb") as out:
                shutil.copyfileobj(resp, out)
            tmp.replace(dest)
            downloaded.append(fname)
            logger.info(f"  下载完成: {fname} ({dest.stat().st_size / 1024 / 1024:.1f}MB)")
        except Exception as e:
            # Best effort per file; drop the partial download.
            tmp.unlink(missing_ok=True)
            logger.warning(f"  下载失败 {fname}: {e}")

    if downloaded:
        logger.info(f"共下载 {len(downloaded)} 个数据文件: {downloaded}")
    else:
        logger.info("没有需要下载的数据文件")
async def download_dataset(req: DatasetDownloadRequest) -> DatasetDownloadResponse:
    """Create a dataset download task and hand it to the background manager.

    A ``DatasetDownloadTask`` row with status ``"pending"`` is stored, the task
    is registered with ``background_task_manager``, and the download coroutine
    is passed to ``background_task_manager.run`` (presumably scheduled rather
    than awaited to completion -- confirm against the manager).

    Args:
        req: Download request carrying ``dataset_id`` and ``use_modelscope``.

    Returns:
        A response with status ``"pending"``; ``task_id`` and ``path`` both
        hold the generated task id.
    """
    dataset_id = req.dataset_id
    task_id = str(uuid.uuid4())

    # Persist the task row before anything can report on it.
    async with async_session() as session:
        session.add(
            DatasetDownloadTask(
                id=task_id,
                dataset_id=dataset_id,
                use_modelscope=int(bool(req.use_modelscope)),
                status="pending",
            )
        )
        await session.commit()

    # Register, then start the worker coroutine.
    background_task_manager.register_task(
        task_id, "dataset_download", {"dataset_id": dataset_id}
    )
    worker = _execute_dataset_download(task_id, req)
    await background_task_manager.run(task_id, "dataset_download", worker)

    logger.info(f"Dataset download task started: {dataset_id} (task_id={task_id})")
    return DatasetDownloadResponse(
        dataset_id=dataset_id,
        status="pending",
        task_id=task_id,
        path=task_id,
    )
def _download_hf_dataset(dataset_id: str) -> tuple[Path, Path, int]:
    """Load a HuggingFace dataset and dump one split to ``data.jsonl`` (blocking).

    The ``train`` split is used when present, otherwise the first split.

    Returns:
        ``(dataset_dir, jsonl_path, record_count)``.
    """
    from datasets import load_dataset

    ds = load_dataset(dataset_id)
    ds_dir = settings.processed_dir / f"hf_{dataset_id.replace('/', '_')}"
    ds_dir.mkdir(parents=True, exist_ok=True)

    split = ds["train"] if "train" in ds else ds[list(ds.keys())[0]]

    output_path = ds_dir / "data.jsonl"
    with open(output_path, "w", encoding="utf-8") as f:
        for item in split:
            f.write(json.dumps(item, ensure_ascii=False) + "\n")

    record_count = len(split) if hasattr(split, "__len__") else 0
    return ds_dir, output_path, record_count


async def _execute_dataset_download(task_id: str, req: DatasetDownloadRequest) -> dict:
    """Run a dataset download in the background and record the outcome.

    Both the ModelScope and the HuggingFace path are blocking (network and
    disk), so both run in a worker thread via ``asyncio.to_thread`` and the
    event loop stays responsive. On success a ``DatasetRecord`` is stored and
    the task is marked ``"completed"``; any exception marks it ``"failed"``.

    Args:
        task_id: Id of the ``DatasetDownloadTask`` row to update.
        req: The original download request.

    Returns:
        ``{"path": ..., "record_count": ...}`` on success, ``{"error": ...}``
        on failure. Exceptions are not propagated.
    """
    try:
        downloader = _download_modelscope_dataset if req.use_modelscope else _download_hf_dataset
        _ds_dir, jsonl_path, record_count = await asyncio.to_thread(downloader, req.dataset_id)

        db_record = DatasetRecord(
            id=str(uuid.uuid4()),
            name=req.dataset_id,
            format="jsonl",
            record_count=record_count,
            file_path=str(jsonl_path),
            created_at=datetime.utcnow(),
        )
        async with async_session() as session:
            session.add(db_record)
            await session.commit()

        await _update_dataset_download_status(
            task_id, "completed", path=str(jsonl_path), record_count=record_count
        )
        logger.info(f"Dataset downloaded: {req.dataset_id} ({record_count} records)")
        return {"path": str(jsonl_path), "record_count": record_count}
    except Exception as e:
        # Top-level task boundary: record the failure instead of raising.
        logger.error(f"Dataset download failed: {e}")
        await _update_dataset_download_status(task_id, "failed", error=str(e))
        return {"error": str(e)}
async def _update_dataset_download_status(task_id: str, status: str, path: str = None, error: str = None, record_count: int = 0):
    """Write a task's new status to the database and to the in-memory manager.

    Args:
        task_id: Id of the ``DatasetDownloadTask`` row.
        status: New status string.
        path: Stored on the row only when truthy.
        error: Stored on the row only when truthy.
        record_count: Stored on the row only when non-zero.

    ``finished_at`` is stamped for the terminal states ``"completed"`` and
    ``"failed"``. The in-memory manager is updated even when no row exists.
    """
    query = select(DatasetDownloadTask).where(DatasetDownloadTask.id == task_id)
    async with async_session() as session:
        record = (await session.execute(query)).scalar_one_or_none()
        if record:
            record.status = status
            # Optional fields overwrite the row only when a value was supplied.
            optional_fields = (("path", path), ("error", error), ("record_count", record_count))
            for attr, value in optional_fields:
                if value:
                    setattr(record, attr, value)
            if status in ("completed", "failed"):
                record.finished_at = datetime.utcnow()
            await session.commit()

    background_task_manager.update_task(
        task_id,
        status=status,
        path=path,
        error=error,
        record_count=record_count,
    )
async def get_dataset_download_status(task_id: str) -> dict[str, Any]:
    """Report the state of a download task.

    The database row is authoritative. If none exists, the in-memory entry of
    ``background_task_manager`` is used; if that is missing too, the status is
    ``"not_found"``.

    Args:
        task_id: Id of the download task.

    Returns:
        A dict describing the task; the fields depend on which source answered.
    """
    query = select(DatasetDownloadTask).where(DatasetDownloadTask.id == task_id)
    async with async_session() as session:
        record = (await session.execute(query)).scalar_one_or_none()
        if record:
            created = record.created_at
            return {
                "task_id": record.id,
                "dataset_id": record.dataset_id,
                "status": record.status,
                "use_modelscope": bool(record.use_modelscope),
                "path": record.path,
                "error": record.error,
                "record_count": record.record_count,
                "created_at": created.isoformat() if created else "",
            }

    # No row: fall back to whatever the task manager still remembers.
    mem = background_task_manager.get_task(task_id)
    if not mem:
        return {"task_id": task_id, "status": "not_found"}
    return {
        "task_id": task_id,
        "dataset_id": mem.get("dataset_id", ""),
        "status": mem["status"],
        "error": mem.get("error"),
        "record_count": mem.get("record_count", 0),
    }
async def list_dataset_downloads() -> list[dict[str, Any]]:
    """Return every download task as a dict, newest ``created_at`` first."""

    def _as_dict(task) -> dict[str, Any]:
        # Same field layout as the single-task status payload.
        created = task.created_at
        return {
            "task_id": task.id,
            "dataset_id": task.dataset_id,
            "status": task.status,
            "use_modelscope": bool(task.use_modelscope),
            "path": task.path,
            "error": task.error,
            "record_count": task.record_count,
            "created_at": created.isoformat() if created else "",
        }

    newest_first = select(DatasetDownloadTask).order_by(DatasetDownloadTask.created_at.desc())
    async with async_session() as session:
        tasks = (await session.execute(newest_first)).scalars().all()
        return [_as_dict(task) for task in tasks]
async def cancel_dataset_download(task_id: str) -> dict[str, Any]:
    """Cancel a download task in the task manager and mark its row cancelled.

    The row is changed only while its status is ``"pending"`` or
    ``"downloading"``. The return value is always ``"cancelled"``, whether or
    not a row was found or changed.

    Args:
        task_id: Id of the download task.
    """
    background_task_manager.cancel_task(task_id)

    query = select(DatasetDownloadTask).where(DatasetDownloadTask.id == task_id)
    async with async_session() as session:
        record = (await session.execute(query)).scalar_one_or_none()
        still_active = record and record.status in ("pending", "downloading")
        if still_active:
            record.status = "cancelled"
            record.error = "Cancelled by user"
            record.finished_at = datetime.utcnow()
            await session.commit()

    return {"task_id": task_id, "status": "cancelled"}
async def recover_stale_downloads() -> None:
    """Mark tasks left ``pending``/``downloading`` as failed.

    Intended for startup: such rows belong to work interrupted by a restart
    (per the stored error message). Nothing is committed or logged when no
    such rows exist.
    """
    unfinished = select(DatasetDownloadTask).where(
        DatasetDownloadTask.status.in_(["pending", "downloading"])
    )
    async with async_session() as session:
        stale = (await session.execute(unfinished)).scalars().all()
        now_failed = 0
        for task in stale:
            task.status = "failed"
            task.error = "Server restarted, task interrupted"
            task.finished_at = datetime.utcnow()
            now_failed += 1
        if not stale:
            return
        await session.commit()
        logger.info(f"Recovered {now_failed} stale dataset download tasks")
def _download_modelscope_dataset(dataset_id: str) -> tuple[Path, Path, int]:
    """Download a ModelScope dataset into ``processed_dir/ms_<id>``.

    ``MsDataset.load()`` (via ``_download_via_msdataset``) is tried first
    because it also fetches images. If it raises, or yields no records, the
    CLI download is used instead.

    Args:
        dataset_id: ``namespace/name`` or bare dataset name.

    Returns:
        ``(dataset_dir, jsonl_path, record_count)``.
    """
    target_dir = settings.processed_dir / f"ms_{dataset_id.replace('/', '_')}"
    target_dir.mkdir(parents=True, exist_ok=True)

    try:
        rows, record_count = _download_via_msdataset(dataset_id, target_dir)
        if rows:
            jsonl_path = target_dir / "data.jsonl"
            with open(jsonl_path, "w", encoding="utf-8") as out:
                out.writelines(
                    json.dumps(row, ensure_ascii=False) + "\n" for row in rows
                )
            logger.info(f"MsDataset.load() 成功: {dataset_id} ({record_count} records)")
            return target_dir, jsonl_path, record_count
    except Exception as exc:
        # Any failure, including one while writing the JSONL, uses the fallback.
        logger.warning(f"MsDataset.load() failed for {dataset_id}: {exc}, falling back to CLI")

    # Fallback: the CLI fetches git repository files only.
    return _download_modelscope_dataset_cli(dataset_id, target_dir)
def _download_via_msdataset(dataset_id: str, ds_dir: Path) -> tuple[list[dict], int]:
    """Load a dataset with ``MsDataset.load()`` and materialise its images.

    Splits ``train``, ``validation`` and ``test`` are tried in order; the first
    truthy result wins, and a load without ``split`` is the last resort. For
    every dict row:

      * ``PIL.Image`` values are saved as JPEG under ``ds_dir/images`` and
        replaced by the relative path ``images/NNNNNN.jpg``;
      * string values ending in an image extension that are existing absolute
        paths are copied into ``ds_dir/images`` and replaced likewise;
      * everything else is kept unchanged.

    Images whose mode JPEG cannot encode are converted to RGB first, so a
    single unusual image (e.g. 16-bit or float) does not abort the load.

    Args:
        dataset_id: ``namespace/name`` or bare dataset name.
        ds_dir: Dataset directory that receives the ``images`` folder.

    Returns:
        ``(records, len(records))``; ``([], 0)`` when nothing could be loaded.
    """
    import os
    import shutil

    from modelscope.msdatasets import MsDataset
    from PIL import Image

    namespace, ds_name = dataset_id.split("/", 1) if "/" in dataset_id else ("", dataset_id)
    images_dir = ds_dir / "images"
    # Modes Pillow's JPEG writer accepts directly; anything else is converted.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
    image_exts = ('.jpg', '.jpeg', '.png', '.gif', '.bmp')

    def _load(**kwargs):
        if namespace:
            return MsDataset.load(ds_name, namespace=namespace, **kwargs)
        return MsDataset.load(dataset_id, **kwargs)

    # Try the usual split names in order.
    ds = None
    for split in ("train", "validation", "test"):
        try:
            ds = _load(split=split)
            if ds:
                logger.info(f"MsDataset.load() loaded split '{split}': {len(ds) if hasattr(ds, '__len__') else '?'} records")
                break
        except Exception as e:
            logger.debug(f"split '{split}' failed: {e}")

    if not ds:
        # Last resort: let the library choose the split.
        try:
            ds = _load()
        except Exception as e:
            logger.warning(f"MsDataset.load() without split also failed: {e}")
            return [], 0

    if not ds or not hasattr(ds, '__iter__'):
        return [], 0

    records = []
    img_counter = 0
    for row in ds:
        if not isinstance(row, dict):
            continue
        record = {}
        for k, v in row.items():
            if isinstance(v, Image.Image):
                # In-memory image: write to disk, keep the relative path.
                images_dir.mkdir(parents=True, exist_ok=True)
                img_name = f"{img_counter:06d}.jpg"
                if v.mode not in jpeg_modes:
                    v = v.convert("RGB")
                v.save(str(images_dir / img_name), format="JPEG", quality=90)
                record[k] = f"images/{img_name}"
                img_counter += 1
            elif isinstance(v, str) and v.lower().endswith(image_exts):
                if os.path.isabs(v) and os.path.exists(v):
                    # Absolute path into a cache: copy beside the dataset.
                    images_dir.mkdir(parents=True, exist_ok=True)
                    img_name = f"{img_counter:06d}{os.path.splitext(v)[1]}"
                    try:
                        shutil.copy2(v, images_dir / img_name)
                        record[k] = f"images/{img_name}"
                        img_counter += 1
                    except Exception as e:
                        logger.warning(f"Failed to copy image {v}: {e}")
                        record[k] = v
                else:
                    # Relative or missing path: keep as is.
                    record[k] = v
            else:
                record[k] = v
        records.append(record)

        if len(records) % 500 == 0:
            logger.info(f"  处理中... {len(records)} records, {img_counter} images saved")

    if img_counter > 0:
        logger.info(f"共保存 {img_counter} 张图片到 {images_dir}")

    return records, len(records)
def _parse_text_records(content: str, suffix: str) -> list:
    """Parse the text of a JSONL / JSON / CSV file into a list of records."""
    if suffix == ".jsonl" or not suffix:
        # JSONL (or extension-less): one JSON value per line. On the first bad
        # line, assume the file is a single JSON document instead.
        records = []
        for line in content.splitlines():
            line = line.strip()
            if not line:
                continue
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError:
                try:
                    data = json.loads(content)
                except json.JSONDecodeError:
                    return []
                return data if isinstance(data, list) else [data]
        return records

    if suffix == ".json":
        # Whole-document JSON first; otherwise treat as JSONL and skip bad lines.
        try:
            data = json.loads(content)
        except json.JSONDecodeError:
            records = []
            for line in content.splitlines():
                line = line.strip()
                if not line:
                    continue
                try:
                    records.append(json.loads(line))
                except json.JSONDecodeError:
                    continue
            return records
        return data if isinstance(data, list) else [data]

    if suffix == ".csv":
        import csv
        import io

        # Feed the reader a stream, not pre-split lines, so newlines inside
        # quoted fields survive.
        return [dict(row) for row in csv.DictReader(io.StringIO(content, newline=""))]

    return []


def _download_modelscope_dataset_cli(dataset_id: str, ds_dir: Path) -> tuple[Path, Path, int]:
    """Download a dataset with the ``modelscope`` CLI and normalise it to JSONL.

    This is the fallback path; the CLI fetches git repository files only. The
    largest file recognised as training data is converted to
    ``ds_dir/data.jsonl``. ``.parquet`` files are converted through pandas
    rather than being read as text.

    Args:
        dataset_id: ``namespace/name`` or bare dataset name.
        ds_dir: Directory the CLI downloads into.

    Returns:
        ``(ds_dir, jsonl_path, record_count)``.

    Raises:
        RuntimeError: The CLI exited with a non-zero status.
        ValueError: No usable data file was downloaded.
    """
    import subprocess

    cmd = ["modelscope", "download", "--dataset", dataset_id, "--local_dir", str(ds_dir)]
    logger.info(f"Fallback CLI: {' '.join(cmd)}")
    proc = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
    if proc.returncode != 0:
        logger.error(f"ModelScope CLI download failed (code={proc.returncode}): {proc.stderr[:500]}")
        raise RuntimeError(f"ModelScope download failed: {proc.stderr[:500]}")

    all_files = [p for p in ds_dir.rglob("*") if p.is_file()]
    logger.info(f"CLI downloaded {len(all_files)} files to {ds_dir}")

    data_files = [f for f in all_files if _is_training_data_file(f)]
    if not data_files:
        # Relax the heuristics: any non-metadata JSON/JSONL file will do.
        fallback = [f for f in all_files
                    if f.suffix in (".json", ".jsonl") and f.name not in META_FILENAMES]
        logger.warning(f"No training data files found in {dataset_id}. "
                       f"Available JSON files: {[f.name for f in fallback]}")
        if not fallback:
            # List everything that was downloaded to help diagnosis.
            logger.error(f"All downloaded files: {[str(f.relative_to(ds_dir)) for f in all_files]}")
            raise ValueError(f"No JSON/JSONL data files found in dataset {dataset_id}. "
                             f"Available files: {[f.name for f in all_files]}")
        data_files = fallback

    # The real dataset is usually the largest candidate.
    target = max(data_files, key=lambda f: f.stat().st_size)
    logger.info(f"Selected data file: {target} (size={target.stat().st_size})")

    jsonl_path = ds_dir / "data.jsonl"

    if target.suffix == ".parquet":
        # Binary format: pandas writes JSON lines directly.
        import pandas as pd

        frame = pd.read_parquet(target)
        frame.to_json(jsonl_path, orient="records", lines=True, force_ascii=False)
        return ds_dir, jsonl_path, len(frame)

    records = _parse_text_records(target.read_text(encoding="utf-8"), target.suffix)

    record_count = 0
    with open(jsonl_path, "w", encoding="utf-8") as f:
        for item in records:
            f.write(json.dumps(item, ensure_ascii=False) + "\n")
            record_count += 1

    return ds_dir, jsonl_path, record_count
async def upload_dataset(file: UploadFile) -> dict[str, Any]:
    """Store an uploaded file under ``settings.uploads_dir`` and record it.

    The client-supplied filename is untrusted. Only its final path component
    is used, so names such as ``../../x`` or ``C:\\x\\y.csv`` cannot place
    the file outside the upload directory. If a file of that name already
    exists, a random hex prefix is added.

    Args:
        file: The uploaded file.

    Returns:
        A dict with the new record's id, name, detected format, record count,
        stored file path and creation timestamp (ISO format).
    """
    upload_dir = settings.uploads_dir
    upload_dir.mkdir(parents=True, exist_ok=True)

    # Strip directory parts (both separator styles) from the client's name.
    raw_name = file.filename or "unknown"
    safe_name = Path(raw_name.replace("\\", "/")).name
    if safe_name in ("", ".", ".."):
        safe_name = "unknown"

    file_path = upload_dir / safe_name
    if file_path.exists():
        file_path = upload_dir / f"{uuid.uuid4().hex}_{safe_name}"

    content = await file.read()
    file_path.write_bytes(content)

    fmt = _detect_format(safe_name)
    record_count = _count_records(file_path, fmt)

    record_id = str(uuid.uuid4())
    record = DatasetRecord(
        id=record_id,
        name=safe_name,
        format=fmt,
        record_count=record_count,
        file_path=str(file_path),
        created_at=datetime.utcnow(),
    )
    async with async_session() as session:
        session.add(record)
        await session.commit()

    logger.info(f"Uploaded dataset: {safe_name} ({record_count} records, format={fmt})")
    return {
        "id": record_id,
        "name": safe_name,
        "format": fmt,
        "record_count": record_count,
        "file_path": str(file_path),
        "created_at": record.created_at.isoformat(),
    }
- def _detect_image_column(columns: list[str]) -> str | None:
- """检测哪一列是图片路径列。"""
- candidates = ["image_path", "image", "img_path", "img", "file_path", "filename", "path", "file"]
- for c in candidates:
- if c in columns:
- return c
- # 模糊匹配:列名包含 image 或 path
- for c in columns:
- cl = c.lower()
- if "image" in cl or ("path" in cl and "label" not in cl):
- return c
- return None
- def _resolve_image_path(path_str: str, data_dir: Path) -> Path | None:
- """解析图片路径,返回绝对路径。"""
- if not path_str:
- return None
- p = Path(path_str)
- # 已经是绝对路径
- if p.is_absolute():
- return p if p.exists() else None
- # 相对路径:先尝试相对于数据目录
- candidate = data_dir / p
- if candidate.exists():
- return candidate
- # 也可能直接在 data_dir 下(去掉目录前缀只保留文件名)
- if data_dir.joinpath(p.name).exists():
- return data_dir / p.name
- # 在 data_dir 的子目录中递归查找
- for child in data_dir.rglob(p.name):
- if child.is_file():
- return child
- logger.debug(f"Image not found: '{path_str}' (searched in {data_dir})")
- return None
def _encode_image_base64(image_path: Path, max_size: int = 200) -> str | None:
    """Render an image as a small JPEG ``data:`` URI for frontend previews.

    The image is shrunk to fit within ``max_size`` x ``max_size`` and
    re-encoded as JPEG. The source file is opened in a ``with`` block so its
    handle is released even when processing fails. Modes JPEG cannot encode
    are converted to RGB first.

    Args:
        image_path: Path of the image file.
        max_size: Maximum width and height of the preview in pixels.

    Returns:
        A ``data:image/jpeg;base64,...`` string, or None when the image cannot
        be read or encoded (the failure is logged).
    """
    import base64
    import io

    try:
        from PIL import Image

        # Modes Pillow's JPEG writer accepts directly.
        jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
        buf = io.BytesIO()
        with Image.open(image_path) as img:
            # thumbnail() resizes in place and keeps the aspect ratio.
            img.thumbnail((max_size, max_size))
            preview = img if img.mode in jpeg_modes else img.convert("RGB")
            preview.save(buf, format="JPEG", quality=85)
        b64 = base64.b64encode(buf.getvalue()).decode("ascii")
        return f"data:image/jpeg;base64,{b64}"
    except Exception as e:
        # Preview is best effort: callers show the raw path instead.
        logger.warning(f"Failed to encode image {image_path}: {e}")
        return None
- def _format_value(value) -> str:
- """将复杂值格式化为可读字符串。"""
- if isinstance(value, (dict, list)):
- return json.dumps(value, ensure_ascii=False, indent=2)
- return str(value)
- def _is_sharegpt_format(records: list[dict]) -> bool:
- """检测是否为 ShareGPT 格式。"""
- if not records:
- return False
- first = records[0]
- if "conversations" in first and isinstance(first["conversations"], list):
- if len(first["conversations"]) > 0 and isinstance(first["conversations"][0], dict):
- conv = first["conversations"][0]
- return "from" in conv and "value" in conv
- return False
- def _flatten_sharegpt(records: list[dict]) -> tuple[list[dict], list[str]]:
- """将 ShareGPT 格式展平为 input/output 列。"""
- flat_rows = []
- for row in records:
- conversations = row.get("conversations", [])
- for i in range(0, len(conversations) - 1, 2):
- user_turn = conversations[i]
- assistant_turn = conversations[i + 1] if i + 1 < len(conversations) else None
- if user_turn.get("from") in ("human", "user"):
- input_text = str(user_turn.get("value", ""))
- output_text = str(assistant_turn.get("value", "")) if assistant_turn else ""
- else:
- input_text = str(assistant_turn.get("value", "")) if assistant_turn else ""
- output_text = str(user_turn.get("value", ""))
- if len(input_text) > 500:
- input_text = input_text[:500] + "..."
- if len(output_text) > 500:
- output_text = output_text[:500] + "..."
- flat_rows.append({"input": input_text, "output": output_text})
- return flat_rows, ["input", "output"]
async def preview_dataset(dataset_id: str, rows: int = 10) -> dict[str, Any]:
    """Build a preview of the first ``rows`` records of a stored dataset.

    ShareGPT-style data is flattened into ``input`` / ``output`` columns. If a
    column looks like it holds image paths, each resolvable image is embedded
    as a base64 thumbnail; unresolved or unreadable images keep their path
    text. Other values are rendered as text.

    Args:
        dataset_id: Id of the ``DatasetRecord``.
        rows: Number of records to read.

    Returns:
        A dict with ``total_records``, ``preview_rows``, ``columns`` and
        ``image_column``; an empty preview when the record or file is missing.
    """
    from sqlalchemy import select

    nothing = {"total_records": 0, "preview_rows": [], "columns": [], "image_column": None}

    async with async_session() as session:
        found = await session.execute(select(DatasetRecord).where(DatasetRecord.id == dataset_id))
        record = found.scalar_one_or_none()
        if not record:
            return nothing

        file_path = Path(record.file_path)
        if not file_path.exists():
            return nothing

        sample = _read_records(file_path, record.format, rows)

        # ShareGPT conversations are shown as flat input/output pairs.
        if _is_sharegpt_format(sample):
            sample, columns = _flatten_sharegpt(sample)
        elif sample:
            columns = list(sample[0].keys())
        else:
            columns = []

        # Vision datasets: embed thumbnails for the image-path column.
        image_column = _detect_image_column(columns)
        data_dir = file_path.parent

        def _render(key, value):
            if key != image_column:
                return _format_value(value)
            raw = str(value)
            resolved = _resolve_image_path(raw, data_dir)
            if not resolved:
                # Could not locate the image: keep the original path text.
                return raw
            return _encode_image_base64(resolved) or raw

        preview_rows = [
            {"row_index": index, "data": {k: _render(k, v) for k, v in row.items()}}
            for index, row in enumerate(sample)
        ]

        return {
            "total_records": record.record_count,
            "preview_rows": preview_rows,
            "columns": columns,
            "image_column": image_column,
        }
async def validate_dataset(dataset_id: str) -> dict[str, Any]:
    """Run basic format and schema checks on a stored dataset.

    Errors: unknown record, missing file, unsupported format, unreadable file.
    Warnings: no records, or a first record with none of the common SFT fields.

    Args:
        dataset_id: Id of the ``DatasetRecord``.

    Returns:
        ``{"is_valid": bool, "errors": [...], "warnings": [...]}``;
        ``is_valid`` is True exactly when there are no errors.
    """
    from sqlalchemy import select

    def _rejected(reason: str) -> dict[str, Any]:
        return {"is_valid": False, "errors": [reason], "warnings": []}

    async with async_session() as session:
        found = await session.execute(select(DatasetRecord).where(DatasetRecord.id == dataset_id))
        record = found.scalar_one_or_none()
        if not record:
            return _rejected("Dataset not found")

        file_path = Path(record.file_path)
        if not file_path.exists():
            return _rejected("File not found")

        errors, warnings = [], []
        fmt = record.format
        if fmt not in ("jsonl", "csv", "json", "parquet"):
            errors.append(f"Unsupported format: {fmt}")

        sft_fields = ("instruction", "prompt", "text", "input", "output", "completion")
        try:
            sample = _read_records(file_path, fmt, 5)
            if sample:
                head = sample[0]
                if not any(field in head for field in sft_fields):
                    warnings.append(f"No common SFT fields found. Keys: {list(head.keys())}")
            else:
                warnings.append("Dataset appears to be empty")
        except Exception as e:
            errors.append(f"Failed to read file: {str(e)}")

        return {"is_valid": len(errors) == 0, "errors": errors, "warnings": warnings}
async def list_datasets() -> list[dict[str, Any]]:
    """Return every stored dataset record as a dict, newest first."""
    from sqlalchemy import select

    newest_first = select(DatasetRecord).order_by(DatasetRecord.created_at.desc())
    async with async_session() as session:
        stored = (await session.execute(newest_first)).scalars().all()
        listing = []
        for ds in stored:
            listing.append(
                {
                    "id": ds.id,
                    "name": ds.name,
                    "format": ds.format,
                    "record_count": ds.record_count,
                    "file_path": ds.file_path,
                    "created_at": ds.created_at.isoformat(),
                }
            )
        return listing
async def delete_dataset(dataset_id: str) -> dict[str, Any]:
    """Delete a dataset record and the files that belong to it.

    A downloaded dataset lives in its own sub-directory of
    ``settings.processed_dir``; that directory is removed as a whole. Any
    other file, including one sitting directly in ``processed_dir``, is
    removed on its own, so the shared processed directory is never deleted.

    Args:
        dataset_id: Id of the ``DatasetRecord``.

    Returns:
        ``{"status": "deleted"}``, also when no such record exists.
    """
    import shutil

    from sqlalchemy import select

    async with async_session() as session:
        result = await session.execute(select(DatasetRecord).where(DatasetRecord.id == dataset_id))
        record = result.scalar_one_or_none()
        if record:
            # Read before delete/commit so the log never touches a stale object.
            name = record.name
            file_path = Path(record.file_path)
            if file_path.exists():
                ds_dir = file_path.parent
                if settings.processed_dir in ds_dir.parents:
                    # Dedicated dataset directory below processed_dir.
                    shutil.rmtree(ds_dir, ignore_errors=True)
                    logger.info(f"Deleted dataset directory: {ds_dir}")
                else:
                    # Uploaded file, or a file directly in processed_dir.
                    file_path.unlink(missing_ok=True)
            await session.delete(record)
            await session.commit()
            logger.info(f"Deleted dataset: {name}")
    return {"status": "deleted"}
- def _detect_format(filename: str) -> str:
- ext = Path(filename).suffix.lower().lstrip(".")
- if ext in ("jsonl", "csv", "parquet", "json"):
- return ext
- return "unknown"
- def _count_records(file_path: Path, fmt: str) -> int:
- try:
- if fmt == "jsonl":
- return sum(1 for line in open(file_path, encoding="utf-8") if line.strip())
- elif fmt == "json":
- with open(file_path, encoding="utf-8") as f:
- data = json.load(f)
- return len(data) if isinstance(data, list) else 1
- elif fmt == "csv":
- import csv
- with open(file_path, encoding="utf-8") as f:
- return sum(1 for _ in csv.reader(f)) - 1
- elif fmt == "parquet":
- import pandas as pd
- return len(pd.read_parquet(file_path))
- except Exception:
- pass
- return 0
- def _read_records(file_path: Path, fmt: str, n: int) -> list[dict]:
- if fmt == "jsonl":
- records = []
- with open(file_path, encoding="utf-8") as f:
- for i, line in enumerate(f):
- if i >= n:
- break
- line = line.strip()
- if line:
- records.append(json.loads(line))
- return records
- elif fmt == "json":
- with open(file_path, encoding="utf-8") as f:
- data = json.load(f)
- return data[:n] if isinstance(data, list) else [data]
- elif fmt == "csv":
- import csv
- with open(file_path, encoding="utf-8") as f:
- reader = csv.DictReader(f)
- return [dict(row) for i, row in enumerate(reader) if i < n]
- elif fmt == "parquet":
- import pandas as pd
- df = pd.read_parquet(file_path)
- return df.head(n).to_dict(orient="records")
- return []
|