# eval_service.py
import asyncio
import json
import shlex
import uuid
from datetime import datetime, timezone
from typing import Any

from sqlalchemy import select

from app.config import get_settings
from app.core.background_tasks import background_task_manager
from app.core.db import async_session, EvalResultModel
from app.core.logging import logger
from app.core.remote_executor import ssh_exec
# Application settings, loaded once at import time. The evaluation helpers below read
# use_remote_compute, adapters_dir and the compute_node_* options from it.
settings = get_settings()
  13. async def run_evaluation(job_id: str, config: dict[str, Any]) -> dict[str, Any]:
  14. """启动评估后台任务,立即返回 eval_id。"""
  15. eval_id = str(uuid.uuid4())
  16. # 写 DB
  17. record = EvalResultModel(
  18. id=eval_id,
  19. job_id=job_id,
  20. status="pending",
  21. metrics="{}",
  22. )
  23. async with async_session() as session:
  24. session.add(record)
  25. await session.commit()
  26. # 注册并启动
  27. background_task_manager.register_task(eval_id, "evaluation", {"job_id": job_id})
  28. await background_task_manager.run(
  29. eval_id, "evaluation", _execute_evaluation(eval_id, job_id, config)
  30. )
  31. logger.info(f"Evaluation task started: job={job_id} (eval_id={eval_id})")
  32. return {"id": eval_id, "job_id": job_id, "status": "pending"}
  33. async def _execute_evaluation(eval_id: str, job_id: str, config: dict[str, Any]) -> dict:
  34. """后台执行评估。"""
  35. try:
  36. # 远程训练模式:把评估任务也发到远程容器执行
  37. if settings.use_remote_compute:
  38. logger.info(f"Running remote evaluation for job {job_id}")
  39. result = await _run_remote_evaluation(eval_id, job_id)
  40. return {"metrics": result.get("metrics", {})}
  41. adapter_path = settings.adapters_dir / job_id
  42. if not adapter_path.exists():
  43. raise ValueError("Adapter not found")
  44. import torch
  45. from transformers import AutoModelForCausalLM, AutoTokenizer
  46. # 加载 base model + adapter
  47. model = AutoModelForCausalLM.from_pretrained(adapter_path, torch_dtype=torch.float16, device_map="auto")
  48. tokenizer = AutoTokenizer.from_pretrained(adapter_path, trust_remote_code=True)
  49. # 计算 perplexity
  50. sample_texts = [
  51. "The quick brown fox jumps over the lazy dog.",
  52. "Hello, how are you doing today?",
  53. ]
  54. losses = []
  55. model.eval()
  56. with torch.no_grad():
  57. for text in sample_texts:
  58. inputs = tokenizer(text, return_tensors="pt").to(model.device)
  59. outputs = model(**inputs, labels=inputs["input_ids"])
  60. losses.append(outputs.loss.item())
  61. avg_loss = sum(losses) / len(losses) if losses else 0
  62. perplexity = torch.exp(torch.tensor(avg_loss)).item() if avg_loss > 0 else 0
  63. metrics = {
  64. "eval_loss": round(avg_loss, 4),
  65. "perplexity": round(perplexity, 2),
  66. "num_samples": len(sample_texts),
  67. }
  68. # 更新 DB
  69. async with async_session() as session:
  70. result = await session.execute(select(EvalResultModel).where(EvalResultModel.id == eval_id))
  71. eval_record = result.scalar_one_or_none()
  72. if eval_record:
  73. eval_record.metrics = json.dumps(metrics)
  74. eval_record.status = "completed"
  75. eval_record.progress = 100.0
  76. await session.commit()
  77. logger.info(f"Evaluation completed for job {job_id}: {metrics}")
  78. return {"metrics": metrics}
  79. except Exception as e:
  80. logger.error(f"Evaluation failed for job {job_id}: {e}")
  81. async with async_session() as session:
  82. result = await session.execute(select(EvalResultModel).where(EvalResultModel.id == eval_id))
  83. eval_record = result.scalar_one_or_none()
  84. if eval_record:
  85. eval_record.status = "failed"
  86. eval_record.error = str(e)
  87. await session.commit()
  88. return {"error": str(e)}
  89. async def _run_remote_evaluation(eval_id: str, job_id: str) -> dict[str, Any]:
  90. """通过 SSH 在远程容器里执行评估。"""
  91. remote_cmd = (
  92. f"docker exec "
  93. f"-e MACA_MPS_MODE=1 "
  94. f"-e CUDA_VISIBLE_DEVICES=3 "
  95. f"-w {settings.compute_node_workdir} "
  96. f"{settings.compute_node_docker_container} "
  97. f"{settings.compute_node_python} -c \""
  98. "import asyncio, json; "
  99. "from app.core.remote_eval import run_remote_eval; "
  100. f"result = asyncio.run(run_remote_eval('{job_id}')); "
  101. "print(json.dumps(result, ensure_ascii=False))\" 2>&1"
  102. )
  103. code, stdout, stderr = await asyncio.to_thread(ssh_exec, remote_cmd, timeout=300)
  104. if code != 0:
  105. raise RuntimeError(f"Remote evaluation failed: {stderr}")
  106. # 提取最后一行 JSON
  107. for line in reversed(stdout.strip().split("\n")):
  108. line = line.strip()
  109. if line.startswith("{"):
  110. try:
  111. result = json.loads(line)
  112. # 保存结果到本地数据库(更新已有记录)
  113. metrics = result.get("metrics", {})
  114. async with async_session() as session:
  115. res = await session.execute(
  116. select(EvalResultModel).where(EvalResultModel.id == eval_id)
  117. )
  118. eval_record = res.scalar_one_or_none()
  119. if eval_record:
  120. eval_record.metrics = json.dumps(metrics)
  121. eval_record.status = "completed"
  122. eval_record.progress = 100.0
  123. await session.commit()
  124. return {"id": eval_id, "job_id": job_id, "metrics": metrics}
  125. except json.JSONDecodeError:
  126. continue
  127. raise RuntimeError(f"Invalid response: {stdout[:500]}")
  128. async def get_evaluation_results(eval_id: str) -> dict[str, Any]:
  129. """获取已完成评估的结果。"""
  130. async with async_session() as session:
  131. result = await session.execute(select(EvalResultModel).where(EvalResultModel.id == eval_id))
  132. record = result.scalar_one_or_none()
  133. if record:
  134. return {
  135. "id": record.id,
  136. "job_id": record.job_id,
  137. "status": record.status,
  138. "progress": record.progress,
  139. "metrics": json.loads(record.metrics) if record.metrics else {},
  140. "error": record.error,
  141. "created_at": record.created_at.isoformat(),
  142. }
  143. return {"id": eval_id, "job_id": "", "status": "not_found", "metrics": {}}
  144. async def recover_stale_evaluations() -> None:
  145. async with async_session() as session:
  146. result = await session.execute(
  147. select(EvalResultModel).where(
  148. EvalResultModel.status.in_(["pending", "running"])
  149. )
  150. )
  151. records = result.scalars().all()
  152. for record in records:
  153. record.status = "failed"
  154. record.error = "Server restarted, task interrupted"
  155. if records:
  156. await session.commit()
  157. logger.info(f"Recovered {len(records)} stale evaluation tasks")