job_queue.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548
  1. import asyncio
  2. import json
  3. from datetime import datetime, timezone
  4. from enum import Enum
  5. from typing import Any, Callable, Coroutine, Optional
  6. from pydantic import BaseModel, Field
  7. from app.core.logging import logger
  8. class JobStatus(str, Enum):
  9. PENDING = "pending"
  10. QUEUED = "queued"
  11. PREPROCESSING = "preprocessing"
  12. TRAINING = "training"
  13. COMPLETED = "completed"
  14. EVALUATING = "evaluating"
  15. EVALUATION_DONE = "evaluation_done"
  16. FAILED = "failed"
  17. CANCELLED = "cancelled"
  18. @property
  19. def is_terminal(self) -> bool:
  20. return self in (self.COMPLETED, self.FAILED, self.CANCELLED, self.EVALUATION_DONE)
  21. class TrainingJob(BaseModel):
  22. id: str
  23. model_id: str
  24. model_type: str
  25. peft_method: str
  26. dataset_id: str
  27. config: dict = Field(default_factory=dict)
  28. status: JobStatus = JobStatus.PENDING
  29. progress: float = 0.0
  30. current_epoch: int = 0
  31. current_step: int = 0
  32. total_steps: int = 0
  33. loss: float | None = None
  34. adapter_path: str | None = None
  35. error_message: str | None = None
  36. created_at: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
  37. started_at: str | None = None
  38. finished_at: str | None = None
  39. class JobQueue:
  40. """异步任务队列,支持取消和并发控制。"""
  41. def __init__(self, max_concurrent: int = 2):
  42. self._queue: asyncio.Queue[str] = asyncio.Queue()
  43. self._jobs: dict[str, TrainingJob] = {}
  44. self._cancel_events: dict[str, asyncio.Event] = {}
  45. self._callbacks: list[Callable[[TrainingJob], Coroutine[Any, Any, None]]] = []
  46. self._max_concurrent = max_concurrent
  47. self._workers: list[asyncio.Task] = []
  48. self._running = False
  49. async def start(self):
  50. """启动后台 worker。"""
  51. if self._running:
  52. return
  53. self._running = True
  54. for _ in range(self._max_concurrent):
  55. worker = asyncio.create_task(self._worker_loop())
  56. self._workers.append(worker)
  57. logger.info(f"JobQueue started with {self._max_concurrent} workers")
  58. async def stop(self):
  59. """停止所有 worker。"""
  60. self._running = False
  61. for event in self._cancel_events.values():
  62. event.set()
  63. for worker in self._workers:
  64. worker.cancel()
  65. self._workers.clear()
  66. logger.info("JobQueue stopped")
  67. async def enqueue(self, job_id: str, job: TrainingJob):
  68. """将任务加入队列。"""
  69. self._jobs[job_id] = job
  70. self._cancel_events[job_id] = asyncio.Event()
  71. await self._queue.put(job_id)
  72. logger.info(f"Job {job_id} enqueued")
  73. async def dequeue(self) -> str:
  74. """从队列中取出任务 ID。"""
  75. return await self._queue.get()
  76. def mark_done(self, job_id: str):
  77. """标记任务完成。"""
  78. self._queue.task_done()
  79. self._cancel_events.pop(job_id, None)
  80. def get_job(self, job_id: str) -> Optional[TrainingJob]:
  81. return self._jobs.get(job_id)
  82. def update_job(self, job_id: str, **kwargs):
  83. if job_id in self._jobs:
  84. job = self._jobs[job_id]
  85. for key, val in kwargs.items():
  86. if hasattr(job, key):
  87. setattr(job, key, val)
  88. def is_cancelled(self, job_id: str) -> bool:
  89. event = self._cancel_events.get(job_id)
  90. return event is not None and event.is_set()
  91. async def cancel(self, job_id: str):
  92. """取消任务。"""
  93. if job_id in self._cancel_events:
  94. self._cancel_events[job_id].set()
  95. self.update_job(job_id, status=JobStatus.CANCELLED)
  96. await self._notify_callbacks()
  97. logger.info(f"Job {job_id} cancelled")
  98. def register_callback(self, callback: Callable[[TrainingJob], Coroutine[Any, Any, None]]):
  99. """注册状态变更回调(用于更新数据库等)。"""
  100. self._callbacks.append(callback)
  101. async def _notify_callbacks(self):
  102. for cb in self._callbacks:
  103. try:
  104. for job in self._jobs.values():
  105. await cb(job)
  106. except Exception as e:
  107. logger.error(f"JobQueue callback error: {e}")
  108. async def _worker_loop(self):
  109. """worker 循环:不断从队列取任务并执行。"""
  110. while self._running:
  111. try:
  112. job_id = await asyncio.wait_for(self._queue.get(), timeout=1.0)
  113. except asyncio.TimeoutError:
  114. continue
  115. try:
  116. await self._run_job(job_id)
  117. except Exception as e:
  118. logger.error(f"Job {job_id} failed: {e}")
  119. self.update_job(job_id, status=JobStatus.FAILED, error_message=str(e))
  120. finally:
  121. self._queue.task_done()
  122. async def _run_job(self, job_id: str):
  123. """执行单个任务:预处理 → 训练 → 完成。"""
  124. job = self._jobs.get(job_id)
  125. if not job:
  126. return
  127. self.update_job(job_id, status=JobStatus.QUEUED)
  128. await self._notify_callbacks()
  129. if self.is_cancelled(job_id):
  130. return
  131. self.update_job(job_id, status=JobStatus.PREPROCESSING, started_at=datetime.now(timezone.utc).isoformat())
  132. await self._notify_callbacks()
  133. if self.is_cancelled(job_id):
  134. return
  135. try:
  136. config = job.config
  137. model_id = job.model_id
  138. model_type = job.model_type
  139. peft_method = job.peft_method
  140. dataset_id = config.get("dataset_id", job.dataset_id)
  141. from app.config import get_settings
  142. settings = get_settings()
  143. # 查找数据集文件路径
  144. dataset_path = await self._lookup_dataset_db(dataset_id)
  145. if not dataset_path:
  146. dataset_path = self._find_dataset_path(dataset_id)
  147. if not dataset_path:
  148. raise FileNotFoundError(f"Dataset not found: {dataset_id}")
  149. # 选择引擎
  150. engine = self._get_engine(model_type)
  151. # 预处理数据集(始终在本地执行)
  152. processed_path = str(settings.processed_dir / f"{job_id}_processed.jsonl")
  153. task_type = config.get("task_type", "sft")
  154. template = config.get("dataset_template", "alpaca")
  155. await engine.preprocess_dataset(dataset_path, processed_path, task_type=task_type, template=template)
  156. # 判断是否远程执行
  157. if settings.use_remote_compute:
  158. # 远程训练模式 — 数据集路径已由上面的代码查好
  159. if not dataset_path:
  160. dataset_path = self._find_dataset_path(dataset_id)
  161. if not dataset_path:
  162. raise FileNotFoundError(f"Dataset not found: {dataset_id}")
  163. # 启动新训练前清理容器内所有残留的 python 进程(释放 GPU ring buffer)
  164. await self._cleanup_remote_processes()
  165. self.update_job(job_id, status=JobStatus.TRAINING)
  166. await self._notify_callbacks()
  167. from app.core.remote_executor import run_training_remote, is_process_running
  168. pid = run_training_remote(job_id, model_id, model_type, dataset_path, config)
  169. if not pid:
  170. raise RuntimeError("Failed to launch remote training")
  171. # 轮询共享日志文件解析进度
  172. await self._poll_remote_progress(job_id, pid)
  173. logger.info(f"Remote training launched for job {job_id}")
  174. else:
  175. # 本地训练模式
  176. await engine.load_model(model_id, quantization="4bit" if peft_method == "qlora" else None)
  177. peft_config = engine.get_peft_config(peft_method, config)
  178. self.update_job(job_id, status=JobStatus.TRAINING)
  179. await self._notify_callbacks()
  180. adapter_path = await engine.train(
  181. job_id=job_id,
  182. dataset_path=processed_path,
  183. peft_config=peft_config,
  184. training_args=config,
  185. )
  186. self.update_job(job_id, status=JobStatus.COMPLETED, adapter_path=adapter_path)
  187. await self._notify_callbacks()
  188. logger.info(f"Job {job_id} completed successfully")
  189. except asyncio.CancelledError:
  190. self.update_job(job_id, status=JobStatus.CANCELLED)
  191. await self._notify_callbacks()
  192. except Exception as e:
  193. # 远程训练模式:异常时也要 kill 远程进程
  194. error_msg = str(e)
  195. if settings.use_remote_compute and "pid" in locals():
  196. from app.core.remote_executor import ssh_exec
  197. container = settings.compute_node_docker_container
  198. try:
  199. await asyncio.to_thread(
  200. ssh_exec,
  201. f"docker exec {container} bash -c 'kill -9 {pid} 2>/dev/null; "
  202. f"pkill -9 -P {pid} 2>/dev/null'",
  203. timeout=5,
  204. )
  205. logger.info(f"Killed remote process {pid} due to exception")
  206. except Exception:
  207. # kill 超时 — 进程可能被 GPU 驱动锁死,由 _poll_remote_progress 兜底处理
  208. logger.warning(f"Failed to kill remote process {pid}, will be handled by progress poller")
  209. logger.error(f"Job {job_id} failed: {error_msg}")
  210. self.update_job(job_id, status=JobStatus.FAILED, error_message=error_msg)
  211. await self._notify_callbacks()
  212. def _find_dataset_path(self, dataset_id: str) -> str | None:
  213. """根据 dataset_id 查找文件路径(数据库或 uploads 目录)。"""
  214. from app.config import get_settings
  215. from pathlib import Path
  216. settings = get_settings()
  217. # 尝试从 uploads 目录查找
  218. upload_path = settings.uploads_dir / dataset_id
  219. if upload_path.exists():
  220. return str(upload_path)
  221. # 如果 dataset_id 本身是路径
  222. if Path(dataset_id).exists():
  223. return dataset_id
  224. return None
  225. async def _cleanup_remote_processes(self):
  226. """通过 SSH 清理容器内所有残留的 python 进程(包括僵尸进程),释放 GPU ring buffer。
  227. 所有操作合并为一条 SSH 命令,避免多次连接导致超时。
  228. """
  229. from app.config import get_settings
  230. from app.core.remote_executor import ssh_exec
  231. settings = get_settings()
  232. container = settings.compute_node_docker_container
  233. # 一条命令完成:检查容器 → 查找 python 进程 → 逐个 kill → 输出清理结果
  234. cmd = (
  235. f"docker inspect -f '{{{{.State.Running}}}}' {container} 2>/dev/null || echo false; "
  236. f"if [ \"$(docker inspect -f '{{{{.State.Running}}}}' {container} 2>/dev/null)\" = 'true' ]; then "
  237. f"pids=$(docker exec {container} bash -c 'ps aux 2>/dev/null | grep \"[p]ython\" | grep -v grep | awk \"{{{{print \\$2}}}}\"'); "
  238. f"if [ -n \"$pids\" ]; then "
  239. f"echo \"$pids\" | while read pid; do "
  240. f"docker exec {container} bash -c 'kill -9 $pid 2>/dev/null; wait $pid 2>/dev/null'; "
  241. f"done; "
  242. f"echo \"cleaned $(echo \"$pids\" | wc -l) processes\"; "
  243. f"else echo 'no python processes'; fi; "
  244. f"else echo 'container not running'; fi"
  245. )
  246. code, stdout, stderr = await asyncio.to_thread(ssh_exec, cmd, timeout=60)
  247. if code != 0:
  248. logger.warning(f"Remote cleanup failed: code={code}, stderr={stderr}")
  249. else:
  250. logger.info(f"Remote cleanup result: {stdout.strip()}")
  251. async def _lookup_dataset_db(self, dataset_id: str) -> str | None:
  252. """从数据库查找数据集路径。"""
  253. from app.core.db import async_session, DatasetRecord
  254. from sqlalchemy import select
  255. async with async_session() as session:
  256. result = await session.execute(select(DatasetRecord).where(
  257. (DatasetRecord.id == dataset_id) | (DatasetRecord.name == dataset_id)
  258. ))
  259. record = result.scalar_one_or_none()
  260. if record:
  261. return record.file_path
  262. return None
  263. def _get_engine(self, model_type: str):
  264. """根据模型类型选择训练引擎。"""
  265. if model_type == "vision":
  266. from app.engines.vision_engine import vision_engine
  267. return vision_engine
  268. elif model_type == "multimodal":
  269. from app.engines.multimodal_engine import multimodal_engine
  270. return multimodal_engine
  271. else:
  272. from app.engines.text_engine import text_engine
  273. return text_engine
  274. async def _poll_remote_progress(self, job_id: str, pid: str):
  275. """通过 SSH 读取远程日志文件,解析训练进度(非阻塞)。
  276. 同时把 253 容器内的 stderr 日志同步输出到 151 后端日志中。
  277. """
  278. from app.config import get_settings
  279. from app.core.websocket import send_progress, send_epoch_done, send_completed, send_error
  280. from app.core.remote_executor import ssh_exec, is_process_running
  281. settings = get_settings()
  282. remote_log = f"{settings.compute_node_remote_data_dir}/logs/{job_id}.jsonl"
  283. container = settings.compute_node_docker_container
  284. last_bytes = 0
  285. stderr_last_bytes = 0 # 跟踪 stderr 日志读取位置
  286. poll_interval = 5
  287. max_polls = 8640
  288. consecutive_empty_polls = 0
  289. max_consecutive_empty = 12 # 60 秒无响应就开始检查 stderr
  290. async def _kill_remote_process(pid: str):
  291. """强制 kill 远程训练进程(多种方式兜底)。"""
  292. # 方式1: docker exec kill -9(常规方式)
  293. try:
  294. await asyncio.to_thread(
  295. ssh_exec,
  296. f"docker exec {container} bash -c 'kill -9 {pid} 2>/dev/null; "
  297. f"pkill -9 -P {pid} 2>/dev/null'",
  298. timeout=10,
  299. )
  300. logger.info(f"Killed remote process {pid} via docker exec")
  301. return
  302. except Exception as e:
  303. logger.warning(f"Failed to kill process {pid} via docker exec: {e}")
  304. # 方式2: nsenter 从宿主机直接进入进程 namespace 发信号
  305. try:
  306. await asyncio.to_thread(
  307. ssh_exec,
  308. f"docker exec {container} bash -c 'nsenter -t {pid} -p -s -- kill -9 {pid} 2>/dev/null || kill -9 {pid} 2>/dev/null'",
  309. timeout=10,
  310. )
  311. logger.info(f"Killed remote process {pid} via nsenter")
  312. return
  313. except Exception as e:
  314. logger.warning(f"Failed to kill process {pid} via nsenter: {e}")
  315. # 方式3: 终极方案 — 重启整个容器(释放所有 GPU 资源)
  316. try:
  317. await asyncio.to_thread(
  318. ssh_exec,
  319. f"docker restart -t 5 {container}",
  320. timeout=30,
  321. )
  322. logger.warning(f"Force restarted container {container} to release GPU resources")
  323. except Exception as e:
  324. logger.error(f"Failed to restart container {container}: {e}")
  325. async def _mark_failed(error_msg: str):
  326. """统一标记失败:先 kill 远程进程,再更新状态。"""
  327. await _kill_remote_process(pid)
  328. self.update_job(job_id, status=JobStatus.FAILED, error_message=error_msg)
  329. await self._notify_callbacks()
  330. await send_error(job_id, error_msg)
  331. for _ in range(max_polls):
  332. if self.is_cancelled(job_id):
  333. await _kill_remote_process(pid)
  334. self.update_job(job_id, status=JobStatus.CANCELLED)
  335. await self._notify_callbacks()
  336. await send_error(job_id, "Training cancelled")
  337. return
  338. # 检查进程是否还在运行(非阻塞)
  339. process_alive = await asyncio.to_thread(is_process_running, pid)
  340. # === 1. 读取 jsonl 进度日志 ===
  341. cat_cmd = f"docker exec {container} bash -c 'wc -c < {remote_log} 2>/dev/null || echo 0'"
  342. code, size_out, _ = await asyncio.to_thread(ssh_exec, cat_cmd, timeout=30)
  343. try:
  344. file_size = int(size_out.strip()) if code == 0 and size_out.strip() else 0
  345. except ValueError:
  346. file_size = 0
  347. has_new_log = False
  348. if file_size > last_bytes:
  349. read_cmd = f"docker exec {container} bash -c 'tail -c +{last_bytes + 1} {remote_log} 2>/dev/null'"
  350. code, log_content, _ = await asyncio.to_thread(ssh_exec, read_cmd, timeout=30)
  351. if code == 0 and log_content.strip():
  352. has_new_log = True
  353. consecutive_empty_polls = 0
  354. for line in log_content.strip().split("\n"):
  355. line = line.strip()
  356. if not line:
  357. continue
  358. try:
  359. entry = json.loads(line)
  360. except json.JSONDecodeError:
  361. continue
  362. entry_type = entry.get("type")
  363. if entry_type == "progress":
  364. step = entry.get("step", 0)
  365. total_steps = entry.get("total_steps", 0)
  366. # 计算进度:total_steps 为 0 时基于 epoch 估算(每 epoch 按 100/epochs% 递增)
  367. if total_steps > 0:
  368. progress = round(step / total_steps * 100, 1)
  369. else:
  370. # 无 total_steps 时,step 每增加 1 按 0.1% 估算(兜底)
  371. progress = round(step * 0.1, 1)
  372. progress = min(99.9, max(0, progress)) # 限制在 0-99.9%,completed 时才会到 100%
  373. self.update_job(job_id,
  374. current_step=step,
  375. total_steps=total_steps,
  376. loss=entry.get("loss"),
  377. progress=progress)
  378. await self._notify_callbacks()
  379. await send_progress(job_id, **{k: v for k, v in entry.items() if k != "type"})
  380. elif entry_type == "epoch_begin":
  381. self.update_job(job_id, current_epoch=entry.get("epoch", 0))
  382. await self._notify_callbacks()
  383. elif entry_type == "epoch_done":
  384. await self._notify_callbacks()
  385. await send_epoch_done(job_id, **{k: v for k, v in entry.items() if k not in ("type", "ts")})
  386. elif entry_type == "completed":
  387. adapter_path = entry.get("adapter_path", str(settings.adapters_dir / job_id))
  388. self.update_job(job_id,
  389. status=JobStatus.COMPLETED,
  390. adapter_path=adapter_path,
  391. progress=100.0)
  392. await self._notify_callbacks()
  393. await send_completed(job_id, **{k: v for k, v in entry.items() if k not in ("type", "ts")})
  394. return
  395. elif entry_type == "error":
  396. error_msg = entry.get("message", "Unknown error")
  397. logger.error(f"Remote job {job_id} failed: {error_msg}")
  398. await _mark_failed(error_msg)
  399. return
  400. last_bytes = file_size
  401. # === 2. 同步 253 stderr 日志到 151 后端日志 ===
  402. stderr_cmd = f"docker exec {container} bash -c 'wc -c < /tmp/train_{job_id}.log 2>/dev/null || echo 0'"
  403. code, stderr_size_out, _ = await asyncio.to_thread(ssh_exec, stderr_cmd, timeout=30)
  404. try:
  405. stderr_size = int(stderr_size_out.strip()) if code == 0 and stderr_size_out.strip() else 0
  406. except ValueError:
  407. stderr_size = 0
  408. if stderr_size > stderr_last_bytes:
  409. read_stderr_cmd = f"docker exec {container} bash -c 'tail -c +{stderr_last_bytes + 1} /tmp/train_{job_id}.log 2>/dev/null'"
  410. code, stderr_content, _ = await asyncio.to_thread(ssh_exec, read_stderr_cmd, timeout=30)
  411. if code == 0 and stderr_content.strip():
  412. for line in stderr_content.strip().split("\n"):
  413. line = line.strip()
  414. if not line:
  415. continue
  416. # 识别日志级别
  417. if "[remote_train]" in line:
  418. logger.info(f"[253:{job_id[:8]}] {line}")
  419. elif "[MXKW][E]" in line or "ERROR" in line or "Error" in line:
  420. logger.error(f"[253:{job_id[:8]}] {line}")
  421. elif "[transformers]" in line or "UserWarning" in line or "Warning" in line:
  422. logger.warning(f"[253:{job_id[:8]}] {line}")
  423. else:
  424. logger.info(f"[253:{job_id[:8]}] {line}")
  425. stderr_last_bytes = stderr_size
  426. if not has_new_log:
  427. consecutive_empty_polls += 1
  428. # 进程已退出但日志里没有 completed/error
  429. if not process_alive:
  430. # 多等几秒让日志写完
  431. await asyncio.sleep(2)
  432. if not await asyncio.to_thread(is_process_running, pid):
  433. # 进程退出但没有写 completed/error 日志,读取 stderr 日志兜底
  434. error_msg = f"Remote process exited unexpectedly (pid={pid})"
  435. try:
  436. from app.core.remote_executor import get_remote_stderr
  437. stderr_content = await asyncio.to_thread(get_remote_stderr, job_id)
  438. if stderr_content:
  439. error_msg = stderr_content[-1000:]
  440. except Exception:
  441. pass
  442. logger.error(f"Remote job {job_id} failed: {error_msg}")
  443. await _mark_failed(error_msg)
  444. return
  445. # 长时间无日志且进程异常,也标记为失败
  446. if consecutive_empty_polls >= max_consecutive_empty and not process_alive:
  447. error_msg = f"Remote process exited unexpectedly (pid={pid}), no error log found"
  448. logger.error(f"Remote job {job_id} failed: {error_msg}")
  449. await _mark_failed(error_msg)
  450. return
  451. await asyncio.sleep(poll_interval)
  452. # 超时
  453. error_msg = "Remote training timed out"
  454. logger.error(f"Remote job {job_id} failed: {error_msg}")
  455. await _mark_failed(error_msg)
  456. @property
  457. def jobs(self) -> dict[str, TrainingJob]:
  458. return dict(self._jobs)
  459. # 全局单例
  460. job_queue = JobQueue(max_concurrent=2)