model_routes.py 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918
  1. import logging
  2. import secrets
  3. from datetime import datetime, timezone
  4. from sqlalchemy.orm import selectinload
  5. from sqlmodel import col, or_, select
  6. from sqlmodel.ext.asyncio.session import AsyncSession
  7. from typing import List, Optional, Set, Tuple, Union, Dict
  8. from fastapi import APIRouter, Depends, Query
  9. from fastapi.responses import StreamingResponse
  10. from gpustack.schemas.model_routes import (
  11. AccessPolicyEnum,
  12. ModelRoute,
  13. ModelRouteCreate,
  14. ModelRouteUpdate,
  15. ModelRoutePublic,
  16. ModelRoutesPublic,
  17. ModelRouteListParams,
  18. ModelRouteTarget,
  19. ModelRouteTargetUpdateItem,
  20. ModelRouteTargetUpdate,
  21. ModelRouteTargetPublic,
  22. ModelRouteTargetsPublic,
  23. ModelRouteTargetListParams,
  24. SetFallbackTargetInput,
  25. ModelAuthorizationList,
  26. ModelAuthorizationUpdate,
  27. ModelUserAccessExtended,
  28. MyModel,
  29. TargetStateEnum,
  30. )
  31. from gpustack.schemas.links import ModelRoutePrincipalLink
  32. from gpustack.schemas.organizations import PLATFORM_PRINCIPAL_ID
  33. from gpustack.schemas.model_provider import ModelProvider
  34. from gpustack.schemas.models import Model
  35. from gpustack.server.db import async_session
  36. from gpustack.server.deps import SessionDep, TenantContextDep
  37. from gpustack.api.tenant import (
  38. TenantContext,
  39. assert_resource_visible,
  40. tenant_list_conditions,
  41. )
  42. from gpustack.schemas.users import User
  43. from gpustack.api.exceptions import (
  44. AlreadyExistsException,
  45. InternalServerErrorException,
  46. NotFoundException,
  47. InvalidException,
  48. )
  49. from gpustack.server.services import (
  50. ModelRouteService,
  51. revoke_model_access_cache,
  52. )
  53. from gpustack.routes.model_common import (
  54. build_category_conditions,
  55. categories_filter,
  56. )
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Router that the ModelRoute endpoints below are registered on.
router = APIRouter()
# Router that the ModelRouteTarget endpoints below are registered on.
target_router = APIRouter()
# Router that the "my models" endpoints below are registered on.
my_models_router = APIRouter()
  61. @router.get("", response_model=ModelRoutesPublic, response_model_exclude_none=True)
  62. async def get_model_routes(
  63. ctx: TenantContextDep,
  64. params: ModelRouteListParams = Depends(),
  65. name: str = None,
  66. search: str = None,
  67. categories: Optional[List[str]] = Query(None, description="Filter by categories."),
  68. ):
  69. return await _get_model_routes(
  70. ctx=ctx,
  71. params=params,
  72. name=name,
  73. search=search,
  74. categories=categories,
  75. )
  76. async def _get_model_routes(
  77. params: ModelRouteListParams,
  78. name: str = None,
  79. search: str = None,
  80. categories: Optional[List[str]] = None,
  81. user_id: Optional[int] = None,
  82. owner_principal_id: Optional[int] = None,
  83. target_class: Union[ModelRoute, MyModel] = ModelRoute,
  84. ctx: Optional[TenantContext] = None,
  85. ):
  86. fuzzy_fields = {}
  87. if search:
  88. fuzzy_fields = {"name": search}
  89. fields = {'deleted_at': None}
  90. if name:
  91. fields = {"name": name}
  92. if user_id is not None:
  93. fields["user_id"] = user_id
  94. if owner_principal_id is not None:
  95. fields["owner_principal_id"] = owner_principal_id
  96. # Apply tenant scoping to the streaming path too. Skipped for the MyModel
  97. # view which handles visibility through its own SQL view definition.
  98. if (
  99. ctx is not None
  100. and target_class is ModelRoute
  101. and ctx.current_principal_id is not None
  102. and "owner_principal_id" not in fields
  103. ):
  104. fields["owner_principal_id"] = ctx.current_principal_id
  105. if params.watch:
  106. return StreamingResponse(
  107. target_class.streaming(
  108. fields=fields,
  109. fuzzy_fields=fuzzy_fields,
  110. filter_func=lambda data: categories_filter(data, categories),
  111. ),
  112. media_type="text/event-stream",
  113. )
  114. async with async_session() as session:
  115. extra_conditions: list = []
  116. # Apply tenant scoping when caller passed a TenantContext. Per-user
  117. # visibility for ModelRoute is via the model_route_principals table.
  118. if ctx is not None and target_class is ModelRoute:
  119. extra_conditions.extend(tenant_list_conditions(ctx, ModelRoute))
  120. if categories:
  121. conditions = build_category_conditions(session, target_class, categories)
  122. extra_conditions.append(or_(*conditions))
  123. return await target_class.paginated_by_query(
  124. session=session,
  125. fields=fields,
  126. fuzzy_fields=fuzzy_fields,
  127. page=params.page,
  128. per_page=params.perPage,
  129. order_by=params.order_by,
  130. extra_conditions=extra_conditions,
  131. )
  132. @router.get("/{id}", response_model=ModelRoutePublic, response_model_exclude_none=True)
  133. async def get_model_route(
  134. session: SessionDep,
  135. ctx: TenantContextDep,
  136. id: int,
  137. ):
  138. return await _get_model_route(session=session, id=id, ctx=ctx)
  139. async def _get_model_route(
  140. session: AsyncSession,
  141. id: int,
  142. target_class: Union[ModelRoute, MyModel] = ModelRoute,
  143. user_id: Optional[int] = None,
  144. owner_principal_id: Optional[int] = None,
  145. ctx: Optional[TenantContext] = None,
  146. ):
  147. fields = {"id": id}
  148. if user_id is not None:
  149. fields["user_id"] = user_id
  150. if owner_principal_id is not None:
  151. fields["owner_principal_id"] = owner_principal_id
  152. existing = await target_class.one_by_fields(
  153. session=session,
  154. fields=fields,
  155. )
  156. if not existing or existing.deleted_at is not None:
  157. raise NotFoundException(f"ModelAccess with id '{id}' not found.")
  158. if ctx is not None and target_class is ModelRoute:
  159. assert_resource_visible(
  160. ctx,
  161. existing,
  162. not_found_message=f"ModelAccess with id '{id}' not found.",
  163. )
  164. return existing
  165. @router.post("", response_model=ModelRoutePublic, response_model_exclude_none=True)
  166. async def create_model_route(
  167. session: SessionDep, ctx: TenantContextDep, input: ModelRouteCreate
  168. ):
  169. # Names are unique within their owning Org. The gateway emits an
  170. # Org-slug prefix as the effective model name for non-platform Orgs,
  171. # so two Orgs can each have a route called "qwen3-0.6b" without
  172. # colliding in the AI proxy match rules.
  173. target_org_id = ctx.target_principal_id_for_write()
  174. existing = await ModelRoute.one_by_fields(
  175. session,
  176. {
  177. 'deleted_at': None,
  178. "name": input.name,
  179. "owner_principal_id": target_org_id,
  180. },
  181. )
  182. if existing:
  183. raise AlreadyExistsException(
  184. f"ModelRoute with name '{input.name}' already exists."
  185. )
  186. source = input.model_dump(exclude={"targets"})
  187. targets = input.targets or []
  188. await validate_targets(session, targets)
  189. source["targets"] = len(targets)
  190. # Stamp the route's owning org from the caller's tenant context.
  191. # ModelRouteBase defaults `owner_principal_id` to PLATFORM_PRINCIPAL_ID
  192. # so `model_dump()` always emits the key — `setdefault` would silently
  193. # keep it at 1 for non-platform admins. Override directly.
  194. if target_org_id is not None:
  195. source["owner_principal_id"] = target_org_id
  196. # Multi-tenant default: a non-platform Org's new route is scoped to
  197. # that Org (ORG policy — `non_admin_user_models` matches by the
  198. # route's `owner_principal_id`). The Default (platform) Org keeps
  199. # AUTHED — admin's shared catalog stays visible to every
  200. # authenticated user, and existing routes migrated to the platform
  201. # Org must keep working. Caller's explicit `access_policy` always
  202. # wins.
  203. owner_org_id = source.get("owner_principal_id")
  204. is_platform_org = owner_org_id == PLATFORM_PRINCIPAL_ID
  205. if (
  206. not is_platform_org
  207. and owner_org_id is not None
  208. and "access_policy" not in input.model_fields_set
  209. ):
  210. source["access_policy"] = AccessPolicyEnum.ORG
  211. try:
  212. route: ModelRoute = await ModelRoute.create(
  213. session=session, source=source, auto_commit=False
  214. )
  215. await create_model_route_targets(
  216. session=session,
  217. route_id=route.id,
  218. route_name=route.name,
  219. targets=targets,
  220. auto_commit=False,
  221. )
  222. await session.commit()
  223. await session.refresh(route)
  224. await revoke_model_access_cache(session=session)
  225. return route
  226. except Exception as e:
  227. await session.rollback()
  228. raise InternalServerErrorException(
  229. f"Failed to create ModelAccess '{input.name}': {e}"
  230. )
  231. @router.put("/{id}", response_model=ModelRoutePublic, response_model_exclude_none=True)
  232. async def update_model_route(
  233. id: int,
  234. session: SessionDep,
  235. ctx: TenantContextDep,
  236. input: ModelRouteUpdate,
  237. ):
  238. existing = await ModelRoute.one_by_id(
  239. session=session,
  240. id=id,
  241. )
  242. if not existing or existing.deleted_at is not None:
  243. raise NotFoundException(f"ModelRoute with id '{id}' not found.")
  244. assert_resource_visible(
  245. ctx,
  246. existing,
  247. not_found_message=f"ModelRoute with id '{id}' not found.",
  248. )
  249. # Names are unique within their owning Org (effective name on the
  250. # gateway side carries the Org slug prefix for non-platform Orgs).
  251. duplicated_name = await ModelRoute.one_by_fields(
  252. session,
  253. {
  254. 'deleted_at': None,
  255. "name": input.name,
  256. "owner_principal_id": existing.owner_principal_id,
  257. },
  258. )
  259. if duplicated_name and duplicated_name.id != id:
  260. raise AlreadyExistsException(
  261. f"ModelRoute with name '{input.name}' already exists."
  262. )
  263. existing_name = existing.name
  264. input_name = input.name
  265. input_data = input.model_dump(exclude={"targets"}, include=input.model_fields_set)
  266. try:
  267. if input.targets is not None or input.name != existing.name:
  268. target_count, _ = await batch_handle_targets(
  269. session=session,
  270. route_id=existing.id,
  271. route_name=existing.name,
  272. targets=input.targets,
  273. auto_commit=False,
  274. new_route_name=input.name if input.name != existing.name else None,
  275. )
  276. input_data["targets"] = target_count
  277. await ModelRouteService(session).update(
  278. existing, source=input_data, auto_commit=False
  279. )
  280. await session.commit()
  281. if existing_name != input_name:
  282. await revoke_model_access_cache(session=session)
  283. except Exception as e:
  284. raise InternalServerErrorException(f"Failed to update ModelRoute '{id}': {e}")
  285. return await ModelRoute.one_by_id(session=session, id=id)
  286. @router.delete("/{id}")
  287. async def delete_model_route(
  288. id: int,
  289. session: SessionDep,
  290. ctx: TenantContextDep,
  291. ):
  292. existing = await ModelRoute.one_by_id(
  293. session=session,
  294. id=id,
  295. options=[selectinload(ModelRoute.route_targets)],
  296. )
  297. if not existing or existing.deleted_at is not None:
  298. raise NotFoundException(f"ModelRoute with id '{id}' not found.")
  299. assert_resource_visible(
  300. ctx,
  301. existing,
  302. not_found_message=f"ModelRoute with id '{id}' not found.",
  303. )
  304. try:
  305. await revoke_model_access_cache(session=session, model=existing)
  306. await ModelRouteService(session).delete(existing)
  307. except Exception as e:
  308. raise InternalServerErrorException(f"Failed to delete ModelRoute '{id}': {e}")
  309. async def unset_fallback_target(
  310. session: AsyncSession,
  311. route_id: int,
  312. auto_commit: bool = False,
  313. ):
  314. targets = await ModelRouteTarget.all_by_field(
  315. session=session,
  316. field="route_id",
  317. value=route_id,
  318. for_update=True,
  319. )
  320. for target in targets:
  321. if target.fallback_status_codes is not None and target.deleted_at is None:
  322. target.fallback_status_codes = None
  323. await target.update(session=session, auto_commit=auto_commit)
  324. @router.post(
  325. "/{id}/add-targets",
  326. response_model=List[ModelRouteTargetPublic],
  327. response_model_exclude_none=True,
  328. )
  329. async def add_model_route_targets(
  330. id: int,
  331. session: SessionDep,
  332. targets: List[ModelRouteTargetUpdateItem],
  333. ):
  334. route = await ModelRoute.one_by_id(session=session, id=id)
  335. if not route or route.deleted_at is not None:
  336. raise NotFoundException(f"ModelRoute with id '{id}' not found.")
  337. target_count, created_targets = await batch_handle_targets(
  338. session=session,
  339. route_id=route.id,
  340. route_name=route.name,
  341. new_route_name=None,
  342. targets=targets,
  343. auto_commit=False,
  344. )
  345. try:
  346. route.targets = target_count
  347. await ModelRouteService(session=session).update(route, auto_commit=True)
  348. await session.commit()
  349. for target in created_targets:
  350. await session.refresh(target)
  351. return created_targets
  352. except Exception as e:
  353. raise InternalServerErrorException(
  354. f"Failed to add targets to ModelRoute '{id}': {e}"
  355. )
  356. async def batch_handle_targets(
  357. session: AsyncSession,
  358. route_id: int,
  359. route_name: str,
  360. targets: List[ModelRouteTargetUpdateItem],
  361. auto_commit: bool = True,
  362. new_route_name: Optional[str] = None,
  363. ) -> Tuple[int, List[ModelRouteTarget]]:
  364. existing_targets = await ModelRouteTarget.all_by_field(
  365. session=session,
  366. field="route_id",
  367. value=route_id,
  368. for_update=True,
  369. )
  370. target_count = len(existing_targets)
  371. existing_target_map = {target.id: target for target in existing_targets}
  372. invalid_target_ids = [
  373. target.id
  374. for target in targets
  375. if target.id is not None and target.id not in existing_target_map
  376. ]
  377. if len(invalid_target_ids) > 0:
  378. raise NotFoundException(
  379. f"ModelRouteTargets with ids '{', '.join(map(str, invalid_target_ids))}' not found."
  380. )
  381. target_count += len([target for target in targets if target.id is None])
  382. to_delete_target_ids = [
  383. target.id
  384. for target in existing_targets
  385. if target.id not in [e.id for e in targets if e.id is not None]
  386. ]
  387. target_count -= len(to_delete_target_ids)
  388. fallback_index = await validate_targets(session=session, targets=targets)
  389. if fallback_index is not None:
  390. fallback_target = targets[fallback_index]
  391. if fallback_target.id is None:
  392. await unset_fallback_target(session, route_id, auto_commit=auto_commit)
  393. targets_to_return = []
  394. try:
  395. # Delete
  396. for target_id in to_delete_target_ids:
  397. target = existing_target_map[target_id]
  398. await target.delete(session=session, auto_commit=auto_commit)
  399. # Update
  400. updated_targets = await update_model_route_targets(
  401. session=session,
  402. targets=targets,
  403. existing_target_map=existing_target_map,
  404. new_route_name=new_route_name,
  405. auto_commit=auto_commit,
  406. )
  407. targets_to_return.extend(updated_targets)
  408. # Create
  409. created_targets = await create_model_route_targets(
  410. session=session,
  411. route_id=route_id,
  412. route_name=new_route_name or route_name,
  413. targets=targets,
  414. auto_commit=auto_commit,
  415. )
  416. targets_to_return.extend(created_targets)
  417. except Exception as e:
  418. raise InternalServerErrorException(
  419. f"Failed to batch handle ModelRouteTargets: {e}"
  420. )
  421. return target_count, targets_to_return
  422. async def update_model_route_targets(
  423. session: AsyncSession,
  424. targets: List[ModelRouteTargetUpdateItem],
  425. existing_target_map: Dict[int, ModelRouteTarget],
  426. new_route_name: Optional[str] = None,
  427. auto_commit: bool = False,
  428. ) -> List[ModelRouteTarget]:
  429. to_update_target_map: Dict[int, ModelRouteTargetUpdateItem] = {
  430. target.id: target
  431. for target in targets
  432. if target.id is not None and target.id in existing_target_map
  433. }
  434. targets_to_return = []
  435. for id, existing_target in existing_target_map.items():
  436. if new_route_name is None and id not in to_update_target_map:
  437. continue
  438. to_compare_fields = {
  439. "route_name",
  440. "provider_model_name",
  441. "weight",
  442. "model_id",
  443. "provider_id",
  444. "fallback_status_codes",
  445. }
  446. existing_dict = existing_target.model_dump(
  447. include=to_compare_fields, exclude_none=True
  448. )
  449. input_target = to_update_target_map.get(id, None)
  450. input_dict = {**existing_dict}
  451. if input_target is not None:
  452. input_dict.update(
  453. input_target.model_dump(include=to_compare_fields, exclude_none=True)
  454. )
  455. if new_route_name is not None:
  456. input_dict["route_name"] = new_route_name
  457. update_source = {}
  458. if existing_dict != input_dict:
  459. # set state to UNAVAILABLE to force re-validate on next use
  460. update_source.update(
  461. {
  462. **input_dict,
  463. "state": TargetStateEnum.UNAVAILABLE,
  464. }
  465. )
  466. if len(update_source) > 0:
  467. updated = await existing_target.update(
  468. session=session, source=update_source, auto_commit=auto_commit
  469. )
  470. targets_to_return.append(updated)
  471. return targets_to_return
  472. async def create_model_route_targets(
  473. session: AsyncSession,
  474. route_id: int,
  475. route_name: str,
  476. targets: List[ModelRouteTargetUpdateItem],
  477. auto_commit: bool = True,
  478. ) -> List[ModelRouteTarget]:
  479. created_targets = []
  480. for target in targets:
  481. if target.id is not None:
  482. continue
  483. route_target = ModelRouteTarget.model_validate(
  484. {
  485. **target.model_dump(),
  486. "route_id": route_id,
  487. "name": route_name + "-" + secrets.token_hex(5),
  488. "route_name": route_name,
  489. }
  490. )
  491. if route_target.model_id is not None:
  492. route_target.state = TargetStateEnum.UNAVAILABLE
  493. route_target: ModelRouteTarget = await ModelRouteTarget.create(
  494. session=session, source=route_target, auto_commit=auto_commit
  495. )
  496. created_targets.append(route_target)
  497. if auto_commit:
  498. await session.commit()
  499. for target in created_targets:
  500. await session.refresh(target)
  501. return created_targets
  502. async def validate_targets(
  503. session: SessionDep,
  504. targets: List[ModelRouteTargetUpdateItem],
  505. ) -> Optional[int]:
  506. fallback_index: Optional[int] = None
  507. for index, target in enumerate(targets):
  508. if (
  509. target.fallback_status_codes is not None
  510. and len(target.fallback_status_codes) > 0
  511. ):
  512. if fallback_index is not None:
  513. raise InvalidException(
  514. "Only one target can be set as fallback for status codes."
  515. )
  516. fallback_index = index
  517. if target.provider_id is not None:
  518. provider = await ModelProvider.one_by_id(
  519. session=session, id=target.provider_id
  520. )
  521. if provider is None or provider.deleted_at is not None:
  522. raise NotFoundException(
  523. f"ModelProvider with id '{target.provider_id}' not found."
  524. )
  525. validate_provider_model_name(provider, target.provider_model_name)
  526. elif target.model_id is not None:
  527. model = await Model.one_by_id(session=session, id=target.model_id)
  528. if model is None or model.deleted_at is not None:
  529. raise NotFoundException(f"Model with id '{target.model_id}' not found.")
  530. return fallback_index
  531. def validate_provider_model_name(
  532. provider: ModelProvider,
  533. model_name: str,
  534. ):
  535. supported_models = provider.models or []
  536. model_names = [model.name for model in supported_models]
  537. if model_name not in model_names:
  538. raise InvalidException(
  539. f"provider_model_name '{model_name}' is not supported by provider '{provider.name}'. Supported models: {', '.join(model_names)}"
  540. )
  541. @target_router.get(
  542. "", response_model=ModelRouteTargetsPublic, response_model_exclude_none=True
  543. )
  544. async def get_model_route_targets(
  545. session: SessionDep,
  546. params: ModelRouteTargetListParams = Depends(),
  547. name: str = None,
  548. search: str = None,
  549. ):
  550. fuzzy_fields = {}
  551. if search:
  552. fuzzy_fields = {"name": search}
  553. fields = {'deleted_at': None}
  554. if name:
  555. fields = {"name": name}
  556. ext_fields = params.model_dump(
  557. include={
  558. "route_id",
  559. "route_name",
  560. "model_id",
  561. "provider_id",
  562. },
  563. exclude_none=True,
  564. )
  565. fields.update(ext_fields)
  566. if params.watch:
  567. return StreamingResponse(
  568. ModelRouteTarget.streaming(fields=fields, fuzzy_fields=fuzzy_fields),
  569. media_type="text/event-stream",
  570. )
  571. return await ModelRouteTarget.paginated_by_query(
  572. session=session,
  573. fields=fields,
  574. fuzzy_fields=fuzzy_fields,
  575. page=params.page,
  576. per_page=params.perPage,
  577. order_by=params.order_by,
  578. )
  579. @target_router.put(
  580. "/{id}",
  581. response_model=ModelRouteTargetPublic,
  582. response_model_exclude_none=True,
  583. )
  584. async def update_model_route_target(
  585. id: int,
  586. session: SessionDep,
  587. input: ModelRouteTargetUpdate,
  588. ):
  589. existing = await ModelRouteTarget.one_by_id(
  590. session=session,
  591. id=id,
  592. )
  593. if not existing or existing.deleted_at is not None:
  594. raise NotFoundException(f"ModelRouteTarget with id '{id}' not found.")
  595. # don't need to update fallback_status_codes here, handled in set-fallback target
  596. targets = [
  597. ModelRouteTargetUpdateItem.model_validate(
  598. {
  599. **input.model_dump(),
  600. "id": id,
  601. "fallback_status_codes": existing.fallback_status_codes,
  602. }
  603. )
  604. ]
  605. await validate_targets(session, targets)
  606. try:
  607. await update_model_route_targets(
  608. session=session,
  609. targets=targets,
  610. existing_target_map={id: existing},
  611. auto_commit=True,
  612. )
  613. except Exception as e:
  614. raise InternalServerErrorException(
  615. f"Failed to update ModelRouteTarget '{id}': {e}"
  616. )
  617. return await ModelRouteTarget.one_by_id(session=session, id=id)
  618. @target_router.delete("/{id}")
  619. async def delete_model_route_target(
  620. id: int,
  621. session: SessionDep,
  622. ):
  623. existing = await ModelRouteTarget.one_by_id(
  624. session=session,
  625. id=id,
  626. )
  627. if not existing or existing.deleted_at is not None:
  628. raise NotFoundException(f"ModelRouteTarget with id '{id}' not found.")
  629. route = existing.model_route
  630. try:
  631. await existing.delete(session=session, auto_commit=False)
  632. if route:
  633. route.targets = max(0, route.targets - 1)
  634. await ModelRouteService(session=session).update(route, auto_commit=False)
  635. await session.commit()
  636. except Exception as e:
  637. await session.rollback()
  638. raise InternalServerErrorException(
  639. f"Failed to delete ModelRouteTarget '{id}': {e}"
  640. )
  641. @target_router.post(
  642. "/{id}/set-fallback",
  643. response_model=ModelRouteTargetPublic,
  644. response_model_exclude_none=True,
  645. )
  646. async def set_fallback_target(
  647. id: int,
  648. session: SessionDep,
  649. input: SetFallbackTargetInput,
  650. ):
  651. existing = await ModelRouteTarget.one_by_id(
  652. session=session,
  653. id=id,
  654. )
  655. if not existing or existing.deleted_at is not None:
  656. raise NotFoundException(f"ModelRouteTarget with id '{id}' not found.")
  657. if existing.fallback_status_codes == input.fallback_status_codes:
  658. return existing
  659. try:
  660. if input.fallback_status_codes is not None:
  661. await unset_fallback_target(session, existing.route_id, auto_commit=False)
  662. existing.fallback_status_codes = input.fallback_status_codes
  663. await existing.update(session=session, auto_commit=False)
  664. await session.commit()
  665. except Exception as e:
  666. await session.rollback()
  667. raise InternalServerErrorException(
  668. f"Failed to set fallback status codes for ModelRouteTarget '{id}': {e}"
  669. )
  670. return await ModelRouteTarget.one_by_id(session=session, id=id)
  671. async def _list_route_users(session, route_id: int) -> List[ModelUserAccessExtended]:
  672. """Build the OSS-facing access list for a route.
  673. User-only ACL rows live in ``model_route_principals`` with
  674. ``user_id`` set; we join them with ``users`` so the response can
  675. carry display-only fields (``username`` / ``full_name`` /
  676. ``avatar_url``) without an extra round trip from the client.
  677. """
  678. stmt = (
  679. select(User, ModelRoutePrincipalLink)
  680. .join(
  681. ModelRoutePrincipalLink,
  682. ModelRoutePrincipalLink.user_id == User.id,
  683. )
  684. .where(
  685. ModelRoutePrincipalLink.route_id == route_id,
  686. ModelRoutePrincipalLink.user_id.is_not(None),
  687. )
  688. )
  689. rows = (await session.exec(stmt)).all()
  690. return [
  691. ModelUserAccessExtended(
  692. id=user.id,
  693. username=user.username,
  694. full_name=user.full_name,
  695. avatar_url=user.avatar_url,
  696. )
  697. for user, _ in rows
  698. ]
  699. async def _replace_route_user_principals(
  700. session, route_id: int, user_ids: List[int]
  701. ) -> None:
  702. """Replace the user-grant rows on a route with exactly ``user_ids``.
  703. Touches only rows where ``user_id IS NOT NULL`` — org / group
  704. grants set by the enterprise UI's ALLOWED_PRINCIPALS flow are left
  705. alone, even if the OSS UI happens to call this endpoint on the
  706. same route.
  707. """
  708. existing_stmt = select(ModelRoutePrincipalLink).where(
  709. ModelRoutePrincipalLink.route_id == route_id,
  710. ModelRoutePrincipalLink.user_id.is_not(None),
  711. )
  712. existing = list((await session.exec(existing_stmt)).all())
  713. existing_by_user = {row.user_id: row for row in existing}
  714. desired = set(user_ids)
  715. for user_id, row in existing_by_user.items():
  716. if user_id not in desired:
  717. await session.delete(row)
  718. now = datetime.now(timezone.utc).replace(tzinfo=None)
  719. for user_id in desired:
  720. if user_id in existing_by_user:
  721. continue
  722. session.add(
  723. ModelRoutePrincipalLink(
  724. route_id=route_id,
  725. user_id=user_id,
  726. principal=f"user:{user_id}",
  727. created_at=now,
  728. updated_at=now,
  729. )
  730. )
  731. @router.get("/{id}/access", response_model=ModelAuthorizationList)
  732. async def get_model_authorization_list(session: SessionDep, id: int):
  733. model: ModelRoute = await ModelRoute.one_by_id(session, id)
  734. if not model:
  735. raise NotFoundException(message="Model not found")
  736. return ModelAuthorizationList(items=await _list_route_users(session, id))
  737. @router.post("/{id}/access", response_model=ModelAuthorizationList)
  738. async def add_model_authorization(
  739. session: SessionDep, id: int, access_request: ModelAuthorizationUpdate
  740. ):
  741. model = await ModelRoute.one_by_id(session, id)
  742. if not model:
  743. raise NotFoundException(message="Model not found")
  744. requested_user_ids = [u.id for u in access_request.users]
  745. if requested_user_ids:
  746. users = await User.all_by_fields(
  747. session=session,
  748. fields={},
  749. extra_conditions=[col(User.id).in_(requested_user_ids)],
  750. )
  751. existing_user_ids = {u.id for u in users}
  752. for req_id in requested_user_ids:
  753. if req_id not in existing_user_ids:
  754. raise NotFoundException(message=f"User ID {req_id} not found")
  755. # Cache invalidation needs the union of "previously granted" and
  756. # "newly granted" user ids — anyone in either set may see a
  757. # different model list after the change.
  758. previous_users = await _list_route_users(session, id)
  759. affected_user_ids: Optional[Set[int]] = {item.id for item in previous_users} | set(
  760. requested_user_ids
  761. )
  762. cache_model = model
  763. if (
  764. access_request.access_policy is not None
  765. and access_request.access_policy != model.access_policy
  766. ):
  767. model.access_policy = access_request.access_policy
  768. # Switching policy (e.g. to PUBLIC) widens visibility beyond
  769. # the explicit user list — broaden cache invalidation.
  770. affected_user_ids = None
  771. cache_model = None
  772. try:
  773. await _replace_route_user_principals(session, id, requested_user_ids)
  774. await revoke_model_access_cache(
  775. session=session,
  776. model=cache_model,
  777. extra_user_ids=affected_user_ids,
  778. )
  779. await ModelRouteService(session).update(model)
  780. except Exception as e:
  781. await session.rollback()
  782. raise InternalServerErrorException(message=f"Failed to add model access: {e}")
  783. return ModelAuthorizationList(items=await _list_route_users(session, id))
  784. @my_models_router.get("", response_model=ModelRoutesPublic)
  785. async def get_my_models(
  786. ctx: TenantContextDep,
  787. params: ModelRouteListParams = Depends(),
  788. search: str = None,
  789. categories: Optional[List[str]] = Query(None, description="Filter by categories."),
  790. ):
  791. """List the model routes available to the calling user.
  792. For non-admin users: visibility is governed by `non_admin_user_models`,
  793. which already encodes PUBLIC/AUTHED/ORG/ALLOWED_PRINCIPALS semantics. We do NOT additionally filter by current_principal_id — routes
  794. published cross-org via ALLOWED_PRINCIPALS would otherwise be hidden.
  795. For platform admins: optionally filter by org if a context was provided.
  796. """
  797. user = ctx.user
  798. user_id = None
  799. target_class = ModelRoute
  800. owner_principal_id = None
  801. if not user.is_admin:
  802. target_class = MyModel
  803. user_id = user.id
  804. else:
  805. # Admin can opt into a per-org view by setting the org context.
  806. owner_principal_id = ctx.current_principal_id
  807. return await _get_model_routes(
  808. params=params,
  809. search=search,
  810. categories=categories,
  811. target_class=target_class,
  812. user_id=user_id,
  813. owner_principal_id=owner_principal_id,
  814. )
  815. @my_models_router.get("/{id}", response_model=ModelRoutePublic)
  816. async def get_my_model(
  817. session: SessionDep,
  818. id: int,
  819. ctx: TenantContextDep,
  820. ):
  821. user = ctx.user
  822. user_id = None
  823. target_class = ModelRoute
  824. owner_principal_id = None
  825. if not user.is_admin:
  826. target_class = MyModel
  827. user_id = user.id
  828. else:
  829. owner_principal_id = ctx.current_principal_id
  830. return await _get_model_route(
  831. session=session,
  832. id=id,
  833. user_id=user_id,
  834. owner_principal_id=owner_principal_id,
  835. target_class=target_class,
  836. )