|
@@ -1,9 +1,10 @@
|
|
|
import json
|
|
import json
|
|
|
import uuid
|
|
import uuid
|
|
|
-from datetime import datetime
|
|
|
|
|
|
|
+from datetime import datetime, timezone
|
|
|
from typing import Any
|
|
from typing import Any
|
|
|
|
|
|
|
|
from app.config import get_settings
|
|
from app.config import get_settings
|
|
|
|
|
+from app.core.background_tasks import background_task_manager
|
|
|
from app.core.db import async_session, EvalResultModel
|
|
from app.core.db import async_session, EvalResultModel
|
|
|
from app.core.logging import logger
|
|
from app.core.logging import logger
|
|
|
from app.core.remote_executor import ssh_exec
|
|
from app.core.remote_executor import ssh_exec
|
|
@@ -13,20 +14,43 @@ settings = get_settings()
|
|
|
|
|
|
|
|
|
|
|
|
|
async def run_evaluation(job_id: str, config: dict[str, Any]) -> dict[str, Any]:
|
|
async def run_evaluation(job_id: str, config: dict[str, Any]) -> dict[str, Any]:
|
|
|
- """在已训练的 adapter 上运行评估(perplexity)。"""
|
|
|
|
|
|
|
+ """启动评估后台任务,立即返回 eval_id。"""
|
|
|
eval_id = str(uuid.uuid4())
|
|
eval_id = str(uuid.uuid4())
|
|
|
|
|
|
|
|
- # 远程训练模式:把评估任务也发到远程容器执行
|
|
|
|
|
- if settings.use_remote_compute:
|
|
|
|
|
- logger.info(f"Running remote evaluation for job {job_id}")
|
|
|
|
|
- return await _run_remote_evaluation(eval_id, job_id)
|
|
|
|
|
|
|
+ # 写 DB
|
|
|
|
|
+ record = EvalResultModel(
|
|
|
|
|
+ id=eval_id,
|
|
|
|
|
+ job_id=job_id,
|
|
|
|
|
+ status="pending",
|
|
|
|
|
+ metrics="{}",
|
|
|
|
|
+ )
|
|
|
|
|
+ async with async_session() as session:
|
|
|
|
|
+ session.add(record)
|
|
|
|
|
+ await session.commit()
|
|
|
|
|
+
|
|
|
|
|
+ # 注册并启动
|
|
|
|
|
+ background_task_manager.register_task(eval_id, "evaluation", {"job_id": job_id})
|
|
|
|
|
+ background_task_manager.run(
|
|
|
|
|
+ eval_id, "evaluation", _execute_evaluation(eval_id, job_id, config)
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
- adapter_path = settings.adapters_dir / job_id
|
|
|
|
|
|
|
+ logger.info(f"Evaluation task started: job={job_id} (eval_id={eval_id})")
|
|
|
|
|
+ return {"id": eval_id, "job_id": job_id, "status": "pending"}
|
|
|
|
|
|
|
|
- if not adapter_path.exists():
|
|
|
|
|
- return {"id": eval_id, "job_id": job_id, "metrics": {}, "created_at": "", "error": "Adapter not found"}
|
|
|
|
|
|
|
|
|
|
|
|
+async def _execute_evaluation(eval_id: str, job_id: str, config: dict[str, Any]) -> dict:
|
|
|
|
|
+ """后台执行评估。"""
|
|
|
try:
|
|
try:
|
|
|
|
|
+ # 远程训练模式:把评估任务也发到远程容器执行
|
|
|
|
|
+ if settings.use_remote_compute:
|
|
|
|
|
+ logger.info(f"Running remote evaluation for job {job_id}")
|
|
|
|
|
+ result = await _run_remote_evaluation(eval_id, job_id)
|
|
|
|
|
+ return {"metrics": result.get("metrics", {})}
|
|
|
|
|
+
|
|
|
|
|
+ adapter_path = settings.adapters_dir / job_id
|
|
|
|
|
+ if not adapter_path.exists():
|
|
|
|
|
+ raise ValueError("Adapter not found")
|
|
|
|
|
+
|
|
|
import torch
|
|
import torch
|
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
|
|
|
@@ -34,24 +58,13 @@ async def run_evaluation(job_id: str, config: dict[str, Any]) -> dict[str, Any]:
|
|
|
model = AutoModelForCausalLM.from_pretrained(adapter_path, torch_dtype=torch.float16, device_map="auto")
|
|
model = AutoModelForCausalLM.from_pretrained(adapter_path, torch_dtype=torch.float16, device_map="auto")
|
|
|
tokenizer = AutoTokenizer.from_pretrained(adapter_path, trust_remote_code=True)
|
|
tokenizer = AutoTokenizer.from_pretrained(adapter_path, trust_remote_code=True)
|
|
|
|
|
|
|
|
- # 加载评估数据
|
|
|
|
|
- async with async_session() as session:
|
|
|
|
|
- from app.core.db import TrainingJobModel
|
|
|
|
|
- result = await session.execute(select(TrainingJobModel).where(TrainingJobModel.id == job_id))
|
|
|
|
|
- record = result.scalar_one_or_none()
|
|
|
|
|
-
|
|
|
|
|
- if record:
|
|
|
|
|
- dataset_path = record.dataset_id
|
|
|
|
|
-
|
|
|
|
|
- metrics = {}
|
|
|
|
|
- model.eval()
|
|
|
|
|
-
|
|
|
|
|
- # 计算 perplexity(使用 adapter 自身的数据或默认样例)
|
|
|
|
|
|
|
+ # 计算 perplexity
|
|
|
sample_texts = [
|
|
sample_texts = [
|
|
|
"The quick brown fox jumps over the lazy dog.",
|
|
"The quick brown fox jumps over the lazy dog.",
|
|
|
"Hello, how are you doing today?",
|
|
"Hello, how are you doing today?",
|
|
|
]
|
|
]
|
|
|
losses = []
|
|
losses = []
|
|
|
|
|
+ model.eval()
|
|
|
with torch.no_grad():
|
|
with torch.no_grad():
|
|
|
for text in sample_texts:
|
|
for text in sample_texts:
|
|
|
inputs = tokenizer(text, return_tensors="pt").to(model.device)
|
|
inputs = tokenizer(text, return_tensors="pt").to(model.device)
|
|
@@ -67,23 +80,29 @@ async def run_evaluation(job_id: str, config: dict[str, Any]) -> dict[str, Any]:
|
|
|
"num_samples": len(sample_texts),
|
|
"num_samples": len(sample_texts),
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- # 保存结果
|
|
|
|
|
- eval_record = EvalResultModel(
|
|
|
|
|
- id=eval_id,
|
|
|
|
|
- job_id=job_id,
|
|
|
|
|
- metrics=json.dumps(metrics),
|
|
|
|
|
- created_at=datetime.utcnow(),
|
|
|
|
|
- )
|
|
|
|
|
|
|
+ # 更新 DB
|
|
|
async with async_session() as session:
|
|
async with async_session() as session:
|
|
|
- session.add(eval_record)
|
|
|
|
|
- await session.commit()
|
|
|
|
|
|
|
+ result = await session.execute(select(EvalResultModel).where(EvalResultModel.id == eval_id))
|
|
|
|
|
+ eval_record = result.scalar_one_or_none()
|
|
|
|
|
+ if eval_record:
|
|
|
|
|
+ eval_record.metrics = json.dumps(metrics)
|
|
|
|
|
+ eval_record.status = "completed"
|
|
|
|
|
+ eval_record.progress = 100.0
|
|
|
|
|
+ await session.commit()
|
|
|
|
|
|
|
|
logger.info(f"Evaluation completed for job {job_id}: {metrics}")
|
|
logger.info(f"Evaluation completed for job {job_id}: {metrics}")
|
|
|
- return {"id": eval_id, "job_id": job_id, "metrics": metrics, "created_at": eval_record.created_at.isoformat()}
|
|
|
|
|
|
|
+ return {"metrics": metrics}
|
|
|
|
|
|
|
|
except Exception as e:
|
|
except Exception as e:
|
|
|
logger.error(f"Evaluation failed for job {job_id}: {e}")
|
|
logger.error(f"Evaluation failed for job {job_id}: {e}")
|
|
|
- return {"id": eval_id, "job_id": job_id, "metrics": {}, "created_at": "", "error": str(e)}
|
|
|
|
|
|
|
+ async with async_session() as session:
|
|
|
|
|
+ result = await session.execute(select(EvalResultModel).where(EvalResultModel.id == eval_id))
|
|
|
|
|
+ eval_record = result.scalar_one_or_none()
|
|
|
|
|
+ if eval_record:
|
|
|
|
|
+ eval_record.status = "failed"
|
|
|
|
|
+ eval_record.error = str(e)
|
|
|
|
|
+ await session.commit()
|
|
|
|
|
+ return {"error": str(e)}
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _run_remote_evaluation(eval_id: str, job_id: str) -> dict[str, Any]:
|
|
async def _run_remote_evaluation(eval_id: str, job_id: str) -> dict[str, Any]:
|
|
@@ -91,7 +110,7 @@ async def _run_remote_evaluation(eval_id: str, job_id: str) -> dict[str, Any]:
|
|
|
remote_cmd = (
|
|
remote_cmd = (
|
|
|
f"docker exec "
|
|
f"docker exec "
|
|
|
f"-e MACA_MPS_MODE=1 "
|
|
f"-e MACA_MPS_MODE=1 "
|
|
|
- f"-e CUDA_VISIBLE_DEVICES=2,3 "
|
|
|
|
|
|
|
+ f"-e CUDA_VISIBLE_DEVICES=3 "
|
|
|
f"-w {settings.compute_node_workdir} "
|
|
f"-w {settings.compute_node_workdir} "
|
|
|
f"{settings.compute_node_docker_container} "
|
|
f"{settings.compute_node_docker_container} "
|
|
|
f"{settings.compute_node_python} -c \""
|
|
f"{settings.compute_node_python} -c \""
|
|
@@ -104,8 +123,7 @@ async def _run_remote_evaluation(eval_id: str, job_id: str) -> dict[str, Any]:
|
|
|
code, stdout, stderr = ssh_exec(remote_cmd, timeout=300)
|
|
code, stdout, stderr = ssh_exec(remote_cmd, timeout=300)
|
|
|
|
|
|
|
|
if code != 0:
|
|
if code != 0:
|
|
|
- logger.error(f"Remote evaluation failed: {stderr}")
|
|
|
|
|
- return {"id": eval_id, "job_id": job_id, "metrics": {}, "created_at": "", "error": stderr.strip()}
|
|
|
|
|
|
|
+ raise RuntimeError(f"Remote evaluation failed: {stderr}")
|
|
|
|
|
|
|
|
# 提取最后一行 JSON
|
|
# 提取最后一行 JSON
|
|
|
for line in reversed(stdout.strip().split("\n")):
|
|
for line in reversed(stdout.strip().split("\n")):
|
|
@@ -114,21 +132,22 @@ async def _run_remote_evaluation(eval_id: str, job_id: str) -> dict[str, Any]:
|
|
|
try:
|
|
try:
|
|
|
result = json.loads(line)
|
|
result = json.loads(line)
|
|
|
# 保存结果到本地数据库
|
|
# 保存结果到本地数据库
|
|
|
- eval_record = EvalResultModel(
|
|
|
|
|
- id=eval_id,
|
|
|
|
|
- job_id=job_id,
|
|
|
|
|
- metrics=json.dumps(result.get("metrics", {})),
|
|
|
|
|
- created_at=datetime.utcnow(),
|
|
|
|
|
- )
|
|
|
|
|
|
|
+ metrics = result.get("metrics", {})
|
|
|
async with async_session() as session:
|
|
async with async_session() as session:
|
|
|
|
|
+ eval_record = EvalResultModel(
|
|
|
|
|
+ id=eval_id,
|
|
|
|
|
+ job_id=job_id,
|
|
|
|
|
+ metrics=json.dumps(metrics),
|
|
|
|
|
+ status="completed",
|
|
|
|
|
+ created_at=datetime.utcnow(),
|
|
|
|
|
+ )
|
|
|
session.add(eval_record)
|
|
session.add(eval_record)
|
|
|
await session.commit()
|
|
await session.commit()
|
|
|
- return {"id": eval_id, "job_id": job_id, "metrics": result.get("metrics", {}),
|
|
|
|
|
- "created_at": eval_record.created_at.isoformat()}
|
|
|
|
|
|
|
+ return {"id": eval_id, "job_id": job_id, "metrics": metrics}
|
|
|
except json.JSONDecodeError:
|
|
except json.JSONDecodeError:
|
|
|
continue
|
|
continue
|
|
|
|
|
|
|
|
- return {"id": eval_id, "job_id": job_id, "metrics": {}, "created_at": "", "error": f"Invalid response: {stdout[:500]}"}
|
|
|
|
|
|
|
+ raise RuntimeError(f"Invalid response: {stdout[:500]}")
|
|
|
|
|
|
|
|
|
|
|
|
|
async def get_evaluation_results(eval_id: str) -> dict[str, Any]:
|
|
async def get_evaluation_results(eval_id: str) -> dict[str, Any]:
|
|
@@ -140,7 +159,26 @@ async def get_evaluation_results(eval_id: str) -> dict[str, Any]:
|
|
|
return {
|
|
return {
|
|
|
"id": record.id,
|
|
"id": record.id,
|
|
|
"job_id": record.job_id,
|
|
"job_id": record.job_id,
|
|
|
|
|
+ "status": record.status,
|
|
|
|
|
+ "progress": record.progress,
|
|
|
"metrics": json.loads(record.metrics) if record.metrics else {},
|
|
"metrics": json.loads(record.metrics) if record.metrics else {},
|
|
|
|
|
+ "error": record.error,
|
|
|
"created_at": record.created_at.isoformat(),
|
|
"created_at": record.created_at.isoformat(),
|
|
|
}
|
|
}
|
|
|
- return {"id": eval_id, "job_id": "", "metrics": {}, "created_at": ""}
|
|
|
|
|
|
|
+ return {"id": eval_id, "job_id": "", "status": "not_found", "metrics": {}}
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+async def recover_stale_evaluations() -> None:
|
|
|
|
|
+ async with async_session() as session:
|
|
|
|
|
+ result = await session.execute(
|
|
|
|
|
+ select(EvalResultModel).where(
|
|
|
|
|
+ EvalResultModel.status.in_(["pending", "running"])
|
|
|
|
|
+ )
|
|
|
|
|
+ )
|
|
|
|
|
+ records = result.scalars().all()
|
|
|
|
|
+ for record in records:
|
|
|
|
|
+ record.status = "failed"
|
|
|
|
|
+ record.error = "Server restarted, task interrupted"
|
|
|
|
|
+ if records:
|
|
|
|
|
+ await session.commit()
|
|
|
|
|
+ logger.info(f"Recovered {len(records)} stale evaluation tasks")
|