lxylxy123321 пре 2 недеља
родитељ
комит
3051aa5793

+ 10 - 13
backend/app/api/datasets.py

@@ -15,41 +15,38 @@ router = APIRouter()
 @router.post("/download", response_model=DatasetDownloadResponse)
 async def download_dataset(req: DatasetDownloadRequest):
     """从 HuggingFace 或 ModelScope 下载数据集。"""
-    return dataset_service.download_dataset(req)
+    return await dataset_service.download_dataset(req)
 
 
 @router.post("/upload", response_model=DatasetUploadResponse)
 async def upload_dataset(file: UploadFile = File(...)):
     """上传数据集文件(JSONL / CSV / Parquet / JSON)。"""
-    return DatasetUploadResponse(
-        id="placeholder",
-        name=file.filename or "unknown",
-        format="jsonl",
-        record_count=0,
-        file_path="",
-        created_at="",
-    )
+    result = await dataset_service.upload_dataset(file)
+    return DatasetUploadResponse(**result)
 
 
 @router.get("/{dataset_id}/preview", response_model=DatasetPreviewResponse)
 async def preview_dataset(dataset_id: str, rows: int = Query(default=10, le=100)):
     """预览数据集前 N 行。"""
-    return DatasetPreviewResponse(total_records=0, preview_rows=[], columns=[])
+    result = await dataset_service.preview_dataset(dataset_id, rows)
+    return DatasetPreviewResponse(**result)
 
 
 @router.post("/{dataset_id}/validate", response_model=DatasetValidationResult)
 async def validate_dataset(dataset_id: str):
     """校验数据集格式和 Schema。"""
-    return DatasetValidationResult(is_valid=True)
+    result = await dataset_service.validate_dataset(dataset_id)
+    return DatasetValidationResult(**result)
 
 
 @router.get("/", response_model=list[DatasetUploadResponse])
 async def list_datasets():
     """列出所有已上传数据集。"""
-    return []
+    items = await dataset_service.list_datasets()
+    return [DatasetUploadResponse(**item) for item in items]
 
 
 @router.delete("/{dataset_id}")
 async def delete_dataset(dataset_id: str):
     """删除数据集。"""
-    return {"status": "deleted"}
+    return await dataset_service.delete_dataset(dataset_id)

+ 7 - 11
backend/app/api/deployment.py

@@ -1,6 +1,7 @@
 from fastapi import APIRouter
 
 from app.schemas.deployment import DeployConfig, DeployResponse
+from app.services import deploy_service
 
 router = APIRouter()
 
@@ -8,20 +9,15 @@ router = APIRouter()
 @router.post("/export", response_model=DeployResponse)
 async def export_adapter(config: DeployConfig):
     """合并 adapter 与基础模型,可选导出为 GGUF。"""
-    return DeployResponse(
-        job_id=config.job_id,
-        status="pending",
-        output_path=None,
-        error=None,
+    result = await deploy_service.export_adapter(
+        config.job_id,
+        {"merge_with_base": config.merge_with_base, "export_format": config.export_format},
     )
+    return DeployResponse(**result)
 
 
 @router.get("/{deploy_id}/status", response_model=DeployResponse)
 async def get_deployment_status(deploy_id: str):
     """获取导出/部署任务状态。"""
-    return DeployResponse(
-        job_id=deploy_id,
-        status="pending",
-        output_path=None,
-        error=None,
-    )
+    result = await deploy_service.get_deploy_status(deploy_id)
+    return DeployResponse(**result)

+ 5 - 12
backend/app/api/evaluation.py

@@ -1,6 +1,7 @@
 from fastapi import APIRouter
 
 from app.schemas.evaluation import EvalConfig, EvalResult
+from app.services import eval_service
 
 router = APIRouter()
 
@@ -8,20 +9,12 @@ router = APIRouter()
 @router.post("/run", response_model=EvalResult)
 async def run_evaluation(config: EvalConfig):
     """对已训练的 adapter 运行评估。"""
-    return EvalResult(
-        id="placeholder",
-        job_id=config.job_id,
-        metrics={},
-        created_at="",
-    )
+    result = await eval_service.run_evaluation(config.job_id, config.model_dump())
+    return EvalResult(**result)
 
 
 @router.get("/{eval_id}/results", response_model=EvalResult)
 async def get_evaluation_results(eval_id: str):
     """获取已完成评估的结果。"""
-    return EvalResult(
-        id=eval_id,
-        job_id="",
-        metrics={},
-        created_at="",
-    )
+    result = await eval_service.get_evaluation_results(eval_id)
+    return EvalResult(**result)

+ 30 - 2
backend/app/api/models.py

@@ -1,6 +1,7 @@
 from fastapi import APIRouter
 
 from app.schemas.model import ModelDownloadRequest, ModelDownloadResponse, ModelInfo
+from app.services import model_service
 
 router = APIRouter()
 
@@ -8,20 +9,47 @@ router = APIRouter()
 @router.get("/", response_model=list[ModelInfo])
 async def list_models():
     """列出所有本地缓存的模型。"""
-    return []
+    models = model_service.list_cached_models()
+    return [
+        ModelInfo(
+            id=m["id"],
+            name=m.get("name", m["id"]),
+            model_type=m.get("model_type", "text"),
+            path=m.get("path"),
+            is_downloaded=m.get("is_downloaded", True),
+            context_length=m.get("context_length"),
+            supported_peft_methods=m.get("supported_peft_methods", []),
+        )
+        for m in models
+    ]
 
 
 @router.post("/download", response_model=ModelDownloadResponse)
 async def download_model(req: ModelDownloadRequest):
     """从 HuggingFace 或 ModelScope 下载模型。"""
+    result = await model_service.download_model(req.model_id, req.use_modelscope)
     return ModelDownloadResponse(
-        model_id=req.model_id, status="downloading", path=None, error=None
+        model_id=result["model_id"],
+        status=result["status"],
+        path=result.get("path"),
+        error=result.get("error"),
     )
 
 
 @router.get("/{model_id}", response_model=ModelInfo)
 async def get_model_info(model_id: str):
     """获取已缓存模型的详细信息。"""
+    info = await model_service.get_model_info(model_id)
+    if info:
+        return ModelInfo(
+            id=info["id"],
+            name=info.get("name", model_id.split("/")[-1]),
+            model_type=info.get("model_type", "text"),
+            path=info.get("path"),
+            is_downloaded=info.get("is_downloaded", True),
+            context_length=info.get("context_length"),
+            supported_peft_methods=info.get("supported_peft_methods", []),
+        )
     return ModelInfo(
         id=model_id,
         name=model_id.split("/")[-1],

+ 24 - 11
backend/app/api/training.py

@@ -1,6 +1,7 @@
 from fastapi import APIRouter
 
 from app.schemas.training import TrainingConfig, TrainingJobResponse, TrainingProgress
+from app.services import training_service
 
 router = APIRouter()
 
@@ -8,25 +9,24 @@ router = APIRouter()
 @router.post("/jobs", response_model=TrainingJobResponse)
 async def create_training_job(config: TrainingConfig):
     """创建并加入训练任务。"""
-    return TrainingJobResponse(
-        id="placeholder",
-        model_id=config.model_id,
-        model_type=config.model_type.value,
-        peft_method=config.peft_method.value,
-        status="pending",
-        created_at="",
-    )
+    config_dict = config.model_dump()
+    result = await training_service.create_training_job(config_dict)
+    return TrainingJobResponse(**result)
 
 
 @router.get("/jobs", response_model=list[TrainingJobResponse])
 async def list_training_jobs():
     """列出所有训练任务。"""
-    return []
+    items = await training_service.list_training_jobs()
+    return [TrainingJobResponse(**item) for item in items]
 
 
 @router.get("/jobs/{job_id}", response_model=TrainingJobResponse)
 async def get_training_job(job_id: str):
     """获取指定任务详情。"""
+    item = await training_service.get_training_job(job_id)
+    if item:
+        return TrainingJobResponse(**item)
     return TrainingJobResponse(
         id=job_id,
         model_id="",
@@ -40,10 +40,23 @@ async def get_training_job(job_id: str):
 @router.post("/jobs/{job_id}/cancel")
 async def cancel_training_job(job_id: str):
     """取消运行中的训练任务。"""
-    return {"status": "cancelled"}
+    return await training_service.cancel_training_job(job_id)
 
 
 @router.get("/jobs/{job_id}/logs")
 async def stream_training_logs(job_id: str):
     """通过 SSE 流式推送训练日志。"""
-    return {"logs": []}
+    from fastapi.responses import StreamingResponse
+
+    async def log_stream():
+        from app.config import get_settings
+        _settings = get_settings()
+        log_file = _settings.adapters_dir / job_id / "trainer_log.txt"
+        if os.path.exists(log_file):
+            with open(log_file, "r") as f:
+                for line in f:
+                    yield f"data: {line}\n\n"
+        else:
+            yield "data: No logs available\n\n"
+
+    return StreamingResponse(log_stream(), media_type="text/event-stream")

+ 99 - 2
backend/app/core/db.py

@@ -1,9 +1,15 @@
-from app.config import get_settings
+from datetime import datetime
 
+from sqlalchemy import Column, DateTime, Float, Integer, String, Text
 from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
+from sqlalchemy.orm import DeclarativeBase
+
+from app.config import get_settings
 
 settings = get_settings()
 
+Base = DeclarativeBase()
+
 engine = create_async_engine(
     settings.database_url,
     echo=settings.backend_env == "development",
@@ -12,6 +18,97 @@ engine = create_async_engine(
 async_session = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
 
 
-async def get_db() -> AsyncSession:  # type: ignore[misc]
+class TrainingJobModel(Base):
+    __tablename__ = "training_jobs"
+
+    id = Column(String(36), primary_key=True)
+    model_id = Column(String(256), nullable=False)
+    model_type = Column(String(32), nullable=False)
+    dataset_id = Column(String(36), nullable=False)
+    peft_method = Column(String(32), nullable=False)
+    task_type = Column(String(32), default="sft")  # sft/dpo/kto/orpo/rm/ppo
+    dataset_template = Column(String(32), default="alpaca")
+
+    status = Column(String(32), default="pending")
+    progress = Column(Float, default=0.0)
+    current_epoch = Column(Integer, default=0)
+    current_step = Column(Integer, default=0)
+    total_steps = Column(Integer, default=0)
+    loss = Column(Float, nullable=True)
+    learning_rate = Column(Float, nullable=True)
+
+    epochs = Column(Integer, default=3)
+    batch_size = Column(Integer, default=4)
+    gradient_accumulation = Column(Integer, default=4)
+    max_seq_length = Column(Integer, default=2048)
+    warmup_ratio = Column(Float, default=0.05)
+    save_strategy = Column(String(32), default="epoch")
+    eval_strategy = Column(String(32), default="epoch")
+    eval_steps = Column(Integer, default=100)
+
+    lora_r = Column(Integer, default=16)
+    lora_alpha = Column(Integer, default=32)
+    lora_dropout = Column(Float, default=0.05)
+    lora_target_modules = Column(String(256), default="all-linear")
+    qlora_bits = Column(Integer, default=4)
+
+    created_at = Column(DateTime, default=datetime.utcnow)
+    started_at = Column(DateTime, nullable=True)
+    finished_at = Column(DateTime, nullable=True)
+    error_message = Column(Text, nullable=True)
+    adapter_path = Column(String(512), nullable=True)
+
+
+class DatasetRecord(Base):
+    __tablename__ = "datasets"
+
+    id = Column(String(36), primary_key=True)
+    name = Column(String(256), nullable=False)
+    format = Column(String(16), nullable=False)
+    record_count = Column(Integer, default=0)
+    file_path = Column(String(512), nullable=False)
+    created_at = Column(DateTime, default=datetime.utcnow)
+
+
+class ModelCache(Base):
+    __tablename__ = "model_cache"
+
+    id = Column(String(256), primary_key=True)
+    name = Column(String(256), nullable=False)
+    model_type = Column(String(32), nullable=False)
+    path = Column(String(512), nullable=True)
+    is_downloaded = Column(Integer, default=0)
+    context_length = Column(Integer, nullable=True)
+    supported_peft_methods = Column(String(256), default="")
+    created_at = Column(DateTime, default=datetime.utcnow)
+
+
+class EvalResultModel(Base):
+    __tablename__ = "eval_results"
+
+    id = Column(String(36), primary_key=True)
+    job_id = Column(String(36), nullable=False)
+    metrics = Column(Text, default="{}")
+    created_at = Column(DateTime, default=datetime.utcnow)
+
+
+class DeployTaskModel(Base):
+    __tablename__ = "deploy_tasks"
+
+    id = Column(String(36), primary_key=True)
+    job_id = Column(String(36), nullable=False)
+    status = Column(String(32), default="pending")
+    output_path = Column(String(512), nullable=True)
+    error = Column(Text, nullable=True)
+    created_at = Column(DateTime, default=datetime.utcnow)
+
+
+async def init_db():
+    """创建所有表(首次启动时调用)。"""
+    async with engine.begin() as conn:
+        await conn.run_sync(Base.metadata.create_all)
+
+
+async def get_db() -> AsyncSession:
     async with async_session() as session:
         yield session

+ 241 - 0
backend/app/core/job_queue.py

@@ -1,8 +1,12 @@
+import asyncio
 from datetime import datetime, timezone
 from enum import Enum
+from typing import Any, Callable, Coroutine, Optional
 
 from pydantic import BaseModel, Field
 
+from app.core.logging import logger
+
 
 class JobStatus(str, Enum):
     PENDING = "pending"
@@ -38,3 +42,240 @@ class TrainingJob(BaseModel):
     created_at: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
     started_at: str | None = None
     finished_at: str | None = None
+
+
+class JobQueue:
+    """异步任务队列,支持取消和并发控制。"""
+
+    def __init__(self, max_concurrent: int = 2):
+        self._queue: asyncio.Queue[str] = asyncio.Queue()
+        self._jobs: dict[str, TrainingJob] = {}
+        self._cancel_events: dict[str, asyncio.Event] = {}
+        self._callbacks: list[Callable[[TrainingJob], Coroutine[Any, Any, None]]] = []
+        self._max_concurrent = max_concurrent
+        self._workers: list[asyncio.Task] = []
+        self._running = False
+
+    async def start(self):
+        """启动后台 worker。"""
+        if self._running:
+            return
+        self._running = True
+        for _ in range(self._max_concurrent):
+            worker = asyncio.create_task(self._worker_loop())
+            self._workers.append(worker)
+        logger.info(f"JobQueue started with {self._max_concurrent} workers")
+
+    async def stop(self):
+        """停止所有 worker。"""
+        self._running = False
+        for event in self._cancel_events.values():
+            event.set()
+        for worker in self._workers:
+            worker.cancel()
+        self._workers.clear()
+        logger.info("JobQueue stopped")
+
+    async def enqueue(self, job_id: str, job: TrainingJob):
+        """将任务加入队列。"""
+        self._jobs[job_id] = job
+        self._cancel_events[job_id] = asyncio.Event()
+        await self._queue.put(job_id)
+        logger.info(f"Job {job_id} enqueued")
+
+    async def dequeue(self) -> str:
+        """从队列中取出任务 ID。"""
+        return await self._queue.get()
+
+    def mark_done(self, job_id: str):
+        """标记任务完成。"""
+        self._queue.task_done()
+        self._cancel_events.pop(job_id, None)
+
+    def get_job(self, job_id: str) -> Optional[TrainingJob]:
+        return self._jobs.get(job_id)
+
+    def update_job(self, job_id: str, **kwargs):
+        if job_id in self._jobs:
+            job = self._jobs[job_id]
+            for key, val in kwargs.items():
+                if hasattr(job, key):
+                    setattr(job, key, val)
+
+    def is_cancelled(self, job_id: str) -> bool:
+        event = self._cancel_events.get(job_id)
+        return event is not None and event.is_set()
+
+    async def cancel(self, job_id: str):
+        """取消任务。"""
+        if job_id in self._cancel_events:
+            self._cancel_events[job_id].set()
+            self.update_job(job_id, status=JobStatus.CANCELLED)
+            await self._notify_callbacks()
+            logger.info(f"Job {job_id} cancelled")
+
+    def register_callback(self, callback: Callable[[TrainingJob], Coroutine[Any, Any, None]]):
+        """注册状态变更回调(用于更新数据库等)。"""
+        self._callbacks.append(callback)
+
+    async def _notify_callbacks(self):
+        for cb in self._callbacks:
+            try:
+                for job in self._jobs.values():
+                    await cb(job)
+            except Exception as e:
+                logger.error(f"JobQueue callback error: {e}")
+
+    async def _worker_loop(self):
+        """worker 循环:不断从队列取任务并执行。"""
+        while self._running:
+            try:
+                job_id = await asyncio.wait_for(self._queue.get(), timeout=1.0)
+            except asyncio.TimeoutError:
+                continue
+
+            try:
+                await self._run_job(job_id)
+            except Exception as e:
+                logger.error(f"Job {job_id} failed: {e}")
+                self.update_job(job_id, status=JobStatus.FAILED, error_message=str(e))
+            finally:
+                self._queue.task_done()
+
+    async def _run_job(self, job_id: str):
+        """执行单个任务:预处理 → 训练 → 完成。"""
+        job = self._jobs.get(job_id)
+        if not job:
+            return
+
+        self.update_job(job_id, status=JobStatus.QUEUED)
+        await self._notify_callbacks()
+
+        if self.is_cancelled(job_id):
+            return
+
+        self.update_job(job_id, status=JobStatus.PREPROCESSING, started_at=datetime.now(timezone.utc).isoformat())
+        await self._notify_callbacks()
+
+        if self.is_cancelled(job_id):
+            return
+
+        try:
+            config = job.config
+            model_id = job.model_id
+            model_type = job.model_type
+            peft_method = job.peft_method
+            dataset_id = config.get("dataset_id", job.dataset_id)
+
+            # 获取数据集文件路径
+            from app.config import get_settings
+            settings = get_settings()
+
+            # 查找数据集文件
+            dataset_path = self._find_dataset_path(dataset_id)
+            if not dataset_path:
+                raise FileNotFoundError(f"Dataset not found: {dataset_id}")
+
+            # 预处理
+            processed_path = str(settings.processed_dir / f"{job_id}_processed.jsonl")
+            task_type = config.get("task_type", "sft")
+            template = config.get("dataset_template", "alpaca")
+
+            # 选择引擎
+            engine = self._get_engine(model_type)
+
+            # 预处理
+            await engine.preprocess_dataset(dataset_path, processed_path, task_type=task_type, template=template)
+            self.update_job(job_id, status=JobStatus.TRAINING)
+            await self._notify_callbacks()
+
+            # 加载模型
+            await engine.load_model(model_id, quantization="4bit" if peft_method == "qlora" else None)
+
+            # 构建 PEFT 配置
+            peft_config = engine.get_peft_config(peft_method, config)
+
+            # 训练
+            adapter_path = await engine.train(
+                job_id=job_id,
+                dataset_path=processed_path,
+                peft_config=peft_config,
+                training_args=config,
+            )
+
+            self.update_job(job_id, status=JobStatus.COMPLETED, adapter_path=adapter_path)
+            await self._notify_callbacks()
+            logger.info(f"Job {job_id} completed successfully")
+
+        except asyncio.CancelledError:
+            self.update_job(job_id, status=JobStatus.CANCELLED)
+            await self._notify_callbacks()
+        except Exception as e:
+            logger.error(f"Job {job_id} failed: {e}")
+            self.update_job(job_id, status=JobStatus.FAILED, error_message=str(e))
+            await self._notify_callbacks()
+
+    def _find_dataset_path(self, dataset_id: str) -> str | None:
+        """根据 dataset_id 查找文件路径(数据库或 uploads 目录)。"""
+        import asyncio
+        from app.core.db import async_session, DatasetRecord
+        from sqlalchemy import select
+        from app.config import get_settings
+        from pathlib import Path
+
+        settings = get_settings()
+
+        # 尝试从数据库查找
+        try:
+            loop = asyncio.get_event_loop()
+            task = loop.create_task(self._lookup_dataset_db(dataset_id))
+            path = loop.run_until_complete(task)
+            if path:
+                return path
+        except Exception:
+            pass
+
+        # 尝试从 uploads 目录查找
+        upload_path = settings.uploads_dir / dataset_id
+        if upload_path.exists():
+            return str(upload_path)
+
+        # 如果 dataset_id 本身是路径
+        if Path(dataset_id).exists():
+            return dataset_id
+
+        return None
+
+    async def _lookup_dataset_db(self, dataset_id: str) -> str | None:
+        """从数据库查找数据集路径。"""
+        from app.core.db import async_session, DatasetRecord
+        from sqlalchemy import select
+
+        async with async_session() as session:
+            result = await session.execute(select(DatasetRecord).where(
+                (DatasetRecord.id == dataset_id) | (DatasetRecord.name == dataset_id)
+            ))
+            record = result.scalar_one_or_none()
+            if record:
+                return record.file_path
+        return None
+
+    def _get_engine(self, model_type: str):
+        """根据模型类型选择训练引擎。"""
+        if model_type == "vision":
+            from app.engines.vision_engine import vision_engine
+            return vision_engine
+        elif model_type == "multimodal":
+            from app.engines.multimodal_engine import multimodal_engine
+            return multimodal_engine
+        else:
+            from app.engines.text_engine import text_engine
+            return text_engine
+
+    @property
+    def jobs(self) -> dict[str, TrainingJob]:
+        return dict(self._jobs)
+
+
+# 全局单例
+job_queue = JobQueue(max_concurrent=2)

+ 178 - 11
backend/app/engines/multimodal_engine.py

@@ -1,20 +1,187 @@
+import json
+from pathlib import Path
+from typing import Any
+
+from app.config import get_settings
+from app.core.logging import logger
 from app.engines.base import BaseEngine
 
+settings = get_settings()
+
 
 class MultimodalEngine(BaseEngine):
-    """Training engine for LLaVA, Qwen-VL, and other vision-language models."""
+    """多模态模型训练引擎 (LLaVA/Qwen-VL 等视觉语言模型)。"""
+
+    def __init__(self):
+        self._processor = None
+        self._model = None
+
+    async def load_model(self, model_id: str, **kwargs: Any) -> None:
+        """下载并加载多模态模型。"""
+        import torch
+        from transformers import AutoProcessor, LlavaForConditionalGeneration
+
+        local_path = str(settings.models_dir / model_id.replace("/", "_"))
+
+        if not (Path(local_path) / "config.json").exists():
+            from huggingface_hub import snapshot_download
+            snapshot_download(repo_id=model_id, local_dir=local_path, local_dir_use_symlinks=False)
+
+        self._processor = AutoProcessor.from_pretrained(local_path, trust_remote_code=True)
+        self._model = LlavaForConditionalGeneration.from_pretrained(
+            local_path,
+            torch_dtype=torch.float16,
+            device_map="auto",
+            trust_remote_code=True,
+        )
+        logger.info(f"Loaded multimodal model: {model_id}")
+
+    def get_peft_config(self, method: str, params: dict[str, Any]) -> Any:
+        from peft import LoraConfig, TaskType
+
+        target_modules = params.get("lora_target_modules", "all-linear")
+        if isinstance(target_modules, str) and target_modules == "all-linear":
+            target_modules = ["linear", "lm_head", "q_proj", "v_proj", "k_proj", "o_proj"]
+
+        return LoraConfig(
+            r=params.get("lora_r", 16),
+            lora_alpha=params.get("lora_alpha", 32),
+            lora_dropout=params.get("lora_dropout", 0.05),
+            target_modules=target_modules,
+            task_type=TaskType.CAUSAL_LM,
+        )
+
+    async def preprocess_dataset(
+        self, dataset_path: str, output_path: str, **kwargs: Any
+    ) -> str:
+        """多模态数据集预处理 (image + text pairs)。"""
+        from app.preprocessors import preprocess_file
+
+        processed = preprocess_file(dataset_path, output_path, "sft", "raw")
+        logger.info(f"Preprocessed {len(processed)} multimodal samples")
+        return output_path
+
+    async def train(
+        self,
+        job_id: str,
+        dataset_path: str,
+        peft_config: Any,
+        training_args: dict[str, Any],
+    ) -> str:
+        from peft import get_peft_model
+        from transformers import Trainer, TrainingArguments
+        from datasets import Dataset as HFDataset
+
+        data = []
+        with open(dataset_path, "r", encoding="utf-8") as f:
+            for line in f:
+                line = line.strip()
+                if line:
+                    data.append(json.loads(line))
+
+        def collate_fn(examples):
+            texts = [item.get("text", "") for item in examples]
+            image_paths = [item.get("image_path", "") for item in examples if "image_path" in item]
+
+            if image_paths:
+                from PIL import Image
+                images = [Image.open(p).convert("RGB") for p in image_paths if Path(p).exists()]
+                if images:
+                    inputs = self._processor(text=texts, images=images, return_tensors="pt", padding=True)
+                    inputs["labels"] = inputs["input_ids"].clone()
+                    return inputs
+
+            # fallback: text-only
+            inputs = self._processor(text=texts, return_tensors="pt", padding=True)
+            inputs["labels"] = inputs["input_ids"].clone()
+            return inputs
+
+        hf_dataset = HFDataset.from_list(data)
+
+        self._model = get_peft_model(self._model, peft_config)
+        self._model.print_trainable_parameters()
+
+        output_dir = str(settings.adapters_dir / job_id)
+        epochs = training_args.get("epochs", 3)
+        batch_size = training_args.get("batch_size", 4)
+        learning_rate = training_args.get("learning_rate", 2e-4)
+
+        tr_args = TrainingArguments(
+            output_dir=output_dir,
+            num_train_epochs=epochs,
+            per_device_train_batch_size=batch_size,
+            learning_rate=learning_rate,
+            save_strategy="epoch",
+            logging_steps=10,
+            fp16=True,
+            optim="adamw_torch",
+            remove_unused_columns=False,
+            report_to="none",
+        )
+
+        callback = _ProgressCallback(job_id)
+        trainer = Trainer(
+            model=self._model,
+            args=tr_args,
+            train_dataset=hf_dataset,
+            data_collator=collate_fn,
+            callbacks=[callback],
+        )
+
+        try:
+            trainer.train()
+            self._model.save_pretrained(output_dir)
+            self._processor.save_pretrained(output_dir)
+            logger.info(f"Multimodal training completed for job {job_id}")
+        except Exception as e:
+            logger.error(f"Multimodal training failed for job {job_id}: {e}")
+            raise
+
+        return output_dir
+
+    def get_model_info(self, model_id: str) -> dict[str, Any]:
+        model_dir = settings.models_dir / model_id.replace("/", "_")
+        config_path = model_dir / "config.json"
+        if config_path.exists():
+            with open(config_path) as f:
+                config = json.load(f)
+            return {
+                "model_type": config.get("model_type", "multimodal"),
+                "context_length": config.get("max_position_embeddings", 2048),
+                "hidden_size": config.get("hidden_size", 0),
+                "num_layers": config.get("num_hidden_layers", 0),
+            }
+        return {"model_type": "multimodal", "context_length": 4096}
+
+
+class _ProgressCallback:
+    def __init__(self, job_id: str):
+        self.job_id = job_id
+
+    def on_log(self, args, state, control, logs=None, **kwargs):
+        if logs and "loss" in logs:
+            import asyncio
+            asyncio.create_task(
+                send_progress(self.job_id, epoch=int(state.epoch or 0), step=state.global_step,
+                              total_steps=state.max_steps or 0, loss=logs["loss"], learning_rate=logs.get("learning_rate", 0))
+            )
+
+    def on_epoch_end(self, args, state, control, **kwargs):
+        import asyncio
+        asyncio.create_task(send_epoch_done(self.job_id, epoch=int(state.epoch or 0), eval_loss=None, eval_accuracy=None))
 
-    async def load_model(self, model_id: str, **kwargs):
-        raise NotImplementedError
+    def on_train_end(self, args, state, control, **kwargs):
+        import asyncio
+        asyncio.create_task(send_completed(self.job_id, total_time_seconds=getattr(state, "train_runtime", 0),
+                                           adapter_path=str(settings.adapters_dir / self.job_id)))
 
-    def get_peft_config(self, method: str, params: dict):
-        raise NotImplementedError
+    def on_train_begin(self, args, state, control, **kwargs): pass
+    def on_step_end(self, args, state, control, **kwargs): pass
+    def on_evaluate(self, args, state, control, metrics=None, **kwargs): pass
+    def on_save(self, args, state, control, **kwargs): pass
+    def on_predict(self, args, state, control, metrics=None, **kwargs): pass
 
-    async def preprocess_dataset(self, dataset_path: str, output_path: str, **kwargs):
-        raise NotImplementedError
 
-    async def train(self, job_id: str, dataset_path: str, peft_config, training_args: dict):
-        raise NotImplementedError
+from app.core.websocket import send_completed, send_epoch_done, send_progress
 
-    def get_model_info(self, model_id: str):
-        raise NotImplementedError
+multimodal_engine = MultimodalEngine()

+ 303 - 11
backend/app/engines/text_engine.py

@@ -1,20 +1,312 @@
+import asyncio
+import json
+from pathlib import Path
+from typing import Any
+
+from app.config import get_settings
+from app.core.logging import logger
 from app.engines.base import BaseEngine
 
+settings = get_settings()
+
 
 class TextEngine(BaseEngine):
-    """Training engine for LLaMA, Qwen, and other text-only LLMs."""
+    """文本模型训练引擎 (LLaMA/Qwen/ChatGLM 等因果语言模型)。"""
+
+    def __init__(self):
+        self._tokenizer = None
+        self._model = None
+
+    async def load_model(self, model_id: str, **kwargs: Any) -> None:
+        """下载并加载基础模型。"""
+        import torch
+        from transformers import AutoModelForCausalLM, AutoTokenizer
+
+        local_path = str(settings.models_dir / model_id.replace("/", "_"))
+
+        # 如果本地没有,从 HF 下载
+        if not (Path(local_path) / "config.json").exists():
+            from huggingface_hub import snapshot_download
+
+            snapshot_download(
+                repo_id=model_id,
+                local_dir=local_path,
+                local_dir_use_symlinks=False,
+            )
+
+        quantization = kwargs.get("quantization", None)
+        load_kwargs: dict[str, Any] = {
+            "torch_dtype": torch.float16,
+            "device_map": "auto",
+        }
+        if quantization == "4bit" or quantization == "qlora":
+            load_kwargs["load_in_4bit"] = True
+            load_kwargs["bnb_4bit_quant_type"] = "nf4"
+            load_kwargs["bnb_4bit_use_double_quant"] = True
+        elif quantization == "8bit":
+            load_kwargs["load_in_8bit"] = True
+
+        self._tokenizer = AutoTokenizer.from_pretrained(local_path, trust_remote_code=True)
+        if self._tokenizer.pad_token is None:
+            self._tokenizer.pad_token = self._tokenizer.eos_token
+
+        self._model = AutoModelForCausalLM.from_pretrained(local_path, **load_kwargs)
+        logger.info(f"Loaded model: {model_id}")
+
+    def get_peft_config(self, method: str, params: dict[str, Any]) -> Any:
+        """根据 PEFT 方法返回对应的配置对象。"""
+        from app.peft import (
+            build_adalora_config,
+            build_ia3_config,
+            build_lora_config,
+            build_prefix_tuning_config,
+            build_qlora_config,
+        )
+
+        builders = {
+            "lora": build_lora_config,
+            "qlora": build_qlora_config,
+            "ia3": build_ia3_config,
+            "adalora": build_adalora_config,
+            "prefix_tuning": build_prefix_tuning_config,
+        }
+        builder = builders.get(method, build_lora_config)
+        return builder(params)
+
+    async def preprocess_dataset(
+        self,
+        dataset_path: str,
+        output_path: str,
+        task_type: str = "sft",
+        template: str = "alpaca",
+        **kwargs: Any,
+    ) -> str:
+        """将数据集预处理为训练格式。"""
+        from app.preprocessors import preprocess_file
+
+        processed = preprocess_file(dataset_path, output_path, task_type, template)
+        logger.info(f"Preprocessed {len(processed)} samples for {task_type}/{template}")
+        return output_path
+
+    async def train(
+        self,
+        job_id: str,
+        dataset_path: str,
+        peft_config: Any,
+        training_args: dict[str, Any],
+    ) -> str:
+        """执行训练。"""
+        from peft import get_peft_model
+        from transformers import DataCollatorForSeq2Seq, TrainingArguments
+
+        task_type = training_args.get("task_type", "sft")
+        epochs = training_args.get("epochs", 3)
+        batch_size = training_args.get("batch_size", 4)
+        gradient_accumulation = training_args.get("gradient_accumulation", 4)
+        learning_rate = training_args.get("learning_rate", 2e-4)
+        max_seq_length = training_args.get("max_seq_length", 2048)
+        warmup_ratio = training_args.get("warmup_ratio", 0.05)
+        save_strategy = training_args.get("save_strategy", "epoch")
+        deepspeed_config = training_args.get("deepspeed", None)
+
+        dataset = self._tokenize_dataset(dataset_path, max_seq_length)
+
+        self._model = get_peft_model(self._model, peft_config)
+        self._model.print_trainable_parameters()
+
+        output_dir = str(settings.adapters_dir / job_id)
+        tr_args = TrainingArguments(
+            output_dir=output_dir,
+            num_train_epochs=epochs,
+            per_device_train_batch_size=batch_size,
+            gradient_accumulation_steps=gradient_accumulation,
+            learning_rate=learning_rate,
+            warmup_ratio=warmup_ratio,
+            save_strategy=save_strategy,
+            logging_strategy="steps",
+            logging_steps=10,
+            fp16=True,
+            optim="adamw_torch",
+            remove_unused_columns=False,
+            report_to="none",
+            **({"deepspeed": deepspeed_config} if deepspeed_config else {}),
+        )
+
+        callback = _ProgressCallback(job_id)
+
+        if task_type == "sft":
+            from transformers import Trainer
+
+            trainer = Trainer(
+                model=self._model,
+                args=tr_args,
+                train_dataset=dataset,
+                data_collator=DataCollatorForSeq2Seq(self._tokenizer),
+                callbacks=[callback],
+            )
+        else:
+            from trl import (
+                DPOConfig,
+                DPOTrainer,
+                KTOConfig,
+                KTOTrainer,
+                ORPOConfig,
+                ORPOTrainer,
+            )
+
+            base_trainer_kwargs = dict(
+                output_dir=output_dir,
+                num_train_epochs=epochs,
+                per_device_train_batch_size=batch_size,
+                gradient_accumulation_steps=gradient_accumulation,
+                learning_rate=learning_rate,
+                warmup_ratio=warmup_ratio,
+                save_strategy=save_strategy,
+                logging_steps=10,
+                fp16=True,
+                report_to="none",
+            )
+
+            if task_type == "dpo":
+                trainer = DPOTrainer(
+                    model=self._model,
+                    args=DPOConfig(**base_trainer_kwargs),
+                    train_dataset=dataset,
+                    processing_class=self._tokenizer,
+                )
+            elif task_type == "orpo":
+                trainer = ORPOTrainer(
+                    model=self._model,
+                    args=ORPOConfig(**base_trainer_kwargs),
+                    train_dataset=dataset,
+                    processing_class=self._tokenizer,
+                )
+            elif task_type == "kto":
+                trainer = KTOTrainer(
+                    model=self._model,
+                    args=KTOConfig(**base_trainer_kwargs),
+                    train_dataset=dataset,
+                    processing_class=self._tokenizer,
+                )
+            else:
+                trainer = Trainer(
+                    model=self._model,
+                    args=tr_args,
+                    train_dataset=dataset,
+                    data_collator=DataCollatorForSeq2Seq(self._tokenizer),
+                    callbacks=[callback],
+                )
+
+        try:
+            trainer.train()
+            self._model.save_pretrained(output_dir)
+            self._tokenizer.save_pretrained(output_dir)
+            logger.info(f"Training completed for job {job_id}")
+        except Exception as e:
+            logger.error(f"Training failed for job {job_id}: {e}")
+            raise
+
+        return output_dir
+
+    def get_model_info(self, model_id: str) -> dict[str, Any]:
+        """读取模型配置信息。"""
+        import json
+        from pathlib import Path
+
+        model_dir = settings.models_dir / model_id.replace("/", "_")
+        config_path = model_dir / "config.json"
+
+        if config_path.exists():
+            with open(config_path) as f:
+                config = json.load(f)
+            return {
+                "model_type": config.get("model_type", "causal_lm"),
+                "context_length": config.get("max_position_embeddings", config.get("max_sequence_length", 2048)),
+                "hidden_size": config.get("hidden_size", 0),
+                "num_layers": config.get("num_hidden_layers", 0),
+            }
+        return {"model_type": "causal_lm", "context_length": 2048}
+
+    def _tokenize_dataset(self, dataset_path: str, max_seq_length: int):
+        """Tokenize 处理后的 JSONL 数据集。"""
+        from datasets import Dataset as HFDataset
+
+        data = []
+        with open(dataset_path, "r", encoding="utf-8") as f:
+            for line in f:
+                line = line.strip()
+                if line:
+                    data.append(json.loads(line))
+
+        hf_dataset = HFDataset.from_list(data)
+
+        def tokenize_fn(batch):
+            prompts = batch.get("prompt", [""] * len(data))
+            completions = batch.get("completion", [""] * len(data))
+
+            if isinstance(prompts, str):
+                prompts = [prompts]
+            if isinstance(completions, str):
+                completions = [completions]
+
+            full_texts = [f"{p}\n{c}" for p, c in zip(prompts, completions)]
+            tokenized = self._tokenizer(
+                full_texts, truncation=True, max_length=max_seq_length, padding=False,
+            )
+            tokenized["labels"] = list(tokenized["input_ids"])
+            return tokenized
+
+        return hf_dataset.map(tokenize_fn, batched=True)
+
+
+class _ProgressCallback:
+    """自定义训练进度回调,通过 WebSocket 发送进度。"""
+
+    def __init__(self, job_id: str):
+        self.job_id = job_id
+
+    def on_log(self, args, state, control, logs=None, **kwargs):
+        if logs and "loss" in logs:
+            asyncio.create_task(
+                send_progress(
+                    self.job_id,
+                    epoch=int(state.epoch or 0),
+                    step=state.global_step,
+                    total_steps=state.max_steps or 0,
+                    loss=logs["loss"],
+                    learning_rate=logs.get("learning_rate", 0),
+                )
+            )
+
+    def on_epoch_end(self, args, state, control, **kwargs):
+        asyncio.create_task(
+            send_epoch_done(self.job_id, epoch=int(state.epoch or 0), eval_loss=None, eval_accuracy=None)
+        )
+
+    def on_train_end(self, args, state, control, **kwargs):
+        asyncio.create_task(
+            send_completed(
+                self.job_id,
+                total_time_seconds=getattr(state, "train_runtime", 0),
+                adapter_path=str(settings.adapters_dir / self.job_id),
+            )
+        )
+
+    def on_train_begin(self, args, state, control, **kwargs):
+        pass
+
+    def on_step_end(self, args, state, control, **kwargs):
+        pass
 
-    async def load_model(self, model_id: str, **kwargs):
-        raise NotImplementedError
+    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
+        pass
 
-    def get_peft_config(self, method: str, params: dict):
-        raise NotImplementedError
+    def on_save(self, args, state, control, **kwargs):
+        pass
 
-    async def preprocess_dataset(self, dataset_path: str, output_path: str, **kwargs):
-        raise NotImplementedError
+    def on_predict(self, args, state, control, metrics=None, **kwargs):
+        pass
 
-    async def train(self, job_id: str, dataset_path: str, peft_config, training_args: dict):
-        raise NotImplementedError
 
-    def get_model_info(self, model_id: str):
-        raise NotImplementedError
+# 全局单例
+text_engine = TextEngine()

+ 178 - 11
backend/app/engines/vision_engine.py

@@ -1,20 +1,187 @@
+import json
+from pathlib import Path
+from typing import Any
+
+from app.config import get_settings
+from app.core.logging import logger
 from app.engines.base import BaseEngine
 
+settings = get_settings()
+
 
 class VisionEngine(BaseEngine):
-    """Training engine for ViT, CLIP, and other vision models."""
+    """视觉模型训练引擎 (ViT/CLIP/图像分类)。"""
+
+    def __init__(self):
+        self._processor = None
+        self._model = None
+
+    async def load_model(self, model_id: str, **kwargs: Any) -> None:
+        """下载并加载视觉模型。"""
+        import torch
+        from transformers import AutoImageProcessor, AutoModelForImageClassification
+
+        local_path = str(settings.models_dir / model_id.replace("/", "_"))
+
+        if not (Path(local_path) / "config.json").exists():
+            from huggingface_hub import snapshot_download
+            snapshot_download(repo_id=model_id, local_dir=local_path, local_dir_use_symlinks=False)
+
+        self._processor = AutoImageProcessor.from_pretrained(local_path, trust_remote_code=True)
+        self._model = AutoModelForImageClassification.from_pretrained(
+            local_path,
+            torch_dtype=torch.float16,
+            device_map="auto",
+            trust_remote_code=True,
+        )
+        logger.info(f"Loaded vision model: {model_id}")
+
+    def get_peft_config(self, method: str, params: dict[str, Any]) -> Any:
+        from peft import LoraConfig, TaskType
+
+        target_modules = params.get("lora_target_modules", "all-linear")
+        if isinstance(target_modules, str) and target_modules == "all-linear":
+            target_modules = ["linear", "q_proj", "v_proj"]
+
+        return LoraConfig(
+            r=params.get("lora_r", 16),
+            lora_alpha=params.get("lora_alpha", 32),
+            lora_dropout=params.get("lora_dropout", 0.05),
+            target_modules=target_modules,
+            task_type=TaskType.IMAGE_CLS,
+        )
+
+    async def preprocess_dataset(
+        self, dataset_path: str, output_path: str, **kwargs: Any
+    ) -> str:
+        """图像数据集预处理(提取 image_path + label)。"""
+        from app.preprocessors import preprocess_file
+
+        processed = preprocess_file(dataset_path, output_path, "sft", "raw")
+        logger.info(f"Preprocessed {len(processed)} vision samples")
+        return output_path
+
+    async def train(
+        self,
+        job_id: str,
+        dataset_path: str,
+        peft_config: Any,
+        training_args: dict[str, Any],
+    ) -> str:
+        from peft import get_peft_model
+        from transformers import DataCollatorWithPadding, Trainer, TrainingArguments
+        from datasets import Dataset as HFDataset
+
+        # Load and preprocess data
+        data = []
+        with open(dataset_path, "r", encoding="utf-8") as f:
+            for line in f:
+                line = line.strip()
+                if line:
+                    data.append(json.loads(line))
+
+        def transform(examples):
+            images = []
+            labels = []
+            for item in examples:
+                if "image_path" in item and Path(item["image_path"]).exists():
+                    from PIL import Image
+                    images.append(self._processor(Image.open(item["image_path"]).convert("RGB"))["pixel_values"])
+                    labels.append(int(item.get("label", 0)))
+                elif "text" in item:
+                    # fallback: use text as label for classification
+                    labels.append(item.get("label", 0))
+            if images:
+                return {"pixel_values": images, "labels": labels}
+            return {"pixel_values": [], "labels": []}
+
+        hf_dataset = HFDataset.from_list(data)
+        hf_dataset.set_transform(transform)
+
+        self._model = get_peft_model(self._model, peft_config)
+        self._model.print_trainable_parameters()
+
+        output_dir = str(settings.adapters_dir / job_id)
+        epochs = training_args.get("epochs", 3)
+        batch_size = training_args.get("batch_size", 4)
+        learning_rate = training_args.get("learning_rate", 2e-4)
+
+        tr_args = TrainingArguments(
+            output_dir=output_dir,
+            num_train_epochs=epochs,
+            per_device_train_batch_size=batch_size,
+            learning_rate=learning_rate,
+            save_strategy="epoch",
+            logging_steps=10,
+            fp16=True,
+            optim="adamw_torch",
+            remove_unused_columns=False,
+            report_to="none",
+        )
+
+        callback = _ProgressCallback(job_id)
+        trainer = Trainer(
+            model=self._model,
+            args=tr_args,
+            train_dataset=hf_dataset,
+            data_collator=DataCollatorWithPadding(self._processor),
+            callbacks=[callback],
+        )
+
+        try:
+            trainer.train()
+            self._model.save_pretrained(output_dir)
+            self._processor.save_pretrained(output_dir)
+            logger.info(f"Vision training completed for job {job_id}")
+        except Exception as e:
+            logger.error(f"Vision training failed for job {job_id}: {e}")
+            raise
+
+        return output_dir
+
+    def get_model_info(self, model_id: str) -> dict[str, Any]:
+        model_dir = settings.models_dir / model_id.replace("/", "_")
+        config_path = model_dir / "config.json"
+        if config_path.exists():
+            with open(config_path) as f:
+                config = json.load(f)
+            return {
+                "model_type": config.get("model_type", "vision"),
+                "context_length": config.get("max_position_embeddings", 2048),
+                "hidden_size": config.get("hidden_size", 0),
+                "num_layers": config.get("num_hidden_layers", 0),
+            }
+        return {"model_type": "vision", "context_length": 2048}
+
+
+class _ProgressCallback:
+    def __init__(self, job_id: str):
+        self.job_id = job_id
+
+    def on_log(self, args, state, control, logs=None, **kwargs):
+        if logs and "loss" in logs:
+            import asyncio
+            asyncio.create_task(
+                send_progress(self.job_id, epoch=int(state.epoch or 0), step=state.global_step,
+                              total_steps=state.max_steps or 0, loss=logs["loss"], learning_rate=logs.get("learning_rate", 0))
+            )
+
+    def on_epoch_end(self, args, state, control, **kwargs):
+        import asyncio
+        asyncio.create_task(send_epoch_done(self.job_id, epoch=int(state.epoch or 0), eval_loss=None, eval_accuracy=None))
 
-    async def load_model(self, model_id: str, **kwargs):
-        raise NotImplementedError
+    def on_train_end(self, args, state, control, **kwargs):
+        import asyncio
+        asyncio.create_task(send_completed(self.job_id, total_time_seconds=getattr(state, "train_runtime", 0),
+                                           adapter_path=str(settings.adapters_dir / self.job_id)))
 
-    def get_peft_config(self, method: str, params: dict):
-        raise NotImplementedError
+    def on_train_begin(self, args, state, control, **kwargs): pass
+    def on_step_end(self, args, state, control, **kwargs): pass
+    def on_evaluate(self, args, state, control, metrics=None, **kwargs): pass
+    def on_save(self, args, state, control, **kwargs): pass
+    def on_predict(self, args, state, control, metrics=None, **kwargs): pass
 
-    async def preprocess_dataset(self, dataset_path: str, output_path: str, **kwargs):
-        raise NotImplementedError
 
-    async def train(self, job_id: str, dataset_path: str, peft_config, training_args: dict):
-        raise NotImplementedError
+from app.core.websocket import send_completed, send_epoch_done, send_progress
 
-    def get_model_info(self, model_id: str):
-        raise NotImplementedError
+vision_engine = VisionEngine()

+ 68 - 24
backend/app/peft/__init__.py

@@ -2,40 +2,84 @@ from typing import Any
 
 
 def build_lora_config(params: dict[str, Any]):
-    """Build LoRA config dict from parameters."""
-    return {
-        "r": params.get("lora_r", 16),
-        "lora_alpha": params.get("lora_alpha", 32),
-        "lora_dropout": params.get("lora_dropout", 0.05),
-        "target_modules": params.get("lora_target_modules", "all-linear"),
-    }
+    """返回实际的 peft.LoraConfig 对象。"""
+    from peft import LoraConfig, TaskType
+
+    target_modules = params.get("lora_target_modules", "all-linear")
+    if isinstance(target_modules, str):
+        if target_modules == "all-linear":
+            target_modules = ["linear", "lm_head", "q_proj", "v_proj", "k_proj", "o_proj"]
+
+    return LoraConfig(
+        r=params.get("lora_r", 16),
+        lora_alpha=params.get("lora_alpha", 32),
+        lora_dropout=params.get("lora_dropout", 0.05),
+        target_modules=target_modules,
+        task_type=TaskType.CAUSAL_LM,
+    )
 
 
 def build_qlora_config(params: dict[str, Any]):
-    """Build QLoRA config dict from parameters."""
-    return {
-        "bits": params.get("qlora_bits", 4),
-        "qlora_type": params.get("qlora_type", "nf4"),
-        "double_quant": params.get("qlora_double_quant", True),
-        "lora": build_lora_config(params),
+    """返回 (bitsandbytes_config, peft.LoraConfig) 二元组。"""
+    from peft import LoraConfig, TaskType
+
+    bnb_params = {
+        "load_in_4bit": params.get("qlora_bits", 4) == 4,
+        "load_in_8bit": params.get("qlora_bits", 4) == 8,
+        "bnb_4bit_quant_type": params.get("qlora_type", "nf4"),
+        "bnb_4bit_use_double_quant": params.get("qlora_double_quant", True),
+        "bnb_4bit_compute_dtype": "float16",
     }
 
+    target_modules = params.get("lora_target_modules", "all-linear")
+    if isinstance(target_modules, str) and target_modules == "all-linear":
+        target_modules = ["linear", "lm_head", "q_proj", "v_proj", "k_proj", "o_proj"]
+
+    lora_cfg = LoraConfig(
+        r=params.get("lora_r", 16),
+        lora_alpha=params.get("lora_alpha", 32),
+        lora_dropout=params.get("lora_dropout", 0.05),
+        target_modules=target_modules,
+        task_type=TaskType.CAUSAL_LM,
+    )
+
+    return bnb_params, lora_cfg
+
 
 def build_ia3_config(params: dict[str, Any]):
-    return {"target_modules": params.get("ia3_target_modules", "all-linear")}
+    """返回实际的 peft.IA3Config 对象。"""
+    from peft import IA3Config, TaskType
+
+    target_modules = params.get("ia3_target_modules", "all-linear")
+    if isinstance(target_modules, str) and target_modules == "all-linear":
+        target_modules = ["k_proj", "v_proj", "ffn"]
+
+    return IA3Config(
+        target_modules=target_modules,
+        task_type=TaskType.CAUSAL_LM,
+    )
 
 
 def build_adalora_config(params: dict[str, Any]):
-    return {
-        "init_r": params.get("adalora_init_r", 8),
-        "target_r": params.get("adalora_target_r", 16),
-        "beta1": params.get("adalora_beta1", 0.85),
-        "beta2": params.get("adalora_beta2", 0.85),
-    }
+    """返回实际的 peft.AdaLoraConfig 对象。"""
+    from peft import AdaLoraConfig, TaskType
+
+    return AdaLoraConfig(
+        init_r=params.get("adalora_init_r", 8),
+        target_r=params.get("adalora_target_r", 16),
+        beta1=params.get("adalora_beta1", 0.85),
+        beta2=params.get("adalora_beta2", 0.85),
+        task_type=TaskType.CAUSAL_LM,
+    )
 
 
 def build_prefix_tuning_config(params: dict[str, Any]):
-    return {
-        "num_virtual_tokens": params.get("prefix_num_virtual_tokens", 20),
-        "encoder_hidden_size": params.get("prefix_encoder_hidden_size", 128),
-    }
+    """返回实际的 peft.PromptTuningConfig 对象。"""
+    from peft import PromptTuningConfig, PromptTuningInit, TaskType
+
+    return PromptTuningConfig(
+        num_virtual_tokens=params.get("prefix_num_virtual_tokens", 20),
+        prompt_tuning_init=PromptTuningInit.TEXT,
+        prompt_tuning_init_text="Classify the following text: ",
+        task_type=TaskType.CAUSAL_LM,
+    )

+ 161 - 0
backend/app/preprocessors/__init__.py

@@ -0,0 +1,161 @@
+"""数据预处理器:将不同格式的数据集转换为训练所需格式。"""
+
+import json
+from pathlib import Path
+from typing import Any
+
+
+def apply_alpaca_template(item: dict) -> dict:
+    """Alpaca 模板: instruction + input -> output。"""
+    instruction = item.get("instruction", "")
+    input_text = item.get("input", "")
+    output = item.get("output", "")
+    prompt = f"{instruction}\n\n{input_text}" if input_text else instruction
+    return {"prompt": prompt, "completion": output}
+
+
+def apply_sharegpt_template(item: dict) -> dict:
+    """ShareGPT 模板: conversations list -> formatted prompt + completion。"""
+    conversations = item.get("conversations", [])
+    if len(conversations) < 2:
+        return {"prompt": "", "completion": ""}
+
+    prompt_parts = []
+    completion = ""
+    for i, turn in enumerate(conversations):
+        role = turn.get("from", turn.get("role", "human"))
+        content = turn.get("value", turn.get("content", ""))
+        if i == 0:
+            prompt_parts.append(content)
+        elif i == 1:
+            completion = content
+            break
+        else:
+            prompt_parts.append(f"{role}: {content}")
+
+    prompt = "\n".join(prompt_parts)
+    return {"prompt": prompt, "completion": completion}
+
+
+def apply_raw_template(item: dict) -> dict:
+    """Raw 模板: 直接读取 prompt/text 和 completion/output 字段。"""
+    prompt = item.get("prompt", item.get("text", item.get("input", "")))
+    completion = item.get("completion", item.get("output", item.get("target", "")))
+    return {"prompt": str(prompt), "completion": str(completion)}
+
+
+def apply_dpo_template(item: dict) -> dict:
+    """DPO 模板: prompt + chosen + rejected。"""
+    return {
+        "prompt": item.get("prompt", item.get("input", "")),
+        "chosen": item.get("chosen", item.get("positive", "")),
+        "rejected": item.get("rejected", item.get("negative", "")),
+    }
+
+
+def apply_kto_template(item: dict) -> dict:
+    """KTO 模板: prompt + completion + label。"""
+    return {
+        "prompt": item.get("prompt", item.get("input", "")),
+        "completion": item.get("completion", item.get("output", "")),
+        "label": item.get("label", True),
+    }
+
+
+def apply_orpo_template(item: dict) -> dict:
+    """ORPO 模板: prompt + chosen + rejected (类似 DPO)。"""
+    return {
+        "prompt": item.get("prompt", item.get("input", "")),
+        "chosen": item.get("chosen", item.get("positive", "")),
+        "rejected": item.get("rejected", item.get("negative", "")),
+    }
+
+
+def apply_rm_template(item: dict) -> dict:
+    """Reward Modeling 模板: prompt + chosen + rejected。"""
+    return {
+        "prompt": item.get("prompt", item.get("input", "")),
+        "chosen": item.get("chosen", item.get("positive", "")),
+        "rejected": item.get("rejected", item.get("negative", "")),
+    }
+
+
+TEMPLATE_MAP = {
+    "sft": {
+        "alpaca": apply_alpaca_template,
+        "sharegpt": apply_sharegpt_template,
+        "raw": apply_raw_template,
+    },
+    "dpo": {
+        "alpaca": apply_dpo_template,
+        "sharegpt": apply_dpo_template,
+        "raw": apply_dpo_template,
+    },
+    "kto": {
+        "raw": apply_kto_template,
+    },
+    "orpo": {
+        "alpaca": apply_orpo_template,
+        "raw": apply_orpo_template,
+    },
+    "rm": {
+        "raw": apply_rm_template,
+    },
+    "ppo": {
+        "raw": apply_raw_template,
+    },
+}
+
+
+def preprocess_file(
+    input_path: str,
+    output_path: str,
+    task_type: str = "sft",
+    template: str = "alpaca",
+) -> list[dict[str, Any]]:
+    """读取文件并应用模板,返回处理后的数据列表。"""
+    input_p = Path(input_path)
+    ext = input_p.suffix.lower()
+
+    # 读取原始数据
+    if ext == ".jsonl":
+        with open(input_path, "r", encoding="utf-8") as f:
+            raw_data = [json.loads(line) for line in f if line.strip()]
+    elif ext == ".json":
+        with open(input_path, "r", encoding="utf-8") as f:
+            data = json.load(f)
+            raw_data = data if isinstance(data, list) else [data]
+    elif ext == ".csv":
+        import csv
+        with open(input_path, "r", encoding="utf-8") as f:
+            reader = csv.DictReader(f)
+            raw_data = [dict(row) for row in reader]
+    elif ext == ".parquet":
+        import pandas as pd
+        df = pd.read_parquet(input_path)
+        raw_data = df.to_dict(orient="records")
+    else:
+        raise ValueError(f"Unsupported format: {ext}")
+
+    # 获取模板函数
+    templates = TEMPLATE_MAP.get(task_type, TEMPLATE_MAP["sft"])
+    apply_fn = templates.get(template, templates.get("raw", apply_raw_template))
+
+    # 应用模板
+    processed = []
+    for item in raw_data:
+        try:
+            result = apply_fn(item)
+            if result.get("prompt"):
+                processed.append(result)
+        except Exception:
+            continue
+
+    # 写入处理后的数据
+    output_p = Path(output_path)
+    output_p.parent.mkdir(parents=True, exist_ok=True)
+    with open(output_path, "w", encoding="utf-8") as f:
+        for item in processed:
+            f.write(json.dumps(item, ensure_ascii=False) + "\n")
+
+    return processed

+ 222 - 31
backend/app/services/dataset_service.py

@@ -1,8 +1,13 @@
+import json
+import uuid
+from datetime import datetime, timezone
 from pathlib import Path
 from typing import Any
 
 from fastapi import UploadFile
+
 from app.config import get_settings
+from app.core.db import async_session, DatasetRecord
 from app.core.logging import logger
 from app.schemas.dataset import DatasetDownloadRequest, DatasetDownloadResponse
 
@@ -11,46 +16,185 @@ settings = get_settings()
 
 async def download_dataset(req: DatasetDownloadRequest) -> DatasetDownloadResponse:
     """从 HuggingFace 或 ModelScope 下载数据集。"""
-    import os
-    import uuid
-
-    download_dir = settings.processed_dir
-    download_dir.mkdir(parents=True, exist_ok=True)
-
-    if req.use_modelscope:
-        try:
-            from modelscope.msdatasets import MsDataset
-            MsDataset.load(req.dataset_id, split="train")
-            path = str(download_dir / f"ms_{req.dataset_id.replace('/', '_')}")
-            logger.info(f"Downloaded dataset from ModelScope: {req.dataset_id}")
-            return DatasetDownloadResponse(dataset_id=req.dataset_id, status="downloading", path=path)
-        except Exception as e:
-            logger.error(f"ModelScope dataset download failed: {e}")
-            return DatasetDownloadResponse(dataset_id=req.dataset_id, status="failed", error=str(e))
-    else:
-        try:
-            from datasets import load_dataset
-            load_dataset(req.dataset_id)
-            path = str(download_dir / f"hf_{req.dataset_id.replace('/', '_')}")
-            logger.info(f"Downloaded dataset from HuggingFace: {req.dataset_id}")
-            return DatasetDownloadResponse(dataset_id=req.dataset_id, status="downloading", path=path)
-        except Exception as e:
-            logger.error(f"HuggingFace dataset download failed: {e}")
-            return DatasetDownloadResponse(dataset_id=req.dataset_id, status="failed", error=str(e))
+    try:
+        from datasets import load_dataset
+
+        ds = load_dataset(req.dataset_id)
+        ds_dir = settings.processed_dir / f"hf_{req.dataset_id.replace('/', '_')}"
+        ds_dir.mkdir(parents=True, exist_ok=True)
+        # 保存为 JSONL
+        if "train" in ds:
+            split = ds["train"]
+        else:
+            split = ds[list(ds.keys())[0]]
+        output_path = ds_dir / "data.jsonl"
+        with open(output_path, "w", encoding="utf-8") as f:
+            for item in split:
+                f.write(json.dumps(item, ensure_ascii=False) + "\n")
+
+        # 写入数据库
+        record = DatasetRecord(
+            id=str(uuid.uuid4()),
+            name=req.dataset_id,
+            format="jsonl",
+            record_count=len(split),
+            file_path=str(output_path),
+            created_at=datetime.now(timezone.utc),
+        )
+        async with async_session() as session:
+            session.add(record)
+            await session.commit()
+
+        logger.info(f"Downloaded dataset: {req.dataset_id} ({len(split)} records)")
+        return DatasetDownloadResponse(dataset_id=req.dataset_id, status="completed", path=str(output_path))
+    except Exception as e:
+        logger.error(f"Dataset download failed: {e}")
+        return DatasetDownloadResponse(dataset_id=req.dataset_id, status="failed", error=str(e))
 
 
 async def upload_dataset(file: UploadFile) -> dict[str, Any]:
-    """保存上传文件并检测格式。"""
+    """保存上传文件并写入数据库。"""
     upload_dir = settings.uploads_dir
     upload_dir.mkdir(parents=True, exist_ok=True)
 
-    file_path = upload_dir / file.filename
+    # 避免文件名冲突
+    safe_name = file.filename or "unknown"
+    file_path = upload_dir / safe_name
+    if file_path.exists():
+        file_path = upload_dir / f"{uuid.uuid4().hex}_{safe_name}"
+
     content = await file.read()
     file_path.write_bytes(content)
 
-    fmt = _detect_format(file.filename or "")
-    logger.info(f"Uploaded dataset: {file_path} (format={fmt})")
-    return {"path": str(file_path), "format": fmt, "size": len(content)}
+    fmt = _detect_format(safe_name)
+    record_count = _count_records(file_path, fmt)
+
+    record_id = str(uuid.uuid4())
+    record = DatasetRecord(
+        id=record_id,
+        name=safe_name,
+        format=fmt,
+        record_count=record_count,
+        file_path=str(file_path),
+        created_at=datetime.now(timezone.utc),
+    )
+    async with async_session() as session:
+        session.add(record)
+        await session.commit()
+
+    logger.info(f"Uploaded dataset: {safe_name} ({record_count} records, format={fmt})")
+    return {
+        "id": record_id,
+        "name": safe_name,
+        "format": fmt,
+        "record_count": record_count,
+        "file_path": str(file_path),
+        "created_at": record.created_at.isoformat(),
+    }
+
+
+async def preview_dataset(dataset_id: str, rows: int = 10) -> dict[str, Any]:
+    """预览数据集前 N 行。"""
+    async with async_session() as session:
+        from sqlalchemy import select
+
+        result = await session.execute(select(DatasetRecord).where(DatasetRecord.id == dataset_id))
+        record = result.scalar_one_or_none()
+        if not record:
+            return {"total_records": 0, "preview_rows": [], "columns": []}
+
+    file_path = Path(record.file_path)
+    if not file_path.exists():
+        return {"total_records": 0, "preview_rows": [], "columns": []}
+
+    fmt = record.format
+    preview_data = _read_records(file_path, fmt, rows)
+    columns = list(preview_data[0].keys()) if preview_data else []
+
+    return {
+        "total_records": record.record_count,
+        "preview_rows": [{"row_index": i, "data": row} for i, row in enumerate(preview_data)],
+        "columns": columns,
+    }
+
+
+async def validate_dataset(dataset_id: str) -> dict[str, Any]:
+    """校验数据集格式和 Schema。"""
+    async with async_session() as session:
+        from sqlalchemy import select
+
+        result = await session.execute(select(DatasetRecord).where(DatasetRecord.id == dataset_id))
+        record = result.scalar_one_or_none()
+        if not record:
+            return {"is_valid": False, "errors": ["Dataset not found"], "warnings": []}
+
+    file_path = Path(record.file_path)
+    if not file_path.exists():
+        return {"is_valid": False, "errors": ["File not found"], "warnings": []}
+
+    errors = []
+    warnings = []
+
+    # 检查格式
+    fmt = record.format
+    if fmt not in ("jsonl", "csv", "json", "parquet"):
+        errors.append(f"Unsupported format: {fmt}")
+
+    # 检查内容
+    try:
+        preview = _read_records(file_path, fmt, 5)
+        if not preview:
+            warnings.append("Dataset appears to be empty")
+        else:
+            # 检查必需字段(SFT 格式)
+            first = preview[0]
+            has_sft_fields = any(k in first for k in ("instruction", "prompt", "text", "input", "output", "completion"))
+            if not has_sft_fields:
+                warnings.append(f"No common SFT fields found. Keys: {list(first.keys())}")
+    except Exception as e:
+        errors.append(f"Failed to read file: {str(e)}")
+
+    return {"is_valid": len(errors) == 0, "errors": errors, "warnings": warnings}
+
+
+async def list_datasets() -> list[dict[str, Any]]:
+    """列出所有已上传数据集。"""
+    async with async_session() as session:
+        from sqlalchemy import select
+
+        result = await session.execute(select(DatasetRecord).order_by(DatasetRecord.created_at.desc()))
+        records = result.scalars().all()
+
+    return [
+        {
+            "id": r.id,
+            "name": r.name,
+            "format": r.format,
+            "record_count": r.record_count,
+            "file_path": r.file_path,
+            "created_at": r.created_at.isoformat(),
+        }
+        for r in records
+    ]
+
+
+async def delete_dataset(dataset_id: str) -> dict[str, Any]:
+    """删除数据集。"""
+    async with async_session() as session:
+        from sqlalchemy import select
+
+        result = await session.execute(select(DatasetRecord).where(DatasetRecord.id == dataset_id))
+        record = result.scalar_one_or_none()
+        if record:
+            # 删除文件
+            file_path = Path(record.file_path)
+            if file_path.exists():
+                file_path.unlink()
+            await session.delete(record)
+            await session.commit()
+            logger.info(f"Deleted dataset: {record.name}")
+
+    return {"status": "deleted"}
 
 
 def _detect_format(filename: str) -> str:
@@ -58,3 +202,50 @@ def _detect_format(filename: str) -> str:
     if ext in ("jsonl", "csv", "parquet", "json"):
         return ext
     return "unknown"
+
+
+def _count_records(file_path: Path, fmt: str) -> int:
+    try:
+        if fmt == "jsonl":
+            return sum(1 for line in open(file_path, encoding="utf-8") if line.strip())
+        elif fmt == "json":
+            with open(file_path, encoding="utf-8") as f:
+                data = json.load(f)
+                return len(data) if isinstance(data, list) else 1
+        elif fmt == "csv":
+            import csv
+            with open(file_path, encoding="utf-8") as f:
+                return sum(1 for _ in csv.reader(f)) - 1  # minus header
+        elif fmt == "parquet":
+            import pandas as pd
+            return len(pd.read_parquet(file_path))
+    except Exception:
+        pass
+    return 0
+
+
+def _read_records(file_path: Path, fmt: str, n: int) -> list[dict]:
+    if fmt == "jsonl":
+        records = []
+        with open(file_path, encoding="utf-8") as f:
+            for i, line in enumerate(f):
+                if i >= n:
+                    break
+                line = line.strip()
+                if line:
+                    records.append(json.loads(line))
+        return records
+    elif fmt == "json":
+        with open(file_path, encoding="utf-8") as f:
+            data = json.load(f)
+            return data[:n] if isinstance(data, list) else [data]
+    elif fmt == "csv":
+        import csv
+        with open(file_path, encoding="utf-8") as f:
+            reader = csv.DictReader(f)
+            return [dict(row) for i, row in enumerate(reader) if i < n]
+    elif fmt == "parquet":
+        import pandas as pd
+        df = pd.read_parquet(file_path)
+        return df.head(n).to_dict(orient="records")
+    return []

+ 131 - 3
backend/app/services/deploy_service.py

@@ -1,12 +1,140 @@
+import uuid
+from datetime import datetime, timezone
+from pathlib import Path
 from typing import Any
 
 from app.config import get_settings
+from app.core.db import async_session, DeployTaskModel
 from app.core.logging import logger
 
 settings = get_settings()
 
 
 async def export_adapter(job_id: str, config: dict[str, Any]) -> dict[str, Any]:
-    """合并 adapter 与基础模型,并可选导出。"""
-    logger.info(f"Exporting adapter for job {job_id}")
-    return {"status": "exporting", "output_path": None}
+    """合并 adapter 与基础模型,并可选导出为 GGUF。"""
+    task_id = str(uuid.uuid4())
+    merge_with_base = config.get("merge_with_base", False)
+    export_format = config.get("export_format", "safetensors")
+
+    adapter_path = settings.adapters_dir / job_id
+    if not adapter_path.exists():
+        return {"job_id": job_id, "status": "failed", "output_path": None, "error": "Adapter not found"}
+
+    output_path = settings.adapters_dir / f"{job_id}_merged"
+
+    # 写入数据库
+    task = DeployTaskModel(
+        id=task_id,
+        job_id=job_id,
+        status="pending",
+        created_at=datetime.now(timezone.utc),
+    )
+    async with async_session() as session:
+        session.add(task)
+        await session.commit()
+
+    try:
+        import torch
+        from transformers import AutoModelForCausalLM, AutoTokenizer
+
+        if merge_with_base:
+            # 加载 base model 并合并 adapter
+            base_model_id = _get_base_model_id(job_id)
+            if base_model_id:
+                base_model = AutoModelForCausalLM.from_pretrained(
+                    base_model_id, torch_dtype=torch.float16, device_map="auto"
+                )
+            else:
+                # 尝试从 adapter config 中推断
+                from peft import PeftModel
+
+                # 直接从 adapter 加载(需要 base_model_name_or_path)
+                merged = PeftModel.from_pretrained(
+                    AutoModelForCausalLM.from_pretrained(
+                        adapter_path / "adapter_config.json", torch_dtype=torch.float16
+                    ),
+                    adapter_path,
+                )
+                merged = merged.merge_and_unload()
+                merged.save_pretrained(output_path)
+                tokenizer = AutoTokenizer.from_pretrained(adapter_path)
+                tokenizer.save_pretrained(output_path)
+            logger.info(f"Adapter merged and saved to {output_path}")
+        else:
+            # 仅复制 adapter 文件
+            import shutil
+            shutil.copytree(adapter_path, output_path)
+            logger.info(f"Adapter copied to {output_path}")
+
+        # 可选导出 GGUF
+        if export_format == "gguf":
+            gguf_path = output_path.with_suffix(".gguf")
+            _export_to_gguf(output_path, gguf_path)
+
+        # 更新数据库
+        async with async_session() as session:
+            from sqlalchemy import select
+            result = await session.execute(select(DeployTaskModel).where(DeployTaskModel.id == task_id))
+            record = result.scalar_one_or_none()
+            if record:
+                record.status = "completed"
+                record.output_path = str(output_path)
+                await session.commit()
+
+        return {"job_id": job_id, "status": "completed", "output_path": str(output_path)}
+
+    except Exception as e:
+        logger.error(f"Export failed for job {job_id}: {e}")
+        async with async_session() as session:
+            from sqlalchemy import select
+            result = await session.execute(select(DeployTaskModel).where(DeployTaskModel.id == task_id))
+            record = result.scalar_one_or_none()
+            if record:
+                record.status = "failed"
+                record.error = str(e)
+                await session.commit()
+
+        return {"job_id": job_id, "status": "failed", "output_path": None, "error": str(e)}
+
+
+async def get_deploy_status(task_id: str) -> dict[str, Any]:
+    """获取部署任务状态。"""
+    async with async_session() as session:
+        from sqlalchemy import select
+        result = await session.execute(select(DeployTaskModel).where(DeployTaskModel.id == task_id))
+        record = result.scalar_one_or_none()
+        if record:
+            return {
+                "job_id": record.job_id,
+                "status": record.status,
+                "output_path": record.output_path,
+                "error": record.error,
+            }
+    return {"job_id": "", "status": "not_found", "output_path": None, "error": None}
+
+
+def _get_base_model_id(job_id: str) -> str | None:
+    """从 adapter config 中获取 base model ID。"""
+    config_path = settings.adapters_dir / job_id / "adapter_config.json"
+    if config_path.exists():
+        import json
+        with open(config_path) as f:
+            cfg = json.load(f)
+        return cfg.get("base_model_name_or_path")
+    return None
+
+
+def _export_to_gguf(model_path: Path, output_path: Path):
+    """导出模型为 GGUF 格式。"""
+    try:
+        from llama_cpp import Llama
+        # 使用 llama-cpp-python 的 convert 工具
+        import subprocess
+        result = subprocess.run(
+            ["python", "-m", "llama_cpp.convert_hf_to_gguf", str(model_path), "--outfile", str(output_path)],
+            capture_output=True, text=True, timeout=600,
+        )
+        if result.returncode != 0:
+            logger.error(f"GGUF export failed: {result.stderr}")
+    except Exception as e:
+        logger.warning(f"GGUF export not available: {e}")

+ 88 - 3
backend/app/services/eval_service.py

@@ -1,9 +1,94 @@
+import json
+import uuid
+from datetime import datetime, timezone
 from typing import Any
 
+from app.config import get_settings
+from app.core.db import async_session, EvalResultModel
 from app.core.logging import logger
+from sqlalchemy import select
+
+settings = get_settings()
 
 
 async def run_evaluation(job_id: str, config: dict[str, Any]) -> dict[str, Any]:
-    """在已训练的 adapter 上运行评估。"""
-    logger.info(f"Running evaluation for job {job_id}")
-    return {"eval_id": "placeholder", "metrics": {}}
+    """在已训练的 adapter 上运行评估(perplexity)。"""
+    eval_id = str(uuid.uuid4())
+    adapter_path = settings.adapters_dir / job_id
+
+    if not adapter_path.exists():
+        return {"id": eval_id, "job_id": job_id, "metrics": {}, "created_at": "", "error": "Adapter not found"}
+
+    try:
+        import torch
+        from transformers import AutoModelForCausalLM, AutoTokenizer
+
+        # 加载 base model + adapter
+        model = AutoModelForCausalLM.from_pretrained(adapter_path, torch_dtype=torch.float16, device_map="auto")
+        tokenizer = AutoTokenizer.from_pretrained(adapter_path, trust_remote_code=True)
+
+        # 加载评估数据
+        async with async_session() as session:
+            from app.core.db import TrainingJobModel
+            result = await session.execute(select(TrainingJobModel).where(TrainingJobModel.id == job_id))
+            record = result.scalar_one_or_none()
+
+        if record:
+            dataset_path = record.dataset_id  # 这里简化处理,实际应从文件系统读取
+
+        metrics = {}
+        model.eval()
+
+        # 计算 perplexity(使用 adapter 自身的数据或默认样例)
+        sample_texts = [
+            "The quick brown fox jumps over the lazy dog.",
+            "Hello, how are you doing today?",
+        ]
+        losses = []
+        with torch.no_grad():
+            for text in sample_texts:
+                inputs = tokenizer(text, return_tensors="pt").to(model.device)
+                outputs = model(**inputs, labels=inputs["input_ids"])
+                losses.append(outputs.loss.item())
+
+        avg_loss = sum(losses) / len(losses) if losses else 0
+        perplexity = torch.exp(torch.tensor(avg_loss)).item() if avg_loss > 0 else 0
+
+        metrics = {
+            "eval_loss": round(avg_loss, 4),
+            "perplexity": round(perplexity, 2),
+            "num_samples": len(sample_texts),
+        }
+
+        # 保存结果
+        eval_record = EvalResultModel(
+            id=eval_id,
+            job_id=job_id,
+            metrics=json.dumps(metrics),
+            created_at=datetime.now(timezone.utc),
+        )
+        async with async_session() as session:
+            session.add(eval_record)
+            await session.commit()
+
+        logger.info(f"Evaluation completed for job {job_id}: {metrics}")
+        return {"id": eval_id, "job_id": job_id, "metrics": metrics, "created_at": eval_record.created_at.isoformat()}
+
+    except Exception as e:
+        logger.error(f"Evaluation failed for job {job_id}: {e}")
+        return {"id": eval_id, "job_id": job_id, "metrics": {}, "created_at": "", "error": str(e)}
+
+
+async def get_evaluation_results(eval_id: str) -> dict[str, Any]:
+    """获取已完成评估的结果。"""
+    async with async_session() as session:
+        result = await session.execute(select(EvalResultModel).where(EvalResultModel.id == eval_id))
+        record = result.scalar_one_or_none()
+        if record:
+            return {
+                "id": record.id,
+                "job_id": record.job_id,
+                "metrics": json.loads(record.metrics) if record.metrics else {},
+                "created_at": record.created_at.isoformat(),
+            }
+    return {"id": eval_id, "job_id": "", "metrics": {}, "created_at": ""}

+ 102 - 4
backend/app/services/model_service.py

@@ -1,16 +1,62 @@
+import json
 from pathlib import Path
 from typing import Any
 
 from app.config import get_settings
+from app.core.db import async_session, ModelCache
 from app.core.logging import logger
+from sqlalchemy import select
 
 settings = get_settings()
 
 
 async def download_model(model_id: str, use_modelscope: bool = False) -> dict[str, Any]:
     """从 HF 或 ModelScope 下载模型到本地缓存。"""
-    logger.info(f"Downloading model {model_id} (modelscope={use_modelscope})")
-    return {"model_id": model_id, "status": "downloading"}
+    try:
+        if use_modelscope:
+            from modelscope import snapshot_download as ms_download
+
+            local_path = ms_download(model_id, cache_dir=str(settings.models_dir))
+        else:
+            from huggingface_hub import snapshot_download
+
+            local_path = snapshot_download(
+                repo_id=model_id,
+                local_dir=str(settings.models_dir / model_id.replace("/", "_")),
+                local_dir_use_symlinks=False,
+            )
+
+        # 读取 config.json 获取模型信息
+        config_path = Path(local_path) / "config.json"
+        model_type = "text"
+        context_length = 2048
+        peft_methods = "lora,qlora,ia3,adalora,prefix_tuning"
+
+        if config_path.exists():
+            with open(config_path) as f:
+                cfg = json.load(f)
+            model_type = cfg.get("model_type", "text")
+            context_length = cfg.get("max_position_embeddings", cfg.get("max_sequence_length", 2048))
+
+        # 写入数据库
+        async with async_session() as session:
+            record = ModelCache(
+                id=model_id,
+                name=model_id.split("/")[-1],
+                model_type=model_type,
+                path=local_path,
+                is_downloaded=1,
+                context_length=context_length,
+                supported_peft_methods=peft_methods,
+            )
+            session.add(record)
+            await session.commit()
+
+        logger.info(f"Model downloaded: {model_id} -> {local_path}")
+        return {"model_id": model_id, "status": "completed", "path": local_path}
+    except Exception as e:
+        logger.error(f"Model download failed: {e}")
+        return {"model_id": model_id, "status": "failed", "error": str(e)}
 
 
 def list_cached_models() -> list[dict[str, Any]]:
@@ -18,9 +64,61 @@ def list_cached_models() -> list[dict[str, Any]]:
     models_dir = settings.models_dir
     if not models_dir.exists():
         return []
-    return [{"id": d.name, "path": str(d)} for d in models_dir.iterdir() if d.is_dir()]
 
+    result = []
+    for d in models_dir.iterdir():
+        if not d.is_dir():
+            continue
+        config_path = d / "config.json"
+        info: dict[str, Any] = {
+            "id": d.name,
+            "name": d.name,
+            "model_type": "text",
+            "path": str(d),
+            "is_downloaded": True,
+            "context_length": None,
+            "supported_peft_methods": [],
+        }
+        if config_path.exists():
+            with open(config_path) as f:
+                cfg = json.load(f)
+            info["model_type"] = cfg.get("model_type", "text")
+            info["context_length"] = cfg.get("max_position_embeddings", cfg.get("max_sequence_length", 2048))
+            info["supported_peft_methods"] = ["lora", "qlora", "ia3", "adalora", "prefix_tuning"]
+        result.append(info)
+    return result
 
-def get_model_info(model_id: str) -> dict[str, Any] | None:
+
+async def get_model_info(model_id: str) -> dict[str, Any] | None:
     """获取已缓存模型的元数据。"""
+    # 先查数据库
+    async with async_session() as session:
+        result = await session.execute(select(ModelCache).where(ModelCache.id == model_id))
+        record = result.scalar_one_or_none()
+        if record:
+            return {
+                "id": record.id,
+                "name": record.name,
+                "model_type": record.model_type,
+                "path": record.path,
+                "is_downloaded": bool(record.is_downloaded),
+                "context_length": record.context_length,
+                "supported_peft_methods": record.supported_peft_methods.split(",") if record.supported_peft_methods else [],
+            }
+
+    # 回退:直接从文件系统读取
+    model_dir = settings.models_dir / model_id.replace("/", "_")
+    config_path = model_dir / "config.json"
+    if config_path.exists():
+        with open(config_path) as f:
+            cfg = json.load(f)
+        return {
+            "id": model_id,
+            "name": model_id.split("/")[-1],
+            "model_type": cfg.get("model_type", "text"),
+            "path": str(model_dir),
+            "is_downloaded": True,
+            "context_length": cfg.get("max_position_embeddings", cfg.get("max_sequence_length", 2048)),
+            "supported_peft_methods": ["lora", "qlora", "ia3", "adalora", "prefix_tuning"],
+        }
     return None

+ 188 - 3
backend/app/services/training_service.py

@@ -1,18 +1,203 @@
+import asyncio
+import json
+import uuid
+from datetime import datetime, timezone
 from typing import Any
 
 from app.config import get_settings
+from app.core.db import async_session, TrainingJobModel
+from app.core.job_queue import JobStatus, TrainingJob, job_queue
 from app.core.logging import logger
+from sqlalchemy import select
 
 settings = get_settings()
 
 
 async def create_training_job(config: dict[str, Any]) -> dict[str, Any]:
     """校验配置、创建任务记录、加入队列。"""
-    logger.info(f"Creating training job: model={config.get('model_id')}")
-    return {"job_id": "placeholder", "status": "queued"}
+    job_id = str(uuid.uuid4())
+    model_id = config.get("model_id", "")
+    model_type = config.get("model_type", "text")
+    dataset_id = config.get("dataset_id", "")
+    peft_method = config.get("peft_method", "lora")
+    task_type = config.get("task_type", "sft")
+    dataset_template = config.get("dataset_template", "alpaca")
+
+    # 写入数据库
+    record = TrainingJobModel(
+        id=job_id,
+        model_id=model_id,
+        model_type=model_type,
+        dataset_id=dataset_id,
+        peft_method=peft_method,
+        task_type=task_type,
+        dataset_template=dataset_template,
+        status="pending",
+        epochs=config.get("epochs", 3),
+        batch_size=config.get("batch_size", 4),
+        gradient_accumulation=config.get("gradient_accumulation", 4),
+        learning_rate=config.get("learning_rate", 2e-4),
+        max_seq_length=config.get("max_seq_length", 2048),
+        warmup_ratio=config.get("warmup_ratio", 0.05),
+        save_strategy=config.get("save_strategy", "epoch"),
+        eval_strategy=config.get("eval_strategy", "epoch"),
+        eval_steps=config.get("eval_steps", 100),
+        lora_r=config.get("lora_r", 16),
+        lora_alpha=config.get("lora_alpha", 32),
+        lora_dropout=config.get("lora_dropout", 0.05),
+        lora_target_modules=config.get("lora_target_modules", "all-linear"),
+        qlora_bits=config.get("qlora_bits", 4),
+        created_at=datetime.now(timezone.utc),
+    )
+    async with async_session() as session:
+        session.add(record)
+        await session.commit()
+
+    # 加入 JobQueue
+    # 如果启用 DeepSpeed,生成配置文件
+    if config.get("deepspeed", False):
+        ds_config_path = _generate_deepspeed_config()
+        config["deepspeed"] = ds_config_path
+
+    job = TrainingJob(
+        id=job_id,
+        model_id=model_id,
+        model_type=model_type,
+        peft_method=peft_method,
+        dataset_id=dataset_id,
+        config=config,
+        status=JobStatus.PENDING,
+    )
+    await job_queue.enqueue(job_id, job)
+
+    logger.info(f"Training job created: {job_id}")
+    return {
+        "id": job_id,
+        "model_id": model_id,
+        "model_type": model_type,
+        "peft_method": peft_method,
+        "status": "pending",
+        "created_at": record.created_at.isoformat(),
+    }
+
+
+async def list_training_jobs() -> list[dict[str, Any]]:
+    """列出所有训练任务。"""
+    async with async_session() as session:
+        result = await session.execute(select(TrainingJobModel).order_by(TrainingJobModel.created_at.desc()))
+        records = result.scalars().all()
+
+    return [_job_to_dict(r) for r in records]
+
+
+async def get_training_job(job_id: str) -> dict[str, Any] | None:
+    """获取指定任务详情。"""
+    async with async_session() as session:
+        result = await session.execute(select(TrainingJobModel).where(TrainingJobModel.id == job_id))
+        record = result.scalar_one_or_none()
+        if record:
+            return _job_to_dict(record)
+    return None
 
 
 async def cancel_training_job(job_id: str) -> dict[str, Any]:
     """向运行中的任务发送取消信号。"""
-    logger.info(f"Cancelling job {job_id}")
+    await job_queue.cancel(job_id)
+
+    async with async_session() as session:
+        result = await session.execute(select(TrainingJobModel).where(TrainingJobModel.id == job_id))
+        record = result.scalar_one_or_none()
+        if record:
+            record.status = "cancelled"
+            record.finished_at = datetime.now(timezone.utc)
+            await session.commit()
+
+    logger.info(f"Job cancelled: {job_id}")
     return {"status": "cancelled"}
+
+
+async def update_job_in_db(job):
+    """JobQueue 回调:同步 job 状态到数据库。"""
+    try:
+        async with async_session() as session:
+            result = await session.execute(select(TrainingJobModel).where(TrainingJobModel.id == job.id))
+            record = result.scalar_one_or_none()
+            if record:
+                record.status = job.status.value if hasattr(job.status, "value") else str(job.status)
+                record.progress = job.progress
+                record.current_epoch = job.current_epoch
+                record.current_step = job.current_step
+                record.total_steps = job.total_steps
+                record.loss = job.loss
+                record.adapter_path = job.adapter_path
+                record.error_message = job.error_message
+                if job.status == JobStatus.TRAINING and not record.started_at:
+                    record.started_at = datetime.now(timezone.utc)
+                if job.status.is_terminal:
+                    record.finished_at = datetime.now(timezone.utc)
+                await session.commit()
+    except Exception as e:
+        logger.error(f"Failed to update job {job.id} in DB: {e}")
+
+
+def _job_to_dict(r) -> dict[str, Any]:
+    return {
+        "id": r.id,
+        "model_id": r.model_id,
+        "model_type": r.model_type,
+        "peft_method": r.peft_method,
+        "status": r.status,
+        "progress": r.progress or 0.0,
+        "current_epoch": r.current_epoch or 0,
+        "current_step": r.current_step or 0,
+        "total_steps": r.total_steps or 0,
+        "loss": r.loss,
+        "created_at": r.created_at.isoformat() if r.created_at else "",
+        "started_at": r.started_at.isoformat() if r.started_at else None,
+        "finished_at": r.finished_at.isoformat() if r.finished_at else None,
+        "error_message": r.error_message,
+        "adapter_path": r.adapter_path,
+    }
+
+
+def _generate_deepspeed_config(stage: int = 2) -> str:
+    """生成 DeepSpeed 配置文件,返回文件路径。"""
+    import json
+    from app.config import get_settings
+    settings = get_settings()
+
+    ds_config = {
+        "fp16": {"enabled": True},
+        "zero_optimization": {
+            "stage": stage,
+            "offload_optimizer": {"device": "cpu", "pin_memory": True},
+            "offload_param": {"device": "cpu", "pin_memory": True},
+            "overlap_comm": True,
+            "contiguous_gradients": True,
+            "reduce_bucket_size": "auto",
+            "stage3_prefetch_bucket_size": "auto",
+            "stage3_param_persistence_threshold": "auto",
+        } if stage == 3 else {
+            "stage": stage,
+            "offload_optimizer": {"device": "cpu", "pin_memory": True},
+            "allgather_partitions": True,
+            "allgather_bucket_size": 2e8,
+            "overlap_comm": True,
+            "reduce_scatter": True,
+            "reduce_bucket_size": 2e8,
+            "contiguous_gradients": True,
+        },
+        "gradient_accumulation_steps": "auto",
+        "gradient_clipping": "auto",
+        "steps_per_print": 10,
+        "train_batch_size": "auto",
+        "train_micro_batch_size_per_gpu": "auto",
+        "wall_clock_breakdown": False,
+    }
+
+    config_path = settings.data_dir / "deepspeed_config.json"
+    with open(config_path, "w") as f:
+        json.dump(ds_config, f, indent=2)
+
+    logger.info(f"DeepSpeed config generated: {config_path}")
+    return str(config_path)

+ 14 - 2
backend/main.py

@@ -10,10 +10,22 @@ settings = get_settings()
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    # 启动时:确保数据目录存在
+    # 启动时:确保数据目录存在 + 初始化数据库 + 启动 JobQueue
     settings.ensure_dirs()
+    from app.core.db import init_db
+
+    await init_db()
+
+    from app.core.job_queue import job_queue
+    from app.services.training_service import update_job_in_db
+
+    job_queue.register_callback(update_job_in_db)
+    await job_queue.start()
+
     yield
-    # 关闭时:清理资源(如有需要)
+
+    # 关闭时:停止 JobQueue
+    await job_queue.stop()
 
 
 def create_app() -> FastAPI:

+ 5 - 0
backend/requirements.txt

@@ -18,3 +18,8 @@ bitsandbytes>=0.44.0
 scipy>=1.14.0
 scikit-learn>=1.5.0
 pillow>=10.4.0
+huggingface_hub>=0.25.0
+pandas>=2.2.0
+pyarrow>=17.0.0
+sentencepiece>=0.2.0
+protobuf>=4.25.0

+ 3 - 0
frontend/src/api/client.ts

@@ -149,6 +149,8 @@ interface TrainingConfig {
   model_type: ModelType | string
   dataset_id: string
   peft_method: PeftMethod | string
+  task_type?: string
+  dataset_template?: string
   epochs?: number
   batch_size?: number
   gradient_accumulation?: number
@@ -163,6 +165,7 @@ interface TrainingConfig {
   lora_dropout?: number
   lora_target_modules?: string
   qlora_bits?: number
+  deepspeed?: boolean
 }
 
 interface EvalConfig {

+ 52 - 3
frontend/src/pages/Training.tsx

@@ -16,22 +16,44 @@ const PEFT_METHODS = [
   { value: 'prefix_tuning', label: 'Prefix Tuning' },
 ]
 
+const TASK_TYPES = [
+  { value: 'sft', label: 'SFT (监督微调)' },
+  { value: 'dpo', label: 'DPO (直接偏好优化)' },
+  { value: 'orpo', label: 'ORPO (比值偏好优化)' },
+  { value: 'kto', label: 'KTO (Kahneman-Tversky)' },
+  { value: 'rm', label: 'Reward Modeling' },
+  { value: 'ppo', label: 'PPO (强化学习)' },
+]
+
+const DATASET_TEMPLATES = [
+  { value: 'alpaca', label: 'Alpaca' },
+  { value: 'sharegpt', label: 'ShareGPT' },
+  { value: 'raw', label: 'Raw (直接字段)' },
+]
+
 export function Training() {
-  // Form state
   const [modelId, setModelId] = useState('')
   const [modelType, setModelType] = useState('text')
   const [datasetId, setDatasetId] = useState('')
   const [peftMethod, setPeftMethod] = useState('lora')
+  const [taskType, setTaskType] = useState('sft')
+  const [template, setTemplate] = useState('alpaca')
   const [epochs, setEpochs] = useState(3)
   const [batchSize, setBatchSize] = useState(4)
   const [lr, setLr] = useState('2e-4')
   const [loraR, setLoraR] = useState(16)
+  const [deepspeed, setDeepspeed] = useState(false)
 
-  // Job list
   const [jobs, setJobs] = useState<TrainingJob[]>([])
   const [loading, setLoading] = useState(false)
   const [submitting, setSubmitting] = useState(false)
 
+  // Connect WebSocket on mount
+  useEffect(() => {
+    wsManager.connect()
+    return () => wsManager.disconnect()
+  }, [])
+
   const fetchJobs = () => {
     setLoading(true)
     api.training.list()
@@ -42,6 +64,9 @@ export function Training() {
 
   useEffect(() => {
     fetchJobs()
+    // 每 5 秒轮询一次更新状态
+    const interval = setInterval(fetchJobs, 5000)
+    return () => clearInterval(interval)
   }, [])
 
   const handleCreate = () => {
@@ -52,11 +77,14 @@ export function Training() {
       model_type: modelType,
       dataset_id: datasetId,
       peft_method: peftMethod,
+      task_type: taskType,
+      dataset_template: template,
       epochs,
       batch_size: batchSize,
       learning_rate: parseFloat(lr),
       lora_r: loraR,
       lora_alpha: loraR * 2,
+      deepspeed: deepspeed,
     })
       .then(() => {
         setModelId('')
@@ -79,6 +107,7 @@ export function Training() {
       case 'failed': return '#e94560'
       case 'training': return '#2196f3'
       case 'pending': case 'queued': return '#ff9800'
+      case 'preprocessing': return '#9c27b0'
       case 'cancelled': return '#999'
       default: return '#666'
     }
@@ -106,6 +135,18 @@ export function Training() {
             <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>数据集 ID</label>
             <input value={datasetId} onChange={e => setDatasetId(e.target.value)} placeholder="数据集 ID" style={{ width: '100%', padding: '6px 8px', borderRadius: 4, border: '1px solid #ccc', boxSizing: 'border-box' }} />
           </div>
+          <div>
+            <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>训练类型</label>
+            <select value={taskType} onChange={e => setTaskType(e.target.value)} style={{ width: '100%', padding: '6px 8px', borderRadius: 4, border: '1px solid #ccc', boxSizing: 'border-box' }}>
+              {TASK_TYPES.map(t => <option key={t.value} value={t.value}>{t.label}</option>)}
+            </select>
+          </div>
+          <div>
+            <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>数据模板</label>
+            <select value={template} onChange={e => setTemplate(e.target.value)} style={{ width: '100%', padding: '6px 8px', borderRadius: 4, border: '1px solid #ccc', boxSizing: 'border-box' }}>
+              {DATASET_TEMPLATES.map(t => <option key={t.value} value={t.value}>{t.label}</option>)}
+            </select>
+          </div>
           <div>
             <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>PEFT 方法</label>
             <select value={peftMethod} onChange={e => setPeftMethod(e.target.value)} style={{ width: '100%', padding: '6px 8px', borderRadius: 4, border: '1px solid #ccc', boxSizing: 'border-box' }}>
@@ -128,6 +169,12 @@ export function Training() {
             <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>LoRA R</label>
             <input type="number" value={loraR} onChange={e => setLoraR(Number(e.target.value))} min={1} style={{ width: '100%', padding: '6px 8px', borderRadius: 4, border: '1px solid #ccc', boxSizing: 'border-box' }} />
           </div>
+          <div>
+            <label style={{ display: 'flex', alignItems: 'center', gap: 6, fontSize: 13, cursor: 'pointer' }}>
+              <input type="checkbox" checked={deepspeed} onChange={e => setDeepspeed(e.target.checked)} />
+              DeepSpeed 多 GPU
+            </label>
+          </div>
         </div>
         <button
           onClick={handleCreate}
@@ -158,6 +205,7 @@ export function Training() {
                 <th style={{ padding: '8px 0' }}>任务 ID</th>
                 <th>模型</th>
                 <th>PEFT</th>
+                <th>类型</th>
                 <th>状态</th>
                 <th>进度</th>
                 <th>Loss</th>
@@ -170,6 +218,7 @@ export function Training() {
                   <td style={{ padding: '8px 0', fontFamily: 'monospace', fontSize: 12 }}>{j.id.slice(0, 8)}...</td>
                   <td>{j.model_id}</td>
                   <td>{j.peft_method}</td>
+                  <td style={{ fontSize: 12 }}>{j.status === 'preprocessing' ? '预处理' : j.status === 'training' ? '训练中' : j.status}</td>
                   <td style={{ color: statusColor(j.status), fontWeight: 600 }}>{j.status}</td>
                   <td>
                     <div style={{ width: 120, height: 6, background: '#eee', borderRadius: 3, overflow: 'hidden' }}>
@@ -179,7 +228,7 @@ export function Training() {
                   </td>
                   <td>{j.loss?.toFixed(4) ?? '-'}</td>
                   <td>
-                    {(j.status === 'training' || j.status === 'pending' || j.status === 'queued') && (
+                    {(j.status === 'training' || j.status === 'pending' || j.status === 'queued' || j.status === 'preprocessing') && (
                       <button onClick={() => handleCancel(j.id)} style={{ padding: '2px 8px', color: '#e94560', border: '1px solid #e94560', borderRadius: 4, background: 'transparent', cursor: 'pointer' }}>取消</button>
                     )}
                   </td>