"""HTTP routes for creating, listing, inspecting, cancelling and tailing training jobs."""

import os
from collections.abc import Iterator

from fastapi import APIRouter
from fastapi.responses import StreamingResponse

from app.schemas.training import TrainingConfig, TrainingJobResponse, TrainingProgress
from app.services import training_service

router = APIRouter()

# SSE event sent when a job has no readable log file.
_NO_LOGS_EVENT = "data: No logs available\n\n"


@router.post("/jobs", response_model=TrainingJobResponse)
async def create_training_job(config: TrainingConfig):
    """Create a training job from ``config`` and hand it to the training service."""
    config_dict = config.model_dump()
    result = await training_service.create_training_job(config_dict)
    return TrainingJobResponse(**result)


@router.get("/jobs", response_model=list[TrainingJobResponse])
async def list_training_jobs():
    """List all training jobs reported by the training service."""
    items = await training_service.list_training_jobs()
    return [TrainingJobResponse(**item) for item in items]


@router.get("/jobs/{job_id}", response_model=TrainingJobResponse)
async def get_training_job(job_id: str):
    """Return the details of the training job identified by ``job_id``.

    When the service returns nothing for ``job_id``, a placeholder response
    with status ``"pending"`` and empty fields is returned instead of an error.
    """
    item = await training_service.get_training_job(job_id)
    if item:
        return TrainingJobResponse(**item)
    # NOTE(review): an unknown job is reported as a "pending" placeholder
    # rather than a 404. Kept as-is because clients may rely on it -- confirm.
    return TrainingJobResponse(
        id=job_id,
        model_id="",
        model_type="text",
        peft_method="lora",
        status="pending",
        created_at="",
    )


@router.post("/jobs/{job_id}/cancel")
async def cancel_training_job(job_id: str):
    """Ask the training service to cancel ``job_id`` and return its result unchanged."""
    return await training_service.cancel_training_job(job_id)


def _iter_log_events(job_id: str) -> Iterator[str]:
    """Yield the trainer log of ``job_id`` as Server-Sent Events, one per line.

    A single "No logs available" event is yielded when ``job_id`` is not a
    plain directory name or when the log file cannot be opened.
    """
    # Imported lazily, as the original code did (presumably to avoid an
    # import cycle with app.config -- confirm before hoisting).
    from app.config import get_settings

    # job_id comes straight from the URL: accept only a single path component
    # so values such as ".." cannot escape the adapters directory.
    if job_id in {"", ".", ".."} or os.path.basename(job_id) != job_id:
        yield _NO_LOGS_EVENT
        return

    log_file = get_settings().adapters_dir / job_id / "trainer_log.txt"
    try:
        # Explicit encoding so decoding does not depend on the host locale;
        # undecodable bytes are replaced instead of aborting the stream.
        handle = open(log_file, "r", encoding="utf-8", errors="replace")
    except OSError:
        # Missing, unreadable or not a regular file: there are no logs to show.
        yield _NO_LOGS_EVENT
        return

    with handle:
        for line in handle:
            # Drop the line terminator so each event ends in exactly one blank line.
            text = line.rstrip("\n")
            yield f"data: {text}\n\n"


@router.get("/jobs/{job_id}/logs")
async def stream_training_logs(job_id: str):
    """Stream the training log of ``job_id`` as Server-Sent Events."""
    # A plain (non-async) generator is passed on purpose: StreamingResponse
    # iterates synchronous iterators in a threadpool, so the blocking file
    # reads do not stall the event loop.
    return StreamingResponse(_iter_log_events(job_id), media_type="text/event-stream")