eval_service.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. import json
  2. import uuid
  3. from datetime import datetime
  4. from typing import Any
  5. from app.config import get_settings
  6. from app.core.db import async_session, EvalResultModel
  7. from app.core.logging import logger
  8. from app.core.remote_executor import ssh_exec
  9. from sqlalchemy import select
  # Application settings, resolved once at import time and read by the functions below.
  settings = get_settings()
  async def run_evaluation(job_id: str, config: dict[str, Any]) -> dict[str, Any]:
      """Evaluate a trained adapter by computing perplexity on sample texts.

      In remote-compute mode (``settings.use_remote_compute``) the work is
      delegated to ``_run_remote_evaluation``. Otherwise the model stored at
      ``settings.adapters_dir / job_id`` is loaded locally, scored on a small
      built-in set of sentences, and the metrics are saved as an
      ``EvalResultModel`` row.

      Args:
          job_id: ID of the training job whose adapter directory is evaluated.
          config: Evaluation options. Not read by this implementation; kept
              for interface compatibility.

      Returns:
          A dict with ``id``, ``job_id``, ``metrics`` and ``created_at``
          (ISO-8601). On failure ``metrics`` is ``{}``, ``created_at`` is
          ``""`` and an ``error`` message is included; errors in the local
          evaluation are reported this way rather than raised.
      """
      import asyncio

      eval_id = str(uuid.uuid4())

      # Remote training mode: dispatch the evaluation to the remote container as well.
      if settings.use_remote_compute:
          logger.info(f"Running remote evaluation for job {job_id}")
          return await _run_remote_evaluation(eval_id, job_id)

      adapter_path = settings.adapters_dir / job_id
      if not adapter_path.exists():
          return {"id": eval_id, "job_id": job_id, "metrics": {}, "created_at": "", "error": "Adapter not found"}

      try:
          # Model loading and the forward passes are blocking and slow; run
          # them in a worker thread so the event loop keeps serving requests.
          # NOTE(review): concurrent calls can now load several models at once;
          # serialize these calls if GPU memory is tight.
          metrics = await asyncio.to_thread(_compute_local_perplexity, adapter_path)

          # Keep our own timestamp instead of reading it back from the ORM
          # object after commit: if the session expires attributes on commit,
          # that read would need a refresh on an already-closed session.
          created_at = datetime.utcnow()
          eval_record = EvalResultModel(
              id=eval_id,
              job_id=job_id,
              metrics=json.dumps(metrics),
              created_at=created_at,
          )
          async with async_session() as session:
              session.add(eval_record)
              await session.commit()

          logger.info(f"Evaluation completed for job {job_id}: {metrics}")
          return {"id": eval_id, "job_id": job_id, "metrics": metrics, "created_at": created_at.isoformat()}
      except Exception as e:
          # Boundary handler: report the failure to the caller (with a logged
          # traceback) instead of raising.
          logger.exception(f"Evaluation failed for job {job_id}: {e}")
          return {"id": eval_id, "job_id": job_id, "metrics": {}, "created_at": "", "error": str(e)}


  def _compute_local_perplexity(adapter_path: Any) -> dict[str, Any]:
      """Load the model at *adapter_path* and return loss/perplexity metrics.

      Blocking (model load + inference); intended to be called from a worker
      thread.

      Returns:
          ``{"eval_loss", "perplexity", "num_samples"}``.

      Raises:
          ValueError: if no sample text yields a token to predict.
      """
      import torch
      from transformers import AutoModelForCausalLM, AutoTokenizer

      # NOTE(review): this loads directly from the adapter directory although the
      # original comment spoke of "base model + adapter"; presumably it relies on
      # a merged checkpoint or transformers' PEFT integration — confirm.
      model = AutoModelForCausalLM.from_pretrained(adapter_path, torch_dtype=torch.float16, device_map="auto")
      try:
          tokenizer = AutoTokenizer.from_pretrained(adapter_path, trust_remote_code=True)
          model.eval()

          # Built-in sample texts; the job's training dataset is not used here.
          sample_texts = [
              "The quick brown fox jumps over the lazy dog.",
              "Hello, how are you doing today?",
          ]

          total_nll = 0.0
          total_tokens = 0
          with torch.no_grad():
              for text in sample_texts:
                  inputs = tokenizer(text, return_tensors="pt").to(model.device)
                  # With labels == input_ids the model shifts labels by one, so a
                  # text contributes (seq_len - 1) predicted tokens; with none the
                  # loss would be NaN, so skip it.
                  n_predicted = inputs["input_ids"].shape[-1] - 1
                  if n_predicted <= 0:
                      continue
                  outputs = model(**inputs, labels=inputs["input_ids"])
                  # outputs.loss is a per-token mean; re-weight by token count so
                  # the aggregate is the corpus mean NLL that perplexity is defined on.
                  total_nll += outputs.loss.item() * n_predicted
                  total_tokens += n_predicted

          if total_tokens == 0:
              raise ValueError("No tokens available to evaluate")

          avg_loss = total_nll / total_tokens
          # Perplexity = exp(mean NLL). Computed in torch so an extreme loss
          # yields inf rather than raising OverflowError.
          perplexity = torch.exp(torch.tensor(avg_loss)).item()
          return {
              "eval_loss": round(avg_loss, 4),
              "perplexity": round(perplexity, 2),
              "num_samples": len(sample_texts),
          }
      finally:
          # Drop the model and release cached GPU memory so repeated
          # evaluations in a long-lived process do not accumulate allocations.
          del model
          if torch.cuda.is_available():
              torch.cuda.empty_cache()
  async def _run_remote_evaluation(eval_id: str, job_id: str) -> dict[str, Any]:
      """Run the evaluation in the remote compute container over SSH.

      Executes ``app.core.remote_eval.run_remote_eval(job_id)`` via
      ``docker exec`` on the compute node, takes the last line of output that
      parses as a JSON object, and stores its ``metrics`` as an
      ``EvalResultModel`` row in the local database.

      Args:
          eval_id: ID to assign to the stored evaluation record.
          job_id: Training job to evaluate.

      Returns:
          Same shape as ``run_evaluation``: ``id``, ``job_id``, ``metrics``,
          ``created_at``. On failure ``metrics`` is ``{}``, ``created_at`` is
          ``""`` and ``error`` describes the problem.
      """
      import asyncio
      import shlex

      # Python snippet run inside the container. job_id is embedded as a Python
      # string literal via repr() and the whole snippet is shell-quoted, so
      # quotes or shell metacharacters in job_id cannot alter the command.
      py_code = (
          "import asyncio, json; "
          "from app.core.remote_eval import run_remote_eval; "
          f"result = asyncio.run(run_remote_eval({str(job_id)!r})); "
          "print(json.dumps(result, ensure_ascii=False))"
      )
      # Settings values are trusted configuration and are interpolated as-is.
      remote_cmd = (
          f"docker exec "
          f"-e MACA_MPS_MODE=1 "
          f"-e METAX_VISIBLE_DEVICES=2,3 "
          f"-w {settings.compute_node_workdir} "
          f"{settings.compute_node_docker_container} "
          f"{settings.compute_node_python} -c {shlex.quote(py_code)} 2>&1"
      )

      try:
          # ssh_exec blocks (up to the 300 s timeout); keep it off the event loop.
          # NOTE(review): assumes ssh_exec is safe to call from a worker thread.
          code, stdout, stderr = await asyncio.to_thread(ssh_exec, remote_cmd, timeout=300)
      except Exception as e:
          logger.exception(f"Remote evaluation could not be executed for job {job_id}: {e}")
          return {"id": eval_id, "job_id": job_id, "metrics": {}, "created_at": "", "error": str(e)}

      stdout = stdout or ""
      if code != 0:
          # The container command's stderr is redirected into stdout (2>&1), so
          # stderr is usually empty; fall back to the tail of stdout.
          detail = (stderr or "").strip() or stdout.strip()[-500:]
          logger.error(f"Remote evaluation failed: {detail}")
          return {"id": eval_id, "job_id": job_id, "metrics": {}, "created_at": "", "error": detail}

      # Take the last line that parses as a JSON object; earlier lines (logs,
      # warnings) are ignored. Split on "\n" only: splitlines() would also split
      # on U+2028/U+2029, which may appear inside JSON printed with ensure_ascii=False.
      result = None
      for line in reversed(stdout.strip().split("\n")):
          line = line.strip()
          if not line.startswith("{"):
              continue
          try:
              result = json.loads(line)
          except json.JSONDecodeError:
              continue
          break

      if not isinstance(result, dict):
          return {"id": eval_id, "job_id": job_id, "metrics": {}, "created_at": "", "error": f"Invalid response: {stdout[:500]}"}

      # Presumably run_remote_eval mirrors run_evaluation's result shape and
      # reports failures via an "error" key — confirm. Don't persist empty
      # metrics as a successful evaluation in that case.
      remote_error = result.get("error")
      if remote_error:
          logger.error(f"Remote evaluation reported an error for job {job_id}: {remote_error}")
          return {"id": eval_id, "job_id": job_id, "metrics": {}, "created_at": "", "error": str(remote_error)}

      metrics = result.get("metrics") or {}
      # Local timestamp instead of reading the ORM attribute after commit
      # (it may be expired on commit and the session is closed by then).
      created_at = datetime.utcnow()
      async with async_session() as session:
          session.add(
              EvalResultModel(
                  id=eval_id,
                  job_id=job_id,
                  metrics=json.dumps(metrics),
                  created_at=created_at,
              )
          )
          await session.commit()

      logger.info(f"Remote evaluation completed for job {job_id}: {metrics}")
      return {"id": eval_id, "job_id": job_id, "metrics": metrics, "created_at": created_at.isoformat()}
  async def get_evaluation_results(eval_id: str) -> dict[str, Any]:
      """Look up a stored evaluation by its ID.

      Returns ``id``, ``job_id``, ``metrics`` (decoded from the stored JSON
      text, ``{}`` when empty) and ``created_at`` in ISO-8601 form. When no row
      matches, a placeholder carrying ``eval_id`` and empty fields is returned
      instead of raising.
      """
      stmt = select(EvalResultModel).where(EvalResultModel.id == eval_id)
      async with async_session() as session:
          row = (await session.execute(stmt)).scalar_one_or_none()
          if not row:
              # Unknown ID: empty placeholder rather than an exception.
              return {"id": eval_id, "job_id": "", "metrics": {}, "created_at": ""}
          return {
              "id": row.id,
              "job_id": row.job_id,
              "metrics": json.loads(row.metrics) if row.metrics else {},
              "created_at": row.created_at.isoformat(),
          }