# inference.py (1.5 KB)

  1. from pydantic import BaseModel, Field
  2. from app.services import inference_service
  3. from fastapi import APIRouter
# Router that the inference endpoints in this module register on
# (POST /generate and GET /adapters).
router = APIRouter()
  5. class GenerateRequest(BaseModel):
  6. adapter_id: str = Field(..., description="训练任务 ID(adapter 目录名)")
  7. prompt: str
  8. max_new_tokens: int = 256
  9. temperature: float = 0.8
  10. top_p: float = 0.95
  11. repetition_penalty: float = 1.1
  12. do_sample: bool = True
  13. class GenerateResponse(BaseModel):
  14. prompt: str
  15. generated_text: str
  16. generated_only: str
  17. tokens_generated: int
  18. error: str | None = None
  19. @router.post("/generate", response_model=GenerateResponse)
  20. async def generate(req: GenerateRequest):
  21. """使用已训练的 adapter 生成文本。"""
  22. from app.config import get_settings
  23. settings = get_settings()
  24. adapter_path = str(settings.adapters_dir / req.adapter_id)
  25. result = await inference_service.generate(
  26. adapter_path,
  27. req.prompt,
  28. req.max_new_tokens,
  29. req.temperature,
  30. req.top_p,
  31. req.repetition_penalty,
  32. req.do_sample,
  33. )
  34. return GenerateResponse(
  35. prompt=result.get("prompt", req.prompt),
  36. generated_text=result.get("generated_text", ""),
  37. generated_only=result.get("generated_only", ""),
  38. tokens_generated=result.get("tokens_generated", 0),
  39. error=result.get("error"),
  40. )
  41. @router.get("/adapters")
  42. async def list_adapters():
  43. """列出所有可用的 adapter。"""
  44. return await inference_service.get_available_adapters()