model_instances.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483
  1. import asyncio
  2. from typing import Dict, List, Optional, Tuple
  3. import aiohttp
  4. from fastapi import APIRouter, Request, status, HTTPException
  5. from fastapi.responses import PlainTextResponse, StreamingResponse, RedirectResponse
  6. from urllib.parse import urlencode
  7. from sqlalchemy.orm import selectinload
  8. from gpustack.api.responses import StreamingResponseWithStatusCode
  9. from gpustack import envs
  10. from gpustack.server.services import ModelInstanceService
  11. from gpustack.server.worker_request import request_to_worker, stream_to_worker
  12. from gpustack.utils.network import use_proxy_env_for_url
  13. from gpustack.worker.logs import LogOptionsDep
  14. from gpustack.api.exceptions import (
  15. InternalServerErrorException,
  16. NotFoundException,
  17. )
  18. from gpustack.schemas.workers import Worker
  19. from gpustack.schemas.clusters import Cluster
  20. from gpustack.api.tenant import (
  21. bypass_tenant_filter,
  22. assert_resource_visible,
  23. tenant_list_conditions,
  24. )
  25. from gpustack.server.db import async_session
  26. from gpustack.server.deps import ListParamsDep, SessionDep, TenantContextDep
  27. from gpustack.schemas.models import (
  28. BackendEnum,
  29. Model,
  30. ModelInstance,
  31. ModelInstanceCreate,
  32. ModelInstanceLogOptions,
  33. ModelInstanceLogWorkerOption,
  34. ModelInstancePublic,
  35. ModelInstanceUpdate,
  36. ModelInstancesPublic,
  37. ModelInstanceStateEnum,
  38. ServeLogOptionsResponse,
  39. )
  40. from gpustack.schemas.model_files import ModelFileStateEnum
  41. from gpustack.config.config import get_global_config
  42. from gpustack.utils.grafana import resolve_grafana_base_url
  43. router = APIRouter()
  44. # Subordinate-worker display names, keyed by BackendEnum values.
  45. _SUBORDINATE_DISPLAY_NAMES: Dict[str, str] = {
  46. BackendEnum.VLLM: "ray-worker",
  47. }
  48. def _default_display_name(backend: Optional[str], is_main_worker: bool) -> str:
  49. """Resolve the UI display name for the internal 'default' container."""
  50. if is_main_worker:
  51. return backend or "default"
  52. if backend and backend in _SUBORDINATE_DISPLAY_NAMES:
  53. return _SUBORDINATE_DISPLAY_NAMES[backend]
  54. # Generic subordinate: "sub-<backend>" or just "subordinate".
  55. return f"sub-{backend}" if backend else "subordinate"
  56. def _map_container_display_name(
  57. internal_name: str, backend: Optional[str], is_main_worker: bool
  58. ) -> str:
  59. """Forward-map an internal container name to its UI display name."""
  60. if internal_name != "default":
  61. return internal_name
  62. return _default_display_name(backend, is_main_worker)
  63. def _unmap_container_display_name(
  64. display_name: str, backend: Optional[str], is_main_worker: bool
  65. ) -> str:
  66. """Reverse-map a UI display name back to the internal container name."""
  67. if display_name == _default_display_name(backend, is_main_worker):
  68. return "default"
  69. return display_name
  70. @router.get("", response_model=ModelInstancesPublic)
  71. async def get_model_instances(
  72. ctx: TenantContextDep,
  73. params: ListParamsDep,
  74. id: Optional[int] = None,
  75. model_id: Optional[int] = None,
  76. worker_id: Optional[int] = None,
  77. state: Optional[str] = None,
  78. ):
  79. fields = {}
  80. if id:
  81. fields["id"] = id
  82. if model_id:
  83. fields["model_id"] = model_id
  84. if worker_id:
  85. fields["worker_id"] = worker_id
  86. if state:
  87. fields["state"] = state
  88. # System users (workers, cluster service accounts) and admin in
  89. # "All" mode must see every Org's instances regardless of their
  90. # ``principal_id`` — otherwise a worker's awatch stream
  91. # would silently filter out instances scheduled to it on clusters
  92. # outside its Personal Org.
  93. if ctx.current_principal_id is not None and not bypass_tenant_filter(ctx):
  94. fields["owner_principal_id"] = ctx.current_principal_id
  95. if params.watch:
  96. return StreamingResponse(
  97. ModelInstance.streaming(fields=fields),
  98. media_type="text/event-stream",
  99. )
  100. async with async_session() as session:
  101. extra_conditions = tenant_list_conditions(ctx, ModelInstance)
  102. return await ModelInstance.paginated_by_query(
  103. session=session,
  104. fields=fields,
  105. extra_conditions=extra_conditions,
  106. page=params.page,
  107. per_page=params.perPage,
  108. )
  109. @router.get("/{id}", response_model=ModelInstancePublic)
  110. async def get_model_instance(
  111. session: SessionDep,
  112. ctx: TenantContextDep,
  113. id: int,
  114. ):
  115. model_instance = await ModelInstance.one_by_id(session, id)
  116. assert_resource_visible(
  117. ctx,
  118. model_instance,
  119. not_found_message="Model instance not found",
  120. )
  121. return model_instance
  122. @router.get("/{id}/dashboard")
  123. async def get_model_instance_dashboard(
  124. session: SessionDep,
  125. id: int,
  126. request: Request,
  127. ):
  128. model_instance = await ModelInstance.one_by_id(session, id)
  129. if not model_instance:
  130. raise NotFoundException(message="Model instance not found")
  131. cfg = get_global_config()
  132. if not cfg.get_grafana_url() or not cfg.grafana_model_dashboard_uid:
  133. raise InternalServerErrorException(
  134. message="Grafana dashboard settings are not configured"
  135. )
  136. cluster = None
  137. if model_instance.cluster_id is not None:
  138. cluster = await Cluster.one_by_id(session, model_instance.cluster_id)
  139. query_params = {}
  140. if cluster is not None:
  141. query_params["var-cluster_name"] = cluster.name
  142. query_params["var-model_name"] = model_instance.model_name
  143. query_params["var-model_instance_name"] = model_instance.name
  144. grafana_base = resolve_grafana_base_url(cfg, request)
  145. slug = "gpustack-model"
  146. dashboard_url = f"{grafana_base}/d/{cfg.grafana_model_dashboard_uid}/{slug}"
  147. if query_params:
  148. dashboard_url = f"{dashboard_url}?{urlencode(query_params)}"
  149. return RedirectResponse(url=dashboard_url, status_code=302)
  150. async def fetch_model_instance(session, id):
  151. model_instance = await ModelInstance.one_by_id(
  152. session, id, options=[selectinload(ModelInstance.model_files)]
  153. )
  154. if not model_instance:
  155. raise NotFoundException(message="Model instance not found")
  156. if not model_instance.worker_id:
  157. raise NotFoundException(message="Model instance not assigned to a worker")
  158. return model_instance
  159. async def fetch_worker(session, worker_id):
  160. worker = await Worker.one_by_id(session, worker_id)
  161. if not worker:
  162. raise NotFoundException(message="Model instance's worker not found")
  163. return worker
  164. @router.get("/{id}/logs")
  165. async def get_serving_logs( # noqa: C901
  166. request: Request,
  167. session: SessionDep,
  168. id: int,
  169. log_options: LogOptionsDep,
  170. worker_id: Optional[int] = None,
  171. container_name: Optional[str] = None,
  172. ):
  173. model_instance = await fetch_model_instance(session, id)
  174. # Reverse-map: convert UI display name back to internal container name.
  175. if container_name:
  176. is_main = (worker_id or model_instance.worker_id) == model_instance.worker_id
  177. container_name = _unmap_container_display_name(
  178. container_name, model_instance.backend, is_main
  179. )
  180. # Build valid worker IDs (main worker + subordinate workers for distributed instances)
  181. valid_worker_ids = {model_instance.worker_id}
  182. if (
  183. model_instance.distributed_servers
  184. and model_instance.distributed_servers.subordinate_workers
  185. ):
  186. valid_worker_ids.update(
  187. sw.worker_id
  188. for sw in model_instance.distributed_servers.subordinate_workers
  189. )
  190. # Determine target worker ID
  191. target_worker_id = worker_id or model_instance.worker_id
  192. if target_worker_id not in valid_worker_ids:
  193. raise NotFoundException(
  194. message=f"Worker {target_worker_id} not found for model instance {id}"
  195. )
  196. worker = await fetch_worker(session, target_worker_id)
  197. params = {
  198. "tail": log_options.tail,
  199. "follow": log_options.follow,
  200. "model_instance_name": model_instance.name,
  201. "previous": log_options.previous,
  202. }
  203. if container_name:
  204. params["container_name"] = container_name
  205. if (
  206. model_instance.state != ModelInstanceStateEnum.RUNNING
  207. and model_instance.model_files
  208. and model_instance.model_files[0].state != ModelFileStateEnum.READY
  209. ):
  210. params["model_file_id"] = model_instance.model_files[0].id
  211. timeout = aiohttp.ClientTimeout(total=envs.PROXY_TIMEOUT, sock_connect=5)
  212. if log_options.follow:
  213. def on_exception(e: Exception, t: aiohttp.ClientTimeout) -> tuple[str, int]:
  214. msg = (
  215. str(e)
  216. if not isinstance(e, TimeoutError)
  217. else f"Log stream timed out ({t.total} seconds). Please reopen the log page."
  218. )
  219. return f"\x1b[999;1H{msg}\n", status.HTTP_500_INTERNAL_SERVER_ERROR
  220. return StreamingResponseWithStatusCode(
  221. stream_to_worker(
  222. worker=worker,
  223. method="GET",
  224. path=f"serveLogs/{model_instance.id}",
  225. proxy_client=request.app.state.http_client,
  226. no_proxy_client=request.app.state.http_client_no_proxy,
  227. params=params,
  228. timeout=timeout,
  229. on_exception=on_exception,
  230. raw=True,
  231. ),
  232. media_type="application/octet-stream",
  233. )
  234. else:
  235. resp, body = await request_to_worker(
  236. worker=worker,
  237. method="GET",
  238. path=f"serveLogs/{model_instance.id}",
  239. proxy_client=request.app.state.http_client,
  240. no_proxy_client=request.app.state.http_client_no_proxy,
  241. params=params,
  242. timeout=timeout,
  243. )
  244. return PlainTextResponse(
  245. content=body.decode() if body else "", status_code=resp.status
  246. )
  247. async def resolve_instance_log_worker_targets(
  248. session, model_instance: ModelInstance
  249. ) -> List[Tuple[int, str, Optional[Worker]]]:
  250. """
  251. Ordered targets: main worker, then distributed subordinate workers.
  252. Worker may be None if the subordinate id is not present in DB (cannot proxy HTTP).
  253. """
  254. targets: List[Tuple[int, str, Optional[Worker]]] = []
  255. seen: set[int] = set()
  256. main_id = model_instance.worker_id
  257. if main_id is not None and main_id not in seen:
  258. main_worker = await fetch_worker(session, main_id)
  259. targets.append((main_id, main_worker.name or "", main_worker))
  260. seen.add(main_id)
  261. dservers = model_instance.distributed_servers
  262. if dservers and dservers.subordinate_workers:
  263. for sw in dservers.subordinate_workers:
  264. wid = sw.worker_id
  265. if wid is None or wid in seen:
  266. continue
  267. name = sw.worker_name or ""
  268. w = await Worker.one_by_id(session, wid)
  269. if not name:
  270. name = w.name if w else ""
  271. targets.append((wid, name or "", w))
  272. seen.add(wid)
  273. return targets
  274. async def fetch_serve_log_options_from_worker(
  275. request: Request,
  276. worker: Worker,
  277. model_instance_id: int,
  278. ) -> ServeLogOptionsResponse:
  279. log_options_url = (
  280. f"http://{worker.advertise_address}:{worker.port}/serveLogOptions"
  281. f"/{model_instance_id}"
  282. )
  283. timeout = aiohttp.ClientTimeout(total=envs.PROXY_TIMEOUT, sock_connect=5)
  284. use_proxy_env = use_proxy_env_for_url(log_options_url)
  285. client: aiohttp.ClientSession = (
  286. request.app.state.http_client
  287. if use_proxy_env
  288. else request.app.state.http_client_no_proxy
  289. )
  290. try:
  291. async with client.get(log_options_url, timeout=timeout) as resp:
  292. if resp.status != 200:
  293. raise ValueError(
  294. f"HTTP {resp.status}: error fetching model instance log options"
  295. )
  296. data = await resp.json()
  297. except ValueError:
  298. raise
  299. except Exception as e:
  300. raise ValueError(str(e)) from e
  301. return ServeLogOptionsResponse.model_validate(
  302. data if isinstance(data, dict) else {}
  303. )
  304. @router.get("/{id}/log-options", response_model=ModelInstanceLogOptions)
  305. async def get_model_instance_log_options(
  306. request: Request,
  307. session: SessionDep,
  308. id: int,
  309. ):
  310. """Return per-worker restart_count values that exist on disk for this model instance."""
  311. model_instance = await fetch_model_instance(session, id)
  312. targets = await resolve_instance_log_worker_targets(session, model_instance)
  313. async def fetch_one(
  314. target: Tuple[int, str, Optional[Worker]],
  315. ) -> ModelInstanceLogWorkerOption:
  316. wid, name, worker = target
  317. display_name = name
  318. if worker is None:
  319. return ModelInstanceLogWorkerOption(
  320. worker_id=wid,
  321. name=display_name,
  322. restarts=[],
  323. error="Worker not found in database",
  324. )
  325. if not display_name:
  326. display_name = worker.name or ""
  327. try:
  328. payload = await fetch_serve_log_options_from_worker(
  329. request, worker, model_instance.id
  330. )
  331. return ModelInstanceLogWorkerOption(
  332. worker_id=wid,
  333. name=display_name,
  334. restarts=payload.restarts,
  335. error=None,
  336. )
  337. except Exception as e:
  338. return ModelInstanceLogWorkerOption(
  339. worker_id=wid,
  340. name=display_name,
  341. restarts=[],
  342. error=str(e),
  343. )
  344. worker_options = await asyncio.gather(
  345. *[fetch_one(t) for t in targets],
  346. )
  347. for wo in worker_options:
  348. is_main = wo.worker_id == model_instance.worker_id
  349. for entry in wo.restarts:
  350. entry.containers = [
  351. _map_container_display_name(c, model_instance.backend, is_main)
  352. for c in entry.containers
  353. ]
  354. if worker_options and all(o.error for o in worker_options):
  355. detail = "; ".join(
  356. f"{o.worker_id}: {o.error}" for o in worker_options if o.error
  357. )
  358. raise HTTPException(
  359. status_code=502,
  360. detail=f"Failed to fetch log options from all workers: {detail}",
  361. )
  362. return ModelInstanceLogOptions(
  363. main_worker_id=model_instance.worker_id,
  364. workers=list(worker_options),
  365. )
  366. @router.post("", response_model=ModelInstancePublic)
  367. async def create_model_instance(
  368. session: SessionDep, model_instance_in: ModelInstanceCreate
  369. ):
  370. # Inherit the parent Model's tenant binding. The schema default of
  371. # PLATFORM_PRINCIPAL_ID would otherwise persist `owner_principal_id=1`
  372. # for instances of a non-platform Model whenever the caller (worker /
  373. # API client) doesn't echo the field back.
  374. if model_instance_in.model_id is not None:
  375. parent = await Model.one_by_id(session, model_instance_in.model_id)
  376. if parent is not None:
  377. model_instance_in.owner_principal_id = parent.owner_principal_id
  378. try:
  379. model_instance = await ModelInstance.create(session, model_instance_in)
  380. except Exception as e:
  381. raise InternalServerErrorException(
  382. message=f"Failed to create model instance: {e}"
  383. )
  384. return model_instance
  385. @router.put("/{id}", response_model=ModelInstancePublic)
  386. async def update_model_instance(
  387. session: SessionDep,
  388. ctx: TenantContextDep,
  389. id: int,
  390. model_instance_in: ModelInstanceUpdate,
  391. ):
  392. model_instance = await ModelInstance.one_by_id(session, id, for_update=True)
  393. assert_resource_visible(
  394. ctx,
  395. model_instance,
  396. not_found_message="Model instance not found",
  397. )
  398. try:
  399. await ModelInstanceService(session).update(model_instance, model_instance_in)
  400. except Exception as e:
  401. raise InternalServerErrorException(
  402. message=f"Failed to update model instance: {e}"
  403. )
  404. return model_instance
  405. @router.delete("/{id}")
  406. async def delete_model_instance(session: SessionDep, ctx: TenantContextDep, id: int):
  407. model_instance = await ModelInstance.one_by_id(session, id, for_update=True)
  408. assert_resource_visible(
  409. ctx,
  410. model_instance,
  411. not_found_message="Model instance not found",
  412. )
  413. try:
  414. await ModelInstanceService(session).delete(model_instance)
  415. except Exception as e:
  416. raise InternalServerErrorException(
  417. message=f"Failed to delete model instance: {e}"
  418. )