| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548 |
- import asyncio
- import json
- from datetime import datetime, timezone
- from enum import Enum
- from typing import Any, Callable, Coroutine, Optional
- from pydantic import BaseModel, Field
- from app.core.logging import logger
class JobStatus(str, Enum):
    """Lifecycle states of a training job.

    The enum mixes in ``str`` so members compare equal to, and serialize as,
    their lowercase string values.
    """

    PENDING = "pending"
    QUEUED = "queued"
    PREPROCESSING = "preprocessing"
    TRAINING = "training"
    COMPLETED = "completed"
    EVALUATING = "evaluating"
    EVALUATION_DONE = "evaluation_done"
    FAILED = "failed"
    CANCELLED = "cancelled"

    @property
    def is_terminal(self) -> bool:
        """Return True for COMPLETED, FAILED, CANCELLED and EVALUATION_DONE."""
        # Look the members up on the class: reaching one member through
        # another (``self.COMPLETED``) emits a DeprecationWarning on
        # Python 3.11.
        cls = type(self)
        return self in (cls.COMPLETED, cls.FAILED, cls.CANCELLED, cls.EVALUATION_DONE)
def _utc_now_iso() -> str:
    """Return the current UTC time as an ISO-8601 string."""
    return datetime.now(timezone.utc).isoformat()


class TrainingJob(BaseModel):
    """In-memory record describing one fine-tuning job and its live progress."""

    # Identity and inputs.
    id: str
    model_id: str
    model_type: str
    peft_method: str
    dataset_id: str
    config: dict = Field(default_factory=dict)

    # Progress tracking.
    status: JobStatus = JobStatus.PENDING
    progress: float = 0.0
    current_epoch: int = 0
    current_step: int = 0
    total_steps: int = 0
    loss: Optional[float] = None

    # Outcome.
    adapter_path: Optional[str] = None
    error_message: Optional[str] = None

    # Timestamps, stored as ISO-8601 strings; created_at defaults to "now" in UTC.
    created_at: str = Field(default_factory=_utc_now_iso)
    started_at: Optional[str] = None
    finished_at: Optional[str] = None
class JobQueue:
    """Asynchronous training-job queue with cancellation and bounded concurrency.

    Jobs are registered with :meth:`enqueue` and executed by background
    worker tasks created in :meth:`start`. Each job is preprocessed locally
    and then trained either by a local engine or, when
    ``settings.use_remote_compute`` is set, by a process launched over SSH
    inside a Docker container whose progress is polled from a JSONL log.

    Remote runs are serialized with an internal lock: before every remote
    launch all python processes in the container are killed, and the kill
    fallback may restart the whole container, so two remote jobs must never
    overlap.
    """

    def __init__(self, max_concurrent: int = 2):
        """Create an idle queue.

        Args:
            max_concurrent: Number of worker tasks spawned by :meth:`start`.
        """
        self._queue: asyncio.Queue[str] = asyncio.Queue()
        self._jobs: dict[str, TrainingJob] = {}
        self._cancel_events: dict[str, asyncio.Event] = {}
        self._callbacks: list[Callable[[TrainingJob], Coroutine[Any, Any, None]]] = []
        self._max_concurrent = max_concurrent
        self._workers: list[asyncio.Task] = []
        self._running = False
        # Guards the remote container: cleanup/kill there affects every
        # python process, so only one remote training may run at a time.
        self._remote_lock = asyncio.Lock()

    async def start(self):
        """Spawn the background workers; a no-op when already running."""
        if self._running:
            return
        self._running = True
        for _ in range(self._max_concurrent):
            worker = asyncio.create_task(self._worker_loop())
            self._workers.append(worker)
        logger.info(f"JobQueue started with {self._max_concurrent} workers")

    async def stop(self):
        """Stop all workers and wait until they have actually finished."""
        self._running = False
        for event in self._cancel_events.values():
            event.set()
        workers, self._workers = self._workers, []
        for worker in workers:
            worker.cancel()
        if workers:
            # Await the cancelled tasks so their cleanup handlers run before
            # stop() returns; cancellation results are collected, not raised.
            await asyncio.gather(*workers, return_exceptions=True)
        logger.info("JobQueue stopped")

    async def enqueue(self, job_id: str, job: TrainingJob):
        """Register ``job`` under ``job_id`` and put it on the queue."""
        self._jobs[job_id] = job
        self._cancel_events[job_id] = asyncio.Event()
        await self._queue.put(job_id)
        logger.info(f"Job {job_id} enqueued")

    async def dequeue(self) -> str:
        """Wait for and return the next job id from the queue."""
        return await self._queue.get()

    def mark_done(self, job_id: str):
        """Acknowledge one dequeued job and release its cancel event."""
        self._queue.task_done()
        self._cancel_events.pop(job_id, None)

    def get_job(self, job_id: str) -> Optional[TrainingJob]:
        """Return the job registered under ``job_id``, or None."""
        return self._jobs.get(job_id)

    def update_job(self, job_id: str, **kwargs):
        """Set the given attributes on a job.

        Unknown job ids and attribute names the job does not have are
        silently ignored.
        """
        job = self._jobs.get(job_id)
        if job is None:
            return
        for key, val in kwargs.items():
            if hasattr(job, key):
                setattr(job, key, val)

    def is_cancelled(self, job_id: str) -> bool:
        """Return True when cancellation was requested for a live job."""
        event = self._cancel_events.get(job_id)
        return event is not None and event.is_set()

    async def cancel(self, job_id: str):
        """Request cancellation of a job and mark it CANCELLED.

        A job that already reached a terminal status is left untouched so a
        late cancel cannot rewrite a COMPLETED or FAILED result.
        """
        job = self._jobs.get(job_id)
        if job is not None and JobStatus(job.status).is_terminal:
            logger.info(f"Job {job_id} already finished, cancel ignored")
            return
        if job_id in self._cancel_events:
            self._cancel_events[job_id].set()
        self.update_job(job_id, status=JobStatus.CANCELLED)
        await self._notify_callbacks()
        logger.info(f"Job {job_id} cancelled")

    def register_callback(self, callback: Callable[[TrainingJob], Coroutine[Any, Any, None]]):
        """Register a status-change callback (e.g. to update the database)."""
        self._callbacks.append(callback)

    async def _notify_callbacks(self):
        """Invoke every callback once for every known job.

        The job and callback collections are snapshotted first because they
        can change while a callback is awaited, and each call is isolated so
        one failure does not skip the remaining jobs or callbacks.
        """
        jobs = list(self._jobs.values())
        for cb in list(self._callbacks):
            for job in jobs:
                try:
                    await cb(job)
                except Exception as e:
                    logger.error(f"JobQueue callback error: {e}")

    async def _worker_loop(self):
        """Worker loop: keep taking job ids from the queue and running them."""
        while self._running:
            try:
                # Short timeout so the loop re-checks ``_running`` regularly.
                job_id = await asyncio.wait_for(self._queue.get(), timeout=1.0)
            except asyncio.TimeoutError:
                continue
            try:
                await self._run_job(job_id)
            except Exception as e:
                logger.error(f"Job {job_id} failed: {e}")
                self.update_job(job_id, status=JobStatus.FAILED, error_message=str(e))
                await self._notify_callbacks()
            finally:
                # Also releases the job's cancel event, which would otherwise
                # accumulate for the lifetime of the process.
                self.mark_done(job_id)

    async def _run_job(self, job_id: str):
        """Run a single job: preprocess, train, then record the outcome.

        Failures are recorded on the job as FAILED instead of being raised.
        Task cancellation marks the job CANCELLED and is re-raised so the
        worker task really ends.
        """
        job = self._jobs.get(job_id)
        if not job:
            return
        # Cancelled while still waiting in the queue: keep the CANCELLED
        # status set by cancel() instead of overwriting it with QUEUED.
        if self.is_cancelled(job_id):
            return

        self.update_job(job_id, status=JobStatus.QUEUED)
        await self._notify_callbacks()
        if self.is_cancelled(job_id):
            return

        self.update_job(job_id, status=JobStatus.PREPROCESSING, started_at=datetime.now(timezone.utc).isoformat())
        await self._notify_callbacks()
        if self.is_cancelled(job_id):
            return

        try:
            from app.config import get_settings

            settings = get_settings()
            config = job.config
            model_id = job.model_id
            peft_method = job.peft_method
            dataset_id = config.get("dataset_id", job.dataset_id)

            # Resolve the dataset file: database first, then the filesystem.
            dataset_path = await self._lookup_dataset_db(dataset_id)
            if not dataset_path:
                dataset_path = self._find_dataset_path(dataset_id)
            if not dataset_path:
                raise FileNotFoundError(f"Dataset not found: {dataset_id}")

            engine = self._get_engine(job.model_type)

            # Dataset preprocessing always runs locally.
            processed_path = str(settings.processed_dir / f"{job_id}_processed.jsonl")
            task_type = config.get("task_type", "sft")
            template = config.get("dataset_template", "alpaca")
            await engine.preprocess_dataset(dataset_path, processed_path, task_type=task_type, template=template)

            if settings.use_remote_compute:
                await self._train_remote(job_id, job, dataset_path)
            else:
                # Local training.
                await engine.load_model(model_id, quantization="4bit" if peft_method == "qlora" else None)
                peft_config = engine.get_peft_config(peft_method, config)
                self.update_job(job_id, status=JobStatus.TRAINING)
                await self._notify_callbacks()
                adapter_path = await engine.train(
                    job_id=job_id,
                    dataset_path=processed_path,
                    peft_config=peft_config,
                    training_args=config,
                )
                if self.is_cancelled(job_id):
                    # A cancel requested during training must not be
                    # overwritten by COMPLETED.
                    self.update_job(job_id, status=JobStatus.CANCELLED)
                    await self._notify_callbacks()
                    return
                self.update_job(job_id, status=JobStatus.COMPLETED, adapter_path=adapter_path)
                await self._notify_callbacks()
                logger.info(f"Job {job_id} completed successfully")
        except asyncio.CancelledError:
            self.update_job(job_id, status=JobStatus.CANCELLED)
            await self._notify_callbacks()
            # Never swallow cancellation: the worker task must terminate.
            raise
        except Exception as e:
            error_msg = str(e)
            logger.error(f"Job {job_id} failed: {error_msg}")
            self.update_job(job_id, status=JobStatus.FAILED, error_message=error_msg)
            await self._notify_callbacks()

    async def _train_remote(self, job_id: str, job: TrainingJob, dataset_path: str):
        """Launch remote training for ``job`` and follow it to completion.

        Holds the remote lock for the whole run. If polling fails or the
        task is cancelled, a best-effort kill of the remote process is made
        before the error propagates to the caller.

        Raises:
            RuntimeError: The remote launcher returned no pid.
        """
        from app.core.remote_executor import run_training_remote

        async with self._remote_lock:
            if self.is_cancelled(job_id):
                # Cancelled while waiting for the container to become free.
                self.update_job(job_id, status=JobStatus.CANCELLED)
                await self._notify_callbacks()
                return

            # Kill leftover python processes in the container before
            # launching (frees the GPU ring buffer).
            await self._cleanup_remote_processes()
            self.update_job(job_id, status=JobStatus.TRAINING)
            await self._notify_callbacks()

            # Run the launcher in a thread like every other SSH call, so the
            # event loop is not blocked while it connects.
            pid = await asyncio.to_thread(
                run_training_remote, job_id, job.model_id, job.model_type, dataset_path, job.config
            )
            if not pid:
                raise RuntimeError("Failed to launch remote training")
            logger.info(f"Remote training launched for job {job_id}")

            try:
                # Poll the shared log file to track progress.
                await self._poll_remote_progress(job_id, pid)
            except (Exception, asyncio.CancelledError):
                await self._kill_remote_quietly(pid)
                raise

    async def _kill_remote_quietly(self, pid: str):
        """Best-effort ``kill -9`` of a remote process and its children."""
        try:
            from app.config import get_settings
            from app.core.remote_executor import ssh_exec

            container = get_settings().compute_node_docker_container
            await asyncio.to_thread(
                ssh_exec,
                f"docker exec {container} bash -c 'kill -9 {pid} 2>/dev/null; "
                f"pkill -9 -P {pid} 2>/dev/null'",
                timeout=5,
            )
            logger.info(f"Killed remote process {pid} due to exception")
        except Exception:
            # The kill may time out; the cleanup that precedes the next
            # remote launch kills leftovers again.
            logger.warning(f"Failed to kill remote process {pid}, will be handled by progress poller")

    def _find_dataset_path(self, dataset_id: str) -> str | None:
        """Resolve ``dataset_id`` on the filesystem.

        Tries ``<uploads_dir>/<dataset_id>`` first, then ``dataset_id``
        itself as a path. Returns None when neither exists or the id is
        empty.
        """
        from app.config import get_settings
        from pathlib import Path

        # An empty id would otherwise resolve to the uploads directory.
        if not dataset_id:
            return None
        settings = get_settings()
        upload_path = settings.uploads_dir / dataset_id
        if upload_path.exists():
            return str(upload_path)
        # The id may itself be a path.
        if Path(dataset_id).exists():
            return dataset_id
        return None

    async def _cleanup_remote_processes(self):
        """Kill every python process inside the remote container over SSH.

        All steps are combined into a single SSH command to avoid the cost
        of several connections. The outcome is only logged.
        """
        from app.config import get_settings
        from app.core.remote_executor import ssh_exec

        settings = get_settings()
        container = settings.compute_node_docker_container
        # One command: check the container, list python pids, kill each one,
        # print a summary.
        cmd = (
            f"docker inspect -f '{{{{.State.Running}}}}' {container} 2>/dev/null || echo false; "
            f"if [ \"$(docker inspect -f '{{{{.State.Running}}}}' {container} 2>/dev/null)\" = 'true' ]; then "
            f"pids=$(docker exec {container} bash -c 'ps aux 2>/dev/null | grep \"[p]ython\" | grep -v grep | awk \"{{{{print \\$2}}}}\"'); "
            f"if [ -n \"$pids\" ]; then "
            f"echo \"$pids\" | while read pid; do "
            # Double quotes so the loop variable is expanded by the outer
            # shell; in single quotes the container-side shell saw an unset
            # $pid and killed nothing.
            f"docker exec {container} bash -c \"kill -9 $pid 2>/dev/null\"; "
            f"done; "
            f"echo \"cleaned $(echo \"$pids\" | wc -l) processes\"; "
            f"else echo 'no python processes'; fi; "
            f"else echo 'container not running'; fi"
        )
        code, stdout, stderr = await asyncio.to_thread(ssh_exec, cmd, timeout=60)
        if code != 0:
            logger.warning(f"Remote cleanup failed: code={code}, stderr={stderr}")
        else:
            logger.info(f"Remote cleanup result: {stdout.strip()}")

    async def _lookup_dataset_db(self, dataset_id: str) -> str | None:
        """Look up a dataset file path in the database by id or by name.

        A row whose id matches wins over rows that only match by name.

        Raises:
            LookupError: No row matches by id and several match by name.
        """
        from app.core.db import async_session, DatasetRecord
        from sqlalchemy import select

        async with async_session() as session:
            result = await session.execute(select(DatasetRecord).where(
                (DatasetRecord.id == dataset_id) | (DatasetRecord.name == dataset_id)
            ))
            # Fetch every match: one row may match by id while another
            # matches by name, which scalar_one_or_none() rejects.
            records = result.scalars().all()
            if not records:
                return None
            for record in records:
                if str(record.id) == str(dataset_id):
                    return record.file_path
            if len(records) > 1:
                raise LookupError(f"Dataset name is ambiguous: {dataset_id}")
            return records[0].file_path

    def _get_engine(self, model_type: str):
        """Pick the training engine for ``model_type`` (text is the default)."""
        if model_type == "vision":
            from app.engines.vision_engine import vision_engine
            return vision_engine
        if model_type == "multimodal":
            from app.engines.multimodal_engine import multimodal_engine
            return multimodal_engine
        from app.engines.text_engine import text_engine
        return text_engine

    @staticmethod
    def _parse_jsonl_chunk(pending: str, chunk: str) -> tuple[list[dict], str]:
        """Parse newly read JSONL text, tolerating a half-written last line.

        Args:
            pending: Unparsed trailing text carried over from the last read.
            chunk: Text read in this poll.

        Returns:
            ``(entries, carry)``: the JSON objects decoded from complete
            lines, and the trailing text that did not parse yet and should
            be passed back as ``pending`` next time. Blank lines, non-JSON
            lines and non-object JSON values are skipped.
        """
        entries: list[dict] = []
        lines = chunk.split("\n")
        last_index = len(lines) - 1
        carry = ""
        for index, raw in enumerate(lines):
            # Only the first line can continue a carried fragment; fall back
            # to the line alone if the carried text was not its beginning.
            candidates = [pending + raw, raw] if index == 0 and pending else [raw]
            parsed = None
            for text in candidates:
                text = text.strip()
                if not text:
                    continue
                try:
                    parsed = json.loads(text)
                except json.JSONDecodeError:
                    continue
                break
            if parsed is None:
                if index == last_index:
                    # Possibly still being written; retry with the next read.
                    carry = candidates[0]
                continue
            if isinstance(parsed, dict):
                entries.append(parsed)
        return entries, carry

    async def _poll_remote_progress(self, job_id: str, pid: str):
        """Follow a remote training run until it reaches a final state.

        Every 5 seconds (at most 8640 polls) this reads the new part of the
        remote JSONL progress log over SSH, applies each entry to the job and
        forwards it over the websocket helpers. It also mirrors new lines of
        ``/tmp/train_<job_id>.log`` into the local logger with a
        ``[253:<job id prefix>]`` tag. The method returns after a
        ``completed`` or ``error`` entry, on cancellation, when the process
        has exited without a final entry, or on timeout. In every case but
        ``completed`` the remote process is killed first.
        """
        from app.config import get_settings
        from app.core.websocket import send_progress, send_epoch_done, send_completed, send_error
        from app.core.remote_executor import ssh_exec, is_process_running

        settings = get_settings()
        remote_log = f"{settings.compute_node_remote_data_dir}/logs/{job_id}.jsonl"
        stderr_log = f"/tmp/train_{job_id}.log"
        container = settings.compute_node_docker_container
        log_tag = f"[253:{job_id[:8]}]"
        last_bytes = 0
        stderr_last_bytes = 0  # read offset into the stderr log
        pending_line = ""  # half-written JSONL line carried between polls
        poll_interval = 5
        max_polls = 8640
        consecutive_empty_polls = 0
        max_consecutive_empty = 12  # 60 seconds without new progress entries

        async def _kill_remote_process(pid: str):
            """Force-kill the remote training process, with fallbacks."""
            # Method 1: plain kill -9 through docker exec.
            try:
                await asyncio.to_thread(
                    ssh_exec,
                    f"docker exec {container} bash -c 'kill -9 {pid} 2>/dev/null; "
                    f"pkill -9 -P {pid} 2>/dev/null'",
                    timeout=10,
                )
                logger.info(f"Killed remote process {pid} via docker exec")
                return
            except Exception as e:
                logger.warning(f"Failed to kill process {pid} via docker exec: {e}")
            # Method 2: signal from inside the process namespaces via nsenter.
            try:
                await asyncio.to_thread(
                    ssh_exec,
                    f"docker exec {container} bash -c 'nsenter -t {pid} -p -s -- kill -9 {pid} 2>/dev/null || kill -9 {pid} 2>/dev/null'",
                    timeout=10,
                )
                logger.info(f"Killed remote process {pid} via nsenter")
                return
            except Exception as e:
                logger.warning(f"Failed to kill process {pid} via nsenter: {e}")
            # Method 3: last resort, restart the whole container.
            try:
                await asyncio.to_thread(
                    ssh_exec,
                    f"docker restart -t 5 {container}",
                    timeout=30,
                )
                logger.warning(f"Force restarted container {container} to release GPU resources")
            except Exception as e:
                logger.error(f"Failed to restart container {container}: {e}")

        async def _mark_failed(error_msg: str):
            """Kill the remote process, then record and broadcast the failure."""
            await _kill_remote_process(pid)
            self.update_job(job_id, status=JobStatus.FAILED, error_message=error_msg)
            await self._notify_callbacks()
            await send_error(job_id, error_msg)

        async def _remote_file_size(path: str) -> int:
            """Return the byte size of a file in the container, 0 if unknown."""
            cmd = f"docker exec {container} bash -c 'wc -c < {path} 2>/dev/null || echo 0'"
            code, out, _ = await asyncio.to_thread(ssh_exec, cmd, timeout=30)
            try:
                return int(out.strip()) if code == 0 and out.strip() else 0
            except ValueError:
                return 0

        async def _remote_read(path: str, start: int, length: int) -> str | None:
            """Read exactly ``length`` bytes at offset ``start``; None on failure."""
            # head -c bounds the read to the size measured earlier. Without
            # it, data appended between the size check and the read would be
            # returned now and read again on the next poll.
            cmd = (
                f"docker exec {container} bash -c "
                f"'tail -c +{start + 1} {path} 2>/dev/null | head -c {length}'"
            )
            code, out, _ = await asyncio.to_thread(ssh_exec, cmd, timeout=30)
            return out if code == 0 else None

        for _ in range(max_polls):
            if self.is_cancelled(job_id):
                await _kill_remote_process(pid)
                self.update_job(job_id, status=JobStatus.CANCELLED)
                await self._notify_callbacks()
                await send_error(job_id, "Training cancelled")
                return

            # Liveness is sampled BEFORE the log is read: if the process is
            # already gone here, the log read below sees its final contents.
            process_alive = await asyncio.to_thread(is_process_running, pid)

            # === 1. Read the JSONL progress log ===
            file_size = await _remote_file_size(remote_log)
            has_new_log = False
            if file_size > last_bytes:
                log_content = await _remote_read(remote_log, last_bytes, file_size - last_bytes)
                if log_content is not None:
                    # Advance only after a successful read so a failed read
                    # is retried instead of skipped.
                    last_bytes = file_size
                    if log_content.strip():
                        has_new_log = True
                        consecutive_empty_polls = 0
                        entries, pending_line = self._parse_jsonl_chunk(pending_line, log_content)
                        for entry in entries:
                            entry_type = entry.get("type")
                            if entry_type == "progress":
                                step = entry.get("step") or 0
                                total_steps = entry.get("total_steps") or 0
                                if total_steps > 0:
                                    progress = round(step / total_steps * 100, 1)
                                else:
                                    # No total known: estimate 0.1% per step.
                                    progress = round(step * 0.1, 1)
                                # Cap at 99.9; only "completed" reports 100.
                                progress = min(99.9, max(0, progress))
                                self.update_job(job_id,
                                                current_step=step,
                                                total_steps=total_steps,
                                                loss=entry.get("loss"),
                                                progress=progress)
                                await self._notify_callbacks()
                                await send_progress(job_id, **{k: v for k, v in entry.items() if k != "type"})
                            elif entry_type == "epoch_begin":
                                self.update_job(job_id, current_epoch=entry.get("epoch", 0))
                                await self._notify_callbacks()
                            elif entry_type == "epoch_done":
                                await self._notify_callbacks()
                                await send_epoch_done(job_id, **{k: v for k, v in entry.items() if k not in ("type", "ts")})
                            elif entry_type == "completed":
                                adapter_path = entry.get("adapter_path", str(settings.adapters_dir / job_id))
                                self.update_job(job_id,
                                                status=JobStatus.COMPLETED,
                                                adapter_path=adapter_path,
                                                progress=100.0)
                                await self._notify_callbacks()
                                await send_completed(job_id, **{k: v for k, v in entry.items() if k not in ("type", "ts")})
                                return
                            elif entry_type == "error":
                                error_msg = entry.get("message", "Unknown error")
                                logger.error(f"Remote job {job_id} failed: {error_msg}")
                                await _mark_failed(error_msg)
                                return

            # === 2. Mirror the remote stderr log into the local log ===
            stderr_size = await _remote_file_size(stderr_log)
            if stderr_size > stderr_last_bytes:
                stderr_content = await _remote_read(stderr_log, stderr_last_bytes, stderr_size - stderr_last_bytes)
                if stderr_content is not None:
                    stderr_last_bytes = stderr_size
                    for line in stderr_content.strip().split("\n"):
                        line = line.strip()
                        if not line:
                            continue
                        # Choose the local log level from markers in the line.
                        if "[remote_train]" in line:
                            logger.info(f"{log_tag} {line}")
                        elif "[MXKW][E]" in line or "ERROR" in line or "Error" in line:
                            logger.error(f"{log_tag} {line}")
                        elif "[transformers]" in line or "UserWarning" in line or "Warning" in line:
                            logger.warning(f"{log_tag} {line}")
                        else:
                            logger.info(f"{log_tag} {line}")

            if not has_new_log:
                consecutive_empty_polls += 1
                # The process is gone but the log has no completed/error entry.
                if not process_alive:
                    # Give it a moment, then confirm it is really gone.
                    await asyncio.sleep(2)
                    if not await asyncio.to_thread(is_process_running, pid):
                        # Fall back to the tail of the stderr log as message.
                        error_msg = f"Remote process exited unexpectedly (pid={pid})"
                        try:
                            from app.core.remote_executor import get_remote_stderr
                            stderr_content = await asyncio.to_thread(get_remote_stderr, job_id)
                            if stderr_content:
                                error_msg = stderr_content[-1000:]
                        except Exception as e:
                            # Keep the generic message, but leave a trace.
                            logger.warning(f"Could not fetch remote stderr for job {job_id}: {e}")
                        logger.error(f"Remote job {job_id} failed: {error_msg}")
                        await _mark_failed(error_msg)
                        return
                # Long silence while the process was sampled as not running.
                if consecutive_empty_polls >= max_consecutive_empty and not process_alive:
                    error_msg = f"Remote process exited unexpectedly (pid={pid}), no error log found"
                    logger.error(f"Remote job {job_id} failed: {error_msg}")
                    await _mark_failed(error_msg)
                    return

            await asyncio.sleep(poll_interval)

        # Poll budget exhausted.
        error_msg = "Remote training timed out"
        logger.error(f"Remote job {job_id} failed: {error_msg}")
        await _mark_failed(error_msg)

    @property
    def jobs(self) -> dict[str, TrainingJob]:
        """Return a shallow copy of the job registry, keyed by job id."""
        return dict(self._jobs)
# Global singleton instance.
job_queue = JobQueue(max_concurrent=2)
|