from fastapi import APIRouter
from pydantic import BaseModel, Field

from app.services import inference_service

router = APIRouter()


class GenerateRequest(BaseModel):
    """Request body for the ``/generate`` endpoint."""

    # Name of the adapter directory, resolved relative to the configured
    # adapters directory by the handler.
    adapter_id: str = Field(..., description="训练任务 ID(adapter 目录名)")
    prompt: str
    # The remaining fields are forwarded unchanged to the inference service.
    max_new_tokens: int = 256
    temperature: float = 0.8
    top_p: float = 0.95
    repetition_penalty: float = 1.1
    do_sample: bool = True


class GenerateResponse(BaseModel):
    """Response body for the ``/generate`` endpoint."""

    prompt: str
    generated_text: str
    generated_only: str
    tokens_generated: int
    # In-band error channel: set when generation could not be performed.
    error: str | None = None


def _is_safe_adapter_id(adapter_id: str) -> bool:
    """Return True if *adapter_id* is a single, plain directory name.

    Rejects the empty string, the special entries ``.`` and ``..``, and any
    value containing a path separator (POSIX or Windows) or a NUL byte, so
    that joining it onto the adapters directory cannot escape that directory.
    """
    if adapter_id in {"", ".", ".."}:
        return False
    return not any(ch in adapter_id for ch in ("/", "\\", "\0"))


@router.post("/generate", response_model=GenerateResponse)
async def generate(req: GenerateRequest):
    """Generate text with a trained adapter.

    The adapter is looked up as ``settings.adapters_dir / req.adapter_id`` and
    the sampling parameters from the request are passed through to
    ``inference_service.generate``. Missing keys in the service result fall
    back to neutral defaults (the request prompt, empty text, zero tokens).

    If ``adapter_id`` is not a plain directory name, no inference is run and
    the problem is reported through the ``error`` field of the response.
    """
    # adapter_id comes straight from the request body; refuse anything that
    # could resolve outside the adapters directory (path traversal).
    if not _is_safe_adapter_id(req.adapter_id):
        return GenerateResponse(
            prompt=req.prompt,
            generated_text="",
            generated_only="",
            tokens_generated=0,
            error="Invalid adapter_id: must be a single directory name",
        )

    # NOTE(review): imported at call time as in the original; presumably to
    # avoid an import cycle with app.config -- confirm before hoisting.
    from app.config import get_settings

    settings = get_settings()
    adapter_path = str(settings.adapters_dir / req.adapter_id)

    result = await inference_service.generate(
        adapter_path,
        req.prompt,
        req.max_new_tokens,
        req.temperature,
        req.top_p,
        req.repetition_penalty,
        req.do_sample,
    )

    return GenerateResponse(
        prompt=result.get("prompt", req.prompt),
        generated_text=result.get("generated_text", ""),
        generated_only=result.get("generated_only", ""),
        tokens_generated=result.get("tokens_generated", 0),
        error=result.get("error"),
    )


@router.get("/adapters")
async def list_adapters():
    """List all available adapters, as reported by the inference service."""
    return await inference_service.get_available_adapters()