rerank.py 1.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. import logging
  2. from typing import List, Optional
  3. from fastapi import APIRouter, Request
  4. from pydantic import BaseModel
  5. from gpustack.routes.openai import proxy_request_by_model
  6. from gpustack.server.deps import SessionDep, CurrentUserDep
  7. router = APIRouter()
  8. logger = logging.getLogger(__name__)
  9. class RerankRequest(BaseModel):
  10. model: str
  11. query: str
  12. documents: List[str]
  13. top_n: Optional[int] = None
  14. return_documents: Optional[bool] = True
  15. class RerankUsage(BaseModel):
  16. total_tokens: Optional[int] = None
  17. prompt_tokens: Optional[int] = None
  18. class RerankResultDocument(BaseModel):
  19. text: str
  20. class RerankResult(BaseModel):
  21. index: int
  22. document: RerankResultDocument
  23. relevance_score: float
  24. class RerankResponse(BaseModel):
  25. model: str
  26. # object: str
  27. usage: RerankUsage
  28. results: List[RerankResult]
  29. @router.post("/rerank", response_model=RerankResponse)
  30. async def rerank(
  31. request: Request,
  32. user: CurrentUserDep,
  33. session: SessionDep,
  34. ):
  35. return await proxy_request_by_model(request, user, session)