"""Evaluation service: compute and persist perplexity metrics for trained adapters."""

import asyncio
import json
import shlex
import uuid
from datetime import datetime, timezone
from typing import Any

from sqlalchemy import select

from app.config import get_settings
from app.core.db import async_session, EvalResultModel
from app.core.logging import logger
from app.core.remote_executor import ssh_exec

settings = get_settings()


def _error_result(eval_id: str, job_id: str, message: str) -> dict[str, Any]:
    """Build the payload returned when an evaluation could not be completed."""
    return {
        "id": eval_id,
        "job_id": job_id,
        "metrics": {},
        "created_at": "",
        "error": message,
    }


async def _save_eval_result(eval_id: str, job_id: str, metrics: dict[str, Any]) -> str:
    """Persist an evaluation record and return its creation time as an ISO string.

    Args:
        eval_id: Identifier of the new evaluation record.
        job_id: Identifier of the training job that was evaluated.
        metrics: JSON-serializable metrics; stored as a JSON string.

    Returns:
        The ISO-8601 representation of the record's ``created_at`` value.
    """
    # Naive UTC, identical in value to the deprecated ``datetime.utcnow()``.
    created_at = datetime.now(timezone.utc).replace(tzinfo=None)
    eval_record = EvalResultModel(
        id=eval_id,
        job_id=job_id,
        metrics=json.dumps(metrics),
        created_at=created_at,
    )
    async with async_session() as session:
        session.add(eval_record)
        await session.commit()
    # Format the local value rather than ``eval_record.created_at``: reading an
    # ORM attribute after commit/close can fail if the session factory expires
    # attributes on commit.
    return created_at.isoformat()


def _compute_sample_metrics(adapter_path: Any) -> dict[str, Any]:
    """Load the model at ``adapter_path`` and compute loss/perplexity on sample texts.

    Heavy dependencies are imported lazily so that importing this module does
    not require ``torch``/``transformers``.

    Args:
        adapter_path: Location passed to ``from_pretrained`` for both the model
            and the tokenizer.

    Returns:
        A dict with ``eval_loss``, ``perplexity`` and ``num_samples``.
    """
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model = AutoModelForCausalLM.from_pretrained(
        adapter_path, torch_dtype=torch.float16, device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(adapter_path, trust_remote_code=True)
    model.eval()

    # TODO: evaluate on the job's own dataset; only fixed sample texts are used.
    sample_texts = [
        "The quick brown fox jumps over the lazy dog.",
        "Hello, how are you doing today?",
    ]

    losses = []
    with torch.no_grad():
        for text in sample_texts:
            inputs = tokenizer(text, return_tensors="pt").to(model.device)
            outputs = model(**inputs, labels=inputs["input_ids"])
            losses.append(outputs.loss.item())

    avg_loss = sum(losses) / len(losses) if losses else 0
    # Perplexity is exp(mean loss); reported as 0 when the loss is not positive.
    perplexity = torch.exp(torch.tensor(avg_loss)).item() if avg_loss > 0 else 0

    return {
        "eval_loss": round(avg_loss, 4),
        "perplexity": round(perplexity, 2),
        "num_samples": len(sample_texts),
    }


async def run_evaluation(job_id: str, config: dict[str, Any]) -> dict[str, Any]:
    """Run a perplexity evaluation on a trained adapter.

    When ``settings.use_remote_compute`` is set, the evaluation is delegated to
    the remote container; otherwise the adapter under
    ``settings.adapters_dir / job_id`` is evaluated in-process.

    Args:
        job_id: Identifier of the training job whose adapter is evaluated.
        config: Evaluation configuration. Currently unused; kept for interface
            compatibility.

    Returns:
        A dict with ``id``, ``job_id``, ``metrics`` and ``created_at``. On
        failure ``metrics`` is empty, ``created_at`` is ``""`` and an ``error``
        message is included.
    """
    eval_id = str(uuid.uuid4())

    # Remote compute mode: dispatch the evaluation to the remote container too.
    if settings.use_remote_compute:
        logger.info(f"Running remote evaluation for job {job_id}")
        return await _run_remote_evaluation(eval_id, job_id)

    adapter_path = settings.adapters_dir / job_id
    if not adapter_path.exists():
        return _error_result(eval_id, job_id, "Adapter not found")

    # Boundary handler: any failure (missing dependency, model load, inference,
    # database) is reported to the caller as an error payload, not an exception.
    try:
        # NOTE(review): model loading and inference run synchronously here and
        # block the event loop for their duration.
        metrics = _compute_sample_metrics(adapter_path)
        created_at = await _save_eval_result(eval_id, job_id, metrics)
        logger.info(f"Evaluation completed for job {job_id}: {metrics}")
        return {
            "id": eval_id,
            "job_id": job_id,
            "metrics": metrics,
            "created_at": created_at,
        }
    except Exception as e:
        logger.error(f"Evaluation failed for job {job_id}: {e}")
        return _error_result(eval_id, job_id, str(e))


async def _run_remote_evaluation(eval_id: str, job_id: str) -> dict[str, Any]:
    """Run the evaluation inside the remote container over SSH.

    Args:
        eval_id: Identifier to assign to the stored evaluation record.
        job_id: Identifier of the training job to evaluate.

    Returns:
        The same payload shape as ``run_evaluation``.
    """
    # The job id is handed to the inline script as an argv value (shell-quoted)
    # instead of being spliced into the Python source, so quotes or shell
    # metacharacters in it cannot alter the command or the script.
    remote_cmd = (
        f"docker exec "
        f"-e MACA_MPS_MODE=1 "
        f"-e CUDA_VISIBLE_DEVICES=2,3 "
        f"-w {settings.compute_node_workdir} "
        f"{settings.compute_node_docker_container} "
        f"{settings.compute_node_python} -c \""
        "import asyncio, json, sys; "
        "from app.core.remote_eval import run_remote_eval; "
        "result = asyncio.run(run_remote_eval(sys.argv[1])); "
        "print(json.dumps(result, ensure_ascii=False))\" "
        f"{shlex.quote(job_id)} 2>&1"
    )

    # ssh_exec is a blocking call; run it in a worker thread so the event loop
    # stays responsive for up to the 300 s timeout.
    code, stdout, stderr = await asyncio.to_thread(ssh_exec, remote_cmd, timeout=300)

    if code != 0:
        # The remote command merges its stderr into stdout (``2>&1``), so fall
        # back to the tail of stdout when stderr carries no detail.
        detail = stderr.strip() or stdout.strip()[-500:]
        logger.error(f"Remote evaluation failed: {detail}")
        return _error_result(eval_id, job_id, detail)

    # Scan from the end for the last line that parses as a JSON object; earlier
    # lines may be log output.
    for line in reversed(stdout.strip().split("\n")):
        line = line.strip()
        if not line.startswith("{"):
            continue
        try:
            result = json.loads(line)
        except json.JSONDecodeError:
            continue

        # Store the result in the local database.
        metrics = result.get("metrics", {})
        created_at = await _save_eval_result(eval_id, job_id, metrics)
        return {
            "id": eval_id,
            "job_id": job_id,
            "metrics": metrics,
            "created_at": created_at,
        }

    return _error_result(eval_id, job_id, f"Invalid response: {stdout[:500]}")


async def get_evaluation_results(eval_id: str) -> dict[str, Any]:
    """Fetch the stored result of a completed evaluation.

    Args:
        eval_id: Identifier of the evaluation record.

    Returns:
        A dict with ``id``, ``job_id``, ``metrics`` and ``created_at``. When no
        record exists, ``job_id`` and ``created_at`` are ``""`` and ``metrics``
        is empty.
    """
    async with async_session() as session:
        result = await session.execute(
            select(EvalResultModel).where(EvalResultModel.id == eval_id)
        )
        record = result.scalar_one_or_none()
        if record is not None:
            return {
                "id": record.id,
                "job_id": record.job_id,
                "metrics": json.loads(record.metrics) if record.metrics else {},
                # Guard against a NULL timestamp instead of raising AttributeError.
                "created_at": record.created_at.isoformat() if record.created_at else "",
            }

    return {"id": eval_id, "job_id": "", "metrics": {}, "created_at": ""}