"""Asynchronous in-process queue that runs training jobs on background workers."""

import asyncio
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Callable, Coroutine, Optional

from pydantic import BaseModel, Field

from app.core.logging import logger


def _utc_now_iso() -> str:
    """Return the current UTC time as an ISO-8601 string."""
    return datetime.now(timezone.utc).isoformat()


class JobStatus(str, Enum):
    """Lifecycle states a job can be in."""

    PENDING = "pending"
    QUEUED = "queued"
    PREPROCESSING = "preprocessing"
    TRAINING = "training"
    COMPLETED = "completed"
    EVALUATING = "evaluating"
    EVALUATION_DONE = "evaluation_done"
    FAILED = "failed"
    CANCELLED = "cancelled"

    @property
    def is_terminal(self) -> bool:
        """Return True for COMPLETED, FAILED, CANCELLED and EVALUATION_DONE."""
        # Members are referenced through the class rather than through
        # ``self`` so the lookup does not rely on member-to-member access.
        return self in (
            JobStatus.COMPLETED,
            JobStatus.FAILED,
            JobStatus.CANCELLED,
            JobStatus.EVALUATION_DONE,
        )


class TrainingJob(BaseModel):
    """In-memory record describing one training job and its progress."""

    id: str
    model_id: str
    model_type: str
    peft_method: str
    dataset_id: str
    # Free-form options; the queue reads "dataset_id", "task_type" and
    # "dataset_template" from it and forwards the whole dict to the engine.
    config: dict = Field(default_factory=dict)
    status: JobStatus = JobStatus.PENDING
    progress: float = 0.0
    current_epoch: int = 0
    current_step: int = 0
    total_steps: int = 0
    loss: float | None = None
    adapter_path: str | None = None
    error_message: str | None = None
    # Timestamps are ISO-8601 strings in UTC.
    created_at: str = Field(default_factory=_utc_now_iso)
    started_at: str | None = None
    finished_at: str | None = None


# Signature of a status-change callback registered on the queue.
JobCallback = Callable[[TrainingJob], Coroutine[Any, Any, None]]


class JobQueue:
    """Asynchronous job queue with cooperative cancellation and bounded concurrency."""

    def __init__(self, max_concurrent: int = 2):
        """Create an idle queue.

        Args:
            max_concurrent: Number of worker tasks spawned by ``start()``.
        """
        self._queue: asyncio.Queue[str] = asyncio.Queue()
        self._jobs: dict[str, TrainingJob] = {}
        self._cancel_events: dict[str, asyncio.Event] = {}
        self._callbacks: list[JobCallback] = []
        self._max_concurrent = max_concurrent
        self._workers: list[asyncio.Task] = []
        self._running = False

    async def start(self) -> None:
        """Start the background workers; a no-op when already running."""
        if self._running:
            return
        self._running = True
        for _ in range(self._max_concurrent):
            worker = asyncio.create_task(self._worker_loop())
            self._workers.append(worker)
        logger.info(f"JobQueue started with {self._max_concurrent} workers")

    async def stop(self) -> None:
        """Stop all workers and wait until their tasks have finished.

        Every known cancel event is set so that code polling
        ``is_cancelled()`` can bail out, then the worker tasks are cancelled
        and awaited.
        """
        self._running = False
        for event in self._cancel_events.values():
            event.set()
        workers = list(self._workers)
        self._workers.clear()
        for worker in workers:
            worker.cancel()
        if workers:
            # Await the cancelled tasks so none is left pending when the
            # caller goes on to shut the event loop down.
            await asyncio.gather(*workers, return_exceptions=True)
        logger.info("JobQueue stopped")

    async def enqueue(self, job_id: str, job: TrainingJob) -> None:
        """Register ``job`` under ``job_id`` and put it on the queue."""
        self._jobs[job_id] = job
        self._cancel_events[job_id] = asyncio.Event()
        await self._queue.put(job_id)
        logger.info(f"Job {job_id} enqueued")

    async def dequeue(self) -> str:
        """Wait for and return the next job id from the queue."""
        return await self._queue.get()

    def mark_done(self, job_id: str) -> None:
        """Acknowledge a dequeued job and drop its cancel event."""
        self._queue.task_done()
        self._cancel_events.pop(job_id, None)

    def get_job(self, job_id: str) -> Optional[TrainingJob]:
        """Return the job registered under ``job_id``, or None."""
        return self._jobs.get(job_id)

    def update_job(self, job_id: str, **kwargs) -> None:
        """Set the given fields on a job.

        Unknown job ids and keyword names that are not attributes of the job
        are silently ignored.
        """
        job = self._jobs.get(job_id)
        if job is None:
            return
        for key, val in kwargs.items():
            if hasattr(job, key):
                setattr(job, key, val)

    def is_cancelled(self, job_id: str) -> bool:
        """Return True when cancellation was requested for ``job_id``."""
        event = self._cancel_events.get(job_id)
        return event is not None and event.is_set()

    async def cancel(self, job_id: str) -> None:
        """Request cancellation of a job.

        Does nothing when the job has no cancel event, or when it has already
        reached a terminal status, so that a finished job is not relabelled
        as cancelled.
        """
        event = self._cancel_events.get(job_id)
        if event is None:
            return
        job = self._jobs.get(job_id)
        # JobStatus(...) also accepts a plain string, which update_job() may
        # have stored because it assigns attributes without validation.
        if job is not None and JobStatus(job.status).is_terminal:
            return
        event.set()
        self._mark_cancelled(job_id)
        await self._notify_callbacks()
        logger.info(f"Job {job_id} cancelled")

    def register_callback(self, callback: Callable[[TrainingJob], Coroutine[Any, Any, None]]) -> None:
        """Register a status-change callback (e.g. to persist jobs to a database)."""
        self._callbacks.append(callback)

    async def _notify_callbacks(self) -> None:
        """Invoke every registered callback once for every known job.

        Callback errors are logged and do not stop the remaining
        notifications.
        """
        # Iterate over snapshots: callbacks are awaited, and another task may
        # enqueue a job meanwhile, which would otherwise raise
        # "dictionary changed size during iteration".
        callbacks = list(self._callbacks)
        jobs = list(self._jobs.values())
        for cb in callbacks:
            for job in jobs:
                try:
                    await cb(job)
                except Exception as e:
                    logger.error(f"JobQueue callback error: {e}")

    def _mark_cancelled(self, job_id: str) -> None:
        """Record CANCELLED on the job, stamping ``finished_at`` if it is unset."""
        job = self._jobs.get(job_id)
        if job is None:
            return
        fields: dict[str, Any] = {"status": JobStatus.CANCELLED}
        if job.finished_at is None:
            fields["finished_at"] = _utc_now_iso()
        self.update_job(job_id, **fields)

    async def _abort_if_cancelled(self, job_id: str) -> bool:
        """Return True, after recording CANCELLED, when the job was cancelled."""
        if not self.is_cancelled(job_id):
            return False
        self._mark_cancelled(job_id)
        await self._notify_callbacks()
        return True

    async def _worker_loop(self) -> None:
        """Worker loop: keep taking job ids from the queue and running them."""
        while self._running:
            try:
                # Wake up once per second so the loop can observe
                # ``_running`` being cleared.
                job_id = await asyncio.wait_for(self._queue.get(), timeout=1.0)
            except asyncio.TimeoutError:
                continue
            try:
                await self._run_job(job_id)
            except Exception as e:
                logger.error(f"Job {job_id} failed: {e}")
                self.update_job(
                    job_id,
                    status=JobStatus.FAILED,
                    error_message=str(e),
                    finished_at=_utc_now_iso(),
                )
                await self._notify_callbacks()
            finally:
                self._queue.task_done()

    async def _run_job(self, job_id: str) -> None:
        """Run one job: preprocess the dataset, train, then mark it completed.

        Cancellation is checked before every status transition so that a
        cancelled job always ends up as CANCELLED rather than being
        overwritten by a later status.
        """
        job = self._jobs.get(job_id)
        if job is None:
            return

        # The job may have been cancelled while it was still waiting in the
        # queue; do not overwrite that with QUEUED.
        if await self._abort_if_cancelled(job_id):
            return

        self.update_job(job_id, status=JobStatus.QUEUED)
        await self._notify_callbacks()
        if await self._abort_if_cancelled(job_id):
            return

        self.update_job(job_id, status=JobStatus.PREPROCESSING, started_at=_utc_now_iso())
        await self._notify_callbacks()
        if await self._abort_if_cancelled(job_id):
            return

        try:
            config = job.config
            model_id = job.model_id
            model_type = job.model_type
            peft_method = job.peft_method
            dataset_id = config.get("dataset_id", job.dataset_id)

            # Imported lazily, as in the rest of this module.
            from app.config import get_settings

            settings = get_settings()

            # Locate the dataset file: database first, then the filesystem.
            dataset_path = await self._lookup_dataset_db(dataset_id)
            if not dataset_path:
                dataset_path = self._find_dataset_path(dataset_id)
            if not dataset_path:
                raise FileNotFoundError(f"Dataset not found: {dataset_id}")

            # Preprocessing parameters.
            processed_path = str(settings.processed_dir / f"{job_id}_processed.jsonl")
            task_type = config.get("task_type", "sft")
            template = config.get("dataset_template", "alpaca")

            # Pick the engine matching the model type.
            engine = self._get_engine(model_type)

            # Preprocess the dataset.
            await engine.preprocess_dataset(
                dataset_path, processed_path, task_type=task_type, template=template
            )
            if await self._abort_if_cancelled(job_id):
                return

            self.update_job(job_id, status=JobStatus.TRAINING)
            await self._notify_callbacks()

            # Load the model; "qlora" requests 4-bit quantization.
            await engine.load_model(
                model_id, quantization="4bit" if peft_method == "qlora" else None
            )

            # Build the PEFT configuration.
            peft_config = engine.get_peft_config(peft_method, config)

            # Train.
            adapter_path = await engine.train(
                job_id=job_id,
                dataset_path=processed_path,
                peft_config=peft_config,
                training_args=config,
            )
            # A cancel requested during training must not be overwritten by
            # COMPLETED.
            if await self._abort_if_cancelled(job_id):
                return

            self.update_job(
                job_id,
                status=JobStatus.COMPLETED,
                adapter_path=adapter_path,
                finished_at=_utc_now_iso(),
            )
            await self._notify_callbacks()
            logger.info(f"Job {job_id} completed successfully")
        except asyncio.CancelledError:
            # Deliberately not re-raised: the worker must survive a
            # CancelledError coming out of the engine, and after stop() the
            # worker loop exits on its own because ``_running`` is False.
            self._mark_cancelled(job_id)
            await self._notify_callbacks()
        except Exception as e:
            logger.error(f"Job {job_id} failed: {e}")
            self.update_job(
                job_id,
                status=JobStatus.FAILED,
                error_message=str(e),
                finished_at=_utc_now_iso(),
            )
            await self._notify_callbacks()

    def _find_dataset_path(self, dataset_id: str) -> str | None:
        """Resolve ``dataset_id`` to a file path on disk.

        Looks for an entry named ``dataset_id`` in the uploads directory
        first, then treats ``dataset_id`` itself as a path.

        Returns:
            The path as a string, or None when nothing matches.
        """
        from app.config import get_settings
        from pathlib import Path

        # An empty id would otherwise resolve to the uploads directory itself.
        if not dataset_id:
            return None

        settings = get_settings()

        # Try the uploads directory first.
        upload_path = settings.uploads_dir / dataset_id
        if upload_path.exists():
            return str(upload_path)

        # Otherwise dataset_id may itself be a path.
        if Path(dataset_id).exists():
            return dataset_id

        return None

    async def _lookup_dataset_db(self, dataset_id: str) -> str | None:
        """Look up the dataset file path in the database.

        Matches records whose ``id`` or ``name`` equals ``dataset_id``.

        Returns:
            The record's ``file_path``, or None when no record matches.
        """
        from app.core.db import async_session, DatasetRecord
        from sqlalchemy import select

        async with async_session() as session:
            result = await session.execute(
                select(DatasetRecord).where(
                    (DatasetRecord.id == dataset_id) | (DatasetRecord.name == dataset_id)
                )
            )
            record = result.scalar_one_or_none()
            if record:
                return record.file_path
        return None

    def _get_engine(self, model_type: str):
        """Return the training engine for ``model_type``.

        "vision" and "multimodal" map to their dedicated engines; any other
        value falls back to the text engine.
        """
        if model_type == "vision":
            from app.engines.vision_engine import vision_engine

            return vision_engine
        elif model_type == "multimodal":
            from app.engines.multimodal_engine import multimodal_engine

            return multimodal_engine
        else:
            from app.engines.text_engine import text_engine

            return text_engine

    @property
    def jobs(self) -> dict[str, TrainingJob]:
        """Return a shallow copy of the job registry."""
        return dict(self._jobs)


# Module-level singleton.
job_queue = JobQueue(max_concurrent=2)