model_provider.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617
  1. import hashlib
  2. from typing import Tuple
  3. from urllib.parse import urlparse
  4. from enum import Enum
  5. from typing import (
  6. ClassVar,
  7. Optional,
  8. List,
  9. Union,
  10. TYPE_CHECKING,
  11. Literal,
  12. Mapping,
  13. Dict,
  14. Any,
  15. )
  16. from pydantic import (
  17. BaseModel,
  18. ConfigDict,
  19. field_validator,
  20. model_validator,
  21. Field as PydanticField,
  22. )
  23. from sqlmodel import (
  24. Field,
  25. Column,
  26. ForeignKey,
  27. Integer,
  28. JSON,
  29. SQLModel,
  30. Relationship,
  31. )
  32. from gpustack.mixins import BaseModelMixin
  33. from gpustack.schemas.common import (
  34. PublicFields,
  35. ListParams,
  36. PaginatedList,
  37. pydantic_column_type,
  38. )
  39. if TYPE_CHECKING:
  40. from gpustack.schemas.model_routes import ModelRouteTarget
  41. # The provider types should be synced with higress ai-proxy supported providers
  42. class ModelProviderTypeEnum(str, Enum):
  43. AI360 = "ai360"
  44. AZURE = "azure"
  45. BAICHUAN = "baichuan"
  46. BAIDU = "baidu"
  47. BEDROCK = "bedrock"
  48. CLAUDE = "claude"
  49. CLOUDFLARE = "cloudflare"
  50. COHERE = "cohere"
  51. COZE = "coze"
  52. DEEPL = "deepl"
  53. DEEPSEEK = "deepseek"
  54. DIFY = "dify"
  55. DOUBAO = "doubao"
  56. FIREWORKS = "fireworks"
  57. GEMINI = "gemini"
  58. GENERIC = "generic"
  59. GITHUB = "github"
  60. GROK = "grok"
  61. GROQ = "groq"
  62. HUNYUAN = "hunyuan"
  63. LONGCAT = "longcat"
  64. MINIMAX = "minimax"
  65. MISTRAL = "mistral"
  66. MOONSHOT = "moonshot"
  67. OLLAMA = "ollama"
  68. OPENAI = "openai"
  69. OPENROUTER = "openrouter"
  70. QWEN = "qwen"
  71. SPARK = "spark"
  72. STEPFUN = "stepfun"
  73. TOGETHERAI = "together-ai"
  74. TRITON = "triton"
  75. YI = "yi"
  76. ZHIPUAI = "zhipuai"
  77. # following types are not supported yet
  78. # For vertex, It has more complex configuration than other providers. Keep it unsupported for now.
  79. # VERTEX = "vertex"
  80. # For vllm, most of the vllm provider functions can be replaced by open-ai compatible provider.
  81. # VLLM = "vllm"
  82. class BaseProviderConfig(BaseModel):
  83. model_config: ConfigDict = {
  84. "extra": "allow",
  85. }
  86. _chat_uri: Optional[str] = "/v1/chat/completions"
  87. _public_endpoint: Optional[str] = None
  88. _default_schema = "https"
  89. _model_uri = None
  90. def get_base_url(self) -> Optional[str]:
  91. if self._public_endpoint:
  92. return f"{self._default_schema}://{self._public_endpoint}"
  93. return None
  94. def check_required_fields(self):
  95. missing_fields = []
  96. for name, field in self.__class__.model_fields.items():
  97. schema_extra = field.json_schema_extra or {}
  98. if schema_extra.get("field_required", False):
  99. value = getattr(self, name)
  100. if value is None:
  101. missing_fields.append(name)
  102. if missing_fields:
  103. raise ValueError(
  104. f"Missing required fields for provider {self.type}: {', '.join(missing_fields)}"
  105. )
  106. return self
  107. def get_model_url(self) -> Tuple[Optional[str], Optional[str]]:
  108. base_url = self.get_base_url()
  109. if base_url:
  110. base_url = base_url.rstrip("/")
  111. return base_url, self._model_uri
  112. def get_chat_url(self) -> Tuple[Optional[str], Optional[str]]:
  113. base_url = self.get_base_url()
  114. if base_url:
  115. base_url = base_url.rstrip("/")
  116. return base_url, self._chat_uri
  117. def model_dump_with_default_override(self) -> Dict[str, Any]:
  118. """Dumps the model, excluding unset fields, and then merges with `_default_override` values.
  119. This method is used to generate a configuration dictionary for services
  120. that require certain default values to be present, even if they are not
  121. explicitly set by the user. User-set values will take precedence over
  122. the default override values.
  123. The `_default_override` attribute should be a dictionary defined on the
  124. config subclass.
  125. """
  126. default_override = getattr(self, "_default_override", {})
  127. values = {
  128. **default_override,
  129. **self.model_dump(exclude_unset=True, exclude={"type"}),
  130. }
  131. return values
  132. class Ai360Config(BaseProviderConfig):
  133. type: Literal[ModelProviderTypeEnum.AI360]
  134. _public_endpoint: str = "api.360.cn"
  135. class AzureOpenAIConfig(BaseProviderConfig):
  136. type: Literal[ModelProviderTypeEnum.AZURE]
  137. azureServiceUrl: Optional[str] = PydanticField(
  138. default=None, json_schema_extra={"field_required": True}
  139. )
  140. def get_base_url(self) -> Optional[str]:
  141. return self.azureServiceUrl
  142. class BaichuanConfig(BaseProviderConfig):
  143. type: Literal[ModelProviderTypeEnum.BAICHUAN]
  144. _public_endpoint: str = "api.baichuan-ai.com"
  145. _model_uri = "/v1/models"
  146. class BaiduConfig(BaseProviderConfig):
  147. type: Literal[ModelProviderTypeEnum.BAIDU]
  148. _public_endpoint: str = "qianfan.baidubce.com"
  149. _model_uri = "/v1/models"
  150. class BedrockConfig(BaseProviderConfig):
  151. type: Literal[ModelProviderTypeEnum.BEDROCK]
  152. awsAccessKey: Optional[str] = PydanticField(
  153. default=None, json_schema_extra={"field_required": True}
  154. )
  155. awsSecretKey: Optional[str] = PydanticField(
  156. default=None, json_schema_extra={"field_required": True}
  157. )
  158. awsRegion: Optional[str] = PydanticField(
  159. default=None, json_schema_extra={"field_required": True}
  160. )
  161. bedrockAdditionalFields: Optional[dict] = None
  162. def get_base_url(self):
  163. return (
  164. f"{self._default_schema}://bedrock-runtime.{self.awsRegion}.amazonaws.com"
  165. )
  166. class ClaudeConfig(BaseProviderConfig):
  167. type: Literal[ModelProviderTypeEnum.CLAUDE]
  168. claudeVersion: Optional[str] = None
  169. _public_endpoint: str = "api.anthropic.com"
  170. _model_uri = "/v1/models"
  171. _chat_uri = "/v1/messages"
  172. class CloudflareConfig(BaseProviderConfig):
  173. type: Literal[ModelProviderTypeEnum.CLOUDFLARE]
  174. cloudflareAccountId: Optional[str] = PydanticField(
  175. default=None, json_schema_extra={"field_required": True}
  176. )
  177. _public_endpoint: str = "api.cloudflare.com"
  178. _model_uri = None
  179. class CohereConfig(BaseProviderConfig):
  180. type: Literal[ModelProviderTypeEnum.COHERE]
  181. _public_endpoint: str = "api.cohere.com"
  182. class CozeConfig(BaseProviderConfig):
  183. type: Literal[ModelProviderTypeEnum.COZE]
  184. _public_endpoint: str = "api.coze.cn"
  185. class DeeplConfig(BaseProviderConfig):
  186. type: Literal[ModelProviderTypeEnum.DEEPL]
  187. targetLang: Optional[str] = PydanticField(
  188. default=None, json_schema_extra={"field_required": True}
  189. )
  190. _public_endpoint: str = "api-free.deepl.com"
  191. class DeepseekConfig(BaseProviderConfig):
  192. type: Literal[ModelProviderTypeEnum.DEEPSEEK]
  193. _public_endpoint: str = "api.deepseek.com"
  194. _model_uri = "/v1/models"
  195. class DifyConfig(BaseProviderConfig):
  196. type: Literal[ModelProviderTypeEnum.DIFY]
  197. difyApiUrl: Optional[str] = None
  198. botType: Optional[str] = None
  199. inputVariable: Optional[str] = None
  200. outputVariable: Optional[str] = None
  201. _public_endpoint: str = "api.dify.ai"
  202. def get_base_url(self) -> Optional[str]:
  203. if self.difyApiUrl:
  204. return self.difyApiUrl
  205. return super().get_base_url()
  206. class DoubaoConfig(BaseProviderConfig):
  207. type: Literal[ModelProviderTypeEnum.DOUBAO]
  208. doubaoDomain: Optional[str] = None
  209. _public_endpoint: str = "ark.cn-beijing.volces.com"
  210. _model_uri = "/api/v3/models"
  211. _chat_uri = "/api/v3/chat/completions"
  212. def get_base_url(self):
  213. domain = self.doubaoDomain or self._public_endpoint
  214. return f"{self._default_schema}://{domain}"
  215. class FireworksConfig(BaseProviderConfig):
  216. type: Literal[ModelProviderTypeEnum.FIREWORKS]
  217. _public_endpoint: str = "api.fireworks.ai"
  218. _model_uri = "/v1/models"
  219. class GeminiConfig(BaseProviderConfig):
  220. type: Literal[ModelProviderTypeEnum.GEMINI]
  221. geminiSafetySetting: Optional[Mapping[str, str]] = None
  222. apiVersion: Optional[str] = None
  223. geminiThinkingBudget: Optional[float] = None
  224. _public_endpoint: str = "generativelanguage.googleapis.com"
  225. _default_override = {"apiVersion": "v1beta"}
  226. class GenericConfig(BaseProviderConfig):
  227. type: Literal[ModelProviderTypeEnum.GENERIC]
  228. _public_endpoint: str = ""
  229. def get_base_url(self) -> Optional[str]:
  230. return None
  231. class GithubConfig(BaseProviderConfig):
  232. type: Literal[ModelProviderTypeEnum.GITHUB]
  233. _public_endpoint: str = "models.inference.ai.azure.com"
  234. class GrokConfig(BaseProviderConfig):
  235. type: Literal[ModelProviderTypeEnum.GROK]
  236. _public_endpoint: str = "api.x.ai"
  237. class GroqConfig(BaseProviderConfig):
  238. type: Literal[ModelProviderTypeEnum.GROQ]
  239. _public_endpoint: str = "api.groq.com"
  240. class HunyuanConfig(BaseProviderConfig):
  241. type: Literal[ModelProviderTypeEnum.HUNYUAN]
  242. hunyuanAuthId: Optional[str] = PydanticField(
  243. default=None, json_schema_extra={"field_required": True}
  244. )
  245. hunyuanAuthKey: Optional[str] = PydanticField(
  246. default=None, json_schema_extra={"field_required": True}
  247. )
  248. _public_endpoint: str = "hunyuan.tencentcloudapi.com"
  249. class LongcatConfig(BaseProviderConfig):
  250. type: Literal[ModelProviderTypeEnum.LONGCAT]
  251. _public_endpoint: str = "api.longcat.chat"
  252. _model_uri = "/v1/models"
  253. class MinimaxConfig(BaseProviderConfig):
  254. type: Literal[ModelProviderTypeEnum.MINIMAX]
  255. minimaxApiType: Optional[str] = None
  256. minimaxGroupId: Optional[str] = None
  257. _public_endpoint: str = "api.minimax.chat"
  258. _default_override = {"minimaxApiType": "v2"}
  259. class MistralConfig(BaseProviderConfig):
  260. type: Literal[ModelProviderTypeEnum.MISTRAL]
  261. _public_endpoint: str = "api.mistral.ai"
  262. class MoonshotConfig(BaseProviderConfig):
  263. type: Literal[ModelProviderTypeEnum.MOONSHOT]
  264. moonshotFileId: Optional[str] = None
  265. _public_endpoint: str = "api.moonshot.cn"
  266. _model_uri = "/v1/models"
  267. class OllamaConfig(BaseProviderConfig):
  268. type: Literal[ModelProviderTypeEnum.OLLAMA]
  269. ollamaServerHost: Optional[str] = PydanticField(
  270. default=None, json_schema_extra={"field_required": True}
  271. )
  272. ollamaServerPort: Optional[int] = PydanticField(
  273. default=None, json_schema_extra={"field_required": True}
  274. )
  275. _default_schema = "http"
  276. _model_uri = "/v1/models"
  277. def get_base_url(self):
  278. if not self.ollamaServerHost:
  279. return None
  280. port_suffix = f":{self.ollamaServerPort}" if self.ollamaServerPort else ""
  281. domain = f"{self.ollamaServerHost}{port_suffix}"
  282. return f"{self._default_schema}://{domain}"
  283. class OpenAIConfig(BaseProviderConfig):
  284. type: Literal[ModelProviderTypeEnum.OPENAI]
  285. openaiCustomUrl: Optional[str] = None
  286. responseJsonSchema: Optional[dict] = None
  287. _public_endpoint: str = "api.openai.com"
  288. _model_uri = "/v1/models"
  289. def get_base_url(self) -> Optional[str]:
  290. if self.openaiCustomUrl:
  291. parsed_url = urlparse(self.openaiCustomUrl)
  292. return f"{parsed_url.scheme}://{parsed_url.netloc}"
  293. return super().get_base_url()
  294. def get_model_url(self) -> Tuple[Optional[str], Optional[str]]:
  295. if not self.openaiCustomUrl:
  296. return super().get_model_url()
  297. parsed_url = urlparse(self.openaiCustomUrl)
  298. model_uri = f"{parsed_url.path.rstrip('/')}/models"
  299. return self.get_base_url(), model_uri
  300. def get_chat_url(self):
  301. if not self.openaiCustomUrl:
  302. return super().get_chat_url()
  303. parsed_url = urlparse(self.openaiCustomUrl)
  304. chat_uri = f"{parsed_url.path.rstrip('/')}/chat/completions"
  305. return self.get_base_url(), chat_uri
  306. class OpenrouterConfig(BaseProviderConfig):
  307. type: Literal[ModelProviderTypeEnum.OPENROUTER]
  308. _public_endpoint: str = "openrouter.ai"
  309. class QwenConfig(BaseProviderConfig):
  310. type: Literal[ModelProviderTypeEnum.QWEN]
  311. qwenEnableSearch: Optional[bool] = None
  312. qwenFileIds: Optional[List[str]] = None
  313. qwenEnableCompatible: Optional[bool] = None
  314. _public_endpoint: str = "dashscope.aliyuncs.com"
  315. _model_uri = "/compatible-mode/v1/models"
  316. _chat_uri = "/compatible-mode/v1/chat/completions"
  317. _default_override = {"qwenEnableCompatible": True}
  318. class SparkConfig(BaseProviderConfig):
  319. type: Literal[ModelProviderTypeEnum.SPARK]
  320. _public_endpoint: str = "spark-api-open.xf-yun.com"
  321. class StepfunConfig(BaseProviderConfig):
  322. type: Literal[ModelProviderTypeEnum.STEPFUN]
  323. _public_endpoint: str = "api.stepfun.com"
  324. class TogetherAIConfig(BaseProviderConfig):
  325. type: Literal[ModelProviderTypeEnum.TOGETHERAI]
  326. _public_endpoint: str = "api.together.xyz"
  327. class TritonConfig(BaseProviderConfig):
  328. type: Literal[ModelProviderTypeEnum.TRITON]
  329. modelVersion: Optional[str] = None
  330. tritonDomain: Optional[str] = None
  331. def get_base_url(self) -> Optional[str]:
  332. if not self.tritonDomain:
  333. return None
  334. return f"{self._default_schema}://{self.tritonDomain}"
  335. class YiConfig(BaseProviderConfig):
  336. type: Literal[ModelProviderTypeEnum.YI]
  337. _public_endpoint: str = "api.lingyiwanwu.com"
  338. class ZhipuaiConfig(BaseProviderConfig):
  339. type: Literal[ModelProviderTypeEnum.ZHIPUAI]
  340. _public_endpoint: str = "open.bigmodel.cn"
  341. ProviderConfigType = Union[
  342. Ai360Config,
  343. AzureOpenAIConfig,
  344. BaichuanConfig,
  345. BaiduConfig,
  346. BedrockConfig,
  347. ClaudeConfig,
  348. CloudflareConfig,
  349. CohereConfig,
  350. CozeConfig,
  351. DeeplConfig,
  352. DeepseekConfig,
  353. DifyConfig,
  354. DoubaoConfig,
  355. FireworksConfig,
  356. GeminiConfig,
  357. GithubConfig,
  358. GrokConfig,
  359. GroqConfig,
  360. HunyuanConfig,
  361. LongcatConfig,
  362. MinimaxConfig,
  363. MistralConfig,
  364. MoonshotConfig,
  365. OllamaConfig,
  366. OpenAIConfig,
  367. OpenrouterConfig,
  368. QwenConfig,
  369. SparkConfig,
  370. StepfunConfig,
  371. TogetherAIConfig,
  372. TritonConfig,
  373. YiConfig,
  374. ZhipuaiConfig,
  375. ]
  376. class ProviderModel(BaseModel):
  377. name: str
  378. category: Optional[str] = None
  379. class MaskedAPIToken(BaseModel):
  380. input: Optional[str] = None
  381. hash: Optional[str] = None
  382. @model_validator(mode="after")
  383. def check_fields(self):
  384. if self.input is None and self.hash is None:
  385. raise ValueError(
  386. "Either 'input' or 'hash' must be provided for a masked API token."
  387. )
  388. if self.input is not None and self.hash is not None:
  389. raise ValueError(
  390. "Only one of 'input' or 'hash' can be provided for a masked API token."
  391. )
  392. if self.input is not None and not self.input.strip():
  393. raise ValueError("API token input cannot be empty or just whitespace.")
  394. return self
  395. class ModelProviderBase(SQLModel):
  396. name: str = Field(index=True, nullable=False, unique=True)
  397. description: Optional[str] = Field(default=None, nullable=True)
  398. timeout: int = Field(default=120, nullable=False)
  399. config: ProviderConfigType = Field(
  400. description="provider specific configuration",
  401. sa_column=Column(
  402. pydantic_column_type(
  403. ProviderConfigType,
  404. exclude_defaults=True,
  405. exclude_none=True,
  406. exclude_unset=True,
  407. ),
  408. ),
  409. )
  410. models: Optional[List[ProviderModel]] = Field(
  411. default=[],
  412. sa_column=Column(
  413. pydantic_column_type(List[ProviderModel]),
  414. nullable=True,
  415. ),
  416. )
  417. proxy_url: Optional[str] = Field(default=None, nullable=True)
  418. proxy_timeout: Optional[int] = Field(default=None, nullable=True)
  419. @model_validator(mode="after")
  420. def check_all(self):
  421. if self.timeout <= 0:
  422. raise ValueError("timeout must be a positive integer")
  423. if self.proxy_timeout is not None and self.proxy_timeout <= 0:
  424. raise ValueError("proxy_timeout must be a positive integer")
  425. if self.proxy_timeout is not None and self.proxy_url is None:
  426. raise ValueError("proxy_url must be set when proxy_timeout is set")
  427. return self
  428. class ModelProviderUpdate(ModelProviderBase):
  429. api_tokens: List[MaskedAPIToken] = PydanticField(
  430. default=[],
  431. )
  432. @field_validator("api_tokens")
  433. def check_api_tokens(cls, v):
  434. if v is not None:
  435. if not isinstance(v, list) or len(v) == 0:
  436. raise ValueError("api_tokens must be a non-empty list")
  437. return v
  438. class ModelProviderCreate(ModelProviderUpdate):
  439. clone_from_id: Optional[int] = PydanticField(default=None)
  440. class ModelProvider(ModelProviderBase, BaseModelMixin, table=True):
  441. __tablename__ = "model_providers"
  442. id: Optional[int] = Field(default=None, primary_key=True)
  443. # Tenant scope. NULL = global (admin-managed). Org-owned
  444. # providers carry the owning Org id.
  445. owner_principal_id: Optional[int] = Field(
  446. default=None,
  447. sa_column=Column(Integer, ForeignKey("principals.id"), nullable=True),
  448. )
  449. api_tokens: List[str] = Field(
  450. sa_column=Column(JSON, nullable=False),
  451. default=[],
  452. )
  453. model_route_targets: List["ModelRouteTarget"] = Relationship(
  454. back_populates="provider",
  455. sa_relationship_kwargs={"lazy": "noload", "cascade": "delete"},
  456. )
  457. @classmethod
  458. def _convert_to_public_class(cls, data) -> "ModelProviderPublic":
  459. # somehow when updating model provider while deleting targets
  460. # the result of await ModelProvider.one_by_id(session=session, id=id) is not fully correct.
  461. # e.g. the provider.config is a dict instead of correct config class and it will
  462. # yields validation warnings when model_dump it. So setting warnings=False to ignore
  463. # the warnings and convert it to correct config class by ourselves.
  464. dict_data = data if isinstance(data, dict) else data.model_dump(warnings=False)
  465. current_tokens: List[str] = dict_data.pop("api_tokens", None)
  466. masked_tokens: List[MaskedAPIToken] = []
  467. if current_tokens:
  468. masked_tokens = [
  469. {"hash": hashlib.sha256(token.encode()).hexdigest()}
  470. for token in current_tokens
  471. ]
  472. dict_data["api_tokens"] = masked_tokens
  473. return ModelProviderPublic.model_validate(dict_data)
  474. class ModelProviderPublic(ModelProviderUpdate, PublicFields):
  475. pass
  476. ModelProvidersPublic = PaginatedList[ModelProviderPublic]
  477. class ModelProviderListParams(ListParams):
  478. sortable_fields: ClassVar[List[str]] = [
  479. "id",
  480. "name",
  481. "created_at",
  482. "updated_at",
  483. ]
  484. class ProviderModelsInput(BaseModel):
  485. api_token: Optional[str] = None
  486. config: Optional[ProviderConfigType] = None
  487. proxy_url: Optional[str] = None
  488. class TestProviderModelInput(ProviderModelsInput):
  489. model_name: str
  490. class TestProviderModelResult(BaseModel):
  491. model_name: str
  492. accessible: bool
  493. error_message: Optional[str] = None