|
|
@@ -232,8 +232,23 @@ class JobQueue:
|
|
|
self.update_job(job_id, status=JobStatus.CANCELLED)
|
|
|
await self._notify_callbacks()
|
|
|
except Exception as e:
|
|
|
- logger.error(f"Job {job_id} failed: {e}")
|
|
|
- self.update_job(job_id, status=JobStatus.FAILED, error_message=str(e))
|
|
|
+ # 远程训练模式:异常时也要 kill 远程进程
|
|
|
+ error_msg = str(e)
|
|
|
+ if settings.use_remote_compute and "pid" in locals():
|
|
|
+ from app.core.remote_executor import ssh_exec
|
|
|
+ container = settings.compute_node_docker_container
|
|
|
+ try:
|
|
|
+ ssh_exec(
|
|
|
+ f"docker exec {container} bash -c 'kill -9 {pid} 2>/dev/null; "
|
|
|
+ f"pkill -9 -P {pid} 2>/dev/null'",
|
|
|
+ timeout=15,
|
|
|
+ )
|
|
|
+ logger.info(f"Killed remote process {pid} due to exception")
|
|
|
+ except Exception:
|
|
|
+ pass
|
|
|
+
|
|
|
+ logger.error(f"Job {job_id} failed: {error_msg}")
|
|
|
+ self.update_job(job_id, status=JobStatus.FAILED, error_message=error_msg)
|
|
|
await self._notify_callbacks()
|
|
|
|
|
|
def _find_dataset_path(self, dataset_id: str) -> str | None:
|
|
|
@@ -281,23 +296,51 @@ class JobQueue:
|
|
|
return text_engine
|
|
|
|
|
|
async def _poll_remote_progress(self, job_id: str, pid: str):
|
|
|
- """通过 SSH 读取远程日志文件,解析训练进度(非阻塞)。"""
|
|
|
+ """通过 SSH 读取远程日志文件,解析训练进度(非阻塞)。
|
|
|
+
|
|
|
+ 同时把 253 容器内的 stderr 日志同步输出到 151 后端日志中。
|
|
|
+ """
|
|
|
from app.config import get_settings
|
|
|
from app.core.websocket import send_progress, send_epoch_done, send_completed, send_error
|
|
|
from app.core.remote_executor import ssh_exec, is_process_running
|
|
|
|
|
|
settings = get_settings()
|
|
|
remote_log = f"{settings.compute_node_remote_data_dir}/logs/{job_id}.jsonl"
|
|
|
+ container = settings.compute_node_docker_container
|
|
|
+
|
|
|
last_bytes = 0
|
|
|
+ stderr_last_bytes = 0 # 跟踪 stderr 日志读取位置
|
|
|
poll_interval = 5
|
|
|
max_polls = 8640
|
|
|
consecutive_empty_polls = 0
|
|
|
max_consecutive_empty = 12 # 60 秒无响应就开始检查 stderr
|
|
|
|
|
|
+ async def _mark_failed(error_msg: str):
|
|
|
+ """统一标记失败:先 kill 远程进程,再更新状态。"""
|
|
|
+ # 先杀远程进程,防止 GPU 一直被占用
|
|
|
+ try:
|
|
|
+ await asyncio.to_thread(
|
|
|
+ ssh_exec,
|
|
|
+ f"docker exec {container} bash -c 'kill -9 {pid} 2>/dev/null; "
|
|
|
+ f"pkill -9 -P {pid} 2>/dev/null'",
|
|
|
+ timeout=15,
|
|
|
+ )
|
|
|
+ logger.info(f"Killed remote process {pid} for job {job_id}")
|
|
|
+ except Exception:
|
|
|
+ pass
|
|
|
+
|
|
|
+ self.update_job(job_id, status=JobStatus.FAILED, error_message=error_msg)
|
|
|
+ await self._notify_callbacks()
|
|
|
+ await send_error(job_id, error_msg)
|
|
|
+
|
|
|
for _ in range(max_polls):
|
|
|
if self.is_cancelled(job_id):
|
|
|
- _s = get_settings()
|
|
|
- await asyncio.to_thread(ssh_exec, f"docker exec {_s.compute_node_docker_container} bash -c 'kill {pid} 2>/dev/null'", timeout=10)
|
|
|
+ await asyncio.to_thread(
|
|
|
+ ssh_exec,
|
|
|
+ f"docker exec {container} bash -c 'kill -9 {pid} 2>/dev/null; "
|
|
|
+ f"pkill -9 -P {pid} 2>/dev/null'",
|
|
|
+ timeout=15,
|
|
|
+ )
|
|
|
self.update_job(job_id, status=JobStatus.CANCELLED)
|
|
|
await self._notify_callbacks()
|
|
|
await send_error(job_id, "Training cancelled")
|
|
|
@@ -306,8 +349,8 @@ class JobQueue:
|
|
|
# 检查进程是否还在运行(非阻塞)
|
|
|
process_alive = await asyncio.to_thread(is_process_running, pid)
|
|
|
|
|
|
- # 通过 SSH 远程读取日志文件(非阻塞)
|
|
|
- cat_cmd = f"docker exec {settings.compute_node_docker_container} bash -c 'wc -c < {remote_log} 2>/dev/null || echo 0'"
|
|
|
+ # === 1. 读取 jsonl 进度日志 ===
|
|
|
+ cat_cmd = f"docker exec {container} bash -c 'wc -c < {remote_log} 2>/dev/null || echo 0'"
|
|
|
code, size_out, _ = await asyncio.to_thread(ssh_exec, cat_cmd, timeout=30)
|
|
|
try:
|
|
|
file_size = int(size_out.strip()) if code == 0 and size_out.strip() else 0
|
|
|
@@ -316,7 +359,7 @@ class JobQueue:
|
|
|
|
|
|
has_new_log = False
|
|
|
if file_size > last_bytes:
|
|
|
- read_cmd = f"docker exec {settings.compute_node_docker_container} bash -c 'tail -c +{last_bytes + 1} {remote_log} 2>/dev/null'"
|
|
|
+ read_cmd = f"docker exec {container} bash -c 'tail -c +{last_bytes + 1} {remote_log} 2>/dev/null'"
|
|
|
code, log_content, _ = await asyncio.to_thread(ssh_exec, read_cmd, timeout=30)
|
|
|
|
|
|
if code == 0 and log_content.strip():
|
|
|
@@ -362,15 +405,38 @@ class JobQueue:
|
|
|
elif entry_type == "error":
|
|
|
error_msg = entry.get("message", "Unknown error")
|
|
|
logger.error(f"Remote job {job_id} failed: {error_msg}")
|
|
|
- self.update_job(job_id,
|
|
|
- status=JobStatus.FAILED,
|
|
|
- error_message=error_msg)
|
|
|
- await self._notify_callbacks()
|
|
|
- await send_error(job_id, error_msg)
|
|
|
+ await _mark_failed(error_msg)
|
|
|
return
|
|
|
|
|
|
last_bytes = file_size
|
|
|
|
|
|
+ # === 2. 同步 253 stderr 日志到 151 后端日志 ===
|
|
|
+ stderr_cmd = f"docker exec {container} bash -c 'wc -c < /tmp/train_{job_id}.log 2>/dev/null || echo 0'"
|
|
|
+ code, stderr_size_out, _ = await asyncio.to_thread(ssh_exec, stderr_cmd, timeout=30)
|
|
|
+ try:
|
|
|
+ stderr_size = int(stderr_size_out.strip()) if code == 0 and stderr_size_out.strip() else 0
|
|
|
+ except ValueError:
|
|
|
+ stderr_size = 0
|
|
|
+
|
|
|
+ if stderr_size > stderr_last_bytes:
|
|
|
+ read_stderr_cmd = f"docker exec {container} bash -c 'tail -c +{stderr_last_bytes + 1} /tmp/train_{job_id}.log 2>/dev/null'"
|
|
|
+ code, stderr_content, _ = await asyncio.to_thread(ssh_exec, read_stderr_cmd, timeout=30)
|
|
|
+ if code == 0 and stderr_content.strip():
|
|
|
+ for line in stderr_content.strip().split("\n"):
|
|
|
+ line = line.strip()
|
|
|
+ if not line:
|
|
|
+ continue
|
|
|
+ # 识别日志级别
|
|
|
+ if "[remote_train]" in line:
|
|
|
+ logger.info(f"[253:{job_id[:8]}] {line}")
|
|
|
+ elif "[MXKW][E]" in line or "ERROR" in line or "Error" in line:
|
|
|
+ logger.error(f"[253:{job_id[:8]}] {line}")
|
|
|
+ elif "[transformers]" in line or "UserWarning" in line or "Warning" in line:
|
|
|
+ logger.warning(f"[253:{job_id[:8]}] {line}")
|
|
|
+ else:
|
|
|
+ logger.info(f"[253:{job_id[:8]}] {line}")
|
|
|
+ stderr_last_bytes = stderr_size
|
|
|
+
|
|
|
if not has_new_log:
|
|
|
consecutive_empty_polls += 1
|
|
|
|
|
|
@@ -390,20 +456,14 @@ class JobQueue:
|
|
|
pass
|
|
|
|
|
|
logger.error(f"Remote job {job_id} failed: {error_msg}")
|
|
|
- self.update_job(job_id,
|
|
|
- status=JobStatus.FAILED,
|
|
|
- error_message=error_msg)
|
|
|
- await self._notify_callbacks()
|
|
|
- await send_error(job_id, error_msg)
|
|
|
+ await _mark_failed(error_msg)
|
|
|
return
|
|
|
|
|
|
# 长时间无日志且进程异常,也标记为失败
|
|
|
if consecutive_empty_polls >= max_consecutive_empty and not process_alive:
|
|
|
error_msg = f"Remote process exited unexpectedly (pid={pid}), no error log found"
|
|
|
logger.error(f"Remote job {job_id} failed: {error_msg}")
|
|
|
- self.update_job(job_id, status=JobStatus.FAILED, error_message=error_msg)
|
|
|
- await self._notify_callbacks()
|
|
|
- await send_error(job_id, error_msg)
|
|
|
+ await _mark_failed(error_msg)
|
|
|
return
|
|
|
|
|
|
await asyncio.sleep(poll_interval)
|
|
|
@@ -411,9 +471,7 @@ class JobQueue:
|
|
|
# 超时
|
|
|
error_msg = "Remote training timed out"
|
|
|
logger.error(f"Remote job {job_id} failed: {error_msg}")
|
|
|
- self.update_job(job_id, status=JobStatus.FAILED, error_message=error_msg)
|
|
|
- await self._notify_callbacks()
|
|
|
- await send_error(job_id, error_msg)
|
|
|
+ await _mark_failed(error_msg)
|
|
|
|
|
|
@property
|
|
|
def jobs(self) -> dict[str, TrainingJob]:
|