| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162 |
- from fastapi import APIRouter
- from app.schemas.training import TrainingConfig, TrainingJobResponse, TrainingProgress
- from app.services import training_service
- router = APIRouter()
@router.post("/jobs", response_model=TrainingJobResponse)
async def create_training_job(config: TrainingConfig):
    """Create a training job from the submitted configuration.

    The validated request body is dumped to a plain dict and handed to
    ``training_service.create_training_job``; the mapping the service
    returns is unpacked into a ``TrainingJobResponse``.
    """
    job_fields = await training_service.create_training_job(config.model_dump())
    return TrainingJobResponse(**job_fields)
@router.get("/jobs", response_model=list[TrainingJobResponse])
async def list_training_jobs():
    """List all training jobs reported by the training service.

    Every mapping returned by ``training_service.list_training_jobs`` is
    unpacked into its own ``TrainingJobResponse``.
    """
    return [
        TrainingJobResponse(**job_fields)
        for job_fields in await training_service.list_training_jobs()
    ]
@router.get("/jobs/{job_id}", response_model=TrainingJobResponse)
async def get_training_job(job_id: str):
    """Return the details of a single training job.

    When the service has a (truthy) record for ``job_id`` it is unpacked into
    a ``TrainingJobResponse``. Otherwise a placeholder response is returned
    that echoes ``job_id`` with status ``"pending"`` and empty model fields.
    """
    job_fields = await training_service.get_training_job(job_id)
    if not job_fields:
        # NOTE(review): an unknown job id yields a successful placeholder
        # rather than an HTTP 404 -- confirm clients rely on this before
        # changing it.
        return TrainingJobResponse(
            id=job_id,
            model_id="",
            model_type="text",
            peft_method="lora",
            status="pending",
            created_at="",
        )
    return TrainingJobResponse(**job_fields)
@router.post("/jobs/{job_id}/cancel")
async def cancel_training_job(job_id: str):
    """Cancel a running training job.

    Delegates to ``training_service.cancel_training_job`` and returns
    whatever the service returns, unchanged.
    """
    outcome = await training_service.cancel_training_job(job_id)
    return outcome
@router.get("/jobs/{job_id}/logs")
async def stream_training_logs(job_id: str):
    """Stream a job's trainer log as Server-Sent Events (SSE).

    Each line of ``<adapters_dir>/<job_id>/trainer_log.txt`` is sent as one
    ``data:`` event. When the log file does not exist, a single
    ``No logs available`` event is sent instead.

    Args:
        job_id: Identifier of the job; also used as the name of the job's
            directory under the configured adapters directory.

    Returns:
        A ``StreamingResponse`` with media type ``text/event-stream``.
    """
    from fastapi.responses import StreamingResponse

    async def log_stream():
        # job_id comes straight from the URL and becomes a path component,
        # so refuse anything that could step outside the job's directory.
        if job_id in {".", ".."} or "/" in job_id or "\\" in job_id:
            yield "data: No logs available\n\n"
            return

        from app.config import get_settings

        log_path = get_settings().adapters_dir / job_id / "trainer_log.txt"
        try:
            # Open directly instead of checking for existence first: this
            # avoids the check-then-open race and needs no `os` import.
            # errors="replace" keeps one undecodable byte from aborting the
            # stream halfway through.
            log_handle = open(log_path, "r", errors="replace")
        except (FileNotFoundError, NotADirectoryError):
            yield "data: No logs available\n\n"
            return

        with log_handle:
            for raw_line in log_handle:
                # Drop the line's own newline so each event is terminated by
                # exactly one blank line.
                text = raw_line.rstrip("\n")
                yield f"data: {text}\n\n"

    return StreamingResponse(log_stream(), media_type="text/event-stream")
|