dashboard.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592
  1. from datetime import date, datetime, timedelta, timezone
  2. from typing import Dict, List, Optional
  3. from fastapi import APIRouter, Query
  4. from sqlmodel import desc, distinct, select, func, col, and_, or_
  5. from sqlmodel.ext.asyncio.session import AsyncSession
  6. from gpustack.schemas.common import ItemList
  7. from gpustack.schemas.dashboard import (
  8. CurrentSystemLoad,
  9. HistorySystemLoad,
  10. ModelSummary,
  11. ModelUsageStats,
  12. ModelUsageSummary,
  13. ModelUsageUserSummary,
  14. ResourceClaim,
  15. ResourceCounts,
  16. SystemLoadSummary,
  17. SystemSummary,
  18. TimeSeriesData,
  19. )
  20. from gpustack.schemas.model_usage import ModelUsage
  21. from gpustack.schemas.models import Model, ModelInstance
  22. from gpustack.schemas.system_load import SystemLoad
  23. from gpustack.schemas.users import User
  24. from gpustack.server.deps import SessionDep
  25. from gpustack.schemas import Worker, Cluster
  26. from gpustack.schemas.model_provider import ModelProvider
  27. from gpustack.server.system_load import compute_system_load
  28. router = APIRouter()
  29. @router.get("")
  30. async def dashboard(
  31. session: SessionDep,
  32. cluster_id: Optional[int] = None,
  33. ):
  34. resoruce_counts = await get_resource_counts(session, cluster_id)
  35. system_load = await get_system_load(session, cluster_id)
  36. model_usage = await get_model_usage_summary(session, cluster_id)
  37. active_models = await get_active_models(session, cluster_id)
  38. summary = SystemSummary(
  39. cluster_id=cluster_id,
  40. resource_counts=resoruce_counts,
  41. system_load=system_load,
  42. model_usage=model_usage,
  43. active_models=active_models,
  44. )
  45. return summary
  46. async def get_resource_counts(
  47. session: AsyncSession, cluster_id: Optional[int] = None
  48. ) -> ResourceCounts:
  49. fields = {}
  50. cluster_count = None
  51. if cluster_id is not None:
  52. fields['cluster_id'] = cluster_id
  53. else:
  54. clusters = await Cluster.all_by_field(session, field="deleted_at", value=None)
  55. cluster_count = len(clusters)
  56. workers = await Worker.all_by_fields(
  57. session,
  58. fields=fields,
  59. )
  60. worker_count = len(workers)
  61. gpu_count = 0
  62. for worker in workers:
  63. gpu_count += len(worker.status.gpu_devices or [])
  64. models = await Model.all_by_fields(session, fields=fields)
  65. model_count = len(models)
  66. model_instances = await ModelInstance.all_by_fields(session, fields=fields)
  67. model_instance_count = len(model_instances)
  68. return ResourceCounts(
  69. cluster_count=cluster_count,
  70. worker_count=worker_count,
  71. gpu_count=gpu_count,
  72. model_count=model_count,
  73. model_instance_count=model_instance_count,
  74. )
  75. async def get_system_load(
  76. session: AsyncSession, cluster_id: Optional[int] = None
  77. ) -> SystemLoadSummary:
  78. fields = {}
  79. if cluster_id is not None:
  80. fields['cluster_id'] = cluster_id
  81. workers = await Worker.all_by_fields(session, fields=fields)
  82. current_system_loads = compute_system_load(workers)
  83. current_system_load = next(
  84. (load for load in current_system_loads if load.cluster_id == cluster_id),
  85. SystemLoad(
  86. cluster_id=cluster_id,
  87. cpu=0,
  88. ram=0,
  89. gpu=0,
  90. vram=0,
  91. ),
  92. )
  93. now = datetime.now(timezone.utc)
  94. one_hour_ago = int((now - timedelta(hours=1)).timestamp())
  95. statement = select(SystemLoad)
  96. statement = statement.where(SystemLoad.cluster_id == cluster_id)
  97. statement = statement.where(SystemLoad.timestamp >= one_hour_ago)
  98. system_loads = (await session.exec(statement)).all()
  99. cpu = []
  100. ram = []
  101. gpu = []
  102. vram = []
  103. for system_load in system_loads:
  104. cpu.append(
  105. TimeSeriesData(
  106. timestamp=system_load.timestamp,
  107. value=system_load.cpu,
  108. )
  109. )
  110. ram.append(
  111. TimeSeriesData(
  112. timestamp=system_load.timestamp,
  113. value=system_load.ram,
  114. )
  115. )
  116. gpu.append(
  117. TimeSeriesData(
  118. timestamp=system_load.timestamp,
  119. value=system_load.gpu,
  120. )
  121. )
  122. vram.append(
  123. TimeSeriesData(
  124. timestamp=system_load.timestamp,
  125. value=system_load.vram,
  126. )
  127. )
  128. cpu.sort(key=lambda x: x.timestamp, reverse=False)
  129. ram.sort(key=lambda x: x.timestamp, reverse=False)
  130. gpu.sort(key=lambda x: x.timestamp, reverse=False)
  131. vram.sort(key=lambda x: x.timestamp, reverse=False)
  132. return SystemLoadSummary(
  133. current=CurrentSystemLoad(
  134. cpu=current_system_load.cpu,
  135. ram=current_system_load.ram,
  136. gpu=current_system_load.gpu,
  137. vram=current_system_load.vram,
  138. ),
  139. history=HistorySystemLoad(
  140. cpu=cpu,
  141. ram=ram,
  142. gpu=gpu,
  143. vram=vram,
  144. ),
  145. )
  146. async def get_model_usage_stats(
  147. session: AsyncSession,
  148. start_date: Optional[date] = None,
  149. end_date: Optional[date] = None,
  150. model_ids: Optional[List[int]] = None,
  151. user_ids: Optional[List[int]] = None,
  152. cluster_id: Optional[int] = None,
  153. provider_model_names: Optional[Dict[int, Optional[List[str]]]] = None,
  154. ) -> ModelUsageStats:
  155. if start_date is None or end_date is None:
  156. end_date = date.today()
  157. start_date = end_date - timedelta(days=31)
  158. if model_ids is None and cluster_id is not None:
  159. models = await Model.all_by_fields(session, fields={"cluster_id": cluster_id})
  160. model_ids = [model.id for model in models]
  161. statement = (
  162. select(
  163. ModelUsage.date,
  164. func.sum(ModelUsage.prompt_token_count).label('total_prompt_tokens'),
  165. func.sum(ModelUsage.completion_token_count).label(
  166. 'total_completion_tokens'
  167. ),
  168. func.sum(ModelUsage.request_count).label('total_requests'),
  169. )
  170. .where(ModelUsage.date >= start_date)
  171. .where(ModelUsage.date <= end_date)
  172. .group_by(ModelUsage.date)
  173. .order_by(ModelUsage.date)
  174. )
  175. or_conditions = []
  176. if model_ids is not None:
  177. or_conditions.append(col(ModelUsage.model_id).in_(model_ids))
  178. for provider_id, model_names in (provider_model_names or {}).items():
  179. if provider_id is not None:
  180. and_conds = [col(ModelUsage.provider_id) == provider_id]
  181. if model_names:
  182. and_conds.append(col(ModelUsage.model_name).in_(model_names))
  183. or_conditions.append(and_(*and_conds))
  184. if or_conditions:
  185. statement = statement.where(or_(*or_conditions))
  186. if user_ids is not None:
  187. statement = statement.where(col(ModelUsage.user_id).in_(user_ids))
  188. results = (await session.exec(statement)).all()
  189. prompt_token_history = []
  190. completion_token_history = []
  191. api_request_history = []
  192. for result in results:
  193. prompt_token_history.append(
  194. TimeSeriesData(
  195. timestamp=int(
  196. datetime.combine(result.date, datetime.min.time()).timestamp()
  197. ),
  198. value=result.total_prompt_tokens,
  199. )
  200. )
  201. completion_token_history.append(
  202. TimeSeriesData(
  203. timestamp=int(
  204. datetime.combine(result.date, datetime.min.time()).timestamp()
  205. ),
  206. value=result.total_completion_tokens,
  207. )
  208. )
  209. api_request_history.append(
  210. TimeSeriesData(
  211. timestamp=int(
  212. datetime.combine(result.date, datetime.min.time()).timestamp()
  213. ),
  214. value=result.total_requests,
  215. )
  216. )
  217. return ModelUsageStats(
  218. api_request_history=api_request_history,
  219. prompt_token_history=prompt_token_history,
  220. completion_token_history=completion_token_history,
  221. )
  222. async def get_model_usage_summary(
  223. session: AsyncSession, cluster_id: Optional[int] = None
  224. ) -> ModelUsageSummary:
  225. model_usage_stats = await get_model_usage_stats(session, cluster_id=cluster_id)
  226. # get top users
  227. today = date.today()
  228. one_month_ago = today - timedelta(days=31)
  229. statement = (
  230. select(
  231. ModelUsage.user_id,
  232. User.username,
  233. func.sum(ModelUsage.prompt_token_count).label('total_prompt_tokens'),
  234. func.sum(ModelUsage.completion_token_count).label(
  235. 'total_completion_tokens'
  236. ),
  237. )
  238. .join(User, ModelUsage.user_id == User.id)
  239. .where(ModelUsage.date >= one_month_ago)
  240. .group_by(ModelUsage.user_id, User.username)
  241. .order_by(
  242. func.sum(
  243. ModelUsage.prompt_token_count + ModelUsage.completion_token_count
  244. ).desc()
  245. )
  246. .limit(10)
  247. )
  248. results = (await session.exec(statement)).all()
  249. top_users = []
  250. for result in results:
  251. top_users.append(
  252. ModelUsageUserSummary(
  253. user_id=result.user_id,
  254. username=result.username,
  255. prompt_token_count=result.total_prompt_tokens,
  256. completion_token_count=result.total_completion_tokens,
  257. )
  258. )
  259. return ModelUsageSummary(
  260. api_request_history=model_usage_stats.api_request_history,
  261. prompt_token_history=model_usage_stats.prompt_token_history,
  262. completion_token_history=model_usage_stats.completion_token_history,
  263. top_users=top_users,
  264. )
  265. async def _get_maas_active_models(session: AsyncSession) -> List[ModelSummary]:
  266. all_providers = await ModelProvider.all_by_field(
  267. session, field="deleted_at", value=None
  268. )
  269. if not all_providers:
  270. return []
  271. provider_ids = [p.id for p in all_providers]
  272. total_tokens = func.sum(
  273. ModelUsage.prompt_token_count + ModelUsage.completion_token_count
  274. )
  275. # Aggregate model usage in the database for efficiency
  276. statement = (
  277. select(
  278. ModelUsage.provider_id,
  279. ModelUsage.model_name,
  280. total_tokens.label("total_token_count"),
  281. )
  282. .where(col(ModelUsage.provider_id).in_(provider_ids))
  283. .group_by(ModelUsage.provider_id, ModelUsage.model_name)
  284. .order_by(func.coalesce(total_tokens, 0).desc())
  285. .limit(10)
  286. )
  287. top_model_usages = (await session.exec(statement)).all()
  288. models_by_provider_and_name = {
  289. (p.id, m.name): m for p in all_providers for m in (p.models or [])
  290. }
  291. provider_id_to_name = {p.id: p.name for p in all_providers}
  292. model_summaries = []
  293. for usage in top_model_usages:
  294. model = models_by_provider_and_name.get((usage.provider_id, usage.model_name))
  295. model_summaries.append(
  296. ModelSummary(
  297. provider_id=usage.provider_id,
  298. provider_name=provider_id_to_name.get(
  299. usage.provider_id, "Unknown Provider"
  300. ),
  301. name=usage.model_name,
  302. instance_count=0,
  303. token_count=int(usage.total_token_count or 0),
  304. categories=([model.category] if model and model.category else None),
  305. )
  306. )
  307. return model_summaries
  308. async def _get_gpustack_active_models(
  309. session: AsyncSession, cluster_id: Optional[int] = None
  310. ) -> List[ModelSummary]:
  311. statement = active_model_statement(cluster_id=cluster_id)
  312. results = (await session.exec(statement)).all()
  313. top_model_ids = [result.id for result in results]
  314. extra_conditions = [
  315. col(ModelInstance.model_id).in_(top_model_ids),
  316. ]
  317. model_instances: List[ModelInstance] = await ModelInstance.all_by_fields(
  318. session, fields={}, extra_conditions=extra_conditions
  319. )
  320. model_instances_by_id: Dict[int, List[ModelInstance]] = {}
  321. for model_instance in model_instances:
  322. if model_instance.model_id not in model_instances_by_id:
  323. model_instances_by_id[model_instance.model_id] = []
  324. model_instances_by_id[model_instance.model_id].append(model_instance)
  325. model_summary = []
  326. for result in results:
  327. # We need to summarize the resource claims for all model instances including distributed servers.
  328. # It's complicated to do this in a SQL statement, so we do it in Python.
  329. resource_claim = ResourceClaim(
  330. ram=0,
  331. vram=0,
  332. )
  333. if result.id in model_instances_by_id:
  334. for model_instance in model_instances_by_id[result.id]:
  335. aggregate_resource_claim(resource_claim, model_instance)
  336. model_summary.append(
  337. ModelSummary(
  338. id=result.id,
  339. name=result.name,
  340. categories=result.categories,
  341. resource_claim=resource_claim,
  342. instance_count=result.instance_count,
  343. token_count=(
  344. result.total_token_count
  345. if result.total_token_count is not None
  346. else 0
  347. ),
  348. )
  349. )
  350. return model_summary
  351. async def get_active_models(
  352. session: AsyncSession, cluster_id: Optional[int] = None
  353. ) -> List[ModelSummary]:
  354. summary = await _get_gpustack_active_models(session, cluster_id)
  355. if cluster_id is None:
  356. maas_active_models = await _get_maas_active_models(session)
  357. summary.extend(maas_active_models)
  358. summary.sort(key=lambda x: x.token_count, reverse=True)
  359. summary = summary[:10]
  360. return summary
  361. def aggregate_resource_claim(
  362. resource_claim: ResourceClaim,
  363. model_instance: ModelInstance,
  364. ):
  365. if model_instance.computed_resource_claim is not None:
  366. resource_claim.ram += model_instance.computed_resource_claim.ram or 0
  367. for vram in (model_instance.computed_resource_claim.vram or {}).values():
  368. resource_claim.vram += vram
  369. if (
  370. model_instance.distributed_servers
  371. and model_instance.distributed_servers.subordinate_workers
  372. ):
  373. for subworker in model_instance.distributed_servers.subordinate_workers:
  374. if subworker.computed_resource_claim is not None:
  375. resource_claim.ram += subworker.computed_resource_claim.ram or 0
  376. for vram in (subworker.computed_resource_claim.vram or {}).values():
  377. resource_claim.vram += vram
  378. def active_model_statement(cluster_id: Optional[int]) -> select:
  379. usage_sum_query = (
  380. select(
  381. Model.id.label('model_id'),
  382. func.sum(
  383. ModelUsage.prompt_token_count + ModelUsage.completion_token_count
  384. ).label('total_token_count'),
  385. )
  386. .outerjoin(ModelUsage, Model.id == ModelUsage.model_id)
  387. .group_by(Model.id)
  388. ).alias('usage_sum')
  389. statement = select(
  390. Model.id,
  391. Model.name,
  392. Model.categories,
  393. func.count(distinct(ModelInstance.id)).label('instance_count'),
  394. usage_sum_query.c.total_token_count,
  395. )
  396. if cluster_id is not None:
  397. statement = statement.where(Model.cluster_id == cluster_id)
  398. statement = (
  399. statement.join(ModelInstance, Model.id == ModelInstance.model_id)
  400. .outerjoin(usage_sum_query, Model.id == usage_sum_query.c.model_id)
  401. .group_by(
  402. Model.id,
  403. usage_sum_query.c.total_token_count,
  404. )
  405. .order_by(func.coalesce(usage_sum_query.c.total_token_count, 0).desc())
  406. .limit(10)
  407. )
  408. return statement
  409. async def get_model_usages(
  410. session: AsyncSession,
  411. start_date: Optional[date] = None,
  412. end_date: Optional[date] = None,
  413. model_ids: Optional[List[int]] = None,
  414. user_ids: Optional[List[int]] = None,
  415. provider_model_names: Optional[Dict[int, Optional[List[str]]]] = None,
  416. ) -> List[ModelUsage]:
  417. if start_date is None or end_date is None:
  418. end_date = date.today()
  419. start_date = end_date - timedelta(days=31)
  420. statement = (
  421. select(ModelUsage)
  422. .where(ModelUsage.date >= start_date)
  423. .where(ModelUsage.date <= end_date)
  424. )
  425. or_conditions = []
  426. if model_ids is not None:
  427. or_conditions.append(col(ModelUsage.model_id).in_(model_ids))
  428. for provider_id, model_names in (provider_model_names or {}).items():
  429. if provider_id is not None:
  430. and_conds = [col(ModelUsage.provider_id) == provider_id]
  431. if model_names:
  432. and_conds.append(col(ModelUsage.model_name).in_(model_names))
  433. or_conditions.append(and_(*and_conds))
  434. if or_conditions:
  435. statement = statement.where(or_(*or_conditions))
  436. if user_ids is not None:
  437. statement = statement.where(col(ModelUsage.user_id).in_(user_ids))
  438. statement = statement.order_by(
  439. desc(ModelUsage.date),
  440. ModelUsage.user_id,
  441. ModelUsage.completion_token_count,
  442. )
  443. return (await session.exec(statement)).all()
  444. def get_models_by_provider_id(
  445. provider_model_names: List[str],
  446. ) -> Optional[Dict[int, Optional[List[str]]]]:
  447. model_names_by_provider_id = {}
  448. for id_prefix_name in provider_model_names or []:
  449. if ":" not in id_prefix_name:
  450. continue
  451. id_str, name = id_prefix_name.split(":", 1)
  452. try:
  453. provider_id = int(id_str)
  454. except ValueError:
  455. continue
  456. names: List[str] = model_names_by_provider_id.setdefault(provider_id, [])
  457. names.extend([name] if name else [])
  458. return model_names_by_provider_id if len(model_names_by_provider_id) > 0 else None
  459. @router.get("/usage")
  460. async def usage(
  461. session: SessionDep,
  462. start_date: Optional[date] = Query(
  463. None,
  464. description="Start date for the usage data (YYYY-MM-DD). Defaults to 31 days ago.",
  465. ),
  466. end_date: Optional[date] = Query(
  467. None, description="End date for the usage data (YYYY-MM-DD). Defaults to today."
  468. ),
  469. model_ids: Optional[List[int]] = Query(
  470. None,
  471. description="Filter by model IDs. Defaults to all models.",
  472. ),
  473. user_ids: Optional[List[int]] = Query(
  474. None, description="Filter by user IDs. Defaults to all users."
  475. ),
  476. provider_model_names: Optional[List[str]] = Query(
  477. None,
  478. description="Filter by provider and model names. Format is 'provider_id:model_name'. To filter by provider ID only, use 'provider_id:'. Defaults to no filtering.",
  479. ),
  480. ):
  481. """
  482. Get model usage records.
  483. This endpoint returns detailed model usage records within a specified date range.
  484. """
  485. items = await get_model_usages(
  486. session,
  487. start_date=start_date,
  488. end_date=end_date,
  489. model_ids=model_ids,
  490. user_ids=user_ids,
  491. provider_model_names=get_models_by_provider_id(provider_model_names or []),
  492. )
  493. return ItemList[ModelUsage](items=items)
  494. @router.get("/usage/stats")
  495. async def usage_stats(
  496. session: SessionDep,
  497. start_date: Optional[date] = Query(
  498. None,
  499. description="Start date for the usage data (YYYY-MM-DD). Defaults to 31 days ago.",
  500. ),
  501. end_date: Optional[date] = Query(
  502. None, description="End date for the usage data (YYYY-MM-DD). Defaults to today."
  503. ),
  504. model_ids: Optional[List[int]] = Query(
  505. None,
  506. description="Filter by model IDs. Defaults to all models.",
  507. ),
  508. user_ids: Optional[List[int]] = Query(
  509. None, description="Filter by user IDs. Defaults to all users."
  510. ),
  511. provider_model_names: Optional[List[str]] = Query(
  512. None,
  513. description="Filter by provider and model names. Format is 'provider_id:model_name'. To filter by provider ID only, use 'provider_id:'. Defaults to no filtering.",
  514. ),
  515. ):
  516. """
  517. Get model usage statistics.
  518. This endpoint returns aggregated statistics for model usage, including token counts and request counts.
  519. It can filter by date range, model IDs, user IDs, model names with provider ID prefix.
  520. """
  521. return await get_model_usage_stats(
  522. session,
  523. start_date=start_date,
  524. end_date=end_date,
  525. model_ids=model_ids,
  526. user_ids=user_ids,
  527. provider_model_names=get_models_by_provider_id(provider_model_names or []),
  528. )