| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274 |
- import asyncio
- import json
- from datetime import datetime, timezone
- from enum import Enum
- from typing import Any, Callable, Coroutine, Optional
- from pydantic import BaseModel, Field
- from app.core.logging import logger
class JobStatus(str, Enum):
    """Lifecycle states of a training job.

    Members are also ``str`` instances, so each one compares equal to its raw
    value (``JobStatus.TRAINING == "training"``).
    """

    PENDING = "pending"
    QUEUED = "queued"
    PREPROCESSING = "preprocessing"
    TRAINING = "training"
    COMPLETED = "completed"
    EVALUATING = "evaluating"
    EVALUATION_DONE = "evaluation_done"
    FAILED = "failed"
    CANCELLED = "cancelled"

    @property
    def is_terminal(self) -> bool:
        """True for COMPLETED, EVALUATION_DONE, FAILED and CANCELLED; False otherwise.

        NOTE(review): COMPLETED counts as terminal even though EVALUATING and
        EVALUATION_DONE exist as later states — presumably "terminal" means the
        training run itself is over; confirm against the evaluation flow.
        """
        cls = type(self)
        finished = frozenset({cls.COMPLETED, cls.EVALUATION_DONE, cls.FAILED, cls.CANCELLED})
        return self in finished
class TrainingJob(BaseModel):
    """State of a single fine-tuning job, as stored and updated by the job queue."""

    # --- what to train -------------------------------------------------------
    id: str
    model_id: str
    model_type: str
    peft_method: str
    dataset_id: str
    # Free-form training options; the queue reads keys such as "task_type" and
    # "dataset_template" from it. default_factory gives every job its own dict.
    config: dict = Field(default_factory=dict)

    # --- progress --------------------------------------------------------------
    status: JobStatus = Field(default=JobStatus.PENDING)
    progress: float = Field(default=0.0)
    current_epoch: int = Field(default=0)
    current_step: int = Field(default=0)
    total_steps: int = Field(default=0)
    loss: Optional[float] = Field(default=None)

    # --- outcome ---------------------------------------------------------------
    adapter_path: Optional[str] = Field(default=None)
    error_message: Optional[str] = Field(default=None)

    # --- timestamps (ISO-8601 strings) -----------------------------------------
    # created_at defaults to the UTC time of construction; started_at and
    # finished_at stay None until the code running the job fills them in.
    created_at: str = Field(default_factory=lambda: datetime.now(tz=timezone.utc).isoformat())
    started_at: Optional[str] = Field(default=None)
    finished_at: Optional[str] = Field(default=None)
class JobQueue:
    """Asynchronous training-job queue with cancellation and bounded concurrency.

    Jobs registered with :meth:`enqueue` are executed by ``max_concurrent``
    worker tasks launched by :meth:`start`. A job moves through
    QUEUED -> PREPROCESSING -> TRAINING -> COMPLETED, or ends as FAILED or
    CANCELLED. After every transition all known jobs are pushed to the
    callbacks registered with :meth:`register_callback`.

    Cancellation is cooperative: :meth:`cancel` sets a per-job
    ``asyncio.Event`` that the queue checks between pipeline stages. The queue
    does not interrupt a stage that is already running (e.g. ``engine.train``);
    such code can only stop early by polling :meth:`is_cancelled` itself.
    """

    def __init__(self, max_concurrent: int = 2):
        """Create an idle queue; no worker runs until :meth:`start` is awaited.

        Args:
            max_concurrent: Number of worker tasks, i.e. how many jobs may run
                at the same time.
        """
        self._queue: asyncio.Queue[str] = asyncio.Queue()
        self._jobs: dict[str, TrainingJob] = {}
        self._cancel_events: dict[str, asyncio.Event] = {}
        self._callbacks: list[Callable[[TrainingJob], Coroutine[Any, Any, None]]] = []
        self._max_concurrent = max_concurrent
        self._workers: list[asyncio.Task] = []
        self._running = False

    @staticmethod
    def _now_iso() -> str:
        """Return the current UTC time as an ISO-8601 string."""
        return datetime.now(timezone.utc).isoformat()

    async def start(self):
        """Launch the background workers; a no-op if they are already running."""
        if self._running:
            return
        self._running = True
        self._workers.extend(
            asyncio.create_task(self._worker_loop()) for _ in range(self._max_concurrent)
        )
        logger.info(f"JobQueue started with {self._max_concurrent} workers")

    async def stop(self):
        """Stop all workers and wait for them to finish their cleanup.

        Every known job's cancel event is set first. Jobs still waiting in the
        queue are therefore recorded as CANCELLED if the queue is started again.
        Running jobs are recorded as CANCELLED by their worker while it unwinds.
        """
        self._running = False
        for event in self._cancel_events.values():
            event.set()
        workers, self._workers = self._workers, []
        for worker in workers:
            worker.cancel()
        # Await the cancelled workers so their status updates and callbacks
        # complete before we return. return_exceptions=True collects each
        # worker's CancelledError instead of raising the first one.
        await asyncio.gather(*workers, return_exceptions=True)
        logger.info("JobQueue stopped")

    async def enqueue(self, job_id: str, job: TrainingJob):
        """Register ``job`` under ``job_id`` and put it on the queue for a worker.

        Re-enqueueing an existing id replaces the stored job and gives it a
        fresh, unset cancel event.
        """
        self._jobs[job_id] = job
        self._cancel_events[job_id] = asyncio.Event()
        await self._queue.put(job_id)
        logger.info(f"Job {job_id} enqueued")

    async def dequeue(self) -> str:
        """Take the next job id off the queue, waiting while it is empty.

        The built-in workers consume the same queue, so an id taken here is
        not run by them. Acknowledge it with :meth:`mark_done`.
        """
        return await self._queue.get()

    def mark_done(self, job_id: str):
        """Acknowledge an id obtained via :meth:`dequeue` and drop its cancel event.

        Do not call this for ids processed by the built-in workers. They already
        call ``task_done()``, and ``asyncio.Queue.task_done`` raises ValueError
        when called more times than items were put.
        """
        self._queue.task_done()
        self._cancel_events.pop(job_id, None)

    def get_job(self, job_id: str) -> Optional[TrainingJob]:
        """Return the job registered under ``job_id``, or None if unknown."""
        return self._jobs.get(job_id)

    def update_job(self, job_id: str, **kwargs):
        """Set attributes on a stored job.

        Unknown job ids, and attribute names the job does not have, are ignored
        silently. Callbacks are not notified here; the change is delivered with
        the next notification, which sends every job.
        """
        job = self._jobs.get(job_id)
        if job is None:
            return
        for key, val in kwargs.items():
            if hasattr(job, key):
                setattr(job, key, val)

    def is_cancelled(self, job_id: str) -> bool:
        """True if the job has a cancel event and it has been set."""
        event = self._cancel_events.get(job_id)
        return event is not None and event.is_set()

    async def cancel(self, job_id: str):
        """Request cancellation of a queued or running job.

        Sets the job's cancel event and records it as CANCELLED immediately.
        Unknown jobs and jobs already in a terminal status are left untouched,
        so a finished job's COMPLETED or FAILED result is never overwritten.
        A stage that is already running is not interrupted; the worker stops at
        its next cancellation check.
        """
        event = self._cancel_events.get(job_id)
        if event is None:
            logger.warning(f"Job {job_id} cannot be cancelled: unknown job")
            return
        job = self._jobs.get(job_id)
        if job is not None and JobStatus(job.status).is_terminal:
            logger.info(f"Job {job_id} already finished with status {JobStatus(job.status).value}, not cancelling")
            return
        event.set()
        await self._mark_cancelled(job_id)
        logger.info(f"Job {job_id} cancelled")

    def register_callback(self, callback: Callable[[TrainingJob], Coroutine[Any, Any, None]]):
        """Register an async status-change callback (e.g. to update the database).

        On every notification the callback is awaited once per known job.
        """
        self._callbacks.append(callback)

    async def _notify_callbacks(self):
        """Push every known job to every registered callback; errors are logged.

        The job list is snapshotted first because callbacks are awaited. Another
        coroutine may enqueue meanwhile, and mutating ``_jobs`` while iterating
        the live dict raises RuntimeError. Each (callback, job) delivery has its
        own try block, so one failure does not skip the remaining jobs.
        """
        jobs = list(self._jobs.values())
        for cb in list(self._callbacks):
            for job in jobs:
                try:
                    await cb(job)
                except Exception as e:
                    logger.error(f"JobQueue callback error: {e}")

    async def _worker_loop(self):
        """Worker body: take job ids off the queue and run them one at a time.

        The loop ends when the task is cancelled by :meth:`stop`. There is no
        polling timeout, because the cancellation also interrupts a pending
        ``get()``. Every id taken from the queue is acknowledged with
        ``task_done()`` exactly once, whatever the outcome.
        """
        while self._running:
            job_id = await self._queue.get()
            try:
                await self._run_job(job_id)
            except Exception as e:
                # _run_job records pipeline failures itself; this only catches
                # errors that escape it.
                logger.error(f"Job {job_id} failed: {e}")
                if not self.is_cancelled(job_id):
                    self.update_job(
                        job_id, status=JobStatus.FAILED, error_message=str(e), finished_at=self._now_iso()
                    )
                    await self._notify_callbacks()
            finally:
                self._queue.task_done()

    async def _mark_cancelled(self, job_id: str):
        """Record the job as CANCELLED (with ``finished_at``) and notify callbacks.

        Does nothing for unknown jobs or jobs already in a terminal status. A
        COMPLETED or FAILED result, or an earlier CANCELLED timestamp, is never
        overwritten.
        """
        job = self._jobs.get(job_id)
        if job is None or JobStatus(job.status).is_terminal:
            return
        self.update_job(job_id, status=JobStatus.CANCELLED, finished_at=self._now_iso())
        await self._notify_callbacks()

    async def _abort_if_cancelled(self, job_id: str) -> bool:
        """Return True, after recording CANCELLED, if cancellation was requested.

        Recording the status here also covers cancel events set by :meth:`stop`,
        which do not update the status themselves.
        """
        if not self.is_cancelled(job_id):
            return False
        logger.info(f"Job {job_id}: cancellation requested, stopping")
        await self._mark_cancelled(job_id)
        return True

    async def _advance(self, job_id: str, status: JobStatus, **fields) -> bool:
        """Move the job to ``status`` (plus any extra ``fields``) and notify callbacks.

        Cancellation is checked before the update, so a CANCELLED status is
        never overwritten. It is checked again after the notification, which
        awaits and therefore lets :meth:`cancel` run in between. Returns False
        when the job was cancelled; the caller must then stop working on it.
        """
        if await self._abort_if_cancelled(job_id):
            return False
        self.update_job(job_id, status=status, **fields)
        await self._notify_callbacks()
        return not await self._abort_if_cancelled(job_id)

    async def _resolve_dataset_path(self, dataset_id: str) -> str:
        """Return the dataset's file path, trying the database before the filesystem.

        Raises:
            FileNotFoundError: if neither lookup finds the dataset.
        """
        dataset_path = await self._lookup_dataset_db(dataset_id)
        if not dataset_path:
            dataset_path = self._find_dataset_path(dataset_id)
        if not dataset_path:
            raise FileNotFoundError(f"Dataset not found: {dataset_id}")
        return dataset_path

    async def _run_job(self, job_id: str):
        """Execute one job: preprocess the dataset, load the model, then train.

        Cancellation is checked before every status change and after each
        long-running stage, so a cancelled job is never later reported as
        QUEUED, TRAINING, COMPLETED or FAILED. Pipeline errors are recorded as
        FAILED (with ``error_message``) rather than raised. If the worker task
        itself is cancelled, the job is recorded as CANCELLED and the
        ``CancelledError`` is re-raised.
        """
        job = self._jobs.get(job_id)
        if not job:
            logger.warning(f"Job {job_id} not found in queue, skipping")
            return
        try:
            # Check before touching the status: a job cancelled while it was
            # still waiting in the queue must stay CANCELLED, not become QUEUED.
            if await self._abort_if_cancelled(job_id):
                return
            logger.info(f"Job {job_id}: Starting execution (status=QUEUED)")
            if not await self._advance(job_id, JobStatus.QUEUED):
                return
            if not await self._advance(job_id, JobStatus.PREPROCESSING, started_at=self._now_iso()):
                return

            config = job.config
            model_id = job.model_id
            peft_method = job.peft_method
            # The job config may override the dataset the job was created with.
            dataset_id = config.get("dataset_id", job.dataset_id)
            from app.config import get_settings

            settings = get_settings()
            dataset_path = await self._resolve_dataset_path(dataset_id)
            engine = self._get_engine(job.model_type)

            # Preprocessing always runs locally.
            processed_path = str(settings.processed_dir / f"{job_id}_processed.jsonl")
            task_type = config.get("task_type", "sft")
            template = config.get("dataset_template", "alpaca")
            logger.info(f"Job {job_id}: Preprocessing dataset (task_type={task_type}, template={template})")
            await engine.preprocess_dataset(dataset_path, processed_path, task_type=task_type, template=template)
            logger.info(f"Job {job_id}: Preprocessing completed, output: {processed_path}")
            if await self._abort_if_cancelled(job_id):
                return

            logger.info(f"Job {job_id}: Loading model {model_id} (peft={peft_method})")
            # Only QLoRA loads the base model with 4-bit quantization.
            await engine.load_model(model_id, quantization="4bit" if peft_method == "qlora" else None)
            logger.info(f"Job {job_id}: Model loaded, building PEFT config")
            peft_config = engine.get_peft_config(peft_method, config)
            logger.info(f"Job {job_id}: PEFT config built, starting training...")
            if not await self._advance(job_id, JobStatus.TRAINING):
                return

            logger.info(f"Job {job_id}: Calling engine.train()...")
            adapter_path = await engine.train(
                job_id=job_id,
                dataset_path=processed_path,
                peft_config=peft_config,
                training_args=config,
            )
            # If cancellation was requested while train() ran, the engine may
            # have returned early (presumably by polling is_cancelled). Such a
            # run must not be reported as COMPLETED.
            if await self._abort_if_cancelled(job_id):
                return
            self.update_job(
                job_id, status=JobStatus.COMPLETED, adapter_path=adapter_path, finished_at=self._now_iso()
            )
            await self._notify_callbacks()
            logger.info(f"Job {job_id} completed successfully")
        except asyncio.CancelledError:
            # The worker task itself was cancelled (see stop()). Record the job,
            # then propagate so the task really ends as cancelled.
            await self._mark_cancelled(job_id)
            raise
        except Exception as e:
            # If cancellation was requested, report CANCELLED rather than FAILED;
            # the error may be a consequence of the cancellation.
            if await self._abort_if_cancelled(job_id):
                return
            logger.error(f"Job {job_id} failed: {e}")
            self.update_job(job_id, status=JobStatus.FAILED, error_message=str(e), finished_at=self._now_iso())
            await self._notify_callbacks()

    def _find_dataset_path(self, dataset_id: str) -> str | None:
        """Resolve ``dataset_id`` on the local filesystem (no database access).

        Tries ``settings.uploads_dir / dataset_id`` first, then treats
        ``dataset_id`` itself as a path. Returns None if neither exists.
        """
        from app.config import get_settings
        from pathlib import Path

        settings = get_settings()
        # NOTE(security): dataset_id is joined onto uploads_dir without
        # normalisation (so "../" escapes it) and is also accepted as an
        # arbitrary path. Any file readable by this process can therefore be
        # used as a dataset. If dataset_id can come from untrusted API input,
        # restrict both branches to uploads_dir.
        upload_path = settings.uploads_dir / dataset_id
        if upload_path.exists():
            return str(upload_path)
        if Path(dataset_id).exists():
            return dataset_id
        return None

    async def _lookup_dataset_db(self, dataset_id: str) -> str | None:
        """Look up a dataset's ``file_path`` in the database by id, then by name.

        Returns None when no record matches.
        """
        from app.core.db import async_session, DatasetRecord
        from sqlalchemy import select

        async with async_session() as session:
            # Query id and name separately so an exact id match always wins.
            # Use first() rather than scalar_one_or_none(): the latter raises
            # MultipleResultsFound when one record's id equals another's name
            # (or when names are not unique), which would fail the whole job.
            for column in (DatasetRecord.id, DatasetRecord.name):
                result = await session.execute(select(DatasetRecord).where(column == dataset_id))
                record = result.scalars().first()
                if record is not None:
                    return record.file_path
        return None

    def _get_engine(self, model_type: str):
        """Return the training engine for ``model_type``.

        "vision" and "multimodal" get their own engines; any other value falls
        back to the text engine. Engines are imported lazily on first use.
        """
        if model_type == "vision":
            from app.engines.vision_engine import vision_engine

            return vision_engine
        if model_type == "multimodal":
            from app.engines.multimodal_engine import multimodal_engine

            return multimodal_engine
        from app.engines.text_engine import text_engine

        return text_engine

    @property
    def jobs(self) -> dict[str, TrainingJob]:
        """Shallow copy of the id -> job mapping.

        Adding or removing entries does not affect the queue, but the
        TrainingJob objects are shared with it.
        """
        return dict(self._jobs)
# Global singleton with two concurrent workers. Constructing it creates the
# underlying asyncio.Queue but starts no workers; those run only after
# ``await job_queue.start()``.
job_queue = JobQueue(2)
|