| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455 |
- from pydantic import BaseModel, Field
- from app.services import inference_service
- from fastapi import APIRouter
# Router that the inference endpoints below are registered on.
router: APIRouter = APIRouter()
class GenerateRequest(BaseModel):
    """Request body for ``POST /generate``.

    Sampling parameters are bounds-checked here, so values that are
    meaningless for text generation are rejected during request
    validation. Without these checks they would reach the inference service.
    """

    # Training-task ID, which is also the adapter's directory name under the
    # configured adapters directory.
    adapter_id: str = Field(..., description="训练任务 ID(adapter 目录名)")
    prompt: str
    # Maximum number of new tokens to generate; at least one.
    max_new_tokens: int = Field(256, ge=1)
    # Sampling temperature; negative values are rejected.
    temperature: float = Field(0.8, ge=0.0)
    # Nucleus-sampling probability mass, restricted to the interval (0, 1].
    top_p: float = Field(0.95, gt=0.0, le=1.0)
    # Penalty applied to repeated tokens; must be strictly positive.
    repetition_penalty: float = Field(1.1, gt=0.0)
    do_sample: bool = True
class GenerateResponse(BaseModel):
    """Response body for ``POST /generate``."""

    # Prompt the generation was run on.
    prompt: str = Field(...)
    # Generated text as reported by the inference service
    # (presumably prompt plus continuation — confirm against the service).
    generated_text: str = Field(...)
    # Newly generated portion only (presumably excludes the prompt).
    generated_only: str = Field(...)
    # Number of tokens the inference service reports as generated.
    tokens_generated: int = Field(...)
    # Error message, if any; defaults to None.
    error: str | None = Field(default=None)
@router.post("/generate", response_model=GenerateResponse)
async def generate(req: GenerateRequest):
    """Generate text using a trained adapter.

    ``req.adapter_id`` comes straight from the client and is joined onto the
    configured adapters directory, so it is validated first. It must be a
    non-empty relative path with no ``..`` components and no root or drive
    anchor. Otherwise a caller could point the inference service at an
    arbitrary directory on disk, because joining an absolute path onto a
    ``Path`` discards the base. A rejected id is reported through the
    response's ``error`` field, which is the same field that carries
    ``result["error"]`` from the inference service.

    Returns:
        A ``GenerateResponse`` built from the dict returned by
        ``inference_service.generate``. Missing keys fall back to the
        request prompt, empty strings, ``0``, or ``None`` respectively.
    """
    from pathlib import PurePath

    # Imported inside the function, as in the original (presumably to avoid
    # an import cycle with the app package).
    from app.config import get_settings

    # Lexical check only: it does not follow symlinks, so operator-created
    # symlinks inside the adapters directory keep working.
    rel = PurePath(req.adapter_id)
    if not rel.parts or rel.anchor or ".." in rel.parts:
        return GenerateResponse(
            prompt=req.prompt,
            generated_text="",
            generated_only="",
            tokens_generated=0,
            error=f"invalid adapter_id: {req.adapter_id!r}",
        )

    settings = get_settings()
    adapter_path = str(settings.adapters_dir / req.adapter_id)
    result = await inference_service.generate(
        adapter_path,
        req.prompt,
        req.max_new_tokens,
        req.temperature,
        req.top_p,
        req.repetition_penalty,
        req.do_sample,
    )
    return GenerateResponse(
        prompt=result.get("prompt", req.prompt),
        generated_text=result.get("generated_text", ""),
        generated_only=result.get("generated_only", ""),
        tokens_generated=result.get("tokens_generated", 0),
        error=result.get("error"),
    )
@router.get("/adapters")
async def list_adapters():
    """Return whatever the inference service reports as the available adapters."""
    available = await inference_service.get_available_adapters()
    return available
|