import asyncio
import json
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

from app.config import get_settings
from app.core.db import async_session, EvalResultModel
from app.core.logging import logger
from sqlalchemy import select

settings = get_settings()

# Fixed probe sentences used for the perplexity estimate.
# TODO: evaluate on the training job's own dataset instead of these fixed samples.
_SAMPLE_TEXTS: tuple[str, ...] = (
    "The quick brown fox jumps over the lazy dog.",
    "Hello, how are you doing today?",
)


def _compute_perplexity_metrics(adapter_path: Path) -> dict[str, Any]:
    """Load the model stored at *adapter_path* and score it on ``_SAMPLE_TEXTS``.

    This is synchronous and compute-heavy (model load plus forward passes), so
    async callers should run it in a worker thread.

    Args:
        adapter_path: Directory passed to ``from_pretrained`` for both the model
            and the tokenizer.

    Returns:
        A dict with ``eval_loss`` (mean per-sample loss, rounded to 4 places),
        ``perplexity`` (exp of that mean, rounded to 2 places) and ``num_samples``.

    Raises:
        Whatever torch/transformers raise on import, load or inference failure.
    """
    # Imported lazily so the module can be imported without the ML stack installed.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # NOTE(review): the original comment said "load base model + adapter"; whether this
    # resolves a PEFT adapter onto its base model depends on the directory contents and
    # the transformers/PEFT integration — confirm.
    model = AutoModelForCausalLM.from_pretrained(
        adapter_path, torch_dtype=torch.float16, device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(adapter_path, trust_remote_code=True)
    model.eval()

    losses: list[float] = []
    with torch.no_grad():
        for text in _SAMPLE_TEXTS:
            inputs = tokenizer(text, return_tensors="pt").to(model.device)
            # Using the input ids as labels makes the model return its language-modeling loss.
            outputs = model(**inputs, labels=inputs["input_ids"])
            losses.append(outputs.loss.item())

    avg_loss = sum(losses) / len(losses) if losses else 0.0
    # exp(0) == 1, so only an empty loss list is special-cased (the old `avg_loss > 0`
    # guard wrongly reported perplexity 0 for a zero loss).
    perplexity = torch.exp(torch.tensor(avg_loss)).item() if losses else 0.0
    return {
        "eval_loss": round(avg_loss, 4),
        "perplexity": round(perplexity, 2),
        "num_samples": len(_SAMPLE_TEXTS),
    }


async def run_evaluation(job_id: str, config: dict[str, Any]) -> dict[str, Any]:
    """Run a perplexity evaluation on a trained adapter and persist the result.

    Args:
        job_id: Training job id; the adapter is read from ``settings.adapters_dir / job_id``.
        config: Evaluation options. Currently unused.

    Returns:
        On success: ``{"id", "job_id", "metrics", "created_at"}`` with ``created_at``
        as an ISO-8601 string. On failure (missing adapter or any exception): the
        same keys with empty ``metrics``, ``created_at == ""`` and an ``"error"`` message.
    """
    eval_id = str(uuid.uuid4())
    adapter_path = settings.adapters_dir / job_id
    if not adapter_path.exists():
        return {"id": eval_id, "job_id": job_id, "metrics": {}, "created_at": "", "error": "Adapter not found"}

    try:
        # Model loading and inference block; keep them off the event loop.
        metrics = await asyncio.to_thread(_compute_perplexity_metrics, adapter_path)

        # Keep the timestamp in a local: reading it back from the ORM object after the
        # session commits and closes can hit an expired/detached instance.
        created_at = datetime.now(timezone.utc)
        eval_record = EvalResultModel(
            id=eval_id,
            job_id=job_id,
            metrics=json.dumps(metrics),
            created_at=created_at,
        )
        async with async_session() as session:
            session.add(eval_record)
            await session.commit()

        logger.info(f"Evaluation completed for job {job_id}: {metrics}")
        return {"id": eval_id, "job_id": job_id, "metrics": metrics, "created_at": created_at.isoformat()}
    except Exception as e:
        # Top-level boundary: log with traceback and report the failure to the caller.
        logger.exception(f"Evaluation failed for job {job_id}: {e}")
        return {"id": eval_id, "job_id": job_id, "metrics": {}, "created_at": "", "error": str(e)}


async def get_evaluation_results(eval_id: str) -> dict[str, Any]:
    """Fetch a stored evaluation result by id.

    Returns:
        The record's ``id``, ``job_id``, decoded ``metrics`` and ISO-8601
        ``created_at``; if no record matches, a placeholder with the given id and
        empty ``job_id``/``metrics``/``created_at``.
    """
    async with async_session() as session:
        result = await session.execute(select(EvalResultModel).where(EvalResultModel.id == eval_id))
        record = result.scalar_one_or_none()
        if record is None:
            return {"id": eval_id, "job_id": "", "metrics": {}, "created_at": ""}
        return {
            "id": record.id,
            "job_id": record.job_id,
            "metrics": json.loads(record.metrics) if record.metrics else {},
            # Guard against a NULL timestamp instead of crashing on .isoformat().
            "created_at": record.created_at.isoformat() if record.created_at else "",
        }