job_queue.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397
  1. import asyncio
  2. import json
  3. from datetime import datetime, timezone
  4. from enum import Enum
  5. from typing import Any, Callable, Coroutine, Optional
  6. from pydantic import BaseModel, Field
  7. from app.core.logging import logger
  # Lifecycle states of a training job. The ``str`` mixin makes every member
  # compare equal to its lowercase value, e.g. ``JobStatus.FAILED == "failed"``.
  class JobStatus(str, Enum):
      PENDING = "pending"
      QUEUED = "queued"
      PREPROCESSING = "preprocessing"
      TRAINING = "training"
      COMPLETED = "completed"
      EVALUATING = "evaluating"
      EVALUATION_DONE = "evaluation_done"
      FAILED = "failed"
      CANCELLED = "cancelled"

      @property
      def is_terminal(self) -> bool:
          """Whether this status is COMPLETED, FAILED, CANCELLED or EVALUATION_DONE."""
          terminal_states = {
              JobStatus.COMPLETED,
              JobStatus.FAILED,
              JobStatus.CANCELLED,
              JobStatus.EVALUATION_DONE,
          }
          return self in terminal_states
  # Pydantic record of one fine-tuning job: its inputs, lifecycle status,
  # progress counters and results. Timestamps are ISO-8601 strings;
  # ``created_at`` defaults to the current UTC time.
  class TrainingJob(BaseModel):
      # --- identity and inputs ---
      id: str
      model_id: str
      model_type: str
      peft_method: str
      dataset_id: str
      config: dict = Field(default_factory=dict)  # free-form training options

      # --- lifecycle and progress ---
      status: JobStatus = JobStatus.PENDING
      progress: float = 0.0
      current_epoch: int = 0
      current_step: int = 0
      total_steps: int = 0
      loss: Optional[float] = None

      # --- results ---
      adapter_path: Optional[str] = None
      error_message: Optional[str] = None

      # --- timestamps (ISO-8601 strings) ---
      created_at: str = Field(default_factory=lambda: datetime.now(tz=timezone.utc).isoformat())
      started_at: Optional[str] = None
      finished_at: Optional[str] = None
  class JobQueue:
      """Asynchronous training-job queue with cancellation and bounded concurrency.

      ``start()`` spawns ``max_concurrent`` worker tasks that take job IDs from
      an ``asyncio.Queue`` and run each job through dataset lookup,
      preprocessing and training (locally, or on a remote compute node when
      ``settings.use_remote_compute`` is set). Registered async callbacks are
      awaited after status changes (e.g. to persist jobs to a database).
      """

      def __init__(self, max_concurrent: int = 2):
          self._queue: asyncio.Queue[str] = asyncio.Queue()
          # Every job ever enqueued, keyed by job ID (entries are never removed).
          self._jobs: dict[str, TrainingJob] = {}
          # Cancellation flags for jobs that have not finished yet; the worker
          # drops a job's entry once it is done with it.
          self._cancel_events: dict[str, asyncio.Event] = {}
          self._callbacks: list[Callable[[TrainingJob], Coroutine[Any, Any, None]]] = []
          self._max_concurrent = max_concurrent
          self._workers: list[asyncio.Task] = []
          self._running = False

      async def start(self):
          """Spawn the background worker tasks (no-op if already running)."""
          if self._running:
              return
          self._running = True
          for _ in range(self._max_concurrent):
              self._workers.append(asyncio.create_task(self._worker_loop()))
          logger.info(f"JobQueue started with {self._max_concurrent} workers")

      async def stop(self):
          """Stop all workers and signal every unfinished job to cancel.

          Waits for the cancelled workers to unwind, so the status updates and
          callbacks they run while cancelling complete before this returns.
          """
          self._running = False
          for event in self._cancel_events.values():
              event.set()
          workers, self._workers = self._workers, []
          for worker in workers:
              worker.cancel()
          # Never await our own task (stop() might be called from a worker/callback).
          current = asyncio.current_task()
          others = [w for w in workers if w is not current]
          # return_exceptions=True collects the workers' CancelledError instead of raising it here.
          await asyncio.gather(*others, return_exceptions=True)
          logger.info("JobQueue stopped")

      async def enqueue(self, job_id: str, job: TrainingJob):
          """Register ``job`` under ``job_id`` and put it on the queue."""
          self._jobs[job_id] = job
          self._cancel_events[job_id] = asyncio.Event()
          await self._queue.put(job_id)
          logger.info(f"Job {job_id} enqueued")

      async def dequeue(self) -> str:
          """Take the next job ID off the queue, waiting until one is available.

          Manual alternative to the built-in workers; pair every call with
          ``mark_done``. Do not mix with ``start()``: the workers also call
          ``task_done()``, and ``asyncio.Queue.task_done`` raises ValueError
          when called more often than items were taken.
          """
          return await self._queue.get()

      def mark_done(self, job_id: str):
          """Acknowledge a job obtained via ``dequeue`` and drop its cancel flag."""
          self._queue.task_done()
          self._cancel_events.pop(job_id, None)

      def get_job(self, job_id: str) -> Optional[TrainingJob]:
          """Return the job registered under ``job_id``, or None."""
          return self._jobs.get(job_id)

      def update_job(self, job_id: str, **kwargs):
          """Set attributes on a registered job in place.

          Unknown job IDs are ignored. Keys that are not attributes of the job
          are skipped with a warning so that misspelled field names surface.
          """
          job = self._jobs.get(job_id)
          if job is None:
              return
          for key, val in kwargs.items():
              if hasattr(job, key):
                  setattr(job, key, val)
              else:
                  logger.warning(f"Job {job_id}: ignoring unknown field {key!r}")

      def is_cancelled(self, job_id: str) -> bool:
          """True if cancellation was requested for a job that has not finished yet."""
          event = self._cancel_events.get(job_id)
          return event is not None and event.is_set()

      async def cancel(self, job_id: str):
          """Request cancellation of an unfinished job and mark it CANCELLED.

          Jobs that already finished (or were never enqueued) are left untouched.
          Workers check the flag between stages and while polling remote runs;
          whether a local ``engine.train`` call observes it depends on the engine.
          """
          event = self._cancel_events.get(job_id)
          if event is None:
              return
          event.set()
          self.update_job(job_id, status=JobStatus.CANCELLED)
          await self._notify_callbacks(job_id)
          logger.info(f"Job {job_id} cancelled")

      def register_callback(self, callback: Callable[[TrainingJob], Coroutine[Any, Any, None]]):
          """Register an async callback awaited with a job after its status changes (e.g. to update a database)."""
          self._callbacks.append(callback)

      async def _notify_callbacks(self, job_id: str | None = None):
          """Await every registered callback for one job, or for all jobs when ``job_id`` is None.

          Each callback/job pair is isolated: an exception is logged and the
          remaining pairs are still notified. The job list is snapshotted so a
          job enqueued while a callback is awaited cannot break the iteration.
          """
          if job_id is None:
              jobs = list(self._jobs.values())
          else:
              job = self._jobs.get(job_id)
              jobs = [] if job is None else [job]
          for cb in self._callbacks:
              for job in jobs:
                  try:
                      await cb(job)
                  except Exception as e:
                      logger.error(f"JobQueue callback error: {e}")

      async def _worker_loop(self):
          """Worker: repeatedly take a job ID from the queue and run it until stopped."""
          while self._running:
              try:
                  # Time out periodically so the loop re-checks ``_running``.
                  job_id = await asyncio.wait_for(self._queue.get(), timeout=1.0)
              except asyncio.TimeoutError:
                  continue
              try:
                  await self._run_job(job_id)
              except Exception as e:
                  logger.error(f"Job {job_id} failed: {e}")
                  self.update_job(job_id, status=JobStatus.FAILED, error_message=str(e))
                  await self._notify_callbacks(job_id)
              finally:
                  # The job is finished: drop its cancel flag so a late cancel()
                  # cannot relabel a completed/failed job as CANCELLED.
                  self._cancel_events.pop(job_id, None)
                  self._queue.task_done()

      async def _mark_cancelled(self, job_id: str):
          """Ensure a job whose cancel flag was observed ends up CANCELLED (no-op if it already is)."""
          job = self._jobs.get(job_id)
          if job is not None and job.status != JobStatus.CANCELLED:
              self.update_job(job_id, status=JobStatus.CANCELLED)
              await self._notify_callbacks(job_id)

      async def _run_job(self, job_id: str):
          """Run one job: dataset lookup -> preprocessing -> local or remote training.

          Failures are recorded on the job (FAILED + error message) and
          callbacks are notified; they are not re-raised.
          """
          job = self._jobs.get(job_id)
          if not job:
              return
          # Cancelled while still waiting in the queue: keep it CANCELLED instead
          # of relabelling it QUEUED.
          if self.is_cancelled(job_id):
              await self._mark_cancelled(job_id)
              return
          self.update_job(job_id, status=JobStatus.QUEUED)
          await self._notify_callbacks(job_id)
          if self.is_cancelled(job_id):
              await self._mark_cancelled(job_id)
              return
          self.update_job(job_id, status=JobStatus.PREPROCESSING, started_at=datetime.now(timezone.utc).isoformat())
          await self._notify_callbacks(job_id)
          if self.is_cancelled(job_id):
              await self._mark_cancelled(job_id)
              return
          try:
              from app.config import get_settings

              settings = get_settings()
              config = job.config
              # The job config may override the dataset recorded on the job.
              dataset_id = config.get("dataset_id", job.dataset_id)

              # Resolve the dataset file: database record first, then the
              # uploads directory or a literal path.
              dataset_path = await self._lookup_dataset_db(dataset_id)
              if not dataset_path:
                  dataset_path = self._find_dataset_path(dataset_id)
              if not dataset_path:
                  raise FileNotFoundError(f"Dataset not found: {dataset_id}")

              engine = self._get_engine(job.model_type)

              # Preprocessing always runs locally, even for remote training.
              processed_path = str(settings.processed_dir / f"{job_id}_processed.jsonl")
              await engine.preprocess_dataset(
                  dataset_path,
                  processed_path,
                  task_type=config.get("task_type", "sft"),
                  template=config.get("dataset_template", "alpaca"),
              )
              # Re-check so a job cancelled during preprocessing does not start training.
              if self.is_cancelled(job_id):
                  await self._mark_cancelled(job_id)
                  return

              if settings.use_remote_compute:
                  self.update_job(job_id, status=JobStatus.TRAINING)
                  await self._notify_callbacks(job_id)
                  from app.core.remote_executor import run_training_remote

                  # NOTE(review): synchronous call on the event loop; if it does
                  # blocking SSH work it stalls the other workers — consider
                  # asyncio.to_thread once its thread-safety is confirmed.
                  pid = run_training_remote(job_id, job.model_id, job.model_type, dataset_id, config)
                  if not pid:
                      raise RuntimeError("Failed to launch remote training")
                  logger.info(f"Remote training launched for job {job_id}")
                  # Follows the shared log until the remote run reaches a final state.
                  await self._poll_remote_progress(job_id, pid)
              else:
                  await engine.load_model(job.model_id, quantization="4bit" if job.peft_method == "qlora" else None)
                  peft_config = engine.get_peft_config(job.peft_method, config)
                  self.update_job(job_id, status=JobStatus.TRAINING)
                  await self._notify_callbacks(job_id)
                  adapter_path = await engine.train(
                      job_id=job_id,
                      dataset_path=processed_path,
                      peft_config=peft_config,
                      training_args=config,
                  )
                  self.update_job(job_id, status=JobStatus.COMPLETED, adapter_path=adapter_path)
                  await self._notify_callbacks(job_id)
                  logger.info(f"Job {job_id} completed successfully")
          except asyncio.CancelledError:
              # Not re-raised, so the worker task keeps running; after stop() its
              # loop exits because ``_running`` is False.
              self.update_job(job_id, status=JobStatus.CANCELLED)
              await self._notify_callbacks(job_id)
          except Exception as e:
              logger.error(f"Job {job_id} failed: {e}")
              self.update_job(job_id, status=JobStatus.FAILED, error_message=str(e))
              await self._notify_callbacks(job_id)

      def _find_dataset_path(self, dataset_id: str) -> str | None:
          """Resolve a dataset on disk: ``<uploads_dir>/<dataset_id>`` first, then ``dataset_id`` as a literal path.

          Returns None when neither exists.
          """
          from app.config import get_settings
          from pathlib import Path

          settings = get_settings()
          upload_path = settings.uploads_dir / dataset_id
          if upload_path.exists():
              return str(upload_path)
          # NOTE(review): any existing path supplied as dataset_id is accepted
          # (including ``..`` components) — confirm dataset_id is trusted input.
          if Path(dataset_id).exists():
              return dataset_id
          return None

      async def _lookup_dataset_db(self, dataset_id: str) -> str | None:
          """Return ``file_path`` of the dataset record whose ID, else name, equals ``dataset_id``.

          Returns None when no record matches. An ID match takes precedence, and
          ``first()`` is used so duplicate names (or an ID equal to another
          record's name) do not raise ``MultipleResultsFound``; among duplicate
          names the database picks one.
          """
          from app.core.db import async_session, DatasetRecord
          from sqlalchemy import select

          async with async_session() as session:
              for column in (DatasetRecord.id, DatasetRecord.name):
                  result = await session.execute(select(DatasetRecord).where(column == dataset_id).limit(1))
                  record = result.scalars().first()
                  if record:
                      return record.file_path
          return None

      def _get_engine(self, model_type: str):
          """Pick the training engine for ``model_type``; anything other than vision/multimodal uses the text engine."""
          if model_type == "vision":
              from app.engines.vision_engine import vision_engine
              return vision_engine
          elif model_type == "multimodal":
              from app.engines.multimodal_engine import multimodal_engine
              return multimodal_engine
          else:
              from app.engines.text_engine import text_engine
              return text_engine

      @staticmethod
      def _kill_remote_process(pid: str):
          """Send ``kill <pid>`` inside the compute-node container; failures are logged, not raised."""
          from app.config import get_settings
          from app.core.remote_executor import ssh_exec

          container = get_settings().compute_node_docker_container
          try:
              ssh_exec(f"docker exec {container} bash -c 'kill {pid} 2>/dev/null'", timeout=10)
          except Exception as e:
              logger.warning(f"Failed to kill remote process {pid}: {e}")

      @staticmethod
      def _read_log_entries(log_file, offset: int, *, final: bool = False) -> tuple[list[dict], int]:
          """Parse JSON-lines appended to ``log_file`` after byte ``offset``.

          Only newline-terminated lines are consumed unless ``final`` is true, so
          a line the remote side is still writing is re-read in full on the next
          poll instead of being skipped as malformed JSON. Blank lines, malformed
          lines and non-object values are skipped; a missing file yields nothing.

          Returns:
              ``(entries, new_offset)``.
          """
          try:
              with open(log_file, "rb") as f:
                  f.seek(offset)
                  data = f.read()
          except FileNotFoundError:
              return [], offset
          if not final:
              # Keep everything up to and including the last newline (nothing if none).
              data = data[: data.rfind(b"\n") + 1]
          entries = []
          for raw in data.splitlines():
              if not raw.strip():
                  continue
              try:
                  entry = json.loads(raw)
              except ValueError:  # covers JSONDecodeError and UnicodeDecodeError
                  continue
              if isinstance(entry, dict):
                  entries.append(entry)
          return entries, offset + len(data)

      async def _apply_log_entry(self, job_id: str, entry: dict):
          """Mirror one remote log entry onto the job, the callbacks and the WebSocket channel."""
          from app.config import get_settings
          from app.core.websocket import send_completed, send_epoch_done, send_error, send_progress

          entry_type = entry.get("type")
          if entry_type == "progress":
              step = entry.get("step", 0)
              total_steps = entry.get("total_steps", 0)
              # ``epoch`` from progress entries is not stored (the job has no such
              # field); ``current_epoch`` is updated from ``epoch_begin`` entries.
              self.update_job(
                  job_id,
                  current_step=step,
                  total_steps=total_steps,
                  loss=entry.get("loss"),
                  # ``or`` guards against null step/total_steps in the log.
                  progress=round((step or 0) / max(total_steps or 1, 1) * 100, 1),
              )
              await self._notify_callbacks(job_id)
              await send_progress(job_id, **{k: v for k, v in entry.items() if k != "type"})
          elif entry_type == "epoch_begin":
              self.update_job(job_id, current_epoch=entry.get("epoch", 0))
              await self._notify_callbacks(job_id)
          elif entry_type == "epoch_done":
              await self._notify_callbacks(job_id)
              await send_epoch_done(job_id, **{k: v for k, v in entry.items() if k not in ("type", "ts")})
          elif entry_type == "completed":
              adapter_path = entry.get("adapter_path", str(get_settings().adapters_dir / job_id))
              self.update_job(job_id, status=JobStatus.COMPLETED, adapter_path=adapter_path, progress=100.0)
              await self._notify_callbacks(job_id)
              await send_completed(job_id, **{k: v for k, v in entry.items() if k not in ("type", "ts")})
          elif entry_type == "error":
              message = entry.get("message", "Unknown error")
              self.update_job(job_id, status=JobStatus.FAILED, error_message=message)
              await self._notify_callbacks(job_id)
              await send_error(job_id, message)

      async def _consume_remote_log(self, job_id: str, log_file, offset: int, *, final: bool) -> tuple[int, bool]:
          """Apply the log entries written since ``offset``.

          Returns the new offset and whether a terminal (``completed``/``error``)
          entry was reached; entries after a terminal one are not applied. A
          failure while forwarding one entry is logged and does not drop the rest.
          """
          try:
              entries, offset = self._read_log_entries(log_file, offset, final=final)
          except OSError as e:
              logger.warning(f"Error reading remote log file: {e}")
              return offset, False
          for entry in entries:
              try:
                  await self._apply_log_entry(job_id, entry)
              except Exception as e:
                  logger.warning(f"Error handling remote log entry for job {job_id}: {e}")
              # A terminal entry ends the run even if forwarding it partly failed.
              if entry.get("type") in ("completed", "error"):
                  return offset, True
          return offset, False

      async def _poll_remote_progress(self, job_id: str, pid: str):
          """Follow a remote training run through its shared JSONL log until it ends.

          Every ``poll_interval`` seconds (for at most 12 hours) this checks for a
          cancellation request, whether process ``pid`` is still alive, and new
          log lines, which are mirrored onto the job, the callbacks and the
          WebSocket channel. Returns once a final state (completed, failed,
          cancelled or timed out) has been recorded.
          """
          from app.config import get_settings
          from app.core.remote_executor import is_process_running
          from app.core.websocket import send_error

          settings = get_settings()
          log_file = settings.data_dir / "logs" / f"{job_id}.jsonl"
          offset = 0
          poll_interval = 5  # seconds between polls
          max_polls = 8640  # 8640 * 5 s = 12 hours

          try:
              for _ in range(max_polls):
                  if self.is_cancelled(job_id):
                      self._kill_remote_process(pid)
                      self.update_job(job_id, status=JobStatus.CANCELLED)
                      await self._notify_callbacks(job_id)
                      await send_error(job_id, "Training cancelled")
                      return

                  # Sample liveness before reading, so lines written just before
                  # the process exits are still read below.
                  process_alive = is_process_running(pid)
                  offset, finished = await self._consume_remote_log(job_id, log_file, offset, final=False)
                  if finished:
                      return

                  if not process_alive:
                      # Wait briefly and re-check before declaring the process gone.
                      await asyncio.sleep(2)
                      if not is_process_running(pid):
                          # Last drain: also accept a final line without a trailing newline.
                          offset, finished = await self._consume_remote_log(job_id, log_file, offset, final=True)
                          if finished:
                              return
                          message = f"Remote process exited unexpectedly (pid={pid})"
                          self.update_job(job_id, status=JobStatus.FAILED, error_message=message)
                          await self._notify_callbacks(job_id)
                          await send_error(job_id, message)
                          return

                  await asyncio.sleep(poll_interval)
          except asyncio.CancelledError:
              # The worker task itself was cancelled (e.g. by stop()); the job is
              # recorded as CANCELLED by _run_job, so don't leave the remote
              # process running.
              self._kill_remote_process(pid)
              raise

          # NOTE(review): the remote process is not killed on timeout — confirm intended.
          self.update_job(job_id, status=JobStatus.FAILED, error_message="Remote training timed out")
          await self._notify_callbacks(job_id)
          await send_error(job_id, "Remote training timed out")

      @property
      def jobs(self) -> dict[str, TrainingJob]:
          """Shallow copy of all registered jobs keyed by job ID."""
          return dict(self._jobs)
  # Process-wide shared queue instance; start() gives it two concurrent workers.
  job_queue: JobQueue = JobQueue(max_concurrent=2)