metrics_collector.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540
  1. import asyncio
  2. import logging
  3. from datetime import date, datetime, timezone
  4. from typing import Dict, List, Optional, Set, Tuple
  5. from pydantic import BaseModel
  6. from sqlmodel.ext.asyncio.session import AsyncSession
  7. from gpustack import envs
  8. from gpustack.schemas.api_keys import ApiKey
  9. from gpustack.schemas.clusters import Cluster
  10. from gpustack.schemas.model_provider import ModelProvider
  11. from gpustack.schemas.model_routes import ModelRoute
  12. from gpustack.schemas.model_usage import ModelUsage, OperationEnum
  13. from gpustack.schemas.model_usage_details import ModelUsageDetails
  14. from gpustack.schemas.models import Model, is_embedding_model, is_reranker_model
  15. from gpustack.schemas.users import User
  16. from gpustack.server.db import async_session
  17. from gpustack.utils.usage_snapshots import build_model_usage_snapshot
  18. logger = logging.getLogger(__name__)
  19. FLUSH_INTERVAL_SECONDS = 10
  20. # Heuristics for partial-stream usage estimation. The proxy never applies
  21. # these ratios itself — they kick in server-side only when an incomplete
  22. # report leaves token fields blank. Tunable via env (see ``gpustack.envs``).
  23. # Buffer to accumulate pushed gateway metrics: {key: ModelUsageMetrics}.
  24. # Key format (see ``_make_buffer_key``):
  25. # "{model_id}.{provider_id}.{model}.{user_id}.{access_key}.{operation}.{date}"
  26. # ``operation`` and ``date`` are part of the key so per-operation rollups
  27. # stay separate and a stream that crosses midnight lands in the period
  28. # it ends in (anchored on completed_at).
  29. gateway_metrics_buffer: Dict[str, "ModelUsageMetrics"] = {}
  30. # Raw per-report metrics retained for ``model_usage_details`` audit rows.
  31. # Unlike ``gateway_metrics_buffer``, entries are not aggregated.
  32. gateway_details_buffer: List["ModelUsageMetrics"] = []
  33. # Single lock guarding both rollup and details buffers; ingest writes
  34. # them together, so they must be drained together too.
  35. gateway_buffers_lock = asyncio.Lock()
# One usage report pushed by the gateway proxy or the request middleware.
# Once ingested, entries stored in ``gateway_metrics_buffer`` also serve as
# running aggregates: token counters and ``request_count`` are summed across
# every report that shares a buffer key.
class ModelUsageMetrics(BaseModel):
    # Model name as requested; checked against the Model / provider records
    # before anything is persisted.
    model: str
    # Token counters. For incomplete reports (``completed`` is False) empty
    # slots may be backfilled by server-side estimation.
    input_token: int = 0
    output_token: int = 0
    total_token: int = 0
    input_cached_token: int = 0
    # Number of requests this entry represents: 1 for a single report,
    # summed when rollup entries are merged.
    request_count: int = 1
    # ``completed`` is True iff the canonical usage chunk was observed before
    # the stream ended. When False, token fields may be 0 (OpenAI/vLLM) or
    # partial (Anthropic message_start carries input_token early), so the
    # server falls back to estimation from the byte/chunk fields below.
    completed: bool = False
    # Estimation inputs for incomplete reports: number of streamed output
    # chunks and request content size in bytes.
    output_chunk_count: int = 0
    request_content_bytes: int = 0
    # Wall-clock UnixMilli stamps captured at request entry and report
    # dispatch respectively. ``None`` means the report didn't carry one;
    # legacy payloads sending literal ``0`` are also treated as absent
    # downstream (see ``_unixmilli_to_naive_utc``).
    started_at: Optional[int] = None
    completed_at: Optional[int] = None
    # Identity fields: used in the rollup buffer key and validated against
    # the database before persisting.
    user_id: Optional[int] = None
    model_id: Optional[int] = None
    # Route id is stored verbatim on detail rows; the route name is looked
    # up best-effort at flush time.
    model_route_id: Optional[int] = None
    # Captured at request time by middleware (request.state.model.cluster_id).
    # Carried on the metric so the historical cluster id survives even if the
    # model is deleted between request and flush.
    cluster_id: Optional[int] = None
    provider_id: Optional[int] = None
    # Reported provider name/type; used as fallbacks when the snapshot built
    # from the database does not supply them.
    provider_name: Optional[str] = None
    provider_type: Optional[str] = None
    # Access key of the calling API key; part of the buffer key and used to
    # look up the ApiKey record.
    access_key: Optional[str] = None
    # Inference operation type (chat_completion / embedding / rerank / ...).
    # None when the gateway report doesn't carry it; middleware-fed metrics
    # always populate it so per-operation rollups survive unification.
    operation: Optional[OperationEnum] = None
  71. def _unixmilli_to_naive_utc(ms: Optional[int]) -> Optional[datetime]:
  72. """Convert a UnixMilli stamp to naive UTC, or None if absent / non-positive.
  73. Accepts ``None`` (current absence sentinel) and ``<= 0`` (legacy absence
  74. sentinel that some older gateway payloads still send) — both collapse to
  75. ``None``. The naive-UTC convention matches ``TimestampsMixin._datetime_func``
  76. and the ``UTCDateTime`` storage type, which both strip tzinfo.
  77. """
  78. if ms is None or ms <= 0:
  79. return None
  80. return datetime.fromtimestamp(ms / 1000.0, tz=timezone.utc).replace(tzinfo=None)
  81. def _resolve_metric_datetime(
  82. metric: ModelUsageMetrics,
  83. ) -> Tuple[date, datetime]:
  84. """Resolve (date, naive-UTC datetime) anchored on the metric's wall-clock.
  85. Prefers ``completed_at`` so a stream that crosses a calendar boundary
  86. lands in the period it ends in (per the proxy contract). Falls back to
  87. ``started_at`` and finally to server ``now`` if both are absent.
  88. """
  89. dt = (
  90. _unixmilli_to_naive_utc(metric.completed_at)
  91. or _unixmilli_to_naive_utc(metric.started_at)
  92. or datetime.now(timezone.utc).replace(tzinfo=None)
  93. )
  94. return dt.date(), dt
  95. def _make_buffer_key(metric: ModelUsageMetrics) -> str:
  96. # Include the completion-anchored date so streams that cross midnight
  97. # accumulate into the correct billing-period rollup instead of being
  98. # merged with the next day's traffic.
  99. metric_date, _ = _resolve_metric_datetime(metric)
  100. operation = metric.operation.value if metric.operation else ""
  101. return ".".join(
  102. str(part or "")
  103. for part in [
  104. metric.model_id,
  105. metric.provider_id,
  106. metric.model,
  107. metric.user_id,
  108. metric.access_key,
  109. operation,
  110. metric_date.isoformat(),
  111. ]
  112. )
  113. def _estimate_partial_usage(metric: ModelUsageMetrics) -> None:
  114. """Backfill input_token / output_token for incomplete reports in place.
  115. Only fills slots that are still empty so that legitimate partial values
  116. (e.g. Anthropic's early ``input_token``) survive untouched. Estimation
  117. is intentionally a server-side concern — the proxy never applies these
  118. ratios itself.
  119. """
  120. if metric.completed:
  121. return
  122. if metric.input_token <= 0 and metric.request_content_bytes > 0:
  123. metric.input_token = max(
  124. 1,
  125. metric.request_content_bytes // envs.USAGE_ESTIMATED_BYTES_PER_INPUT_TOKEN,
  126. )
  127. if metric.output_token <= 0 and metric.output_chunk_count > 0:
  128. metric.output_token = (
  129. metric.output_chunk_count * envs.USAGE_ESTIMATED_TOKENS_PER_OUTPUT_CHUNK
  130. )
  131. estimated_total = metric.input_token + metric.output_token
  132. if metric.total_token < estimated_total:
  133. metric.total_token = estimated_total
  134. def _resolve_usage_tokens(
  135. metric: ModelUsageMetrics, model: Optional[Model]
  136. ) -> tuple[int, int]:
  137. prompt_tokens = metric.input_token
  138. completion_tokens = metric.output_token
  139. if (
  140. model is not None
  141. and (is_reranker_model(model) or is_embedding_model(model))
  142. and metric.total_token > (prompt_tokens + completion_tokens)
  143. ):
  144. return metric.total_token - completion_tokens, completion_tokens
  145. return prompt_tokens, completion_tokens
  146. async def accumulate_gateway_metrics(metrics: List[ModelUsageMetrics]):
  147. async with gateway_buffers_lock:
  148. for incoming in metrics:
  149. # Take ownership before any in-place work:
  150. # * ``_estimate_partial_usage`` mutates token fields directly.
  151. # * The rollup buffer's ``+=`` mutates the stored entry, which
  152. # would also mutate the caller's instance (and bleed into the
  153. # details audit row) if we shared references.
  154. # One copy at the top + one for details keeps both buffers, the
  155. # caller, and the audit trail isolated from one another.
  156. metric = incoming.model_copy()
  157. # Backfill estimated tokens before either buffer sees the metric:
  158. # the rollup buffer aggregates by += and would otherwise lose the
  159. # per-row byte/chunk context needed for estimation later on.
  160. _estimate_partial_usage(metric)
  161. gateway_details_buffer.append(metric.model_copy())
  162. key = _make_buffer_key(metric)
  163. existing = gateway_metrics_buffer.get(key)
  164. if existing is None:
  165. gateway_metrics_buffer[key] = metric
  166. else:
  167. existing.input_token += metric.input_token
  168. existing.output_token += metric.output_token
  169. existing.total_token += metric.total_token
  170. existing.input_cached_token += metric.input_cached_token
  171. existing.request_count += metric.request_count
  172. _trim_details_buffer_locked()
  173. def _trim_details_buffer_locked() -> None:
  174. """Cap ``gateway_details_buffer`` to bound memory under persistent flush
  175. failure.
  176. The flush failure path re-accumulates pending details so transient errors
  177. don't lose the audit trail, but persistent failures (DB down, schema
  178. drift, constraint violation) would let the buffer grow unbounded as new
  179. ingest piles on. Drop oldest entries past the cap and log once per
  180. overflow event so operators notice. Caller must hold
  181. ``gateway_buffers_lock``.
  182. """
  183. cap = envs.USAGE_DETAILS_BUFFER_MAX_SIZE
  184. overflow = len(gateway_details_buffer) - cap
  185. if overflow <= 0:
  186. return
  187. del gateway_details_buffer[:overflow]
  188. logger.warning(
  189. "gateway_details_buffer exceeded cap (%d); dropped %d oldest detail "
  190. "rows. Likely cause: persistent flush failure to model_usage_details.",
  191. cap,
  192. overflow,
  193. )
  194. async def flush_gateway_metrics():
  195. async with gateway_buffers_lock:
  196. if not gateway_metrics_buffer and not gateway_details_buffer:
  197. return
  198. pending_rollups = list(gateway_metrics_buffer.values())
  199. pending_details = list(gateway_details_buffer)
  200. gateway_metrics_buffer.clear()
  201. gateway_details_buffer.clear()
  202. try:
  203. await store_usage_metrics(pending_rollups, pending_details)
  204. except Exception as e:
  205. logger.error(f"Error flushing gateway metrics to DB: {e}")
  206. # Re-buffering raw details restores both buffers via the same
  207. # aggregation logic as the original ingest path.
  208. await accumulate_gateway_metrics(pending_details)
  209. async def flush_gateway_metrics_to_db():
  210. while True:
  211. await asyncio.sleep(FLUSH_INTERVAL_SECONDS)
  212. await flush_gateway_metrics()
  213. async def create_or_update_model_usage(
  214. session: AsyncSession, metric: ModelUsage, auto_commit: bool = True
  215. ):
  216. current_usage = await ModelUsage.one_by_fields(
  217. session=session,
  218. fields={
  219. "model_id": metric.model_id,
  220. "user_id": metric.user_id,
  221. "provider_id": metric.provider_id,
  222. "provider_name": metric.provider_name,
  223. "provider_type": metric.provider_type,
  224. "model_name": metric.model_name,
  225. "access_key": metric.access_key,
  226. "operation": metric.operation,
  227. "date": metric.date,
  228. },
  229. )
  230. if current_usage is None:
  231. await metric.save(session=session, auto_commit=auto_commit)
  232. else:
  233. current_usage.prompt_token_count += metric.prompt_token_count
  234. current_usage.completion_token_count += metric.completion_token_count
  235. current_usage.prompt_cached_token_count += metric.prompt_cached_token_count
  236. current_usage.request_count += metric.request_count
  237. await current_usage.save(session=session, auto_commit=auto_commit)
  238. def _validate_usage_metric(
  239. metric: ModelUsageMetrics,
  240. models: Dict[int, Model],
  241. providers: Dict[int, ModelProvider],
  242. user_ids: Set[int],
  243. ) -> bool:
  244. if metric.model_id is None and metric.provider_id is None:
  245. logger.debug(
  246. f"Both model_id and provider_id are None for metric: {metric}, skipping."
  247. )
  248. return False
  249. if metric.model_id is not None:
  250. model = models.get(metric.model_id)
  251. if not model:
  252. logger.debug(f"Model ID {metric.model_id} not found in database.")
  253. return False
  254. if model.name != metric.model:
  255. logger.debug(
  256. f"Model name {metric.model} does not match database record {model.name} for model ID {metric.model_id}."
  257. )
  258. return False
  259. if metric.provider_id is not None:
  260. provider = providers.get(metric.provider_id)
  261. if not provider:
  262. logger.debug(f"Provider ID {metric.provider_id} not found in database.")
  263. return False
  264. if metric.model not in {m.name for m in provider.models}:
  265. logger.debug(
  266. f"Model name {metric.model} not found for provider ID {metric.provider_id} in database."
  267. )
  268. return False
  269. if metric.user_id is not None and metric.user_id not in user_ids:
  270. logger.debug(f"User ID {metric.user_id} not found in database.")
  271. return False
  272. return True
  273. def _build_metric_snapshot(
  274. metric: ModelUsageMetrics,
  275. model_by_id: Dict[int, Model],
  276. provider_by_id: Dict[int, ModelProvider],
  277. user_by_id: Dict[int, User],
  278. api_key_by_access_key: Dict[str, ApiKey],
  279. cluster_names_by_id: Dict[int, str],
  280. ) -> dict:
  281. user = user_by_id.get(metric.user_id)
  282. api_key = api_key_by_access_key.get(metric.access_key)
  283. model = model_by_id.get(metric.model_id)
  284. provider = provider_by_id.get(metric.provider_id)
  285. if model is None:
  286. snapshot = {
  287. "model_id": metric.model_id,
  288. "model_name": metric.model,
  289. "cluster_name": None,
  290. }
  291. if provider is not None:
  292. provider_type = getattr(getattr(provider, "config", None), "type", None)
  293. if provider_type is not None and hasattr(provider_type, "value"):
  294. provider_type = provider_type.value
  295. snapshot.update(
  296. {
  297. "provider_id": provider.id,
  298. "provider_name": provider.name,
  299. "provider_type": provider_type,
  300. }
  301. )
  302. if user is not None:
  303. snapshot.update(
  304. {
  305. "user_id": user.id,
  306. "user_name": user.username,
  307. }
  308. )
  309. if api_key is not None:
  310. snapshot.update(
  311. {
  312. "api_key_id": api_key.id,
  313. "api_key_name": api_key.name,
  314. "access_key": api_key.access_key,
  315. "api_key_is_custom": api_key.is_custom,
  316. }
  317. )
  318. else:
  319. snapshot = build_model_usage_snapshot(
  320. model,
  321. cluster_name=cluster_names_by_id.get(model.cluster_id),
  322. user=user,
  323. api_key=api_key,
  324. provider=provider,
  325. )
  326. snapshot.setdefault("user_id", metric.user_id)
  327. snapshot.setdefault("provider_id", metric.provider_id)
  328. snapshot.setdefault("provider_name", metric.provider_name)
  329. snapshot.setdefault("provider_type", metric.provider_type)
  330. snapshot.setdefault("access_key", metric.access_key)
  331. snapshot.setdefault("api_key_is_custom", None)
  332. return snapshot
  333. async def store_usage_metrics(
  334. metrics: List[ModelUsageMetrics],
  335. detail_metrics: Optional[List[ModelUsageMetrics]] = None,
  336. ):
  337. detail_metrics = list(detail_metrics or [])
  338. if not metrics and not detail_metrics:
  339. return
  340. all_metrics = list(metrics) + detail_metrics
  341. dedup_model_names = {m.model for m in all_metrics}
  342. dedup_user_ids = {m.user_id for m in all_metrics if m.user_id is not None}
  343. dedup_access_keys = {m.access_key for m in all_metrics if m.access_key is not None}
  344. dedup_provider_ids = {
  345. m.provider_id for m in all_metrics if m.provider_id is not None
  346. }
  347. dedup_route_ids = {
  348. m.model_route_id for m in all_metrics if m.model_route_id is not None
  349. }
  350. async with async_session() as session:
  351. try:
  352. models = await Model.all_by_fields(
  353. session=session,
  354. fields={},
  355. extra_conditions=[Model.name.in_(dedup_model_names)],
  356. )
  357. providers = await ModelProvider.all_by_fields(
  358. session=session,
  359. fields={},
  360. extra_conditions=(
  361. [ModelProvider.id.in_(dedup_provider_ids)]
  362. if dedup_provider_ids
  363. else []
  364. ),
  365. )
  366. users = await User.all_by_fields(
  367. session=session,
  368. fields={},
  369. extra_conditions=[User.id.in_(dedup_user_ids)],
  370. )
  371. api_keys = await ApiKey.all_by_fields(
  372. session=session,
  373. fields={},
  374. extra_conditions=(
  375. [ApiKey.access_key.in_(dedup_access_keys)]
  376. if dedup_access_keys
  377. else []
  378. ),
  379. )
  380. route_name_by_id: Dict[int, str] = {}
  381. if dedup_route_ids:
  382. routes = await ModelRoute.all_by_fields(
  383. session=session,
  384. fields={},
  385. extra_conditions=[ModelRoute.id.in_(dedup_route_ids)],
  386. )
  387. route_name_by_id = {r.id: r.name for r in routes}
  388. validated_user_ids = {u.id for u in users}
  389. user_by_id = {u.id: u for u in users}
  390. api_key_by_access_key = {k.access_key: k for k in api_keys}
  391. model_by_id = {m.id: m for m in models}
  392. cluster_ids = {m.cluster_id for m in models if m.cluster_id is not None}
  393. clusters = await Cluster.all_by_fields(
  394. session=session,
  395. fields={},
  396. extra_conditions=([Cluster.id.in_(cluster_ids)] if cluster_ids else []),
  397. )
  398. cluster_names_by_id = {c.id: c.name for c in clusters}
  399. provider_by_id = {p.id: p for p in providers}
  400. for metric in metrics:
  401. if not _validate_usage_metric(
  402. metric, model_by_id, provider_by_id, validated_user_ids
  403. ):
  404. continue
  405. snapshot = _build_metric_snapshot(
  406. metric,
  407. model_by_id,
  408. provider_by_id,
  409. user_by_id,
  410. api_key_by_access_key,
  411. cluster_names_by_id,
  412. )
  413. prompt_tokens, completion_tokens = _resolve_usage_tokens(
  414. metric, model_by_id.get(metric.model_id)
  415. )
  416. metric_date, _ = _resolve_metric_datetime(metric)
  417. model_usage = ModelUsage(
  418. date=metric_date,
  419. prompt_token_count=prompt_tokens,
  420. completion_token_count=completion_tokens,
  421. prompt_cached_token_count=metric.input_cached_token,
  422. request_count=metric.request_count,
  423. operation=metric.operation,
  424. **snapshot,
  425. )
  426. await create_or_update_model_usage(
  427. session, model_usage, auto_commit=False
  428. )
  429. for metric in detail_metrics:
  430. if not _validate_usage_metric(
  431. metric, model_by_id, provider_by_id, validated_user_ids
  432. ):
  433. continue
  434. snapshot = _build_metric_snapshot(
  435. metric,
  436. model_by_id,
  437. provider_by_id,
  438. user_by_id,
  439. api_key_by_access_key,
  440. cluster_names_by_id,
  441. )
  442. prompt_tokens, completion_tokens = _resolve_usage_tokens(
  443. metric, model_by_id.get(metric.model_id)
  444. )
  445. # Preserve the reported model_route_id verbatim — details
  446. # is FK-less by design (ModelUsageDetails docstring) so the
  447. # historical id stays audit-valuable even when the route
  448. # was deleted between request dispatch and this flush.
  449. # Name is best-effort from the live table; null when the
  450. # route is gone.
  451. model_route_id = metric.model_route_id
  452. model_route_name = route_name_by_id.get(metric.model_route_id)
  453. # cluster_id only lives on the audit/details rows, not on
  454. # the dashboard rollup (ModelUsage). Prefer the metric's
  455. # own cluster_id (captured at request time, survives model
  456. # deletes); fall back to the live model only when the
  457. # ingest source didn't carry one (older gateway clients).
  458. cluster_id = metric.cluster_id
  459. if cluster_id is None:
  460. cluster_id = getattr(
  461. model_by_id.get(metric.model_id), "cluster_id", None
  462. )
  463. metric_date, metric_dt = _resolve_metric_datetime(metric)
  464. started_dt = _unixmilli_to_naive_utc(metric.started_at)
  465. completed_dt = _unixmilli_to_naive_utc(metric.completed_at)
  466. session.add(
  467. ModelUsageDetails(
  468. date=metric_date,
  469. model_route_id=model_route_id,
  470. model_route_name=model_route_name,
  471. cluster_id=cluster_id,
  472. prompt_token_count=prompt_tokens,
  473. completion_token_count=completion_tokens,
  474. prompt_cached_token_count=metric.input_cached_token,
  475. operation=metric.operation,
  476. # Proxy-reported wall-clock — preserved as NULL when
  477. # the report didn't carry it, so reconciliation jobs
  478. # can tell estimated rows apart from authoritative
  479. # ones.
  480. started_at=started_dt,
  481. completed_at=completed_dt,
  482. # Audit timestamps still pinned to the request's
  483. # wall-clock so the row's lifecycle stamps don't
  484. # drift by the flush interval.
  485. created_at=metric_dt,
  486. updated_at=metric_dt,
  487. **snapshot,
  488. )
  489. )
  490. await session.commit()
  491. except Exception as e:
  492. logger.exception(f"Error storing gateway metrics: {e}")
  493. await session.rollback()
  494. # Propagate so flush_gateway_metrics can re-buffer the pending
  495. # records — without this, a transactional rollback silently
  496. # drops a flush window's worth of audit rows.
  497. raise