| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394 |
- import json
- import uuid
- from datetime import datetime, timezone
- from typing import Any
- from app.config import get_settings
- from app.core.db import async_session, EvalResultModel
- from app.core.logging import logger
- from sqlalchemy import select
- settings = get_settings()
# Built-in texts scored when evaluating an adapter.
# TODO: evaluate on the job's own dataset; only these sample texts are scored for now.
_DEFAULT_EVAL_TEXTS: tuple[str, ...] = (
    "The quick brown fox jumps over the lazy dog.",
    "Hello, how are you doing today?",
)


def _failed_evaluation(eval_id: str, job_id: str, reason: str) -> dict[str, Any]:
    """Build the payload returned when an evaluation cannot be completed."""
    return {"id": eval_id, "job_id": job_id, "metrics": {}, "created_at": "", "error": reason}


async def run_evaluation(job_id: str, config: dict[str, Any]) -> dict[str, Any]:
    """Run a perplexity evaluation on the trained adapter of a job.

    The model and tokenizer are loaded from ``settings.adapters_dir / job_id``,
    the causal-LM loss is computed for each built-in sample text, and the
    resulting metrics are stored as an ``EvalResultModel`` row.

    Args:
        job_id: Identifier of the training job; also the adapter directory name.
        config: Evaluation options. Currently not read by this function.

    Returns:
        On success, a dict with ``id``, ``job_id``, ``metrics`` (``eval_loss``,
        ``perplexity``, ``num_samples``) and ``created_at`` (ISO 8601 string).
        On failure, the same keys with empty ``metrics``, an empty
        ``created_at`` and an additional ``error`` message. Exceptions raised
        while loading, scoring or saving are caught and reported this way.
    """
    eval_id = str(uuid.uuid4())
    adapter_path = settings.adapters_dir / job_id
    if not adapter_path.exists():
        return _failed_evaluation(eval_id, job_id, "Adapter not found")

    try:
        # Function-scope imports: a failing import is reported through the
        # error payload below instead of breaking the module import.
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer

        # NOTE(review): loading and scoring run synchronously inside this
        # coroutine, so the event loop is blocked until they finish.
        model = AutoModelForCausalLM.from_pretrained(adapter_path, torch_dtype=torch.float16, device_map="auto")
        # NOTE(review): trust_remote_code=True allows code shipped with the
        # tokenizer files to run; confirm adapter directories are trusted.
        tokenizer = AutoTokenizer.from_pretrained(adapter_path, trust_remote_code=True)
        model.eval()

        losses = []
        with torch.no_grad():
            for text in _DEFAULT_EVAL_TEXTS:
                inputs = tokenizer(text, return_tensors="pt").to(model.device)
                # Passing the input ids as labels makes the model return its LM loss.
                outputs = model(**inputs, labels=inputs["input_ids"])
                losses.append(outputs.loss.item())

        if losses:
            avg_loss = sum(losses) / len(losses)
            # Perplexity is exp(mean loss); a mean loss of exactly 0 gives 1,
            # and a NaN loss stays NaN instead of being reported as 0.
            perplexity = torch.exp(torch.tensor(avg_loss)).item()
        else:
            avg_loss = 0
            perplexity = 0

        metrics = {
            "eval_loss": round(avg_loss, 4),
            "perplexity": round(perplexity, 2),
            "num_samples": len(_DEFAULT_EVAL_TEXTS),
        }

        # Keep the timestamp in a local: reading it back from the ORM object
        # after commit may trigger a refresh of an expired instance, depending
        # on how the session factory is configured.
        created_at = datetime.now(timezone.utc)
        eval_record = EvalResultModel(
            id=eval_id,
            job_id=job_id,
            metrics=json.dumps(metrics),
            created_at=created_at,
        )
        async with async_session() as session:
            session.add(eval_record)
            await session.commit()

        logger.info(f"Evaluation completed for job {job_id}: {metrics}")
        return {"id": eval_id, "job_id": job_id, "metrics": metrics, "created_at": created_at.isoformat()}
    except Exception as e:
        # Top-level boundary for the evaluation: log with traceback and report
        # the failure in the returned payload rather than raising.
        logger.exception(f"Evaluation failed for job {job_id}: {e}")
        return _failed_evaluation(eval_id, job_id, str(e))
async def get_evaluation_results(eval_id: str) -> dict[str, Any]:
    """Return the stored result of a completed evaluation.

    Args:
        eval_id: Identifier of the evaluation to look up.

    Returns:
        A dict with ``id``, ``job_id``, ``metrics`` and ``created_at`` (ISO 8601
        string). When no row matches ``eval_id``, a placeholder is returned
        with an empty ``job_id``, empty ``metrics`` and an empty ``created_at``.
    """
    query = select(EvalResultModel).where(EvalResultModel.id == eval_id)
    async with async_session() as db:
        row = (await db.execute(query)).scalar_one_or_none()

        if not row:
            return {"id": eval_id, "job_id": "", "metrics": {}, "created_at": ""}

        # Metrics are stored as a JSON string; an empty value means no metrics.
        if row.metrics:
            parsed_metrics = json.loads(row.metrics)
        else:
            parsed_metrics = {}

        return {
            "id": row.id,
            "job_id": row.job_id,
            "metrics": parsed_metrics,
            "created_at": row.created_at.isoformat(),
        }
|