model_provider.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511
  1. import httpx
  2. import logging
  3. import hashlib
  4. from urllib.parse import urlparse
  5. from typing import List, Dict, Any, Optional, Union
  6. from datetime import datetime, timezone
  7. from sqlalchemy.orm import selectinload
  8. from fastapi import APIRouter, Depends
  9. from fastapi.responses import StreamingResponse
  10. from gpustack.schemas.model_provider import (
  11. MaskedAPIToken,
  12. ModelProvider,
  13. ModelProviderCreate,
  14. ModelProviderUpdate,
  15. ModelProviderPublic,
  16. ModelProvidersPublic,
  17. ModelProviderListParams,
  18. ProviderModelsInput,
  19. ModelProviderTypeEnum,
  20. TestProviderModelInput,
  21. TestProviderModelResult,
  22. ProviderModel,
  23. OpenAIConfig,
  24. )
  25. from gpustack.schemas.models import CategoryEnum
  26. from gpustack.schemas.model_routes import ModelRouteTarget
  27. from gpustack.api.exceptions import (
  28. AlreadyExistsException,
  29. InternalServerErrorException,
  30. NotFoundException,
  31. InvalidException,
  32. )
  33. from gpustack.api.tenant import (
  34. assert_resource_visible,
  35. tenant_list_conditions,
  36. )
  37. from gpustack.server.db import async_session
  38. from gpustack.server.deps import SessionDep, TenantContextDep
  39. from openai.types import Model as OAIModel
  40. from openai.pagination import SyncPage
# Router that the model-provider endpoints in this module are registered on.
router = APIRouter()
logger = logging.getLogger(__name__)
# Fallback value for the "anthropic-version" request header, used when a
# Claude provider config does not supply its own ``claudeVersion``.
ANTHROPIC_API_VERSION = "2023-06-01"
  44. @router.get("", response_model=ModelProvidersPublic, response_model_exclude_none=True)
  45. async def get_model_providers(
  46. ctx: TenantContextDep,
  47. params: ModelProviderListParams = Depends(),
  48. name: str = None,
  49. search: str = None,
  50. ):
  51. fuzzy_fields = {}
  52. if search:
  53. fuzzy_fields = {"name": search}
  54. fields = {'deleted_at': None}
  55. if name:
  56. fields = {"name": name}
  57. extra_conditions = list(tenant_list_conditions(ctx, ModelProvider))
  58. if params.watch:
  59. return StreamingResponse(
  60. ModelProvider.streaming(fields=fields, fuzzy_fields=fuzzy_fields),
  61. media_type="text/event-stream",
  62. )
  63. async with async_session() as session:
  64. provider_list = await ModelProvider.paginated_by_query(
  65. session=session,
  66. fields=fields,
  67. fuzzy_fields=fuzzy_fields,
  68. extra_conditions=extra_conditions,
  69. page=params.page,
  70. per_page=params.perPage,
  71. order_by=params.order_by,
  72. )
  73. provider_list.items = [
  74. ModelProvider._convert_to_public_class(provider)
  75. for provider in provider_list.items
  76. ]
  77. return provider_list
  78. def validate_provider(provider: Union[ModelProviderCreate, ModelProviderUpdate]):
  79. if provider.config is not None and len(provider.config.model_extra or {}) > 0:
  80. raise InvalidException(
  81. message=f"fields {', '.join(provider.config.model_extra.keys())} are not allowed in {provider.config.type.value} config"
  82. )
  83. try:
  84. provider.config.check_required_fields()
  85. except ValueError as e:
  86. raise InvalidException(message=f"{e}")
  87. if len(provider.api_tokens) > 1:
  88. llm_model = next(
  89. (model for model in provider.models or [] if model.category == "llm"),
  90. None,
  91. )
  92. if not llm_model:
  93. raise InvalidException(
  94. message="At least one llm model is required when api_tokens has more than 1 token for failover"
  95. )
  96. if len(provider.models or []) == 0:
  97. raise InvalidException(message="At least one model is required for a provider")
  98. if isinstance(provider.config, OpenAIConfig) and provider.config.openaiCustomUrl:
  99. parsed_url = urlparse(provider.config.openaiCustomUrl.rstrip("/"))
  100. if parsed_url.path == "":
  101. raise InvalidException(
  102. message=f"openaiCustomUrl {provider.config.openaiCustomUrl} is invalid, it must include a path, e.g. http://my-openai.com/v1"
  103. )
  104. def parse_api_tokens(
  105. existing_tokens: List[str], api_tokens: List[MaskedAPIToken]
  106. ) -> List[str]:
  107. target_tokens = []
  108. hashed_token_dict = {
  109. hashlib.sha256(token.encode()).hexdigest(): token for token in existing_tokens
  110. }
  111. for index, api_token in enumerate(api_tokens):
  112. token_value = api_token.input
  113. if api_token.hash is not None:
  114. token_value = hashed_token_dict.get(api_token.hash)
  115. if not token_value or not token_value.strip():
  116. raise InvalidException(
  117. message=f"API token at index {index} is invalid, empty, or does not match any existing token"
  118. )
  119. target_tokens.append(token_value)
  120. return target_tokens
  121. @router.post("", response_model=ModelProviderPublic, response_model_exclude_none=True)
  122. async def create_model_provider(
  123. session: SessionDep, ctx: TenantContextDep, input: ModelProviderCreate
  124. ):
  125. # Provider names are unique per Org (admin's Global providers
  126. # — owner_principal_id=NULL — form their own namespace).
  127. existing = await ModelProvider.one_by_fields(
  128. session,
  129. {
  130. 'deleted_at': None,
  131. "name": input.name,
  132. "owner_principal_id": ctx.current_principal_id,
  133. },
  134. )
  135. if existing:
  136. raise AlreadyExistsException(message=f"provider {input.name} already exists")
  137. validate_provider(input)
  138. input_dict = input.model_dump(exclude={"api_tokens", "clone_from_id"})
  139. existing_tokens = []
  140. if input.clone_from_id is not None:
  141. clone_from = await ModelProvider.one_by_id(
  142. session=session,
  143. id=input.clone_from_id,
  144. )
  145. if not clone_from:
  146. raise NotFoundException(
  147. message=f"provider {input.clone_from_id} to clone from not found"
  148. )
  149. existing_tokens = clone_from.api_tokens or []
  150. input_dict["api_tokens"] = parse_api_tokens(
  151. existing_tokens=existing_tokens, api_tokens=input.api_tokens
  152. )
  153. # Tenant scope: bind to caller's current Org. Platform admin in
  154. # "All" mode (no current_principal_id) creates global (NULL).
  155. input_dict["owner_principal_id"] = ctx.current_principal_id
  156. try:
  157. created = await ModelProvider.create(session=session, source=input_dict)
  158. return ModelProvider._convert_to_public_class(created)
  159. except Exception as e:
  160. raise InternalServerErrorException(
  161. message=f"Failed to create provider {input.name}: {e}"
  162. )
  163. @router.get(
  164. "/{id}", response_model=ModelProviderPublic, response_model_exclude_none=True
  165. )
  166. async def get_model_provider(session: SessionDep, ctx: TenantContextDep, id: int):
  167. provider = await ModelProvider.one_by_id(session=session, id=id)
  168. assert_resource_visible(
  169. ctx,
  170. provider,
  171. not_found_message=f"provider {id} not found",
  172. )
  173. return ModelProvider._convert_to_public_class(provider)
  174. def deleted_model_names(
  175. existing_models: List[ProviderModel],
  176. input_models: List[ProviderModel],
  177. ) -> List[str]:
  178. input_model_names = {model.name for model in input_models}
  179. deleted_names = [
  180. model.name for model in existing_models if model.name not in input_model_names
  181. ]
  182. return deleted_names
  183. @router.put(
  184. "/{id}", response_model=ModelProviderPublic, response_model_exclude_none=True
  185. )
  186. async def update_model_provider(
  187. session: SessionDep,
  188. ctx: TenantContextDep,
  189. id: int,
  190. input: ModelProviderUpdate,
  191. ):
  192. provider = await ModelProvider.one_by_id(session=session, id=id)
  193. assert_resource_visible(
  194. ctx,
  195. provider,
  196. not_found_message=f"provider {id} not found",
  197. )
  198. validate_provider(input)
  199. deleted_models = deleted_model_names(provider.models or [], input.models or [])
  200. try:
  201. input_dict = input.model_dump(exclude={"api_tokens"})
  202. if input.api_tokens is not None:
  203. input_dict["api_tokens"] = parse_api_tokens(
  204. existing_tokens=provider.api_tokens or [],
  205. api_tokens=input.api_tokens,
  206. )
  207. await provider.update(
  208. session=session, source=input_dict, auto_commit=len(deleted_models) == 0
  209. )
  210. if len(deleted_models) > 0:
  211. routes = await ModelRouteTarget.all_by_fields(
  212. session=session,
  213. fields={"provider_id": id},
  214. extra_conditions=[
  215. ModelRouteTarget.provider_model_name.in_(deleted_models)
  216. ],
  217. )
  218. for route in routes:
  219. await route.delete(session=session, auto_commit=False)
  220. await session.commit()
  221. except Exception as e:
  222. raise InternalServerErrorException(
  223. message=f"Failed to update provider {id}: {e}"
  224. )
  225. updated_provider = await ModelProvider.one_by_id(session=session, id=id)
  226. return ModelProvider._convert_to_public_class(updated_provider)
  227. @router.delete("/{id}")
  228. async def delete_model_provider(session: SessionDep, ctx: TenantContextDep, id: int):
  229. existing = await ModelProvider.one_by_id(
  230. session=session,
  231. id=id,
  232. options=[selectinload(ModelProvider.model_route_targets)],
  233. )
  234. if not existing or existing.deleted_at is not None:
  235. raise NotFoundException(message=f"provider {id} not found")
  236. assert_resource_visible(
  237. ctx,
  238. existing,
  239. not_found_message=f"provider {id} not found",
  240. )
  241. try:
  242. await existing.delete(session=session)
  243. except Exception as e:
  244. raise InternalServerErrorException(
  245. message=f"Failed to delete provider {id}: {e}"
  246. )
  247. def get_model_name(model: Dict[str, Any]) -> Optional[str]:
  248. return model.get("id", model.get("name", None))
  249. categories_to_infer = [
  250. CategoryEnum.IMAGE,
  251. CategoryEnum.EMBEDDING,
  252. CategoryEnum.RERANKER,
  253. ]
  254. category_values = {e.value for e in CategoryEnum}
  255. def determine_model_category(
  256. provider_type: ModelProviderTypeEnum,
  257. model: Dict[str, Any],
  258. ) -> List[str]:
  259. if provider_type == ModelProviderTypeEnum.DOUBAO:
  260. domain: str = model.get("domain", "").lower()
  261. if domain in category_values:
  262. return [domain]
  263. model_id: str = get_model_name(model) or ""
  264. model_name = model_id.rsplit("/", 1)[-1]
  265. for category_enum in categories_to_infer:
  266. if category_enum.value in model_name:
  267. return [category_enum.value]
  268. return [CategoryEnum.LLM.value]
class CustomOAIModel(OAIModel):
    """OpenAI ``Model`` entry extended with an optional list of categories."""

    # Category values attached to the model; None when not set.
    categories: Optional[List[str]] = None
  271. @router.post("/get-models")
  272. async def get_models_from_provider(
  273. input: ProviderModelsInput,
  274. ):
  275. if input.api_token is None or input.config is None:
  276. raise InvalidException(
  277. message="api_token and config are required to fetch models from provider"
  278. )
  279. result = SyncPage[CustomOAIModel](data=[], object="list")
  280. try:
  281. input.config.check_required_fields()
  282. except ValueError as e:
  283. logger.error(f"{e}")
  284. raise InvalidException(message=f"{e}")
  285. base_url, model_uri = input.config.get_model_url()
  286. if not base_url or not model_uri:
  287. logger.warning(
  288. f"provider type {input.config.type} not supported for fetching models"
  289. )
  290. return result
  291. data = []
  292. async with httpx.AsyncClient(
  293. base_url=base_url,
  294. proxy=input.proxy_url,
  295. trust_env=True,
  296. ) as client:
  297. headers = {}
  298. if input.config.type == ModelProviderTypeEnum.CLAUDE:
  299. headers["X-API-Key"] = input.api_token
  300. headers["anthropic-version"] = (
  301. getattr(input.config, "claudeVersion", None) or ANTHROPIC_API_VERSION
  302. )
  303. else:
  304. headers["Authorization"] = f"Bearer {input.api_token}"
  305. try:
  306. response = await client.get(url=model_uri, headers=headers, timeout=30)
  307. response.raise_for_status()
  308. content = response.json()
  309. data: List[Dict[str, Any]] = content.get("data") or []
  310. except httpx.HTTPStatusError as exc:
  311. raise InvalidException(
  312. message=f"Failed to get models from {input.config.type}: {exc.response.status_code} {exc.response.text}"
  313. )
  314. except httpx.RequestError as exc:
  315. raise InternalServerErrorException(
  316. message=f"Network error: {exc.__class__.__name__}: {exc}"
  317. )
  318. fallback_created = int(datetime.now(timezone.utc).timestamp())
  319. for item in data:
  320. if input.config.type == ModelProviderTypeEnum.DOUBAO:
  321. status = item.get("status", None)
  322. if status is not None:
  323. continue
  324. model_id = get_model_name(item)
  325. if not model_id:
  326. continue
  327. categories = determine_model_category(input.config.type, item)
  328. model = CustomOAIModel(
  329. id=model_id,
  330. created=item.get("created") or fallback_created,
  331. object=item.get("object") or "model",
  332. owned_by=item.get("owned_by") or input.config.type.value,
  333. categories=categories,
  334. )
  335. result.data.append(model)
  336. return result
  337. @router.post("/{id}/get-models")
  338. async def get_models_from_specific_provider(
  339. session: SessionDep,
  340. ctx: TenantContextDep,
  341. id: int,
  342. input: ProviderModelsInput,
  343. ):
  344. provider = await ModelProvider.one_by_id(session=session, id=id)
  345. if not provider or provider.deleted_at is not None:
  346. raise NotFoundException(message=f"provider {id} not found")
  347. assert_resource_visible(
  348. ctx,
  349. provider,
  350. not_found_message=f"provider {id} not found",
  351. )
  352. if provider.api_tokens is None or len(provider.api_tokens) == 0:
  353. raise InvalidException(
  354. message=f"provider {provider.name} id: {id} has no API tokens configured"
  355. )
  356. proxy_url = (
  357. input.proxy_url if 'proxy_url' in input.model_fields_set else provider.proxy_url
  358. )
  359. return await get_models_from_provider(
  360. ProviderModelsInput(
  361. api_token=input.api_token or provider.api_tokens[0],
  362. config=input.config or provider.config,
  363. proxy_url=proxy_url,
  364. )
  365. )
  366. def _get_model_output_token_dict(model_name: str) -> Dict[str, Any]:
  367. name = model_name.lower().rsplit("/", 1)[-1]
  368. max_token_key = (
  369. "max_completion_tokens"
  370. if name.startswith(("gpt-5", "o1", "o3", "o4"))
  371. else "max_tokens"
  372. )
  373. return {max_token_key: 16}
  374. @router.post(
  375. "/test-model",
  376. response_model=TestProviderModelResult,
  377. response_model_exclude_none=True,
  378. )
  379. async def try_model_with_provider(
  380. input: TestProviderModelInput,
  381. ):
  382. if input.api_token is None or input.config is None:
  383. raise InvalidException(
  384. message="api_token and config are required to fetch models from provider"
  385. )
  386. endpoint, completion_url = input.config.get_chat_url()
  387. if not endpoint or not completion_url:
  388. raise InvalidException(
  389. message=f"provider type {input.config.type} does not support testing model accessibility"
  390. )
  391. max_output_token_dict = _get_model_output_token_dict(input.model_name)
  392. data = {
  393. "model": input.model_name,
  394. "messages": [{"role": "user", "content": "Ping"}],
  395. **max_output_token_dict,
  396. }
  397. if input.config.type == ModelProviderTypeEnum.QWEN:
  398. data["enable_thinking"] = False
  399. async with httpx.AsyncClient(
  400. base_url=f"{endpoint}",
  401. proxy=input.proxy_url,
  402. trust_env=True,
  403. ) as client:
  404. headers = {}
  405. if input.config.type == ModelProviderTypeEnum.CLAUDE:
  406. headers["X-API-Key"] = input.api_token
  407. headers["anthropic-version"] = (
  408. getattr(input.config, "claudeVersion", None) or ANTHROPIC_API_VERSION
  409. )
  410. else:
  411. headers["Authorization"] = f"Bearer {input.api_token}"
  412. try:
  413. response = await client.post(
  414. url=completion_url, json=data, headers=headers, timeout=60
  415. )
  416. response.raise_for_status()
  417. return TestProviderModelResult(
  418. model_name=input.model_name,
  419. accessible=True,
  420. )
  421. except httpx.HTTPStatusError as exc:
  422. return TestProviderModelResult(
  423. model_name=input.model_name,
  424. accessible=False,
  425. error_message=f"Provider API error: {exc.response.status_code} {exc.response.text}",
  426. )
  427. except httpx.RequestError as exc:
  428. raise InternalServerErrorException(
  429. message=f"Network error: {exc.__class__.__name__}: {exc}"
  430. )
  431. @router.post(
  432. "/{id}/test-model",
  433. response_model=TestProviderModelResult,
  434. response_model_exclude_none=True,
  435. )
  436. async def try_model_with_specific_provider(
  437. session: SessionDep,
  438. ctx: TenantContextDep,
  439. id: int,
  440. input: TestProviderModelInput,
  441. ):
  442. provider = await ModelProvider.one_by_id(session=session, id=id)
  443. if not provider or provider.deleted_at is not None:
  444. raise NotFoundException(message=f"provider {id} not found")
  445. assert_resource_visible(
  446. ctx,
  447. provider,
  448. not_found_message=f"provider {id} not found",
  449. )
  450. if provider.api_tokens is None or len(provider.api_tokens) == 0:
  451. raise InvalidException(
  452. message=f"provider {provider.name} id: {id} has no API tokens configured"
  453. )
  454. proxy_url = (
  455. input.proxy_url if 'proxy_url' in input.model_fields_set else provider.proxy_url
  456. )
  457. return await try_model_with_provider(
  458. TestProviderModelInput(
  459. api_token=input.api_token or provider.api_tokens[0],
  460. config=input.config or provider.config,
  461. proxy_url=proxy_url,
  462. model_name=input.model_name,
  463. )
  464. )