| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540 |
- import asyncio
- import logging
- from datetime import date, datetime, timezone
- from typing import Dict, List, Optional, Set, Tuple
- from pydantic import BaseModel
- from sqlmodel.ext.asyncio.session import AsyncSession
- from gpustack import envs
- from gpustack.schemas.api_keys import ApiKey
- from gpustack.schemas.clusters import Cluster
- from gpustack.schemas.model_provider import ModelProvider
- from gpustack.schemas.model_routes import ModelRoute
- from gpustack.schemas.model_usage import ModelUsage, OperationEnum
- from gpustack.schemas.model_usage_details import ModelUsageDetails
- from gpustack.schemas.models import Model, is_embedding_model, is_reranker_model
- from gpustack.schemas.users import User
- from gpustack.server.db import async_session
- from gpustack.utils.usage_snapshots import build_model_usage_snapshot
logger = logging.getLogger(__name__)

# Seconds ``flush_gateway_metrics_to_db`` sleeps between buffer flushes.
FLUSH_INTERVAL_SECONDS = 10

# Heuristics for partial-stream usage estimation: the ratios
# (``USAGE_ESTIMATED_BYTES_PER_INPUT_TOKEN`` and
# ``USAGE_ESTIMATED_TOKENS_PER_OUTPUT_CHUNK``) live in ``gpustack.envs`` and
# are applied server-side by ``_estimate_partial_usage`` only when an
# incomplete report leaves token fields blank. The proxy never applies them.

# Buffer to accumulate pushed gateway metrics: {key: ModelUsageMetrics}.
# Key format (see ``_make_buffer_key``; None / empty parts render as ""):
# "{model_id}.{provider_id}.{model}.{user_id}.{access_key}.{operation}.{date}"
# ``operation`` and ``date`` are part of the key so per-operation rollups
# stay separate and a stream that crosses midnight lands in the period
# it ends in (anchored on completed_at, falling back to started_at and then
# to the server's current time).
gateway_metrics_buffer: Dict[str, "ModelUsageMetrics"] = {}
# Raw per-report metrics retained for ``model_usage_details`` audit rows.
# Unlike ``gateway_metrics_buffer``, entries are not aggregated. The list is
# capped at ``envs.USAGE_DETAILS_BUFFER_MAX_SIZE``; the earliest-appended
# entries are dropped first.
gateway_details_buffer: List["ModelUsageMetrics"] = []
# Single lock guarding both rollup and details buffers; ingest writes
# them together, so they must be drained together too. ``asyncio.Lock`` is
# not reentrant: code already holding it must not call
# ``accumulate_gateway_metrics`` or ``flush_gateway_metrics``.
gateway_buffers_lock = asyncio.Lock()
class ModelUsageMetrics(BaseModel):
    """Usage report for one inference request, pushed by the gateway or fed
    in by request middleware.

    Instances are also reused as rollup accumulators in
    ``gateway_metrics_buffer``: the token counters and ``request_count`` of
    every report sharing a buffer key are summed onto the first one.
    """

    # Model name as reported; checked against the live Model row (by id) or
    # the provider's model list before anything is stored.
    model: str
    # Token counters; summed when reports are merged into a rollup entry.
    input_token: int = 0
    output_token: int = 0
    total_token: int = 0
    input_cached_token: int = 0
    # Number of requests this entry represents: 1 per report, summed when
    # merged into a rollup entry.
    request_count: int = 1
    # ``completed`` is True iff the canonical usage chunk was observed before
    # the stream ended. When False, token fields may be 0 (OpenAI/vLLM) or
    # partial (Anthropic message_start carries input_token early), so the
    # server falls back to estimation from the byte/chunk fields below.
    completed: bool = False
    # Estimation inputs, only consulted while ``completed`` is False
    # (see ``_estimate_partial_usage``).
    output_chunk_count: int = 0
    request_content_bytes: int = 0
    # Wall-clock UnixMilli stamps captured at request entry and report
    # dispatch respectively. ``None`` means the report didn't carry one;
    # legacy payloads sending literal ``0`` are also treated as absent
    # downstream (see ``_unixmilli_to_naive_utc``).
    started_at: Optional[int] = None
    completed_at: Optional[int] = None
    user_id: Optional[int] = None
    model_id: Optional[int] = None
    # Stored verbatim on detail rows; the route name is looked up
    # best-effort at flush time and is null if the route no longer exists.
    model_route_id: Optional[int] = None
    # Captured at request time by middleware (request.state.model.cluster_id).
    # Carried on the metric so the historical cluster id survives even if the
    # model is deleted between request and flush.
    cluster_id: Optional[int] = None
    provider_id: Optional[int] = None
    # Fallback provider name / type, used in the stored snapshot only when
    # no live provider row supplies them.
    provider_name: Optional[str] = None
    provider_type: Optional[str] = None
    # API key access key; resolved to its ApiKey row at flush time.
    access_key: Optional[str] = None
    # Inference operation type (chat_completion / embedding / rerank / ...).
    # None when the gateway report doesn't carry it; middleware-fed metrics
    # always populate it so per-operation rollups survive unification.
    operation: Optional[OperationEnum] = None
def _unixmilli_to_naive_utc(ms: Optional[int]) -> Optional[datetime]:
    """Convert a UnixMilli stamp to naive UTC, or None if absent / unusable.

    Accepts ``None`` (current absence sentinel) and ``<= 0`` (legacy absence
    sentinel that some older gateway payloads still send); both collapse to
    ``None``. Stamps that cannot be represented as a ``datetime`` (for
    example a client sending micro- or nanoseconds instead of milliseconds)
    are treated as absent too. Callers then fall back to their next
    candidate timestamp instead of raising in the middle of ingest or flush.
    Raising there would leave the rollup and details buffers out of step and
    break the re-buffer path.

    The naive-UTC convention matches ``TimestampsMixin._datetime_func``
    and the ``UTCDateTime`` storage type, which both strip tzinfo.
    """
    if ms is None or ms <= 0:
        return None
    try:
        aware = datetime.fromtimestamp(ms / 1000.0, tz=timezone.utc)
    except (OverflowError, OSError, ValueError):
        # OverflowError: too large for float / platform time_t.
        # OSError: gmtime failure. ValueError: year outside datetime's range.
        return None
    return aware.replace(tzinfo=None)
def _resolve_metric_datetime(
    metric: ModelUsageMetrics,
) -> Tuple[date, datetime]:
    """Return ``(date, naive-UTC datetime)`` for the metric's wall-clock.

    Candidates are tried in order: ``completed_at`` (so a stream spanning a
    calendar boundary is booked to the period it ends in), then
    ``started_at``, and finally the server's current UTC time when neither
    stamp is usable.
    """
    for stamp in (metric.completed_at, metric.started_at):
        resolved = _unixmilli_to_naive_utc(stamp)
        if resolved is not None:
            break
    else:
        resolved = datetime.now(timezone.utc).replace(tzinfo=None)
    return resolved.date(), resolved
def _make_buffer_key(metric: ModelUsageMetrics) -> str:
    """Build the rollup-buffer key for ``metric``.

    Format:
    ``{model_id}.{provider_id}.{model}.{user_id}.{access_key}.{operation}.{date}``
    with every falsy part rendered as an empty string. The date is the
    completion-anchored day from ``_resolve_metric_datetime``. A stream that
    crosses midnight therefore aggregates into the day it finished, rather
    than sharing a rollup with another day's traffic.
    """
    day, _ = _resolve_metric_datetime(metric)
    op_value = "" if not metric.operation else metric.operation.value
    segments = (
        metric.model_id,
        metric.provider_id,
        metric.model,
        metric.user_id,
        metric.access_key,
        op_value,
        day.isoformat(),
    )
    return ".".join(str(segment) if segment else "" for segment in segments)
def _estimate_partial_usage(metric: ModelUsageMetrics) -> None:
    """Fill in estimated token counts for an incomplete report, in place.

    Completed reports are returned untouched. For incomplete ones, only the
    token slots that are still ``<= 0`` are estimated, so genuine partial
    values (e.g. Anthropic's early ``input_token``) are preserved:

    * ``input_token`` = ``request_content_bytes`` floor-divided by
      ``envs.USAGE_ESTIMATED_BYTES_PER_INPUT_TOKEN``, never below 1;
    * ``output_token`` = ``output_chunk_count`` times
      ``envs.USAGE_ESTIMATED_TOKENS_PER_OUTPUT_CHUNK``.

    ``total_token`` is then raised, never lowered, to at least
    ``input_token + output_token``. Estimation is intentionally a
    server-side concern; the proxy never applies these ratios itself.
    """
    if metric.completed:
        return

    missing_input = metric.input_token <= 0 and metric.request_content_bytes > 0
    if missing_input:
        guessed_input = (
            metric.request_content_bytes // envs.USAGE_ESTIMATED_BYTES_PER_INPUT_TOKEN
        )
        metric.input_token = max(1, guessed_input)

    missing_output = metric.output_token <= 0 and metric.output_chunk_count > 0
    if missing_output:
        per_chunk = envs.USAGE_ESTIMATED_TOKENS_PER_OUTPUT_CHUNK
        metric.output_token = metric.output_chunk_count * per_chunk

    floor_total = metric.input_token + metric.output_token
    if floor_total > metric.total_token:
        metric.total_token = floor_total
def _resolve_usage_tokens(
    metric: ModelUsageMetrics, model: Optional[Model]
) -> tuple[int, int]:
    """Return ``(prompt_tokens, completion_tokens)`` to record for ``metric``.

    Normally these are the reported ``input_token`` / ``output_token``. For
    embedding and reranker models, when ``total_token`` exceeds their sum,
    the prompt count becomes ``total_token - output_token`` so the
    unaccounted remainder is attributed to the prompt side.
    """
    prompt, completion = metric.input_token, metric.output_token
    if model is None:
        return prompt, completion
    if not (is_reranker_model(model) or is_embedding_model(model)):
        return prompt, completion
    if metric.total_token > prompt + completion:
        prompt = metric.total_token - completion
    return prompt, completion
async def accumulate_gateway_metrics(metrics: List[ModelUsageMetrics]):
    """Buffer usage reports until the next flush.

    Each report is copied (the caller's objects are never mutated) and has
    missing token counts estimated via ``_estimate_partial_usage``. Then,
    under ``gateway_buffers_lock``, it is:

    * appended unaggregated to ``gateway_details_buffer`` (audit rows), and
    * summed into ``gateway_metrics_buffer`` under ``_make_buffer_key``
      (input/output/total/cached tokens and ``request_count``).

    The details buffer is trimmed to its cap afterwards, even if a report
    raised part-way through the batch.
    """
    async with gateway_buffers_lock:
        try:
            for incoming in metrics:
                # Own a private copy: estimation mutates token fields in
                # place, and the rollup ``+=`` mutates the stored entry. If
                # references were shared, neither change could be kept out of
                # the caller's instance or the audit row.
                metric = incoming.model_copy()
                # Estimate before either buffer sees the metric; once summed
                # into a rollup the per-report byte/chunk context is gone.
                _estimate_partial_usage(metric)
                # Compute the key before touching either buffer. If it
                # raises, neither buffer has been modified for this report,
                # so rollups and details stay in step.
                key = _make_buffer_key(metric)
                gateway_details_buffer.append(metric.model_copy())
                existing = gateway_metrics_buffer.get(key)
                if existing is None:
                    gateway_metrics_buffer[key] = metric
                    continue
                existing.input_token += metric.input_token
                existing.output_token += metric.output_token
                existing.total_token += metric.total_token
                existing.input_cached_token += metric.input_cached_token
                existing.request_count += metric.request_count
        finally:
            # Lock is still held here, as the helper requires.
            _trim_details_buffer_locked()
def _trim_details_buffer_locked() -> None:
    """Bound ``gateway_details_buffer`` to ``envs.USAGE_DETAILS_BUFFER_MAX_SIZE``.

    A failed flush re-buffers its pending details, so transient DB errors
    don't lose audit rows. Under a persistent failure (DB down, schema
    drift, constraint violation), though, the list would keep growing as
    new ingest arrives. Entries beyond the cap are removed from the front of
    the list (earliest appended), and one warning is logged per trim.

    Caller must hold ``gateway_buffers_lock``.
    """
    limit = envs.USAGE_DETAILS_BUFFER_MAX_SIZE
    excess = len(gateway_details_buffer) - limit
    if excess > 0:
        del gateway_details_buffer[:excess]
        logger.warning(
            "gateway_details_buffer exceeded cap (%d); dropped %d oldest detail "
            "rows. Likely cause: persistent flush failure to model_usage_details.",
            limit,
            excess,
        )
async def flush_gateway_metrics():
    """Drain both buffers and persist them, re-buffering on failure.

    Both buffers are snapshotted and cleared together under
    ``gateway_buffers_lock``. The DB write runs after the lock is released,
    so ingest is not blocked while it is in progress. If
    ``store_usage_metrics`` raises, the drained raw details are fed back
    through ``accumulate_gateway_metrics``, which re-aggregates them into the
    rollup buffer; the drained rollup snapshot itself is discarded.
    """
    async with gateway_buffers_lock:
        if not (gateway_metrics_buffer or gateway_details_buffer):
            return
        rollups_to_store = list(gateway_metrics_buffer.values())
        details_to_store = gateway_details_buffer[:]
        gateway_metrics_buffer.clear()
        gateway_details_buffer.clear()

    try:
        await store_usage_metrics(rollups_to_store, details_to_store)
    except Exception as e:
        logger.error(f"Error flushing gateway metrics to DB: {e}")
        # Must run outside the lock: accumulate acquires it itself.
        await accumulate_gateway_metrics(details_to_store)
async def flush_gateway_metrics_to_db():
    """Background loop: flush the gateway buffers every FLUSH_INTERVAL_SECONDS.

    Runs until the task is cancelled. Any ``Exception`` escaping a flush is
    logged and the loop continues, so a single bad flush does not end
    usage persistence. One example is a failure while re-buffering after a
    DB error. ``asyncio.CancelledError`` derives from ``BaseException``
    (Python 3.8+), so it is not caught and cancellation still stops the loop.
    """
    while True:
        await asyncio.sleep(FLUSH_INTERVAL_SECONDS)
        try:
            await flush_gateway_metrics()
        except Exception:
            logger.exception("Unexpected error in gateway metrics flush loop")
async def create_or_update_model_usage(
    session: AsyncSession, metric: ModelUsage, auto_commit: bool = True
):
    """Upsert ``metric`` into ``ModelUsage``.

    Looks for an existing row with the same identity columns (model id,
    user, provider id/name/type, model name, access key, operation, date).
    If there is none, ``metric`` itself is saved as a new row. Otherwise its
    prompt, completion and cached-prompt token counts and its request count
    are added onto the existing row, which is then saved. ``auto_commit`` is
    passed straight through to ``save``.
    """
    identity_columns = (
        "model_id",
        "user_id",
        "provider_id",
        "provider_name",
        "provider_type",
        "model_name",
        "access_key",
        "operation",
        "date",
    )
    lookup = {column: getattr(metric, column) for column in identity_columns}
    existing = await ModelUsage.one_by_fields(session=session, fields=lookup)

    if existing is None:
        await metric.save(session=session, auto_commit=auto_commit)
        return

    existing.prompt_token_count += metric.prompt_token_count
    existing.completion_token_count += metric.completion_token_count
    existing.prompt_cached_token_count += metric.prompt_cached_token_count
    existing.request_count += metric.request_count
    await existing.save(session=session, auto_commit=auto_commit)
def _validate_usage_metric(
    metric: ModelUsageMetrics,
    models: Dict[int, Model],
    providers: Dict[int, ModelProvider],
    user_ids: Set[int],
) -> bool:
    """Return True if ``metric`` may be stored, False otherwise.

    A metric is rejected when any of the following holds:

    * it has neither ``model_id`` nor ``provider_id``;
    * its ``model_id`` is unknown, or that model's name differs from
      ``metric.model``;
    * its ``provider_id`` is unknown, or that provider does not list
      ``metric.model``;
    * its ``user_id`` is set but not in ``user_ids``.

    Every rejection is logged at debug level.
    """
    has_model = metric.model_id is not None
    has_provider = metric.provider_id is not None

    if not (has_model or has_provider):
        logger.debug(
            f"Both model_id and provider_id are None for metric: {metric}, skipping."
        )
        return False

    if has_model:
        record = models.get(metric.model_id)
        if not record:
            logger.debug(f"Model ID {metric.model_id} not found in database.")
            return False
        if record.name != metric.model:
            logger.debug(
                f"Model name {metric.model} does not match database record {record.name} for model ID {metric.model_id}."
            )
            return False

    if has_provider:
        owner = providers.get(metric.provider_id)
        if not owner:
            logger.debug(f"Provider ID {metric.provider_id} not found in database.")
            return False
        offered_names = {entry.name for entry in owner.models}
        if metric.model not in offered_names:
            logger.debug(
                f"Model name {metric.model} not found for provider ID {metric.provider_id} in database."
            )
            return False

    unknown_user = metric.user_id is not None and metric.user_id not in user_ids
    if unknown_user:
        logger.debug(f"User ID {metric.user_id} not found in database.")
        return False

    return True
def _build_metric_snapshot(
    metric: ModelUsageMetrics,
    model_by_id: Dict[int, Model],
    provider_by_id: Dict[int, ModelProvider],
    user_by_id: Dict[int, User],
    api_key_by_access_key: Dict[str, ApiKey],
    cluster_names_by_id: Dict[int, str],
) -> dict:
    """Build the denormalized identity columns for one usage row.

    The returned dict is splatted into ``ModelUsage`` / ``ModelUsageDetails``.
    When a live ``Model`` row exists it comes from
    ``build_model_usage_snapshot``. Otherwise it is assembled from the
    metric plus whichever provider, user and API-key rows were found.
    Finally, any of ``user_id``, ``provider_id``, ``provider_name``,
    ``provider_type`` and ``access_key`` still missing are taken from the
    metric itself, and ``api_key_is_custom`` defaults to None.
    """
    user = user_by_id.get(metric.user_id)
    api_key = api_key_by_access_key.get(metric.access_key)
    model = model_by_id.get(metric.model_id)
    provider = provider_by_id.get(metric.provider_id)

    if model is not None:
        snapshot = build_model_usage_snapshot(
            model,
            cluster_name=cluster_names_by_id.get(model.cluster_id),
            user=user,
            api_key=api_key,
            provider=provider,
        )
    else:
        snapshot = {
            "model_id": metric.model_id,
            "model_name": metric.model,
            "cluster_name": None,
        }
        if provider is not None:
            # ``config.type`` may be an enum (use its value) or a plain value.
            kind = getattr(getattr(provider, "config", None), "type", None)
            if kind is not None and hasattr(kind, "value"):
                kind = kind.value
            snapshot["provider_id"] = provider.id
            snapshot["provider_name"] = provider.name
            snapshot["provider_type"] = kind
        if user is not None:
            snapshot["user_id"] = user.id
            snapshot["user_name"] = user.username
        if api_key is not None:
            snapshot["api_key_id"] = api_key.id
            snapshot["api_key_name"] = api_key.name
            snapshot["access_key"] = api_key.access_key
            snapshot["api_key_is_custom"] = api_key.is_custom

    fallbacks = {
        "user_id": metric.user_id,
        "provider_id": metric.provider_id,
        "provider_name": metric.provider_name,
        "provider_type": metric.provider_type,
        "access_key": metric.access_key,
        "api_key_is_custom": None,
    }
    for column, value in fallbacks.items():
        snapshot.setdefault(column, value)
    return snapshot
async def store_usage_metrics(
    metrics: List[ModelUsageMetrics],
    detail_metrics: Optional[List[ModelUsageMetrics]] = None,
):
    """Persist rollup and per-request usage metrics in one transaction.

    Args:
        metrics: Aggregated rollup entries. Each one that passes
            ``_validate_usage_metric`` is upserted into ``ModelUsage`` via
            ``create_or_update_model_usage``.
        detail_metrics: Raw per-report entries. Each one that passes
            validation is inserted as a ``ModelUsageDetails`` audit row.

    Reference rows (models, providers, users, API keys, routes, clusters)
    are loaded once per call, and only for the names / ids / keys the batch
    actually references. Everything is committed once at the end.

    Raises:
        Exception: whatever failed while loading, building or committing.
            The session is rolled back first, and the error is re-raised so
            ``flush_gateway_metrics`` can re-buffer the pending records.
    """
    detail_metrics = list(detail_metrics or [])
    if not metrics and not detail_metrics:
        return

    all_metrics = list(metrics) + detail_metrics
    dedup_model_names = {m.model for m in all_metrics}
    dedup_user_ids = {m.user_id for m in all_metrics if m.user_id is not None}
    dedup_access_keys = {m.access_key for m in all_metrics if m.access_key is not None}
    dedup_provider_ids = {
        m.provider_id for m in all_metrics if m.provider_id is not None
    }
    dedup_route_ids = {
        m.model_route_id for m in all_metrics if m.model_route_id is not None
    }

    async with async_session() as session:
        try:
            # ``all_metrics`` is non-empty, so there is always a name to match.
            models = await Model.all_by_fields(
                session=session,
                fields={},
                extra_conditions=[Model.name.in_(dedup_model_names)],
            )

            # Every other lookup is skipped when the batch references no
            # keys of that kind (same pattern as the routes lookup). Passing
            # an empty condition list together with ``fields={}`` would run
            # the query unfiltered and load the whole table (every provider,
            # API key or cluster) on every flush, only for none of it to be
            # looked up.
            providers: List[ModelProvider] = []
            if dedup_provider_ids:
                providers = await ModelProvider.all_by_fields(
                    session=session,
                    fields={},
                    extra_conditions=[ModelProvider.id.in_(dedup_provider_ids)],
                )
            users: List[User] = []
            if dedup_user_ids:
                users = await User.all_by_fields(
                    session=session,
                    fields={},
                    extra_conditions=[User.id.in_(dedup_user_ids)],
                )
            api_keys: List[ApiKey] = []
            if dedup_access_keys:
                api_keys = await ApiKey.all_by_fields(
                    session=session,
                    fields={},
                    extra_conditions=[ApiKey.access_key.in_(dedup_access_keys)],
                )
            route_name_by_id: Dict[int, str] = {}
            if dedup_route_ids:
                routes = await ModelRoute.all_by_fields(
                    session=session,
                    fields={},
                    extra_conditions=[ModelRoute.id.in_(dedup_route_ids)],
                )
                route_name_by_id = {r.id: r.name for r in routes}

            validated_user_ids = {u.id for u in users}
            user_by_id = {u.id: u for u in users}
            api_key_by_access_key = {k.access_key: k for k in api_keys}
            model_by_id = {m.id: m for m in models}
            provider_by_id = {p.id: p for p in providers}

            cluster_ids = {m.cluster_id for m in models if m.cluster_id is not None}
            cluster_names_by_id: Dict[int, str] = {}
            if cluster_ids:
                clusters = await Cluster.all_by_fields(
                    session=session,
                    fields={},
                    extra_conditions=[Cluster.id.in_(cluster_ids)],
                )
                cluster_names_by_id = {c.id: c.name for c in clusters}

            # Rollups: upsert into the ModelUsage dashboard table.
            for metric in metrics:
                if not _validate_usage_metric(
                    metric, model_by_id, provider_by_id, validated_user_ids
                ):
                    continue
                snapshot = _build_metric_snapshot(
                    metric,
                    model_by_id,
                    provider_by_id,
                    user_by_id,
                    api_key_by_access_key,
                    cluster_names_by_id,
                )
                prompt_tokens, completion_tokens = _resolve_usage_tokens(
                    metric, model_by_id.get(metric.model_id)
                )
                metric_date, _ = _resolve_metric_datetime(metric)
                model_usage = ModelUsage(
                    date=metric_date,
                    prompt_token_count=prompt_tokens,
                    completion_token_count=completion_tokens,
                    prompt_cached_token_count=metric.input_cached_token,
                    request_count=metric.request_count,
                    operation=metric.operation,
                    **snapshot,
                )
                await create_or_update_model_usage(
                    session, model_usage, auto_commit=False
                )

            # Details: one ModelUsageDetails audit row per raw report.
            for metric in detail_metrics:
                if not _validate_usage_metric(
                    metric, model_by_id, provider_by_id, validated_user_ids
                ):
                    continue
                snapshot = _build_metric_snapshot(
                    metric,
                    model_by_id,
                    provider_by_id,
                    user_by_id,
                    api_key_by_access_key,
                    cluster_names_by_id,
                )
                prompt_tokens, completion_tokens = _resolve_usage_tokens(
                    metric, model_by_id.get(metric.model_id)
                )
                # Preserve the reported model_route_id verbatim. Details are
                # FK-less by design (ModelUsageDetails docstring), so the
                # historical id stays audit-valuable even when the route was
                # deleted between request dispatch and this flush. The name
                # is best-effort from the live table and null when the route
                # is gone.
                model_route_id = metric.model_route_id
                model_route_name = route_name_by_id.get(metric.model_route_id)
                # cluster_id only lives on the audit/details rows, not on the
                # dashboard rollup (ModelUsage). Prefer the metric's own
                # cluster_id (captured at request time, survives model
                # deletes); fall back to the live model only when the ingest
                # source didn't carry one (older gateway clients).
                cluster_id = metric.cluster_id
                if cluster_id is None:
                    cluster_id = getattr(
                        model_by_id.get(metric.model_id), "cluster_id", None
                    )
                metric_date, metric_dt = _resolve_metric_datetime(metric)
                started_dt = _unixmilli_to_naive_utc(metric.started_at)
                completed_dt = _unixmilli_to_naive_utc(metric.completed_at)
                session.add(
                    ModelUsageDetails(
                        date=metric_date,
                        model_route_id=model_route_id,
                        model_route_name=model_route_name,
                        cluster_id=cluster_id,
                        prompt_token_count=prompt_tokens,
                        completion_token_count=completion_tokens,
                        prompt_cached_token_count=metric.input_cached_token,
                        operation=metric.operation,
                        # Proxy-reported wall-clock, preserved as NULL when
                        # the report didn't carry it, so reconciliation jobs
                        # can tell estimated rows apart from authoritative
                        # ones.
                        started_at=started_dt,
                        completed_at=completed_dt,
                        # Audit timestamps are pinned to the request's
                        # wall-clock so the row's lifecycle stamps don't
                        # drift by the flush interval.
                        created_at=metric_dt,
                        updated_at=metric_dt,
                        **snapshot,
                    )
                )
            await session.commit()
        except Exception as e:
            logger.exception(f"Error storing gateway metrics: {e}")
            await session.rollback()
            # Propagate so flush_gateway_metrics can re-buffer the pending
            # records. Without this, a transactional rollback silently drops
            # a flush window's worth of audit rows.
            raise
|