| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418 |
- from datetime import datetime, timezone
- import functools
- import json
- import logging
- import time
- from typing import Type, Union
- from fastapi import Request, Response, status
- from fastapi.responses import FileResponse, StreamingResponse, JSONResponse
- from jwt import DecodeError, ExpiredSignatureError
- from starlette.middleware.base import BaseHTTPMiddleware
- from openai.types.chat import ChatCompletion, ChatCompletionChunk
- from openai.types import CompletionUsage
- from openai.types.audio.transcription_create_response import (
- Transcription,
- )
- from openai.types.create_embedding_response import (
- Usage as EmbeddingUsage,
- )
- from gpustack.api.exceptions import ErrorResponse
- from gpustack.routes.rerank import RerankResponse, RerankUsage
- from gpustack.schemas.images import ImageGenerationChunk, ImagesResponse
- from gpustack.schemas.model_usage import OperationEnum
- from gpustack.schemas.api_keys import ApiKey
- from gpustack.schemas.models import Model
- from gpustack.schemas.users import User
- from gpustack.security import JWTManager
- from gpustack import envs
- from gpustack.api.auth import SESSION_COOKIE_NAME
- from gpustack.server.metrics_collector import (
- ModelUsageMetrics,
- accumulate_gateway_metrics,
- )
- from gpustack.api.types.openai_ext import CreateEmbeddingResponseExt, CompletionExt
# Module-scoped logger named after this module, so its records can be
# filtered or re-levelled through the standard ``logging`` hierarchy.
logger: logging.Logger = logging.getLogger(__name__)
- @functools.lru_cache(maxsize=1)
- def _warn_about_missing_start_time() -> None:
- """Per-process warn-once for the RequestTimeMiddleware misconfiguration.
- ``lru_cache`` keeps the latch encapsulated inside the function — there's
- no module-level mutable state to track or reset in tests.
- """
- logger.warning(
- "request.state.start_time missing in record_model_usage; "
- "RequestTimeMiddleware may not be registered or runs after "
- "ModelUsageMiddleware. Falling back to now() — started_at and "
- "completed_at on the audit row will be equal until this is fixed."
- )
class RequestTimeMiddleware(BaseHTTPMiddleware):
    """Stamp each request with a start time and turn escaped errors into 500s.

    ``request.state.start_time`` (timezone-aware UTC) is set before the
    downstream app runs. Any exception that escapes the downstream app is
    logged with its traceback and answered with a JSON ``ErrorResponse``
    carrying status 500.
    """

    async def dispatch(self, request: Request, call_next):
        request.state.start_time = datetime.now(timezone.utc)
        try:
            return await call_next(request)
        except Exception as exc:
            # logger.exception records the full traceback. The response body
            # only carries str(exc), which often hides the real cause
            # (validation errors, terse attribute-error reprs, ...).
            logger.exception(
                "Unhandled exception in request %s %s",
                request.method,
                request.url.path,
            )
            error_body = ErrorResponse(
                code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                reason="Internal Server Error",
                message=f"Unexpected error occurred: {exc}",
            )
            return JSONResponse(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                content=error_body.model_dump(),
            )
class ModelUsageMiddleware(BaseHTTPMiddleware):
    """Hand successful inference responses to usage accounting.

    Only responses with status 200 whose request path is a key of
    ``_USAGE_ROUTES`` are passed to ``process_request``. Every other
    response is returned untouched.
    """

    # request path -> (model used to parse the response body, usage operation).
    # Every endpoint is matched under both the ``/v1-openai`` and ``/v1``
    # prefixes, except rerank, which is only matched as ``/v1/rerank``.
    _USAGE_ROUTES = {
        "/v1-openai/chat/completions": (
            ChatCompletion,
            OperationEnum.CHAT_COMPLETION,
        ),
        "/v1/chat/completions": (ChatCompletion, OperationEnum.CHAT_COMPLETION),
        "/v1-openai/completions": (CompletionExt, OperationEnum.COMPLETION),
        "/v1/completions": (CompletionExt, OperationEnum.COMPLETION),
        "/v1-openai/embeddings": (
            CreateEmbeddingResponseExt,
            OperationEnum.EMBEDDING,
        ),
        "/v1/embeddings": (CreateEmbeddingResponseExt, OperationEnum.EMBEDDING),
        "/v1-openai/images/generations": (
            ImagesResponse,
            OperationEnum.IMAGE_GENERATION,
        ),
        "/v1/images/generations": (
            ImagesResponse,
            OperationEnum.IMAGE_GENERATION,
        ),
        "/v1-openai/images/edits": (
            ImagesResponse,
            OperationEnum.IMAGE_GENERATION,
        ),
        "/v1/images/edits": (ImagesResponse, OperationEnum.IMAGE_GENERATION),
        "/v1-openai/audio/speech": (FileResponse, OperationEnum.AUDIO_SPEECH),
        "/v1/audio/speech": (FileResponse, OperationEnum.AUDIO_SPEECH),
        "/v1-openai/audio/transcriptions": (
            Transcription,
            OperationEnum.AUDIO_TRANSCRIPTION,
        ),
        "/v1/audio/transcriptions": (
            Transcription,
            OperationEnum.AUDIO_TRANSCRIPTION,
        ),
        "/v1/rerank": (RerankResponse, OperationEnum.RERANK),
    }

    async def dispatch(self, request: Request, call_next):
        response = await call_next(request)
        if response.status_code != 200:
            return response

        route = self._USAGE_ROUTES.get(request.url.path)
        if route is None:
            return response

        response_class, operation = route
        return await process_request(request, response, response_class, operation)
async def process_request(
    request: Request,
    response: StreamingResponse,
    response_class: Type[
        Union[
            ChatCompletion,
            CompletionExt,
            CreateEmbeddingResponseExt,
            RerankResponse,
            ImagesResponse,
            FileResponse,
            Transcription,
        ]
    ],
    operation: OperationEnum,
):
    """Record model usage for a successful response and return a replayable one.

    When ``request.state.stream`` is truthy, the work is delegated to
    ``handle_streaming_response``. Chat and image response models are first
    swapped for their streaming-chunk counterparts.

    Otherwise the whole body is buffered. If the content type is JSON, the
    body is parsed into ``response_class`` and its ``usage`` (if any) is
    passed to ``record_model_usage``. Non-JSON bodies are recorded with
    ``usage=None``. Usage accounting is best-effort: failures are logged and
    the client still receives the original bytes.

    Args:
        request: Incoming request; ``request.state`` carries the stream flag
            and the data ``record_model_usage`` reads.
        response: Downstream response whose ``body_iterator`` is consumed.
        response_class: Model used to parse the JSON body.
        operation: Operation label stored with the usage record.

    Returns:
        A new response carrying the same body bytes, status code and headers.
    """
    stream: bool = getattr(request.state, "stream", False)
    if stream:
        # Streamed payloads are parsed chunk by chunk, not as a full response.
        if response_class is ChatCompletion:
            response_class = ChatCompletionChunk
        elif response_class is ImagesResponse:
            response_class = ImageGenerationChunk
        return await handle_streaming_response(
            request, response, response_class, operation
        )

    # Buffering drains ``response.body_iterator``, so a fresh Response carrying
    # the same bytes is built below.
    response_body = b"".join([chunk async for chunk in response.body_iterator])
    try:
        usage = None
        # ``or ""`` keeps a response without a Content-Type header from
        # raising AttributeError and skipping usage accounting altogether.
        content_type = (response.headers.get("content-type") or "").lower()
        if content_type.startswith("application/json"):
            response_instance = response_class(**json.loads(response_body))
            usage = getattr(response_instance, "usage", None)
        await record_model_usage(request, usage, operation)
    except Exception as e:
        # Best-effort accounting: never fail the client request because of it.
        logger.error("Error processing model usage: %s", e)

    # Pass the Headers object itself instead of dict(response.headers).
    # Building a dict keeps only one value per header name (e.g. multiple
    # Set-Cookie lines). Response(headers=...) copies every (name, value)
    # pair from the Headers object.
    return Response(
        content=response_body,
        status_code=response.status_code,
        headers=response.headers,
    )
def _usage_token_count(usage, field: str, default: int) -> int:
    """Return ``usage.<field>``, or ``default`` when it is absent or None.

    Unlike ``getattr(...) or default``, an explicitly reported 0 is kept.
    """
    value = getattr(usage, field, None)
    return default if value is None else value


async def record_model_usage(
    request: Request,
    usage: Union[CompletionUsage, EmbeddingUsage, RerankUsage, None],
    operation: OperationEnum,
):
    """Accumulate one completed gateway usage metric for this request.

    Token counts come from ``usage``. A field that is absent or None falls
    back as follows:

    * ``total_tokens`` falls back to 0.
    * ``prompt_tokens`` falls back to the total.
    * ``completion_tokens`` falls back to ``total - prompt``.

    A reported 0 is kept as 0. An ``x or default`` fallback would replace it;
    for example, prompt_tokens=0 with completion_tokens=N would be recorded as
    N input plus N output tokens. ``usage=None`` records a request with zero
    tokens.

    Args:
        request: Its ``state`` must carry ``user`` and ``model``. It may carry
            ``api_key``, ``model_route_id`` and ``start_time``; the latter is
            set by RequestTimeMiddleware.
        usage: Usage object reported by the backend, or None.
        operation: Operation label stored with the metric.
    """
    total_tokens = _usage_token_count(usage, "total_tokens", 0)
    prompt_tokens = _usage_token_count(usage, "prompt_tokens", total_tokens)
    completion_tokens = _usage_token_count(
        usage, "completion_tokens", total_tokens - prompt_tokens
    )

    # prompt_tokens_details may arrive as a plain dict or as a model object.
    prompt_token_details = getattr(usage, "prompt_tokens_details", None)
    input_cached_tokens = 0
    if prompt_token_details:
        if isinstance(prompt_token_details, dict):
            input_cached_tokens = prompt_token_details.get("cached_tokens", 0) or 0
        else:
            input_cached_tokens = getattr(prompt_token_details, "cached_tokens", 0) or 0

    user: User = request.state.user
    model: Model = request.state.model
    api_key: ApiKey | None = getattr(request.state, "api_key", None)

    # Reaching this function means the canonical usage chunk was observed,
    # so the report is ``completed=True``. Wall-clock anchors come from
    # RequestTimeMiddleware (start) and now (completion); the unified
    # flusher uses ``completed_at`` to choose the billing period.
    now = datetime.now(timezone.utc)
    started_at = getattr(request.state, "start_time", None)
    if started_at is None:
        # Falling back to ``now`` collapses request duration to ~0, which
        # silently breaks latency analytics built on
        # (completed_at - started_at). Surface the misconfiguration once
        # without flooding the log on every request.
        _warn_about_missing_start_time()
        started_at = now

    metric = ModelUsageMetrics(
        model=model.name,
        input_token=prompt_tokens,
        output_token=completion_tokens,
        total_token=total_tokens,
        input_cached_token=input_cached_tokens,
        request_count=1,
        completed=True,
        started_at=int(started_at.timestamp() * 1000),
        completed_at=int(now.timestamp() * 1000),
        user_id=user.id if user is not None else None,
        model_id=model.id,
        model_route_id=getattr(request.state, "model_route_id", None),
        # Capture cluster_id at request time so it survives a later model
        # delete; the unified flusher prefers this over re-reading the
        # live model row.
        cluster_id=getattr(model, "cluster_id", None),
        access_key=api_key.access_key if api_key is not None else None,
        operation=operation,
    )
    await accumulate_gateway_metrics([metric])
async def handle_streaming_response(
    request: Request,
    response: StreamingResponse,
    response_class: Type[
        Union[ChatCompletionChunk, CompletionExt, ImageGenerationChunk]
    ],
    operation: OperationEnum,
):
    """Wrap an SSE response so usage is recorded while events stream through.

    Upstream body chunks do not have to line up with SSE event boundaries.
    One chunk may hold several events, and one event may be split across
    chunks. Bytes are therefore buffered until a ``\\n\\n`` separator is
    seen, and each complete event is passed to ``process_chunk`` on its own.

    If rewriting an event fails, that event's original bytes are forwarded
    instead, so the client stream is never truncated or duplicated. Bytes
    left without a terminating separator when the upstream stream ends are
    forwarded unchanged.

    Returns:
        A StreamingResponse with the upstream status code and headers.
    """
    event_separator = b"\n\n"

    async def streaming_generator():
        pending = b""
        async for chunk in response.body_iterator:
            if isinstance(chunk, str):
                chunk = chunk.encode("utf-8")
            pending += chunk
            # rpartition leaves everything in ``pending`` when no separator
            # has arrived yet; otherwise ``complete`` holds whole events.
            complete, found, pending = pending.rpartition(event_separator)
            if not found:
                continue
            for event in complete.split(event_separator):
                raw_event = event + event_separator
                try:
                    rewritten = [
                        part
                        async for part in process_chunk(
                            raw_event, request, response_class, operation
                        )
                    ]
                except Exception as e:
                    logger.error("Error processing streaming response: %s", e)
                    rewritten = [raw_event]
                for part in rewritten:
                    yield part
        if pending:
            # The upstream stream ended mid-event; pass the tail through.
            yield pending

    return StreamingResponse(
        streaming_generator(),
        status_code=response.status_code,
        headers=response.headers,
    )
async def process_chunk(
    chunk,
    request,
    response_class,
    operation: OperationEnum,
):
    """Rewrite the SSE events in ``chunk``, recording usage on the way.

    ``chunk`` is UTF-8 bytes holding zero or more ``\\n\\n``-terminated SSE
    events. It yields bytes, one item per event:

    * Non-``data:`` events are forwarded unchanged.
    * ``data: [DONE]`` is normalized to ``data: [DONE]\\n\\n``.
    * ``data:`` payloads containing ``"usage":`` are parsed into
      ``response_class``. If the chunk is the usage-carrying one, usage is
      recorded. Rate metrics are added when ``should_add_metrics`` says so,
      and the payload is re-serialized compactly.
    * All other ``data:`` events are forwarded unchanged.

    Text after the last separator (an event cut at a chunk boundary) is
    forwarded as-is rather than dropped.

    The first call for a request also stamps
    ``request.state.first_token_time``.

    Raises:
        Whatever JSON parsing, ``response_class`` construction or usage
        recording raises for a ``"usage":`` payload.
    """
    if not hasattr(request.state, "first_token_time"):
        request.state.first_token_time = datetime.now(timezone.utc)

    data_prefix = "data: "
    *events, trailing = chunk.decode("utf-8").split("\n\n")
    for event in events:
        if not event.startswith(data_prefix):
            # skip non-data SSE messages
            yield f"{event}\n\n".encode("utf-8")
            continue

        # Strip only the leading prefix. Splitting on "data: " would cut
        # payloads whose content itself contains that text.
        data = event[len(data_prefix):]
        if data.startswith("[DONE]"):
            yield b"data: [DONE]\n\n"
            continue

        if '"usage":' not in data:
            yield f"{event}\n\n".encode("utf-8")
            continue

        response_dict = json.loads(data.strip())
        response_chunk = response_class(**response_dict)
        if is_usage_chunk(response_chunk):
            await record_model_usage(request, response_chunk.usage, operation)
        # Fill rate metrics. These are extended info not included in OAI APIs.
        # llama-box provides them out-of-the-box. Align with other backends here.
        if should_add_metrics(response_dict):
            add_metrics(response_dict, request, response_chunk)
        yield f"data: {json.dumps(response_dict, separators=(',', ':'))}\n\n".encode(
            "utf-8"
        )

    if trailing:
        # Not \n\n-terminated: forward the bytes untouched so the client can
        # reassemble the event with the following chunk, instead of losing it.
        yield trailing.encode("utf-8")
def should_add_metrics(response_dict):
    """Return True when a parsed SSE payload needs rate metrics filled in.

    That is the case when the payload is a dict whose ``usage`` is a dict
    reporting ``prompt_tokens`` but not yet ``tokens_per_second``. A
    non-dict ``usage`` returns False instead of raising. This covers
    ``"usage": null``, which OpenAI-style streams may send on intermediate
    chunks.
    """
    if not isinstance(response_dict, dict):
        return False
    usage = response_dict.get("usage")
    if not isinstance(usage, dict):
        return False
    return "prompt_tokens" in usage and "tokens_per_second" not in usage
def add_metrics(response_dict, request, response_chunk):
    """Add latency/throughput fields to ``response_dict["usage"]`` in place.

    Timestamps are read from ``request.state``:

    * ``time_to_first_token_ms`` is ``first_token_time - start_time``.
    * ``time_per_output_token_ms`` is the time since the first token,
      divided by the completion tokens after the first (at least 1).
    * ``tokens_per_second`` is ``1000 / time_per_output_token_ms``, or 0
      when that interval is not positive.

    Returns None; ``response_dict`` is mutated.
    """
    captured_at = datetime.now(timezone.utc)
    state = request.state
    first_token_at = state.first_token_time
    ttft_ms = (first_token_at - state.start_time).total_seconds() * 1000

    # Clamp so that a 0/1-token completion never divides by zero.
    generated_after_first = max(response_chunk.usage.completion_tokens - 1, 1)
    elapsed_ms = (captured_at - first_token_at).total_seconds() * 1000
    tpot_ms = elapsed_ms / generated_after_first

    if tpot_ms > 0:
        tps = 1000 / tpot_ms
    else:
        tps = 0

    response_dict['usage'].update(
        {
            "time_to_first_token_ms": ttft_ms,
            "time_per_output_token_ms": tpot_ms,
            "tokens_per_second": tps,
        }
    )
class RefreshTokenMiddleware(BaseHTTPMiddleware):
    """Slide the expiry of authenticated sessions forward after each request.

    Two token sources are handled:

    * Session cookie (``SESSION_COOKIE_NAME``): when fewer than 15 minutes
      remain before ``exp``, a fresh token is written back as an HttpOnly
      cookie.
    * ``Authorization: Bearer`` header, consulted only when no session
      cookie is present: once more than half of the token's lifetime has
      elapsed, a fresh token is returned in the ``X-New-Token`` header.

    Refreshing is best-effort. A token that fails to decode or validate
    leaves the already-produced response unchanged.
    """

    async def dispatch(self, request: Request, call_next):
        # InvalidTokenError is PyJWT's base class for DecodeError,
        # ExpiredSignatureError and other validation failures (e.g.
        # InvalidAlgorithmError, ImmatureSignatureError) that subclass neither.
        # NOTE(review): assumes decode_jwt_token lets PyJWT errors propagate.
        from jwt import InvalidTokenError

        token_errors = (ExpiredSignatureError, DecodeError, InvalidTokenError)

        response = await call_next(request)
        jwt_manager: JWTManager = request.app.state.jwt_manager

        # Cookie-based refresh (existing local auth)
        token = request.cookies.get(SESSION_COOKIE_NAME)
        if token:
            try:
                payload = jwt_manager.decode_jwt_token(token)
                # Re-issue when fewer than 15 minutes are left before expiry.
                if payload and payload['exp'] - time.time() < 15 * 60:
                    new_token = jwt_manager.create_jwt_token(username=payload['sub'])
                    response.set_cookie(
                        key=SESSION_COOKIE_NAME,
                        value=new_token,
                        httponly=True,
                        max_age=envs.JWT_TOKEN_EXPIRE_MINUTES * 60,
                        expires=envs.JWT_TOKEN_EXPIRE_MINUTES * 60,
                    )
            except token_errors:
                # An invalid or expired session is simply not refreshed.
                pass
            return response

        # SSO Bearer token sliding expiration
        auth_header = request.headers.get("Authorization", "")
        if auth_header.startswith("Bearer "):
            bearer_token = auth_header[len("Bearer "):]
            try:
                payload = jwt_manager.decode_jwt_token(bearer_token)
                if payload:
                    exp = payload.get('exp', 0)
                    # Without ``iat``, assume the token was issued with the
                    # configured default lifetime.
                    iat = payload.get('iat', 0) or (
                        exp - envs.JWT_TOKEN_EXPIRE_MINUTES * 60
                    )
                    lifetime = exp - iat
                    remaining = exp - time.time()
                    # Issue a new token once more than 50% of its lifetime is used.
                    if remaining < lifetime * 0.5:
                        new_token = jwt_manager.create_jwt_token(
                            username=payload['sub']
                        )
                        response.headers['X-New-Token'] = new_token
            except token_errors:
                # Non-JWT bearer credentials (or invalid JWTs) are not refreshed.
                pass
        return response
def is_usage_chunk(
    chunk: Union[ChatCompletionChunk, CompletionExt, ImageGenerationChunk],
) -> bool:
    """Tell whether a streamed chunk is the one carrying final usage.

    A chunk qualifies when its ``usage`` is truthy and either it has no
    (or empty) ``choices``, or at least one choice has a non-None
    ``finish_reason``.
    """
    choices = getattr(chunk, "choices", None)
    if not choices:
        return bool(chunk.usage)
    # ``usage`` is only inspected once a finished choice has been found.
    return any(
        choice.finish_reason is not None and bool(chunk.usage)
        for choice in choices
    )
|