import asyncio
import json
import shlex
import uuid
from datetime import datetime, timezone
from typing import Any

from sqlalchemy import select

from app.config import get_settings
from app.core.background_tasks import background_task_manager
from app.core.db import async_session, EvalResultModel
from app.core.logging import logger
from app.core.remote_executor import ssh_exec

settings = get_settings()


async def _update_eval_record(
    eval_id: str,
    *,
    status: str,
    metrics: dict[str, Any] | None = None,
    error: str | None = None,
    progress: float | None = None,
) -> None:
    """Update the stored evaluation row identified by ``eval_id``.

    Only the fields that are passed (not ``None``) are written, in addition
    to ``status``. Nothing is committed when no matching row exists.
    """
    async with async_session() as session:
        result = await session.execute(
            select(EvalResultModel).where(EvalResultModel.id == eval_id)
        )
        record = result.scalar_one_or_none()
        if record is None:
            return
        record.status = status
        if metrics is not None:
            record.metrics = json.dumps(metrics)
        if error is not None:
            record.error = error
        if progress is not None:
            record.progress = progress
        await session.commit()


async def run_evaluation(job_id: str, config: dict[str, Any]) -> dict[str, Any]:
    """Start an evaluation as a background task and return its id right away.

    Args:
        job_id: Identifier of the job whose adapter is evaluated.
        config: Evaluation options, forwarded to the background coroutine.

    Returns:
        A dict with the new ``id``, the ``job_id`` and ``status`` "pending".
    """
    eval_id = str(uuid.uuid4())

    # Persist a pending row first so the result can be queried immediately.
    record = EvalResultModel(
        id=eval_id,
        job_id=job_id,
        status="pending",
        metrics="{}",
    )
    async with async_session() as session:
        session.add(record)
        await session.commit()

    # Register the task and hand the coroutine to the background manager.
    background_task_manager.register_task(eval_id, "evaluation", {"job_id": job_id})
    await background_task_manager.run(
        eval_id, "evaluation", _execute_evaluation(eval_id, job_id, config)
    )

    logger.info(f"Evaluation task started: job={job_id} (eval_id={eval_id})")
    return {"id": eval_id, "job_id": job_id, "status": "pending"}


def _compute_local_metrics(adapter_path: Any) -> dict[str, Any]:
    """Load the model at ``adapter_path`` and compute loss / perplexity.

    Synchronous and potentially slow (model loading plus forward passes), so
    callers running on an event loop should invoke it in a worker thread.
    """
    # Heavy optional dependencies are imported lazily, as in the original code.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Model and tokenizer are both loaded straight from the adapter directory.
    model = AutoModelForCausalLM.from_pretrained(
        adapter_path, torch_dtype=torch.float16, device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(adapter_path, trust_remote_code=True)

    # Perplexity is derived from the mean loss over a fixed set of sample texts.
    sample_texts = [
        "The quick brown fox jumps over the lazy dog.",
        "Hello, how are you doing today?",
    ]

    losses = []
    model.eval()
    with torch.no_grad():
        for text in sample_texts:
            inputs = tokenizer(text, return_tensors="pt").to(model.device)
            outputs = model(**inputs, labels=inputs["input_ids"])
            losses.append(outputs.loss.item())

    avg_loss = sum(losses) / len(losses) if losses else 0
    perplexity = torch.exp(torch.tensor(avg_loss)).item() if avg_loss > 0 else 0

    return {
        "eval_loss": round(avg_loss, 4),
        "perplexity": round(perplexity, 2),
        "num_samples": len(sample_texts),
    }


async def _execute_evaluation(
    eval_id: str, job_id: str, config: dict[str, Any]
) -> dict[str, Any]:
    """Run one evaluation in the background and persist its outcome.

    Args:
        eval_id: Id of the evaluation row to update.
        job_id: Job whose adapter is evaluated.
        config: Evaluation options (currently not read by this function).

    Returns:
        ``{"metrics": ...}`` on success, ``{"error": ...}`` on failure.
        Exceptions are caught, logged and recorded on the row, not raised.
    """
    try:
        # Remote-compute mode: delegate the evaluation to the remote container.
        if settings.use_remote_compute:
            logger.info(f"Running remote evaluation for job {job_id}")
            result = await _run_remote_evaluation(eval_id, job_id)
            return {"metrics": result.get("metrics", {})}

        adapter_path = settings.adapters_dir / job_id
        if not adapter_path.exists():
            raise ValueError("Adapter not found")

        # Model loading and inference are blocking; run them off the event loop.
        metrics = await asyncio.to_thread(_compute_local_metrics, adapter_path)

        await _update_eval_record(
            eval_id, status="completed", metrics=metrics, progress=100.0
        )

        logger.info(f"Evaluation completed for job {job_id}: {metrics}")
        return {"metrics": metrics}

    except Exception as e:
        logger.error(f"Evaluation failed for job {job_id}: {e}")
        try:
            await _update_eval_record(eval_id, status="failed", error=str(e))
        except Exception as db_err:
            # Recording the failure is best-effort; never let it escape the task.
            logger.error(
                f"Could not record failure for evaluation {eval_id}: {db_err}"
            )
        return {"error": str(e)}


async def _run_remote_evaluation(eval_id: str, job_id: str) -> dict[str, Any]:
    """Run the evaluation inside the remote container over SSH.

    Args:
        eval_id: Id of the local evaluation row to update with the result.
        job_id: Job identifier forwarded to the remote ``run_remote_eval``.

    Returns:
        A dict with ``id``, ``job_id`` and the parsed ``metrics``.

    Raises:
        RuntimeError: If the remote command exits non-zero or prints no line
            that parses as a JSON object.
    """
    # job_id travels as a shell-quoted argv value and is read via sys.argv, so
    # it can never be interpreted as shell syntax or as Python source.
    script = (
        "import asyncio, json, sys; "
        "from app.core.remote_eval import run_remote_eval; "
        "result = asyncio.run(run_remote_eval(sys.argv[1])); "
        "print(json.dumps(result, ensure_ascii=False))"
    )
    remote_cmd = (
        "docker exec "
        "-e MACA_MPS_MODE=1 "
        "-e CUDA_VISIBLE_DEVICES=3 "
        f"-w {settings.compute_node_workdir} "
        f"{settings.compute_node_docker_container} "
        f"{settings.compute_node_python} -c {shlex.quote(script)} "
        f"{shlex.quote(job_id)} 2>&1"
    )

    code, stdout, stderr = await asyncio.to_thread(ssh_exec, remote_cmd, timeout=300)

    if code != 0:
        # The command redirects stderr into stdout (2>&1), so stderr is usually
        # empty; fall back to the tail of stdout to keep the error informative.
        detail = (stderr or "").strip() or (stdout or "").strip()[-500:]
        raise RuntimeError(f"Remote evaluation failed: {detail}")

    # Scan from the end for the last line that parses as a JSON object.
    for raw_line in reversed(stdout.strip().split("\n")):
        candidate = raw_line.strip()
        if not candidate.startswith("{"):
            continue
        try:
            result = json.loads(candidate)
        except json.JSONDecodeError:
            continue

        # Store the remote result on the existing local row.
        metrics = result.get("metrics", {})
        await _update_eval_record(
            eval_id, status="completed", metrics=metrics, progress=100.0
        )
        return {"id": eval_id, "job_id": job_id, "metrics": metrics}

    raise RuntimeError(f"Invalid response: {stdout[:500]}")


async def get_evaluation_results(eval_id: str) -> dict[str, Any]:
    """Return the stored state and metrics of an evaluation.

    Returns:
        The record as a dict, or a ``status`` "not_found" placeholder when no
        row with ``eval_id`` exists.
    """
    async with async_session() as session:
        result = await session.execute(
            select(EvalResultModel).where(EvalResultModel.id == eval_id)
        )
        record = result.scalar_one_or_none()
        if record:
            created_at = record.created_at
            return {
                "id": record.id,
                "job_id": record.job_id,
                "status": record.status,
                "progress": record.progress,
                "metrics": json.loads(record.metrics) if record.metrics else {},
                "error": record.error,
                # Guard against rows without a timestamp instead of crashing.
                "created_at": created_at.isoformat() if created_at else None,
            }
    return {"id": eval_id, "job_id": "", "status": "not_found", "metrics": {}}


async def recover_stale_evaluations() -> None:
    """Mark evaluations still "pending" or "running" as failed.

    The error text written to each row states that the server restarted and
    the task was interrupted.
    """
    async with async_session() as session:
        result = await session.execute(
            select(EvalResultModel).where(
                EvalResultModel.status.in_(["pending", "running"])
            )
        )
        records = result.scalars().all()
        for record in records:
            record.status = "failed"
            record.error = "Server restarted, task interrupted"
        if records:
            await session.commit()
            logger.info(f"Recovered {len(records)} stale evaluation tasks")