job_queue.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270
  1. import asyncio
  2. from datetime import datetime, timezone
  3. from enum import Enum
  4. from typing import Any, Callable, Coroutine, Optional
  5. from pydantic import BaseModel, Field
  6. from app.core.logging import logger
  class JobStatus(str, Enum):
      """Lifecycle states of a training job.

      The enum mixes in ``str``, so every member compares equal to its string
      value (e.g. ``JobStatus.PENDING == "pending"``).
      """

      PENDING = "pending"
      QUEUED = "queued"
      PREPROCESSING = "preprocessing"
      TRAINING = "training"
      COMPLETED = "completed"
      EVALUATING = "evaluating"
      EVALUATION_DONE = "evaluation_done"
      FAILED = "failed"
      CANCELLED = "cancelled"

      @property
      def is_terminal(self) -> bool:
          """Whether this is an end state: COMPLETED, FAILED, CANCELLED or EVALUATION_DONE."""
          status_cls = type(self)
          end_states = (
              status_cls.COMPLETED,
              status_cls.FAILED,
              status_cls.CANCELLED,
              status_cls.EVALUATION_DONE,
          )
          # Enum members are singletons, so identity is the exact membership test.
          return any(self is state for state in end_states)
  class TrainingJob(BaseModel):
      """In-memory record of one fine-tuning job: its inputs, status and progress."""

      # --- identity and inputs ---
      id: str
      model_id: str
      model_type: str
      peft_method: str
      dataset_id: str
      # Free-form training options; a fresh dict per instance via default_factory.
      config: dict = Field(default_factory=dict)

      # --- lifecycle and progress ---
      status: JobStatus = JobStatus.PENDING
      progress: float = 0.0
      current_epoch: int = 0
      current_step: int = 0
      total_steps: int = 0
      loss: Optional[float] = None
      adapter_path: Optional[str] = None
      error_message: Optional[str] = None

      # --- timestamps (stored as strings; created_at is UTC ISO-8601) ---
      created_at: str = Field(default_factory=lambda: datetime.now(tz=timezone.utc).isoformat())
      started_at: Optional[str] = None
      finished_at: Optional[str] = None
  class JobQueue:
      """Async training-job queue with cancellation support and bounded concurrency.

      ``start()`` spawns ``max_concurrent`` worker tasks. Each worker pulls job IDs
      from an ``asyncio.Queue`` and runs the job through dataset preprocessing and
      PEFT training, using an engine chosen from the job's ``model_type``. Job
      records live in memory in ``self._jobs``. After each status transition,
      every registered async callback is awaited once per known job (for example,
      to update a database).

      Cancellation is cooperative. ``cancel()`` sets a per-job ``asyncio.Event``,
      which the pipeline checks between stages and which other code can query
      through ``is_cancelled()``.
      """

      def __init__(self, max_concurrent: int = 2):
          self._queue: asyncio.Queue[str] = asyncio.Queue()
          self._jobs: dict[str, TrainingJob] = {}
          # One Event per enqueued job; setting it requests cancellation.
          self._cancel_events: dict[str, asyncio.Event] = {}
          self._callbacks: list[Callable[[TrainingJob], Coroutine[Any, Any, None]]] = []
          self._max_concurrent = max_concurrent
          self._workers: list[asyncio.Task] = []
          self._running = False

      async def start(self):
          """Spawn the background workers. Calling it again while running is a no-op."""
          if self._running:
              return
          self._running = True
          for _ in range(self._max_concurrent):
              worker = asyncio.create_task(self._worker_loop())
              self._workers.append(worker)
          logger.info(f"JobQueue started with {self._max_concurrent} workers")

      async def stop(self):
          """Stop all workers and wait until they have exited.

          Every known cancel event is set and each worker task is cancelled. The
          workers are then awaited, so in-flight cancellation handling (status
          update and callbacks) finishes before this coroutine returns.
          """
          self._running = False
          for event in self._cancel_events.values():
              event.set()
          workers, self._workers = self._workers, []
          for worker in workers:
              worker.cancel()
          # return_exceptions=True: collect CancelledError/other errors instead of raising them here.
          await asyncio.gather(*workers, return_exceptions=True)
          logger.info("JobQueue stopped")

      async def enqueue(self, job_id: str, job: TrainingJob):
          """Register ``job`` under ``job_id`` and put its ID on the work queue."""
          self._jobs[job_id] = job
          self._cancel_events[job_id] = asyncio.Event()
          await self._queue.put(job_id)
          logger.info(f"Job {job_id} enqueued")

      async def dequeue(self) -> str:
          """Take the next job ID from the queue, waiting until one is available."""
          return await self._queue.get()

      def mark_done(self, job_id: str):
          """Acknowledge a job ID obtained via ``dequeue()`` and drop its cancel event.

          ``asyncio.Queue.task_done`` raises ValueError if it is called more times
          than items were retrieved.
          """
          self._queue.task_done()
          self._cancel_events.pop(job_id, None)

      def get_job(self, job_id: str) -> Optional[TrainingJob]:
          """Return the job record, or None if the ID is unknown."""
          return self._jobs.get(job_id)

      def update_job(self, job_id: str, **kwargs):
          """Assign each keyword onto the job record via setattr.

          Unknown job IDs and unknown attribute names are silently ignored.
          """
          if job_id in self._jobs:
              job = self._jobs[job_id]
              for key, val in kwargs.items():
                  if hasattr(job, key):
                      setattr(job, key, val)

      def is_cancelled(self, job_id: str) -> bool:
          """True if cancellation has been requested for ``job_id``."""
          event = self._cancel_events.get(job_id)
          return event is not None and event.is_set()

      async def cancel(self, job_id: str):
          """Request cancellation of a job and mark it CANCELLED.

          Unknown jobs and jobs already in a terminal state are left untouched, so
          a finished job cannot be relabelled as cancelled.
          """
          job = self._jobs.get(job_id)
          event = self._cancel_events.get(job_id)
          if job is None or event is None:
              return
          if JobStatus(job.status).is_terminal:
              return
          event.set()
          self.update_job(job_id, status=JobStatus.CANCELLED, finished_at=self._utcnow_iso())
          await self._notify_callbacks()
          logger.info(f"Job {job_id} cancelled")

      def register_callback(self, callback: Callable[[TrainingJob], Coroutine[Any, Any, None]]):
          """Register an async status-change callback (e.g. to update a database)."""
          self._callbacks.append(callback)

      async def _notify_callbacks(self):
          """Await every registered callback once per known job.

          The callbacks and jobs are snapshotted first. Callbacks are awaited, so
          a concurrent ``enqueue`` could otherwise resize ``self._jobs`` mid-loop
          and raise "dictionary changed size during iteration". Errors are
          isolated per (callback, job) pair, so one failure does not skip the rest.
          """
          callbacks = list(self._callbacks)
          jobs = list(self._jobs.values())
          for cb in callbacks:
              for job in jobs:
                  try:
                      await cb(job)
                  except Exception as e:
                      logger.error(f"JobQueue callback error: {e}")

      @staticmethod
      def _utcnow_iso() -> str:
          """Return the current UTC time as an ISO-8601 string."""
          return datetime.now(timezone.utc).isoformat()

      async def _abort_if_cancelled(self, job_id: str) -> bool:
          """Return True if cancellation was requested, ensuring the job is CANCELLED.

          ``cancel()`` already sets CANCELLED. ``stop()`` only sets the events, so
          in that case the status is updated here and callbacks are notified.
          """
          if not self.is_cancelled(job_id):
              return False
          job = self._jobs.get(job_id)
          if job is not None and job.status != JobStatus.CANCELLED:
              self.update_job(job_id, status=JobStatus.CANCELLED, finished_at=self._utcnow_iso())
              await self._notify_callbacks()
          return True

      def _mark_failed(self, job_id: str, exc: BaseException) -> None:
          """Record a failure. A job whose cancellation was requested stays CANCELLED."""
          finished_at = self._utcnow_iso()
          if self.is_cancelled(job_id):
              self.update_job(job_id, status=JobStatus.CANCELLED, finished_at=finished_at)
          else:
              self.update_job(
                  job_id,
                  status=JobStatus.FAILED,
                  error_message=str(exc),
                  finished_at=finished_at,
              )

      async def _worker_loop(self):
          """Worker loop: keep taking job IDs from the queue and running them."""
          while self._running:
              try:
                  # Bounded wait so the loop re-checks self._running about once per second.
                  job_id = await asyncio.wait_for(self._queue.get(), timeout=1.0)
              except asyncio.TimeoutError:
                  continue
              try:
                  await self._run_job(job_id)
              except Exception as e:
                  # Fallback for errors raised outside _run_job's own try/except.
                  logger.error(f"Job {job_id} failed: {e}")
                  self._mark_failed(job_id, e)
                  await self._notify_callbacks()
              finally:
                  self._queue.task_done()

      async def _run_job(self, job_id: str):
          """Run one job: preprocess -> train -> complete.

          The status moves QUEUED -> PREPROCESSING -> TRAINING -> COMPLETED, and
          callbacks are notified after each step. Cancellation is checked before
          every stage and after training, so a cancelled job is never relabelled
          by a later stage. Exceptions are recorded as FAILED with the message.
          """
          job = self._jobs.get(job_id)
          if not job:
              return
          # Check before touching the status, so a job cancelled while it waited
          # in the queue stays CANCELLED instead of being reset to QUEUED.
          if await self._abort_if_cancelled(job_id):
              return
          self.update_job(job_id, status=JobStatus.QUEUED)
          await self._notify_callbacks()
          if await self._abort_if_cancelled(job_id):
              return
          self.update_job(job_id, status=JobStatus.PREPROCESSING, started_at=self._utcnow_iso())
          await self._notify_callbacks()
          if await self._abort_if_cancelled(job_id):
              return
          try:
              from app.config import get_settings

              settings = get_settings()
              config = job.config
              # A dataset_id inside config takes precedence over the job field.
              dataset_id = config.get("dataset_id", job.dataset_id)
              dataset_path = await self._resolve_dataset_path(dataset_id)

              processed_path = str(settings.processed_dir / f"{job_id}_processed.jsonl")
              engine = self._get_engine(job.model_type)
              await engine.preprocess_dataset(
                  dataset_path,
                  processed_path,
                  task_type=config.get("task_type", "sft"),
                  template=config.get("dataset_template", "alpaca"),
              )
              if await self._abort_if_cancelled(job_id):
                  return

              self.update_job(job_id, status=JobStatus.TRAINING)
              await self._notify_callbacks()
              # QLoRA loads the base model with 4-bit quantization.
              await engine.load_model(
                  job.model_id,
                  quantization="4bit" if job.peft_method == "qlora" else None,
              )
              peft_config = engine.get_peft_config(job.peft_method, config)
              adapter_path = await engine.train(
                  job_id=job_id,
                  dataset_path=processed_path,
                  peft_config=peft_config,
                  training_args=config,
              )
              # A cancel that arrived during training must not be overwritten by COMPLETED.
              if await self._abort_if_cancelled(job_id):
                  return
              self.update_job(
                  job_id,
                  status=JobStatus.COMPLETED,
                  adapter_path=adapter_path,
                  finished_at=self._utcnow_iso(),
              )
              await self._notify_callbacks()
              logger.info(f"Job {job_id} completed successfully")
          except asyncio.CancelledError:
              # Typically the worker task was cancelled by stop(). The error is not
              # re-raised; the worker loop then exits because stop() cleared
              # self._running.
              self.update_job(job_id, status=JobStatus.CANCELLED, finished_at=self._utcnow_iso())
              await self._notify_callbacks()
          except Exception as e:
              logger.error(f"Job {job_id} failed: {e}")
              self._mark_failed(job_id, e)
              await self._notify_callbacks()

      async def _resolve_dataset_path(self, dataset_id: str) -> str:
          """Resolve a dataset ID or name to a file path, trying the database first.

          Raises:
              FileNotFoundError: if neither the database nor the filesystem has it.
          """
          dataset_path = await self._lookup_dataset_db(dataset_id)
          if not dataset_path:
              dataset_path = self._find_dataset_path(dataset_id)
          if not dataset_path:
              raise FileNotFoundError(f"Dataset not found: {dataset_id}")
          return dataset_path

      def _find_dataset_path(self, dataset_id: str) -> str | None:
          """Locate a dataset file on disk.

          Looks under ``settings.uploads_dir`` first, then treats ``dataset_id``
          itself as a path. Returns None if neither exists.
          """
          from app.config import get_settings
          from pathlib import Path

          settings = get_settings()
          upload_path = settings.uploads_dir / dataset_id
          if upload_path.exists():
              return str(upload_path)
          # NOTE(review): neither branch confines the path. An ID such as "../x" or
          # "/abs/file" reaches files outside uploads_dir. If dataset_id can come
          # from API clients, validate it before it gets here.
          if Path(dataset_id).exists():
              return dataset_id
          return None

      async def _lookup_dataset_db(self, dataset_id: str) -> str | None:
          """Look up a dataset's file path in the database by id or by name.

          Several rows can match (one by id and another by name, or duplicate
          names). An exact id match is preferred; otherwise the first match is
          used, instead of raising MultipleResultsFound.
          """
          from app.core.db import async_session, DatasetRecord
          from sqlalchemy import select

          async with async_session() as session:
              result = await session.execute(select(DatasetRecord).where(
                  (DatasetRecord.id == dataset_id) | (DatasetRecord.name == dataset_id)
              ))
              records = result.scalars().all()
              for record in records:
                  if str(record.id) == dataset_id:
                      return record.file_path
              if records:
                  return records[0].file_path
          return None

      def _get_engine(self, model_type: str):
          """Pick the training engine for a model type (the text engine is the default)."""
          if model_type == "vision":
              from app.engines.vision_engine import vision_engine
              return vision_engine
          elif model_type == "multimodal":
              from app.engines.multimodal_engine import multimodal_engine
              return multimodal_engine
          else:
              from app.engines.text_engine import text_engine
              return text_engine

      @property
      def jobs(self) -> dict[str, TrainingJob]:
          """Shallow copy of the job-ID -> job mapping."""
          return dict(self._jobs)
  # Module-level singleton: every importer of this module shares this one queue (2 workers).
  job_queue = JobQueue(2)