training.py 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. from fastapi import APIRouter
  2. from app.schemas.training import TrainingConfig, TrainingJobResponse, TrainingProgress
  3. from app.services import training_service
  # Router on which every training-job endpoint in this module is registered.
  router: APIRouter = APIRouter()
  @router.post("/jobs", response_model=TrainingJobResponse)
  async def create_training_job(config: TrainingConfig):
      """Create a new training job and add it to the training service.

      The validated ``config`` is dumped to a plain dict for the service; the
      record the service hands back is wrapped in a ``TrainingJobResponse``.
      """
      created = await training_service.create_training_job(config.model_dump())
      return TrainingJobResponse(**created)
  @router.get("/jobs", response_model=list[TrainingJobResponse])
  async def list_training_jobs():
      """Return all training jobs reported by the training service.

      Each record from the service is converted into a ``TrainingJobResponse``,
      preserving the service's ordering.
      """
      jobs = await training_service.list_training_jobs()
      return [TrainingJobResponse(**job) for job in jobs]
  @router.get("/jobs/{job_id}", response_model=TrainingJobResponse)
  async def get_training_job(job_id: str):
      """Return the details of the training job identified by ``job_id``.

      If the service returns no (falsy) record, a placeholder response is
      returned instead: empty ``model_id``/``created_at``, ``model_type="text"``,
      ``peft_method="lora"`` and ``status="pending"``.

      NOTE(review): the placeholder makes an unknown ID look like a queued job,
      so a client polling for completion would wait forever. Consider raising
      ``HTTPException(status_code=404)`` once it is confirmed no caller relies
      on this fallback.
      """
      record = await training_service.get_training_job(job_id)
      if not record:
          record = {
              "id": job_id,
              "model_id": "",
              "model_type": "text",
              "peft_method": "lora",
              "status": "pending",
              "created_at": "",
          }
      return TrainingJobResponse(**record)
  @router.post("/jobs/{job_id}/cancel")
  async def cancel_training_job(job_id: str):
      """Ask the training service to cancel the running job ``job_id``.

      Whatever the service returns is passed back to the client unchanged.
      """
      outcome = await training_service.cancel_training_job(job_id)
      return outcome
  @router.get("/jobs/{job_id}/logs")
  async def stream_training_logs(job_id: str):
      """Send a training job's log file to the client as Server-Sent Events.

      Each line of ``<adapters_dir>/<job_id>/trainer_log.txt`` is emitted as
      one ``data:`` event (without its line terminator). If the file does not
      exist, or ``job_id`` is not a single plain path component, one
      ``data: No logs available`` event is sent instead.

      The file is read once, start to end; lines appended afterwards are not
      followed.
      """
      from pathlib import Path

      from fastapi.responses import StreamingResponse

      def log_stream():
          # Deliberately a *sync* generator: StreamingResponse iterates sync
          # iterators in a worker thread, so the blocking open()/readline work
          # below does not stall the event loop.
          from app.config import get_settings

          # job_id comes straight from the URL; refuse "", "." / ".." and
          # anything containing a path separator so the lookup cannot leave
          # adapters_dir.
          if not job_id or job_id in (".", "..") or Path(job_id).name != job_id:
              yield "data: No logs available\n\n"
              return

          settings = get_settings()
          log_file = Path(settings.adapters_dir) / job_id / "trainer_log.txt"
          try:
              # EAFP rather than exists()+open(): no race if the file
              # disappears between the check and the open. Explicit UTF-8 with
              # replacement so undecodable bytes cannot abort the stream.
              handle = open(log_file, "r", encoding="utf-8", errors="replace")
          except (FileNotFoundError, NotADirectoryError):
              yield "data: No logs available\n\n"
              return
          with handle:
              for line in handle:
                  # Drop the line's own terminator; the "\n\n" below is what
                  # ends the SSE event, and an extra "\n" would add a stray
                  # blank line after every event.
                  text = line.rstrip("\r\n")
                  yield f"data: {text}\n\n"

      return StreamingResponse(log_stream(), media_type="text/event-stream")