# eval_service.py
  1. import json
  2. import uuid
  3. from datetime import datetime, timezone
  4. from typing import Any
  5. from app.config import get_settings
  6. from app.core.db import async_session, EvalResultModel
  7. from app.core.logging import logger
  8. from sqlalchemy import select
  9. settings = get_settings()
  def _compute_perplexity_metrics(adapter_path: Any) -> dict[str, Any]:
      """Load the model stored at *adapter_path* and score the built-in sample texts.

      This is synchronous, potentially long-running work (model loading plus
      forward passes), so ``run_evaluation`` executes it in a worker thread.

      Args:
          adapter_path: Directory handed to ``from_pretrained`` for both the
              model and the tokenizer.

      Returns:
          A dict with ``eval_loss`` (mean per-sample loss, 4 decimals),
          ``perplexity`` (``exp`` of the mean loss, 2 decimals) and
          ``num_samples``.
      """
      import torch
      from transformers import AutoModelForCausalLM, AutoTokenizer

      model = AutoModelForCausalLM.from_pretrained(
          adapter_path, torch_dtype=torch.float16, device_map="auto"
      )
      # NOTE(review): trust_remote_code=True executes Python code shipped next to
      # the tokenizer files; confirm adapter directories only hold trusted output.
      tokenizer = AutoTokenizer.from_pretrained(adapter_path, trust_remote_code=True)
      model.eval()

      # TODO: evaluate on the training job's own dataset instead of fixed samples.
      sample_texts = [
          "The quick brown fox jumps over the lazy dog.",
          "Hello, how are you doing today?",
      ]

      losses = []
      with torch.no_grad():
          for text in sample_texts:
              inputs = tokenizer(text, return_tensors="pt").to(model.device)
              outputs = model(**inputs, labels=inputs["input_ids"])
              losses.append(outputs.loss.item())

      if losses:
          avg_loss = sum(losses) / len(losses)
          # Perplexity is exp(mean loss); a mean loss of exactly 0 maps to 1.
          perplexity = torch.exp(torch.tensor(avg_loss)).item()
      else:
          # Nothing was scored, so there is no meaningful loss or perplexity.
          avg_loss = 0.0
          perplexity = 0.0

      return {
          "eval_loss": round(avg_loss, 4),
          "perplexity": round(perplexity, 2),
          "num_samples": len(sample_texts),
      }


  async def run_evaluation(job_id: str, config: dict[str, Any]) -> dict[str, Any]:
      """Run a perplexity evaluation on a job's trained adapter and persist it.

      Args:
          job_id: Training job identifier; the adapter is looked up at
              ``settings.adapters_dir / job_id``.
          config: Evaluation options. Currently unused; kept so existing
              callers keep working.

      Returns:
          On success, a dict with ``id``, ``job_id``, ``metrics`` and an
          ISO-formatted ``created_at``. If the adapter directory is missing or
          evaluation/persistence raises, the same keys with empty ``metrics``
          and ``created_at`` plus an ``error`` message.
      """
      import asyncio

      eval_id = str(uuid.uuid4())
      adapter_path = settings.adapters_dir / job_id
      if not adapter_path.exists():
          return {"id": eval_id, "job_id": job_id, "metrics": {}, "created_at": "", "error": "Adapter not found"}

      try:
          # Model loading and inference are blocking; keep them off the event loop.
          metrics = await asyncio.to_thread(_compute_perplexity_metrics, adapter_path)

          # Keep the timestamp in a local so the response never has to read an
          # ORM attribute after commit, when it may already be expired/detached.
          created_at = datetime.now(timezone.utc)
          eval_record = EvalResultModel(
              id=eval_id,
              job_id=job_id,
              metrics=json.dumps(metrics),
              created_at=created_at,
          )
          async with async_session() as session:
              session.add(eval_record)
              await session.commit()

          logger.info(f"Evaluation completed for job {job_id}: {metrics}")
          return {"id": eval_id, "job_id": job_id, "metrics": metrics, "created_at": created_at.isoformat()}
      except Exception as e:
          # Boundary handler: callers get an error payload instead of an exception;
          # logger.exception keeps the traceback in the logs.
          logger.exception(f"Evaluation failed for job {job_id}: {e}")
          return {"id": eval_id, "job_id": job_id, "metrics": {}, "created_at": "", "error": str(e)}
  async def get_evaluation_results(eval_id: str) -> dict[str, Any]:
      """Look up a stored evaluation by its id.

      Args:
          eval_id: Identifier of the evaluation record to fetch.

      Returns:
          The record's ``id``, ``job_id``, decoded ``metrics`` and ISO-formatted
          ``created_at``. When no record matches, a placeholder dict carrying
          ``eval_id`` with empty ``job_id``, ``metrics`` and ``created_at``.
      """
      query = select(EvalResultModel).where(EvalResultModel.id == eval_id)

      async with async_session() as session:
          row = (await session.execute(query)).scalar_one_or_none()

          # Guard clause: unknown ids yield an empty placeholder, not an error.
          if not row:
              return {"id": eval_id, "job_id": "", "metrics": {}, "created_at": ""}

          # Metrics are stored as a JSON string; an empty value decodes to {}.
          if row.metrics:
              decoded_metrics = json.loads(row.metrics)
          else:
              decoded_metrics = {}

          return {
              "id": row.id,
              "job_id": row.job_id,
              "metrics": decoded_metrics,
              "created_at": row.created_at.isoformat(),
          }