import asyncio
import json
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Callable, Coroutine, Optional

from pydantic import BaseModel, Field

from app.core.logging import logger


class JobStatus(str, Enum):
    """Lifecycle states of a training job."""

    PENDING = "pending"
    QUEUED = "queued"
    PREPROCESSING = "preprocessing"
    TRAINING = "training"
    COMPLETED = "completed"
    EVALUATING = "evaluating"
    EVALUATION_DONE = "evaluation_done"
    FAILED = "failed"
    CANCELLED = "cancelled"

    @property
    def is_terminal(self) -> bool:
        """Return True for COMPLETED, FAILED, CANCELLED and EVALUATION_DONE."""
        # Members are referenced through the class rather than through the
        # instance so the lookup does not depend on member-from-member access.
        return self in (
            JobStatus.COMPLETED,
            JobStatus.FAILED,
            JobStatus.CANCELLED,
            JobStatus.EVALUATION_DONE,
        )


class TrainingJob(BaseModel):
    """In-memory record describing one training job and its progress."""

    id: str
    model_id: str
    model_type: str
    peft_method: str
    dataset_id: str
    # Free-form training configuration; read by JobQueue._run_job.
    config: dict = Field(default_factory=dict)
    status: JobStatus = JobStatus.PENDING
    # Percentage in the range 0-100.
    progress: float = 0.0
    current_epoch: int = 0
    current_step: int = 0
    total_steps: int = 0
    loss: float | None = None
    adapter_path: str | None = None
    error_message: str | None = None
    # Timestamps are ISO-8601 strings in UTC.
    created_at: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
    started_at: str | None = None
    finished_at: str | None = None


class JobQueue:
    """Asynchronous job queue with cancellation support and bounded concurrency."""

    def __init__(self, max_concurrent: int = 2):
        self._queue: asyncio.Queue[str] = asyncio.Queue()
        self._jobs: dict[str, TrainingJob] = {}
        self._cancel_events: dict[str, asyncio.Event] = {}
        self._callbacks: list[Callable[[TrainingJob], Coroutine[Any, Any, None]]] = []
        self._max_concurrent = max_concurrent
        self._workers: list[asyncio.Task] = []
        self._running = False

    async def start(self):
        """Start the background workers (no-op when already running)."""
        if self._running:
            return
        self._running = True
        for _ in range(self._max_concurrent):
            self._workers.append(asyncio.create_task(self._worker_loop()))
        logger.info(f"JobQueue started with {self._max_concurrent} workers")

    async def stop(self):
        """Stop all workers and wait until their cancellation has completed."""
        self._running = False
        for event in self._cancel_events.values():
            event.set()
        workers, self._workers = self._workers, []
        for worker in workers:
            worker.cancel()
        if workers:
            # Await the cancelled tasks so no worker is left half-way through
            # its cleanup when this coroutine returns.
            await asyncio.gather(*workers, return_exceptions=True)
        logger.info("JobQueue stopped")

    async def enqueue(self, job_id: str, job: TrainingJob):
        """Register *job* under *job_id* and put it on the queue."""
        self._jobs[job_id] = job
        self._cancel_events[job_id] = asyncio.Event()
        await self._queue.put(job_id)
        logger.info(f"Job {job_id} enqueued")

    async def dequeue(self) -> str:
        """Take the next job ID from the queue, waiting until one is available."""
        return await self._queue.get()

    def mark_done(self, job_id: str):
        """Mark a manually dequeued job as done and drop its cancel event."""
        self._queue.task_done()
        self._cancel_events.pop(job_id, None)

    def get_job(self, job_id: str) -> Optional[TrainingJob]:
        """Return the job registered under *job_id*, or None."""
        return self._jobs.get(job_id)

    def update_job(self, job_id: str, **kwargs):
        """Set the given attributes on a job; unknown jobs and attributes are ignored.

        When the update moves the job into a terminal status and the caller did
        not pass ``finished_at``, the finish time is stamped once.
        """
        job = self._jobs.get(job_id)
        if job is None:
            return
        for key, val in kwargs.items():
            if hasattr(job, key):
                setattr(job, key, val)

        new_status = kwargs.get("status")
        if new_status is None or "finished_at" in kwargs or job.finished_at is not None:
            return
        try:
            reached_terminal = JobStatus(new_status).is_terminal
        except ValueError:
            reached_terminal = False
        if reached_terminal:
            job.finished_at = datetime.now(timezone.utc).isoformat()

    def is_cancelled(self, job_id: str) -> bool:
        """Return True when cancellation has been requested for a tracked job."""
        event = self._cancel_events.get(job_id)
        return event is not None and event.is_set()

    async def cancel(self, job_id: str):
        """Request cancellation of a job that has not yet reached a terminal status."""
        event = self._cancel_events.get(job_id)
        if event is None:
            logger.warning(f"Job {job_id} cannot be cancelled: not queued or running")
            return
        job = self._jobs.get(job_id)
        if job is not None and JobStatus(job.status).is_terminal:
            # Never overwrite COMPLETED / FAILED / ... with CANCELLED.
            return
        event.set()
        self.update_job(job_id, status=JobStatus.CANCELLED)
        await self._notify_callbacks()
        logger.info(f"Job {job_id} cancelled")

    def register_callback(self, callback: Callable[[TrainingJob], Coroutine[Any, Any, None]]):
        """Register a status-change callback (e.g. to persist jobs to a database)."""
        self._callbacks.append(callback)

    async def _notify_callbacks(self):
        """Invoke every registered callback once for every known job."""
        # Iterate over snapshots: awaiting a callback yields to the event loop,
        # and a concurrent enqueue() would otherwise resize the dict mid-loop.
        jobs = list(self._jobs.values())
        for cb in list(self._callbacks):
            for job in jobs:
                try:
                    await cb(job)
                except Exception as e:
                    # One failing (callback, job) pair must not skip the rest.
                    logger.error(f"JobQueue callback error: {e}")

    async def _abort_if_cancelled(self, job_id: str) -> bool:
        """Return True when the job is cancelled, making sure its status says so."""
        if not self.is_cancelled(job_id):
            return False
        job = self._jobs.get(job_id)
        if job is not None and job.status != JobStatus.CANCELLED:
            self.update_job(job_id, status=JobStatus.CANCELLED)
            await self._notify_callbacks()
        return True

    async def _worker_loop(self):
        """Worker loop: keep taking job IDs from the queue and running them."""
        while self._running:
            try:
                # The timeout lets the loop re-check self._running periodically.
                job_id = await asyncio.wait_for(self._queue.get(), timeout=1.0)
            except asyncio.TimeoutError:
                continue
            try:
                await self._run_job(job_id)
            except Exception as e:
                logger.error(f"Job {job_id} failed: {e}")
                self.update_job(job_id, status=JobStatus.FAILED, error_message=str(e))
                await self._notify_callbacks()
            finally:
                self._queue.task_done()
                # The event is only meaningful while the job is queued/running.
                self._cancel_events.pop(job_id, None)

    async def _run_job(self, job_id: str):
        """Run one job: preprocess, train (locally or remotely), then finish.

        Failures are recorded on the job as FAILED. Task cancellation is
        recorded as CANCELLED and then re-raised.
        """
        job = self._jobs.get(job_id)
        if not job:
            return

        # A job cancelled while still waiting in the queue must keep its
        # CANCELLED status instead of being moved to QUEUED.
        if await self._abort_if_cancelled(job_id):
            return

        self.update_job(job_id, status=JobStatus.QUEUED)
        await self._notify_callbacks()
        if await self._abort_if_cancelled(job_id):
            return

        self.update_job(
            job_id,
            status=JobStatus.PREPROCESSING,
            started_at=datetime.now(timezone.utc).isoformat(),
        )
        await self._notify_callbacks()
        if await self._abort_if_cancelled(job_id):
            return

        try:
            config = job.config
            model_id = job.model_id
            model_type = job.model_type
            peft_method = job.peft_method
            dataset_id = config.get("dataset_id", job.dataset_id)

            from app.config import get_settings

            settings = get_settings()

            # Resolve the dataset file: database first, then the file system.
            dataset_path = await self._lookup_dataset_db(dataset_id)
            if not dataset_path:
                dataset_path = self._find_dataset_path(dataset_id)
            if not dataset_path:
                raise FileNotFoundError(f"Dataset not found: {dataset_id}")

            # Pick the engine for this model type.
            engine = self._get_engine(model_type)

            # Preprocess the dataset (always executed locally).
            processed_path = str(settings.processed_dir / f"{job_id}_processed.jsonl")
            task_type = config.get("task_type", "sft")
            template = config.get("dataset_template", "alpaca")
            await engine.preprocess_dataset(
                dataset_path, processed_path, task_type=task_type, template=template
            )

            if settings.use_remote_compute:
                # Remote training mode.
                self.update_job(job_id, status=JobStatus.TRAINING)
                await self._notify_callbacks()

                from app.core.remote_executor import run_training_remote

                # Run the launcher in a thread, like the other remote_executor
                # helpers, so it cannot stall the event loop.
                pid = await asyncio.to_thread(
                    run_training_remote, job_id, model_id, model_type, dataset_path, config
                )
                if not pid:
                    raise RuntimeError("Failed to launch remote training")
                logger.info(f"Remote training launched for job {job_id}")

                # Poll the remote log; this sets the final job status itself.
                await self._poll_remote_progress(job_id, pid)
            else:
                # Local training mode.
                await engine.load_model(
                    model_id, quantization="4bit" if peft_method == "qlora" else None
                )
                peft_config = engine.get_peft_config(peft_method, config)

                self.update_job(job_id, status=JobStatus.TRAINING)
                await self._notify_callbacks()

                adapter_path = await engine.train(
                    job_id=job_id,
                    dataset_path=processed_path,
                    peft_config=peft_config,
                    training_args=config,
                )

                # Do not report success for a job cancelled during training.
                if await self._abort_if_cancelled(job_id):
                    return

                self.update_job(job_id, status=JobStatus.COMPLETED, adapter_path=adapter_path)
                await self._notify_callbacks()
                logger.info(f"Job {job_id} completed successfully")

        except asyncio.CancelledError:
            self.update_job(job_id, status=JobStatus.CANCELLED)
            await self._notify_callbacks()
            # Re-raise so the cancelled worker task actually terminates.
            raise
        except Exception as e:
            logger.error(f"Job {job_id} failed: {e}")
            self.update_job(job_id, status=JobStatus.FAILED, error_message=str(e))
            await self._notify_callbacks()

    def _find_dataset_path(self, dataset_id: str) -> str | None:
        """Resolve *dataset_id* on disk: the uploads directory first, then as a path."""
        from pathlib import Path

        from app.config import get_settings

        settings = get_settings()

        # Try the uploads directory first.
        upload_path = settings.uploads_dir / dataset_id
        if upload_path.exists():
            return str(upload_path)

        # Fall back to treating dataset_id itself as a path.
        if Path(dataset_id).exists():
            return dataset_id

        return None

    async def _lookup_dataset_db(self, dataset_id: str) -> str | None:
        """Return the file path of the dataset whose id or name equals *dataset_id*."""
        from sqlalchemy import select

        from app.core.db import DatasetRecord, async_session

        async with async_session() as session:
            result = await session.execute(
                select(DatasetRecord).where(
                    (DatasetRecord.id == dataset_id) | (DatasetRecord.name == dataset_id)
                )
            )
            # NOTE(review): scalar_one_or_none() raises when several rows match
            # (one by id, another by name); confirm that names cannot collide
            # with ids.
            record = result.scalar_one_or_none()
            if record:
                return record.file_path
        return None

    def _get_engine(self, model_type: str):
        """Select the training engine for *model_type* (text is the default)."""
        if model_type == "vision":
            from app.engines.vision_engine import vision_engine

            return vision_engine
        if model_type == "multimodal":
            from app.engines.multimodal_engine import multimodal_engine

            return multimodal_engine
        from app.engines.text_engine import text_engine

        return text_engine

    async def _handle_remote_entry(self, job_id: str, entry: dict, settings) -> bool:
        """Apply one parsed remote log entry; return True when the job is finished."""
        from app.core.websocket import send_completed, send_epoch_done, send_error, send_progress

        entry_type = entry.get("type")

        if entry_type == "progress":
            step = entry.get("step") or 0
            total = entry.get("total_steps") or 0
            updates = {"current_step": step, "total_steps": total, "loss": entry.get("loss")}
            if total > 0:
                # Without a known total a percentage would be meaningless, so
                # the previous progress value is kept in that case.
                updates["progress"] = min(round(step / total * 100, 1), 100.0)
            self.update_job(job_id, **updates)
            await self._notify_callbacks()
            await send_progress(job_id, **{k: v for k, v in entry.items() if k != "type"})
            return False

        if entry_type == "epoch_begin":
            self.update_job(job_id, current_epoch=entry.get("epoch", 0))
            await self._notify_callbacks()
            return False

        if entry_type == "epoch_done":
            await self._notify_callbacks()
            await send_epoch_done(
                job_id, **{k: v for k, v in entry.items() if k not in ("type", "ts")}
            )
            return False

        if entry_type == "completed":
            adapter_path = entry.get("adapter_path", str(settings.adapters_dir / job_id))
            self.update_job(
                job_id, status=JobStatus.COMPLETED, adapter_path=adapter_path, progress=100.0
            )
            await self._notify_callbacks()
            await send_completed(
                job_id, **{k: v for k, v in entry.items() if k not in ("type", "ts")}
            )
            return True

        if entry_type == "error":
            error_msg = entry.get("message", "Unknown error")
            logger.error(f"Remote job {job_id} failed: {error_msg}")
            self.update_job(job_id, status=JobStatus.FAILED, error_message=error_msg)
            await self._notify_callbacks()
            await send_error(job_id, error_msg)
            return True

        return False

    async def _poll_remote_progress(self, job_id: str, pid: str):
        """Poll the remote JSONL log over SSH and mirror its progress onto the job.

        Returns once the log reports ``completed`` or ``error``, the job is
        cancelled, the remote process has exited without a final entry, or the
        poll budget (max_polls * poll_interval seconds) is used up.
        """
        from app.config import get_settings
        from app.core.remote_executor import is_process_running, ssh_exec
        from app.core.websocket import send_error

        settings = get_settings()
        container = settings.compute_node_docker_container
        # NOTE(review): job_id and pid are interpolated into shell commands
        # unquoted; confirm that both are always generated internally.
        remote_log = f"{settings.compute_node_remote_data_dir}/logs/{job_id}.jsonl"
        last_bytes = 0
        poll_interval = 5
        max_polls = 8640

        for _ in range(max_polls):
            if self.is_cancelled(job_id):
                await asyncio.to_thread(
                    ssh_exec,
                    f"docker exec {container} bash -c 'kill {pid} 2>/dev/null'",
                    timeout=10,
                )
                self.update_job(job_id, status=JobStatus.CANCELLED)
                await self._notify_callbacks()
                await send_error(job_id, "Training cancelled")
                return

            # Sample liveness BEFORE reading the log, so that a process that
            # exits right after writing its last entry is still read in full.
            process_alive = await asyncio.to_thread(is_process_running, pid)

            size_cmd = f"docker exec {container} bash -c 'wc -c < {remote_log} 2>/dev/null || echo 0'"
            code, size_out, _stderr = await asyncio.to_thread(ssh_exec, size_cmd, timeout=30)
            try:
                file_size = int(size_out.strip()) if code == 0 and size_out.strip() else 0
            except ValueError:
                file_size = 0

            if file_size > last_bytes:
                read_cmd = f"docker exec {container} bash -c 'tail -c +{last_bytes + 1} {remote_log} 2>/dev/null'"
                code, log_content, _stderr = await asyncio.to_thread(
                    ssh_exec, read_cmd, timeout=30
                )
                if code == 0:
                    for raw_line in log_content.strip().split("\n"):
                        raw_line = raw_line.strip()
                        if not raw_line:
                            continue
                        try:
                            entry = json.loads(raw_line)
                        except json.JSONDecodeError:
                            continue
                        if not isinstance(entry, dict):
                            # Valid JSON that is not an object is not a log entry.
                            continue
                        if await self._handle_remote_entry(job_id, entry, settings):
                            return
                    # Advance only after a successful read, so a transient SSH
                    # failure does not skip log content.
                    # NOTE(review): a line that is only partially written at
                    # read time is dropped; confirm the writer flushes whole
                    # lines.
                    last_bytes = file_size

            # The process has exited but the log holds no completed/error entry.
            if not process_alive:
                await asyncio.sleep(2)
                if not await asyncio.to_thread(is_process_running, pid):
                    error_msg = f"Remote process exited unexpectedly (pid={pid})"
                    logger.error(f"Remote job {job_id} failed: {error_msg}")
                    self.update_job(job_id, status=JobStatus.FAILED, error_message=error_msg)
                    await self._notify_callbacks()
                    await send_error(job_id, error_msg)
                    return

            await asyncio.sleep(poll_interval)

        # Poll budget exhausted.
        error_msg = "Remote training timed out"
        logger.error(f"Remote job {job_id} failed: {error_msg}")
        self.update_job(job_id, status=JobStatus.FAILED, error_message=error_msg)
        await self._notify_callbacks()
        await send_error(job_id, error_msg)

    @property
    def jobs(self) -> dict[str, TrainingJob]:
        """Return a shallow copy of the job registry."""
        return dict(self._jobs)


# Module-level singleton.
job_queue = JobQueue(max_concurrent=2)