middlewares.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418
  1. from datetime import datetime, timezone
  2. import functools
  3. import json
  4. import logging
  5. import time
  6. from typing import Type, Union
  7. from fastapi import Request, Response, status
  8. from fastapi.responses import FileResponse, StreamingResponse, JSONResponse
  9. from jwt import DecodeError, ExpiredSignatureError
  10. from starlette.middleware.base import BaseHTTPMiddleware
  11. from openai.types.chat import ChatCompletion, ChatCompletionChunk
  12. from openai.types import CompletionUsage
  13. from openai.types.audio.transcription_create_response import (
  14. Transcription,
  15. )
  16. from openai.types.create_embedding_response import (
  17. Usage as EmbeddingUsage,
  18. )
  19. from gpustack.api.exceptions import ErrorResponse
  20. from gpustack.routes.rerank import RerankResponse, RerankUsage
  21. from gpustack.schemas.images import ImageGenerationChunk, ImagesResponse
  22. from gpustack.schemas.model_usage import OperationEnum
  23. from gpustack.schemas.api_keys import ApiKey
  24. from gpustack.schemas.models import Model
  25. from gpustack.schemas.users import User
  26. from gpustack.security import JWTManager
  27. from gpustack import envs
  28. from gpustack.api.auth import SESSION_COOKIE_NAME
  29. from gpustack.server.metrics_collector import (
  30. ModelUsageMetrics,
  31. accumulate_gateway_metrics,
  32. )
  33. from gpustack.api.types.openai_ext import CreateEmbeddingResponseExt, CompletionExt
  34. logger = logging.getLogger(__name__)
  35. @functools.lru_cache(maxsize=1)
  36. def _warn_about_missing_start_time() -> None:
  37. """Per-process warn-once for the RequestTimeMiddleware misconfiguration.
  38. ``lru_cache`` keeps the latch encapsulated inside the function — there's
  39. no module-level mutable state to track or reset in tests.
  40. """
  41. logger.warning(
  42. "request.state.start_time missing in record_model_usage; "
  43. "RequestTimeMiddleware may not be registered or runs after "
  44. "ModelUsageMiddleware. Falling back to now() — started_at and "
  45. "completed_at on the audit row will be equal until this is fixed."
  46. )
  47. class RequestTimeMiddleware(BaseHTTPMiddleware):
  48. async def dispatch(self, request: Request, call_next):
  49. request.state.start_time = datetime.now(timezone.utc)
  50. try:
  51. response = await call_next(request)
  52. except Exception as e:
  53. # Log the full traceback so unexpected errors don't disappear
  54. # behind the generic 500 response. The exception is otherwise
  55. # serialized only via str(e), which often hides the real cause
  56. # (validation errors, attribute errors with terse repr, etc.).
  57. logger.exception(
  58. "Unhandled exception in request %s %s",
  59. request.method,
  60. request.url.path,
  61. )
  62. response = JSONResponse(
  63. status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  64. content=ErrorResponse(
  65. code=status.HTTP_500_INTERNAL_SERVER_ERROR,
  66. reason="Internal Server Error",
  67. message=f"Unexpected error occurred: {e}",
  68. ).model_dump(),
  69. )
  70. return response
  71. class ModelUsageMiddleware(BaseHTTPMiddleware):
  72. async def dispatch(self, request: Request, call_next):
  73. response = await call_next(request)
  74. if response.status_code == 200:
  75. path = request.url.path
  76. if path == "/v1-openai/chat/completions" or path == "/v1/chat/completions":
  77. return await process_request(
  78. request, response, ChatCompletion, OperationEnum.CHAT_COMPLETION
  79. )
  80. elif path == "/v1-openai/completions" or path == "/v1/completions":
  81. return await process_request(
  82. request, response, CompletionExt, OperationEnum.COMPLETION
  83. )
  84. elif path == "/v1-openai/embeddings" or path == "/v1/embeddings":
  85. return await process_request(
  86. request,
  87. response,
  88. CreateEmbeddingResponseExt,
  89. OperationEnum.EMBEDDING,
  90. )
  91. elif (
  92. path == "/v1-openai/images/generations"
  93. or path == "/v1/images/generations"
  94. or path == "/v1-openai/images/edits"
  95. or path == "/v1/images/edits"
  96. ):
  97. return await process_request(
  98. request,
  99. response,
  100. ImagesResponse,
  101. OperationEnum.IMAGE_GENERATION,
  102. )
  103. elif path == "/v1-openai/audio/speech" or path == "/v1/audio/speech":
  104. return await process_request(
  105. request,
  106. response,
  107. FileResponse,
  108. OperationEnum.AUDIO_SPEECH,
  109. )
  110. elif (
  111. path == "/v1-openai/audio/transcriptions"
  112. or path == "/v1/audio/transcriptions"
  113. ):
  114. return await process_request(
  115. request,
  116. response,
  117. Transcription,
  118. OperationEnum.AUDIO_TRANSCRIPTION,
  119. )
  120. elif request.url.path == "/v1/rerank":
  121. return await process_request(
  122. request,
  123. response,
  124. RerankResponse,
  125. OperationEnum.RERANK,
  126. )
  127. return response
  128. async def process_request(
  129. request: Request,
  130. response: StreamingResponse,
  131. response_class: Type[
  132. Union[
  133. ChatCompletion,
  134. CompletionExt,
  135. CreateEmbeddingResponseExt,
  136. RerankResponse,
  137. ImagesResponse,
  138. FileResponse,
  139. Transcription,
  140. ]
  141. ],
  142. operation: OperationEnum,
  143. ):
  144. stream: bool = getattr(request.state, "stream", False)
  145. if stream:
  146. if response_class == ChatCompletion:
  147. response_class = ChatCompletionChunk
  148. if response_class == ImagesResponse:
  149. response_class = ImageGenerationChunk
  150. return await handle_streaming_response(
  151. request, response, response_class, operation
  152. )
  153. else:
  154. response_body = b"".join([chunk async for chunk in response.body_iterator])
  155. try:
  156. usage = None
  157. if (
  158. response.headers.get("content-type")
  159. .lower()
  160. .startswith("application/json")
  161. ):
  162. response_dict = json.loads(response_body)
  163. response_instance = response_class(**response_dict)
  164. if hasattr(response_instance, "usage"):
  165. usage = response_instance.usage
  166. await record_model_usage(request, usage, operation)
  167. except Exception as e:
  168. logger.error(f"Error processing model usage: {e}")
  169. response = Response(
  170. content=response_body,
  171. status_code=response.status_code,
  172. headers=dict(response.headers),
  173. )
  174. return response
  175. async def record_model_usage(
  176. request: Request,
  177. usage: Union[CompletionUsage, EmbeddingUsage, RerankUsage, None],
  178. operation: OperationEnum,
  179. ):
  180. total_tokens = getattr(usage, 'total_tokens', 0) or 0
  181. prompt_tokens = getattr(usage, 'prompt_tokens', total_tokens) or total_tokens
  182. completion_tokens = (
  183. getattr(usage, 'completion_tokens', total_tokens - prompt_tokens)
  184. or total_tokens - prompt_tokens
  185. )
  186. prompt_token_details = (
  187. getattr(usage, "prompt_tokens_details", None) if usage else None
  188. )
  189. input_cached_tokens = 0
  190. if prompt_token_details:
  191. if isinstance(prompt_token_details, dict):
  192. input_cached_tokens = prompt_token_details.get("cached_tokens", 0) or 0
  193. else:
  194. input_cached_tokens = getattr(prompt_token_details, "cached_tokens", 0) or 0
  195. user: User = request.state.user
  196. model: Model = request.state.model
  197. api_key: ApiKey | None = getattr(request.state, "api_key", None)
  198. # Reaching this function means the canonical usage chunk was observed,
  199. # so the report is ``completed=True``. Wall-clock anchors come from
  200. # RequestTimeMiddleware (start) and now (completion); the unified
  201. # flusher uses ``completed_at`` to choose the billing period.
  202. now = datetime.now(timezone.utc)
  203. started_at = getattr(request.state, "start_time", None)
  204. if started_at is None:
  205. # Falling back to ``now`` means request duration collapses to ~0,
  206. # which silently breaks SLO/latency analytics built off of
  207. # (completed_at - started_at). Surface the misconfiguration once
  208. # without flooding the log on every request.
  209. _warn_about_missing_start_time()
  210. started_at = now
  211. metric = ModelUsageMetrics(
  212. model=model.name,
  213. input_token=prompt_tokens,
  214. output_token=completion_tokens,
  215. total_token=total_tokens,
  216. input_cached_token=input_cached_tokens,
  217. request_count=1,
  218. completed=True,
  219. started_at=int(started_at.timestamp() * 1000),
  220. completed_at=int(now.timestamp() * 1000),
  221. user_id=user.id if user is not None else None,
  222. model_id=model.id,
  223. model_route_id=getattr(request.state, "model_route_id", None),
  224. # Capture cluster_id at request time so it survives a later model
  225. # delete; the unified flusher prefers this over re-reading the
  226. # live model row.
  227. cluster_id=getattr(model, "cluster_id", None),
  228. access_key=api_key.access_key if api_key is not None else None,
  229. operation=operation,
  230. )
  231. await accumulate_gateway_metrics([metric])
  232. async def handle_streaming_response(
  233. request: Request,
  234. response: StreamingResponse,
  235. response_class: Type[
  236. Union[ChatCompletionChunk, CompletionExt, ImageGenerationChunk]
  237. ],
  238. operation: OperationEnum,
  239. ):
  240. async def streaming_generator():
  241. async for chunk in response.body_iterator:
  242. try:
  243. async for processed_chunk in process_chunk(
  244. chunk, request, response_class, operation
  245. ):
  246. yield processed_chunk
  247. except Exception as e:
  248. logger.error(f"Error processing streaming response: {e}")
  249. yield chunk
  250. return StreamingResponse(streaming_generator(), headers=response.headers)
  251. async def process_chunk(
  252. chunk,
  253. request,
  254. response_class,
  255. operation: OperationEnum,
  256. ):
  257. if not hasattr(request.state, 'first_token_time'):
  258. request.state.first_token_time = datetime.now(timezone.utc)
  259. # each chunk may contain multiple data lines
  260. lines = chunk.decode("utf-8").split("\n\n")
  261. for line in lines[:-1]:
  262. if not line.startswith('data: '):
  263. # skip non-data SSE messages
  264. yield f"{line}\n\n".encode("utf-8")
  265. continue
  266. data = line.split('data: ')[-1]
  267. if data.startswith('[DONE]'):
  268. yield "data: [DONE]\n\n".encode("utf-8")
  269. continue
  270. if '"usage":' in data:
  271. response_dict = None
  272. try:
  273. response_dict = json.loads(data.strip())
  274. except Exception as e:
  275. raise e
  276. response_chunk = response_class(**response_dict)
  277. if is_usage_chunk(response_chunk):
  278. await record_model_usage(request, response_chunk.usage, operation)
  279. # Fill rate metrics. These are extended info not included in OAI APIs.
  280. # llama-box provides them out-of-the-box. Align with other backends here.
  281. if should_add_metrics(response_dict):
  282. add_metrics(response_dict, request, response_chunk)
  283. yield f"data: {json.dumps(response_dict, separators=(',', ':'))}\n\n".encode(
  284. "utf-8"
  285. )
  286. else:
  287. yield f"{line}\n\n".encode("utf-8")
  288. def should_add_metrics(response_dict):
  289. if not isinstance(response_dict, dict):
  290. return False
  291. usage = response_dict.get('usage', {})
  292. return 'prompt_tokens' in usage and 'tokens_per_second' not in usage
  293. def add_metrics(response_dict, request, response_chunk):
  294. now = datetime.now(timezone.utc)
  295. time_to_first_token_ms = (
  296. request.state.first_token_time - request.state.start_time
  297. ).total_seconds() * 1000
  298. tokens_after_first = max(response_chunk.usage.completion_tokens - 1, 1)
  299. time_per_output_token_ms = (
  300. (now - request.state.first_token_time).total_seconds()
  301. * 1000
  302. / tokens_after_first
  303. )
  304. tokens_per_second = (
  305. 1000 / time_per_output_token_ms if time_per_output_token_ms > 0 else 0
  306. )
  307. response_dict['usage'].update(
  308. {
  309. "time_to_first_token_ms": time_to_first_token_ms,
  310. "time_per_output_token_ms": time_per_output_token_ms,
  311. "tokens_per_second": tokens_per_second,
  312. }
  313. )
  314. class RefreshTokenMiddleware(BaseHTTPMiddleware):
  315. async def dispatch(self, request: Request, call_next):
  316. response = await call_next(request)
  317. jwt_manager: JWTManager = request.app.state.jwt_manager
  318. # Cookie-based refresh (existing local auth)
  319. token = request.cookies.get(SESSION_COOKIE_NAME)
  320. if token:
  321. try:
  322. payload = jwt_manager.decode_jwt_token(token)
  323. if payload:
  324. # Check if the token is about to expire (less than 15 minutes left)
  325. if payload['exp'] - time.time() < 15 * 60:
  326. new_token = jwt_manager.create_jwt_token(
  327. username=payload['sub']
  328. )
  329. response.set_cookie(
  330. key=SESSION_COOKIE_NAME,
  331. value=new_token,
  332. httponly=True,
  333. max_age=envs.JWT_TOKEN_EXPIRE_MINUTES * 60,
  334. expires=envs.JWT_TOKEN_EXPIRE_MINUTES * 60,
  335. )
  336. except (ExpiredSignatureError, DecodeError):
  337. pass
  338. else:
  339. # SSO Bearer token sliding expiration
  340. auth_header = request.headers.get("Authorization", "")
  341. if auth_header.startswith("Bearer "):
  342. bearer_token = auth_header[7:]
  343. try:
  344. payload = jwt_manager.decode_jwt_token(bearer_token)
  345. if payload:
  346. exp = payload.get('exp', 0)
  347. iat = payload.get('iat', 0) or (exp - envs.JWT_TOKEN_EXPIRE_MINUTES * 60)
  348. lifetime = exp - iat
  349. remaining = exp - time.time()
  350. # If token has used more than 50% of its lifetime, issue a new one
  351. if remaining < lifetime * 0.5:
  352. new_token = jwt_manager.create_jwt_token(
  353. username=payload['sub']
  354. )
  355. response.headers['X-New-Token'] = new_token
  356. except (ExpiredSignatureError, DecodeError):
  357. pass
  358. return response
  359. def is_usage_chunk(
  360. chunk: Union[ChatCompletionChunk, CompletionExt, ImageGenerationChunk],
  361. ) -> bool:
  362. choices = getattr(chunk, "choices", None)
  363. if not choices and chunk.usage:
  364. return True
  365. for choice in choices or []:
  366. if choice.finish_reason is not None and chunk.usage:
  367. return True
  368. return False