openai.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465
  1. import json
  2. import re
  3. import random
  4. import asyncio
  5. from typing import AsyncGenerator, List, Optional, Tuple, Union, Dict
  6. import aiohttp
  7. import logging
  8. from fastapi import APIRouter, Query, Request, Response, status
  9. from openai.types import Model as OAIModel
  10. from openai.pagination import SyncPage
  11. from sqlmodel import or_, select
  12. from sqlmodel.ext.asyncio.session import AsyncSession
  13. from starlette.datastructures import UploadFile
  14. from gpustack.api.exceptions import (
  15. BadRequestException,
  16. NotFoundException,
  17. InternalServerErrorException,
  18. OpenAIAPIError,
  19. OpenAIAPIErrorResponse,
  20. ServiceUnavailableException,
  21. GatewayTimeoutException,
  22. )
  23. from gpustack.api.responses import StreamingResponseWithStatusCode
  24. from gpustack import envs
  25. from gpustack.http_proxy.load_balancer import LoadBalancer
  26. from gpustack.routes.model_common import build_category_conditions
  27. from gpustack.schemas.models import Model
  28. from gpustack.schemas.model_routes import (
  29. ModelRoute,
  30. MyModel,
  31. effective_route_name,
  32. )
  33. from gpustack.schemas.principals import Principal, PLATFORM_PRINCIPAL_ID
  34. from gpustack.schemas.workers import Worker
  35. from gpustack.server.deps import SessionDep, CurrentUserDep
  36. from gpustack.server.services import (
  37. ModelInstanceService,
  38. ModelRouteService,
  39. WorkerService,
  40. UserService,
  41. )
  42. from gpustack.server.worker_request import request_to_worker, stream_to_worker
  43. from gpustack.gateway.utils import (
  44. model_instance_prefix,
  45. openai_model_prefixes,
  46. router_header_key,
  47. )
  48. logger = logging.getLogger(__name__)
  49. load_balancer = LoadBalancer()
  50. # Endpoints served by a dedicated server router (e.g. rerank.router), so the
  51. # auto-registration must skip them to avoid duplicate /v1/<endpoint> registration.
  52. _server_managed_elsewhere = {"/rerank"}
  53. def get_api_router() -> APIRouter:
  54. """Full OpenAI-compatible endpoint set, mounted at /v1."""
  55. router = APIRouter()
  56. router.add_api_route("/models", list_models, methods=["GET"])
  57. for rp in openai_model_prefixes:
  58. for endpoint in rp.prefixes:
  59. if endpoint in _server_managed_elsewhere:
  60. continue
  61. router.add_api_route(endpoint, proxy_request_by_model, methods=["POST"])
  62. return router
  63. def get_legacy_api_router() -> APIRouter:
  64. """Legacy subset, mounted at /v1-openai. Frozen — new endpoints go to /v1 only."""
  65. router = APIRouter()
  66. router.add_api_route("/models", list_models, methods=["GET"])
  67. for rp in openai_model_prefixes:
  68. if not rp.support_legacy:
  69. continue
  70. for endpoint in rp.prefixes:
  71. if endpoint in _server_managed_elsewhere:
  72. continue
  73. router.add_api_route(endpoint, proxy_request_by_model, methods=["POST"])
  74. return router
  75. async def list_models(
  76. user: CurrentUserDep,
  77. session: SessionDep,
  78. categories: List[str] = Query(
  79. [],
  80. description="Model categories to filter by.",
  81. ),
  82. with_meta: Optional[bool] = Query(
  83. None,
  84. description="Include model meta information.",
  85. ),
  86. ):
  87. target_class = ModelRoute if user.is_admin else MyModel
  88. statement = select(target_class).where(target_class.ready_targets > 0)
  89. if target_class == MyModel:
  90. # Non-admin users should only see their own private models when filtering by categories.
  91. statement = statement.where(target_class.user_id == user.id)
  92. if categories:
  93. conditions = build_category_conditions(session, target_class, categories)
  94. statement = statement.where(or_(*conditions))
  95. models = (await session.exec(statement)).all()
  96. # Bulk-load owner principals to resolve each route's effective
  97. # name (slug-prefixed for non-platform owners). Without the prefix,
  98. # two owners holding routes named "qwen3-0.6b" would publish the
  99. # same ``id`` here and Higress's AI proxy would dispatch
  100. # ambiguously.
  101. principal_ids = {
  102. m.owner_principal_id for m in models if m.owner_principal_id is not None
  103. }
  104. principal_by_id: Dict[int, Principal] = {}
  105. if principal_ids:
  106. rows = (
  107. await session.exec(select(Principal).where(Principal.id.in_(principal_ids)))
  108. ).all()
  109. principal_by_id = {p.id: p for p in rows}
  110. result = SyncPage[OAIModel](data=[], object="list")
  111. for model in models:
  112. owner = (
  113. principal_by_id.get(model.owner_principal_id)
  114. if model.owner_principal_id
  115. else None
  116. )
  117. eff_name = effective_route_name(
  118. model.name,
  119. getattr(owner, "slug", None),
  120. getattr(owner, "id", None) == PLATFORM_PRINCIPAL_ID,
  121. )
  122. result.data.append(
  123. OAIModel(
  124. id=eff_name,
  125. object="model",
  126. created=int(model.created_at.timestamp()),
  127. owned_by="gpustack",
  128. meta=model.meta if with_meta else None,
  129. )
  130. )
  131. return result
  132. async def proxy_request_by_model(
  133. request: Request,
  134. user: CurrentUserDep,
  135. session: SessionDep,
  136. ):
  137. endpoint = re.sub(r"^/(v1|v1-openai)/", "", request.url.path)
  138. """
  139. Proxy the request to the model instance that is running the model specified in the
  140. request body.
  141. """
  142. model_name, stream, body_json, form_data = await parse_request_body(request)
  143. if not await UserService(session).model_allowed_for_user(
  144. model_name=model_name,
  145. user_id=user.id,
  146. api_key=getattr(request.state, "api_key", None),
  147. ):
  148. raise NotFoundException(
  149. message="Model not found",
  150. is_openai_exception=True,
  151. )
  152. model_route_service = ModelRouteService(session)
  153. models: List[Model] = await model_route_service.get_model_ids_by_model_route_name(
  154. model_name
  155. )
  156. if len(models) == 0:
  157. raise NotFoundException(
  158. message="Model not found or no running instances available",
  159. is_openai_exception=True,
  160. )
  161. request.state.stream = stream
  162. model = random.choice(models)
  163. request.state.model = model
  164. # Resolve the route id so downstream middleware (usage recording) can
  165. # attribute the request to the route it entered through. The lookup
  166. # is @locked_cached so repeat hits within the same session are cheap.
  167. model_route = await model_route_service.get_by_name(model_name)
  168. request.state.model_route_id = model_route.id if model_route else None
  169. mutate_request(request, model_name, body_json, form_data)
  170. instance = await get_running_instance(session, model.id)
  171. worker: Worker = await WorkerService(session).get_by_id(instance.worker_id)
  172. if not worker:
  173. raise InternalServerErrorException(
  174. message=f"Worker with ID {instance.worker_id} not found",
  175. is_openai_exception=True,
  176. )
  177. extra_headers = {
  178. router_header_key: f"{model_instance_prefix(instance)}.static",
  179. }
  180. path = f"v1/{endpoint}"
  181. logger.debug(
  182. f"proxying to {instance.worker_ip}:{instance.port}, instance port: {instance.port}"
  183. )
  184. try:
  185. headers, data = _prepare_proxy_request(
  186. request,
  187. body_json,
  188. form_data,
  189. extra_headers,
  190. add_stream_options=stream,
  191. )
  192. if stream:
  193. return StreamingResponseWithStatusCode(
  194. _stream_response(
  195. worker,
  196. request.method,
  197. path,
  198. headers,
  199. data,
  200. request.app.state.http_client,
  201. request.app.state.http_client_no_proxy,
  202. ),
  203. media_type="text/event-stream",
  204. )
  205. else:
  206. resp, body = await request_to_worker(
  207. worker=worker,
  208. method=request.method,
  209. path=path,
  210. proxy_client=request.app.state.http_client,
  211. no_proxy_client=request.app.state.http_client_no_proxy,
  212. data=data,
  213. headers=headers,
  214. timeout=aiohttp.ClientTimeout(total=envs.PROXY_TIMEOUT),
  215. )
  216. return Response(
  217. status_code=resp.status,
  218. headers=dict(resp.headers),
  219. content=body,
  220. )
  221. except asyncio.TimeoutError as e:
  222. error_message = f"Request to worker {worker.id} timed out"
  223. if str(e):
  224. error_message += f": {e}"
  225. raise GatewayTimeoutException(
  226. message=error_message,
  227. is_openai_exception=True,
  228. )
  229. except Exception as e:
  230. error_message = "An unexpected error occurred"
  231. if str(e):
  232. error_message += f": {e}"
  233. raise ServiceUnavailableException(
  234. message=error_message,
  235. is_openai_exception=True,
  236. )
  237. async def parse_request_body(request: Request):
  238. model_name = None
  239. stream = False
  240. body_json = None
  241. form_data = None
  242. content_type = request.headers.get("content-type", "application/json").lower()
  243. if request.method == "GET":
  244. model_name = request.query_params.get("model")
  245. elif content_type.startswith("multipart/form-data"):
  246. form_data, model_name, stream = await parse_form_data(request)
  247. else:
  248. body_json, model_name, stream = await parse_json_body(request)
  249. if not model_name:
  250. raise BadRequestException(
  251. message="Missing 'model' field",
  252. is_openai_exception=True,
  253. )
  254. return model_name, stream, body_json, form_data
  255. async def parse_form_data(request: Request) -> Tuple[aiohttp.FormData, str, bool]:
  256. try:
  257. form = await request.form()
  258. model_name = form.get("model")
  259. stream = form.get("stream", False)
  260. form_data = aiohttp.FormData()
  261. for key, value in form.items():
  262. if isinstance(value, UploadFile):
  263. form_data.add_field(
  264. key,
  265. await value.read(),
  266. filename=value.filename,
  267. content_type=value.content_type,
  268. )
  269. else:
  270. form_data.add_field(key, value)
  271. return form_data, model_name, stream
  272. except Exception as e:
  273. raise BadRequestException(
  274. message=f"We could not parse the form body of your request: {e}",
  275. is_openai_exception=True,
  276. )
  277. async def parse_json_body(request: Request):
  278. try:
  279. body_json = await request.json()
  280. model_name = body_json.get("model")
  281. stream = body_json.get("stream", False)
  282. return body_json, model_name, stream
  283. except Exception as e:
  284. raise BadRequestException(
  285. message=f"We could not parse the JSON body of your request: {e}",
  286. is_openai_exception=True,
  287. )
  288. def _prepare_proxy_request(
  289. request: Request,
  290. body_json: Optional[dict],
  291. form_data: Optional[aiohttp.FormData],
  292. extra_headers: Optional[dict] = None,
  293. add_stream_options: bool = False,
  294. ) -> Tuple[Dict[str, str], Optional[Union[bytes, aiohttp.FormData]]]:
  295. """
  296. Prepare headers and body for proxy requests.
  297. Returns (headers, data) tuple.
  298. """
  299. headers = filter_headers(request.headers)
  300. if extra_headers:
  301. headers.update(extra_headers)
  302. if add_stream_options and body_json and "stream_options" not in body_json:
  303. # Defaults to include usage.
  304. # TODO Record usage without client awareness.
  305. body_json["stream_options"] = {"include_usage": True}
  306. # Convert body to data
  307. data = (
  308. form_data
  309. if form_data
  310. else (json.dumps(body_json).encode() if body_json else None)
  311. )
  312. # When using data=bytes (instead of json=), aiohttp doesn't set Content-Type
  313. if body_json and not form_data:
  314. headers["Content-Type"] = "application/json"
  315. return headers, data
  316. async def _stream_response(
  317. worker: Worker,
  318. method: str,
  319. path: str,
  320. headers: Dict[str, str],
  321. data: Optional[Union[bytes, aiohttp.FormData]],
  322. proxy_client: aiohttp.ClientSession,
  323. no_proxy_client: aiohttp.ClientSession,
  324. ) -> AsyncGenerator[Tuple[Union[bytes, str], Dict[str, str], int], None]:
  325. """
  326. Stream response from worker. Yields (chunk, headers, status) tuples.
  327. """
  328. try:
  329. async for chunk, resp_headers, resp_status in stream_to_worker(
  330. worker=worker,
  331. method=method,
  332. path=path,
  333. proxy_client=proxy_client,
  334. no_proxy_client=no_proxy_client,
  335. data=data,
  336. headers=headers,
  337. timeout=aiohttp.ClientTimeout(total=envs.PROXY_TIMEOUT),
  338. ):
  339. yield chunk, resp_headers, resp_status
  340. except aiohttp.ClientError as e:
  341. error_response = OpenAIAPIErrorResponse(
  342. error=OpenAIAPIError(
  343. message=f"Service unavailable. Please retry your requests after a brief wait. Original error: {e}",
  344. code=status.HTTP_503_SERVICE_UNAVAILABLE,
  345. type="ServiceUnavailable",
  346. ),
  347. )
  348. yield error_response.model_dump_json(), {}, status.HTTP_503_SERVICE_UNAVAILABLE
  349. except Exception as e:
  350. error_response = OpenAIAPIErrorResponse(
  351. error=OpenAIAPIError(
  352. message=f"Internal server error: {e}",
  353. code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  354. type="InternalServerError",
  355. ),
  356. )
  357. yield error_response.model_dump_json(), {}, status.HTTP_500_INTERNAL_SERVER_ERROR
  358. def filter_headers(headers):
  359. return {
  360. key: value
  361. for key, value in headers.items()
  362. if key.lower() != "content-length"
  363. and key.lower() != "host"
  364. and key.lower() != "content-type"
  365. and key.lower() != "transfer-encoding"
  366. and key.lower() != "authorization"
  367. }
  368. async def get_running_instance(session: AsyncSession, model_id: int):
  369. running_instances = await ModelInstanceService(session).get_running_instances(
  370. model_id
  371. )
  372. if not running_instances:
  373. raise ServiceUnavailableException(
  374. message="No running instances available",
  375. is_openai_exception=True,
  376. )
  377. return await load_balancer.get_instance(running_instances)
  378. def mutate_request(
  379. request: Request,
  380. model_name: str,
  381. body_json: Optional[dict],
  382. form_data: Optional[aiohttp.FormData],
  383. ):
  384. path = request.url.path
  385. model: Model = request.state.model
  386. if (
  387. path == "/v1/rerank"
  388. and body_json
  389. and model.env
  390. and model.env.get("GPUSTACK_APPLY_QWEN3_RERANKER_TEMPLATES", False)
  391. ):
  392. apply_qwen3_reranker_templates(body_json)
  393. if model_name != model.name:
  394. if body_json is not None:
  395. body_json["model"] = model.name
  396. elif form_data is not None:
  397. form_data.add_field("model", model.name)
  398. def apply_qwen3_reranker_templates(body_json: dict):
  399. """
  400. Apply Qwen3 reranker templates to the request body.
  401. See instructions in https://huggingface.co/Qwen/Qwen3-Reranker-0.6B.
  402. Note: Once vLLM supports built-in template rendering for this model, this can be removed.
  403. """
  404. prefix = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
  405. suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
  406. query_template = "{prefix}<Instruct>: Given a web search query, retrieve relevant passages that answer the query\n<Query>: {query}\n"
  407. document_template = "<Document>: {doc}{suffix}"
  408. if "query" in body_json and "documents" in body_json:
  409. query = body_json["query"]
  410. documents = body_json["documents"]
  411. body_json["query"] = query_template.format(prefix=prefix, query=query)
  412. body_json["documents"] = [
  413. document_template.format(doc=doc, suffix=suffix) for doc in documents
  414. ]