| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617 |
- import hashlib
- from typing import Tuple
- from urllib.parse import urlparse
- from enum import Enum
- from typing import (
- ClassVar,
- Optional,
- List,
- Union,
- TYPE_CHECKING,
- Literal,
- Mapping,
- Dict,
- Any,
- )
- from pydantic import (
- BaseModel,
- ConfigDict,
- field_validator,
- model_validator,
- Field as PydanticField,
- )
- from sqlmodel import (
- Field,
- Column,
- ForeignKey,
- Integer,
- JSON,
- SQLModel,
- Relationship,
- )
- from gpustack.mixins import BaseModelMixin
- from gpustack.schemas.common import (
- PublicFields,
- ListParams,
- PaginatedList,
- pydantic_column_type,
- )
- if TYPE_CHECKING:
- from gpustack.schemas.model_routes import ModelRouteTarget
# The provider types should be synced with higress ai-proxy supported providers
class ModelProviderTypeEnum(str, Enum):
    """Identifiers of the model provider types this module can configure.

    The enum mixes in ``str``; each member's value is the string used as the
    ``type`` discriminator of the matching ``*Config`` class.
    """

    AI360 = "ai360"
    AZURE = "azure"
    BAICHUAN = "baichuan"
    BAIDU = "baidu"
    BEDROCK = "bedrock"
    CLAUDE = "claude"
    CLOUDFLARE = "cloudflare"
    COHERE = "cohere"
    COZE = "coze"
    DEEPL = "deepl"
    DEEPSEEK = "deepseek"
    DIFY = "dify"
    DOUBAO = "doubao"
    FIREWORKS = "fireworks"
    GEMINI = "gemini"
    GENERIC = "generic"
    GITHUB = "github"
    GROK = "grok"
    GROQ = "groq"
    HUNYUAN = "hunyuan"
    LONGCAT = "longcat"
    MINIMAX = "minimax"
    MISTRAL = "mistral"
    MOONSHOT = "moonshot"
    OLLAMA = "ollama"
    OPENAI = "openai"
    OPENROUTER = "openrouter"
    QWEN = "qwen"
    SPARK = "spark"
    STEPFUN = "stepfun"
    TOGETHERAI = "together-ai"
    TRITON = "triton"
    YI = "yi"
    ZHIPUAI = "zhipuai"
    # The following types are not supported yet.
    # Vertex has a more complex configuration than the other providers, so it
    # is kept unsupported for now.
    # VERTEX = "vertex"
    # Most of the vLLM provider functions can be replaced by the
    # OpenAI-compatible provider.
    # VLLM = "vllm"
class BaseProviderConfig(BaseModel):
    """Common base for every provider-specific configuration model.

    Subclasses override the private attributes below to describe where the
    provider lives: ``_public_endpoint`` (host), ``_default_schema`` (URL
    scheme), ``_model_uri`` (model-listing path) and ``_chat_uri`` (chat
    path). ``model_config`` sets ``extra`` to ``"allow"``.
    """

    model_config: ConfigDict = {
        "extra": "allow",
    }
    _chat_uri: Optional[str] = "/v1/chat/completions"
    _public_endpoint: Optional[str] = None
    _default_schema = "https"
    _model_uri = None

    def get_base_url(self) -> Optional[str]:
        """Return ``<scheme>://<public endpoint>``, or None without an endpoint."""
        if not self._public_endpoint:
            return None
        return f"{self._default_schema}://{self._public_endpoint}"

    def check_required_fields(self):
        """Ensure every field flagged ``field_required`` has a non-None value.

        A field is considered required when its ``json_schema_extra`` mapping
        contains a truthy ``field_required`` entry.

        Returns:
            This instance, so the call can be chained.

        Raises:
            ValueError: Listing the names of all required fields that are None.
        """
        missing_fields = [
            name
            for name, field in self.__class__.model_fields.items()
            if (field.json_schema_extra or {}).get("field_required", False)
            and getattr(self, name) is None
        ]
        if missing_fields:
            raise ValueError(
                f"Missing required fields for provider {self.type}: {', '.join(missing_fields)}"
            )
        return self

    def get_model_url(self) -> Tuple[Optional[str], Optional[str]]:
        """Return ``(base_url, model_uri)``; trailing slashes are stripped from the base."""
        root = self.get_base_url()
        return (root.rstrip("/") if root else root), self._model_uri

    def get_chat_url(self) -> Tuple[Optional[str], Optional[str]]:
        """Return ``(base_url, chat_uri)``; trailing slashes are stripped from the base."""
        root = self.get_base_url()
        return (root.rstrip("/") if root else root), self._chat_uri

    def model_dump_with_default_override(self) -> Dict[str, Any]:
        """Dump explicitly-set fields layered on top of ``_default_override``.

        The result starts from the ``_default_override`` dictionary defined
        on the config subclass (empty when the subclass defines none) and is
        then updated with ``model_dump(exclude_unset=True)`` minus the
        ``type`` field, so values set by the user take precedence over the
        override defaults.
        """
        overrides = getattr(self, "_default_override", {})
        explicitly_set = self.model_dump(exclude_unset=True, exclude={"type"})
        merged: Dict[str, Any] = dict(overrides)
        merged.update(explicitly_set)
        return merged
class Ai360Config(BaseProviderConfig):
    """Configuration for the ``ai360`` provider type (host ``api.360.cn``)."""

    type: Literal[ModelProviderTypeEnum.AI360]
    _public_endpoint: str = "api.360.cn"
class AzureOpenAIConfig(BaseProviderConfig):
    """Configuration for the ``azure`` provider type."""

    type: Literal[ModelProviderTypeEnum.AZURE]
    # Flagged ``field_required``: the inherited check_required_fields() rejects None.
    azureServiceUrl: Optional[str] = PydanticField(
        default=None, json_schema_extra={"field_required": True}
    )

    def get_base_url(self) -> Optional[str]:
        """Return ``azureServiceUrl`` unchanged (None when it is unset)."""
        return self.azureServiceUrl
class BaichuanConfig(BaseProviderConfig):
    """Configuration for the ``baichuan`` provider type (host ``api.baichuan-ai.com``)."""

    type: Literal[ModelProviderTypeEnum.BAICHUAN]
    _public_endpoint: str = "api.baichuan-ai.com"
    _model_uri = "/v1/models"
class BaiduConfig(BaseProviderConfig):
    """Configuration for the ``baidu`` provider type (host ``qianfan.baidubce.com``)."""

    type: Literal[ModelProviderTypeEnum.BAIDU]
    _public_endpoint: str = "qianfan.baidubce.com"
    _model_uri = "/v1/models"
class BedrockConfig(BaseProviderConfig):
    """Configuration for the ``bedrock`` provider type.

    The access key, secret key and region are flagged ``field_required``, so
    the inherited ``check_required_fields()`` rejects a None value for any of
    them.
    """

    type: Literal[ModelProviderTypeEnum.BEDROCK]
    awsAccessKey: Optional[str] = PydanticField(
        default=None, json_schema_extra={"field_required": True}
    )
    awsSecretKey: Optional[str] = PydanticField(
        default=None, json_schema_extra={"field_required": True}
    )
    awsRegion: Optional[str] = PydanticField(
        default=None, json_schema_extra={"field_required": True}
    )
    bedrockAdditionalFields: Optional[dict] = None

    def get_base_url(self) -> Optional[str]:
        """Return the region-specific Bedrock runtime URL.

        Returns:
            ``<scheme>://bedrock-runtime.<awsRegion>.amazonaws.com``, or None
            when ``awsRegion`` is not set. Returning None avoids building a
            host that contains the literal text "None".
        """
        if not self.awsRegion:
            return None
        return (
            f"{self._default_schema}://bedrock-runtime.{self.awsRegion}.amazonaws.com"
        )
class ClaudeConfig(BaseProviderConfig):
    """Configuration for the ``claude`` provider type (host ``api.anthropic.com``).

    The chat path is ``/v1/messages`` instead of the base class default.
    """

    type: Literal[ModelProviderTypeEnum.CLAUDE]
    claudeVersion: Optional[str] = None
    _public_endpoint: str = "api.anthropic.com"
    _model_uri = "/v1/models"
    _chat_uri = "/v1/messages"
class CloudflareConfig(BaseProviderConfig):
    """Configuration for the ``cloudflare`` provider type (host ``api.cloudflare.com``)."""

    type: Literal[ModelProviderTypeEnum.CLOUDFLARE]
    # Flagged ``field_required``: the inherited check_required_fields() rejects None.
    cloudflareAccountId: Optional[str] = PydanticField(
        default=None, json_schema_extra={"field_required": True}
    )
    _public_endpoint: str = "api.cloudflare.com"
    # Same value as the base class default; no model-listing path is defined.
    _model_uri = None
class CohereConfig(BaseProviderConfig):
    """Configuration for the ``cohere`` provider type (host ``api.cohere.com``)."""

    type: Literal[ModelProviderTypeEnum.COHERE]
    _public_endpoint: str = "api.cohere.com"
class CozeConfig(BaseProviderConfig):
    """Configuration for the ``coze`` provider type (host ``api.coze.cn``)."""

    type: Literal[ModelProviderTypeEnum.COZE]
    _public_endpoint: str = "api.coze.cn"
class DeeplConfig(BaseProviderConfig):
    """Configuration for the ``deepl`` provider type (host ``api-free.deepl.com``)."""

    type: Literal[ModelProviderTypeEnum.DEEPL]
    # Flagged ``field_required``: the inherited check_required_fields() rejects None.
    targetLang: Optional[str] = PydanticField(
        default=None, json_schema_extra={"field_required": True}
    )
    _public_endpoint: str = "api-free.deepl.com"
class DeepseekConfig(BaseProviderConfig):
    """Configuration for the ``deepseek`` provider type (host ``api.deepseek.com``)."""

    type: Literal[ModelProviderTypeEnum.DEEPSEEK]
    _public_endpoint: str = "api.deepseek.com"
    _model_uri = "/v1/models"
class DifyConfig(BaseProviderConfig):
    """Configuration for the ``dify`` provider type.

    ``difyApiUrl`` overrides the default host ``api.dify.ai`` when it is set.
    """

    type: Literal[ModelProviderTypeEnum.DIFY]
    difyApiUrl: Optional[str] = None
    botType: Optional[str] = None
    inputVariable: Optional[str] = None
    outputVariable: Optional[str] = None
    _public_endpoint: str = "api.dify.ai"

    def get_base_url(self) -> Optional[str]:
        """Return ``difyApiUrl`` when it is non-empty, else the default base URL."""
        return self.difyApiUrl or super().get_base_url()
class DoubaoConfig(BaseProviderConfig):
    """Configuration for the ``doubao`` provider type.

    Model and chat paths live under ``/api/v3``; ``doubaoDomain`` replaces
    the default host ``ark.cn-beijing.volces.com`` when it is set.
    """

    type: Literal[ModelProviderTypeEnum.DOUBAO]
    doubaoDomain: Optional[str] = None
    _public_endpoint: str = "ark.cn-beijing.volces.com"
    _model_uri = "/api/v3/models"
    _chat_uri = "/api/v3/chat/completions"

    def get_base_url(self):
        """Return ``<scheme>://<host>``, preferring ``doubaoDomain`` over the default host."""
        host = self.doubaoDomain if self.doubaoDomain else self._public_endpoint
        return f"{self._default_schema}://{host}"
class FireworksConfig(BaseProviderConfig):
    """Configuration for the ``fireworks`` provider type (host ``api.fireworks.ai``)."""

    type: Literal[ModelProviderTypeEnum.FIREWORKS]
    _public_endpoint: str = "api.fireworks.ai"
    _model_uri = "/v1/models"
class GeminiConfig(BaseProviderConfig):
    """Configuration for the ``gemini`` provider type.

    Default host is ``generativelanguage.googleapis.com``.
    """

    type: Literal[ModelProviderTypeEnum.GEMINI]
    geminiSafetySetting: Optional[Mapping[str, str]] = None
    apiVersion: Optional[str] = None
    geminiThinkingBudget: Optional[float] = None
    _public_endpoint: str = "generativelanguage.googleapis.com"
    # Used by model_dump_with_default_override() when apiVersion was not set explicitly.
    _default_override = {"apiVersion": "v1beta"}
class GenericConfig(BaseProviderConfig):
    """Configuration for the ``generic`` provider type; it exposes no base URL."""

    type: Literal[ModelProviderTypeEnum.GENERIC]
    _public_endpoint: str = ""

    def get_base_url(self) -> Optional[str]:
        """Always return None."""
        return None
class GithubConfig(BaseProviderConfig):
    """Configuration for the ``github`` provider type (host ``models.inference.ai.azure.com``)."""

    type: Literal[ModelProviderTypeEnum.GITHUB]
    _public_endpoint: str = "models.inference.ai.azure.com"
class GrokConfig(BaseProviderConfig):
    """Configuration for the ``grok`` provider type (host ``api.x.ai``)."""

    type: Literal[ModelProviderTypeEnum.GROK]
    _public_endpoint: str = "api.x.ai"
class GroqConfig(BaseProviderConfig):
    """Configuration for the ``groq`` provider type (host ``api.groq.com``)."""

    type: Literal[ModelProviderTypeEnum.GROQ]
    _public_endpoint: str = "api.groq.com"
class HunyuanConfig(BaseProviderConfig):
    """Configuration for the ``hunyuan`` provider type (host ``hunyuan.tencentcloudapi.com``).

    Both auth fields are flagged ``field_required``, so the inherited
    ``check_required_fields()`` rejects a None value for either.
    """

    type: Literal[ModelProviderTypeEnum.HUNYUAN]
    hunyuanAuthId: Optional[str] = PydanticField(
        default=None, json_schema_extra={"field_required": True}
    )
    hunyuanAuthKey: Optional[str] = PydanticField(
        default=None, json_schema_extra={"field_required": True}
    )
    _public_endpoint: str = "hunyuan.tencentcloudapi.com"
class LongcatConfig(BaseProviderConfig):
    """Configuration for the ``longcat`` provider type (host ``api.longcat.chat``)."""

    type: Literal[ModelProviderTypeEnum.LONGCAT]
    _public_endpoint: str = "api.longcat.chat"
    _model_uri = "/v1/models"
class MinimaxConfig(BaseProviderConfig):
    """Configuration for the ``minimax`` provider type (host ``api.minimax.chat``)."""

    type: Literal[ModelProviderTypeEnum.MINIMAX]
    minimaxApiType: Optional[str] = None
    minimaxGroupId: Optional[str] = None
    _public_endpoint: str = "api.minimax.chat"
    # Used by model_dump_with_default_override() when minimaxApiType was not set explicitly.
    _default_override = {"minimaxApiType": "v2"}
class MistralConfig(BaseProviderConfig):
    """Configuration for the ``mistral`` provider type (host ``api.mistral.ai``)."""

    type: Literal[ModelProviderTypeEnum.MISTRAL]
    _public_endpoint: str = "api.mistral.ai"
class MoonshotConfig(BaseProviderConfig):
    """Configuration for the ``moonshot`` provider type (host ``api.moonshot.cn``)."""

    type: Literal[ModelProviderTypeEnum.MOONSHOT]
    moonshotFileId: Optional[str] = None
    _public_endpoint: str = "api.moonshot.cn"
    _model_uri = "/v1/models"
class OllamaConfig(BaseProviderConfig):
    """Configuration for the ``ollama`` provider type.

    Host and port are flagged ``field_required`` and the URL scheme defaults
    to ``http`` rather than the base class ``https``.
    """

    type: Literal[ModelProviderTypeEnum.OLLAMA]
    ollamaServerHost: Optional[str] = PydanticField(
        default=None, json_schema_extra={"field_required": True}
    )
    ollamaServerPort: Optional[int] = PydanticField(
        default=None, json_schema_extra={"field_required": True}
    )
    _default_schema = "http"
    _model_uri = "/v1/models"

    def get_base_url(self):
        """Return ``<scheme>://<host>[:<port>]``, or None when no host is set.

        The ``:<port>`` suffix is only appended for a truthy port.
        """
        host = self.ollamaServerHost
        if not host:
            return None
        port = self.ollamaServerPort
        authority = f"{host}:{port}" if port else f"{host}"
        return f"{self._default_schema}://{authority}"
class OpenAIConfig(BaseProviderConfig):
    """Configuration for the ``openai`` provider type.

    Without ``openaiCustomUrl`` the defaults (host ``api.openai.com``, model
    path ``/v1/models``, inherited chat path) apply. With it, the scheme and
    network location come from the custom URL, and ``/models`` and
    ``/chat/completions`` are appended to the custom URL's path.
    """

    type: Literal[ModelProviderTypeEnum.OPENAI]
    openaiCustomUrl: Optional[str] = None
    responseJsonSchema: Optional[dict] = None
    _public_endpoint: str = "api.openai.com"
    _model_uri = "/v1/models"

    def get_base_url(self) -> Optional[str]:
        """Return ``<scheme>://<netloc>`` of the custom URL, else the default base URL."""
        custom_url = self.openaiCustomUrl
        if not custom_url:
            return super().get_base_url()
        parts = urlparse(custom_url)
        return f"{parts.scheme}://{parts.netloc}"

    def get_model_url(self) -> Tuple[Optional[str], Optional[str]]:
        """Return ``(base_url, model_uri)``, deriving the path from the custom URL if set."""
        if self.openaiCustomUrl:
            path_prefix = urlparse(self.openaiCustomUrl).path.rstrip("/")
            return self.get_base_url(), f"{path_prefix}/models"
        return super().get_model_url()

    def get_chat_url(self):
        """Return ``(base_url, chat_uri)``, deriving the path from the custom URL if set."""
        if self.openaiCustomUrl:
            path_prefix = urlparse(self.openaiCustomUrl).path.rstrip("/")
            return self.get_base_url(), f"{path_prefix}/chat/completions"
        return super().get_chat_url()
class OpenrouterConfig(BaseProviderConfig):
    """Configuration for the ``openrouter`` provider type (host ``openrouter.ai``)."""

    type: Literal[ModelProviderTypeEnum.OPENROUTER]
    _public_endpoint: str = "openrouter.ai"
class QwenConfig(BaseProviderConfig):
    """Configuration for the ``qwen`` provider type (host ``dashscope.aliyuncs.com``).

    Model and chat paths live under ``/compatible-mode/v1``.
    """

    type: Literal[ModelProviderTypeEnum.QWEN]
    qwenEnableSearch: Optional[bool] = None
    qwenFileIds: Optional[List[str]] = None
    qwenEnableCompatible: Optional[bool] = None
    _public_endpoint: str = "dashscope.aliyuncs.com"
    _model_uri = "/compatible-mode/v1/models"
    _chat_uri = "/compatible-mode/v1/chat/completions"
    # Used by model_dump_with_default_override() when qwenEnableCompatible was not set explicitly.
    _default_override = {"qwenEnableCompatible": True}
class SparkConfig(BaseProviderConfig):
    """Configuration for the ``spark`` provider type (host ``spark-api-open.xf-yun.com``)."""

    type: Literal[ModelProviderTypeEnum.SPARK]
    _public_endpoint: str = "spark-api-open.xf-yun.com"
class StepfunConfig(BaseProviderConfig):
    """Configuration for the ``stepfun`` provider type (host ``api.stepfun.com``)."""

    type: Literal[ModelProviderTypeEnum.STEPFUN]
    _public_endpoint: str = "api.stepfun.com"
class TogetherAIConfig(BaseProviderConfig):
    """Configuration for the ``together-ai`` provider type (host ``api.together.xyz``)."""

    type: Literal[ModelProviderTypeEnum.TOGETHERAI]
    _public_endpoint: str = "api.together.xyz"
class TritonConfig(BaseProviderConfig):
    """Configuration for the ``triton`` provider type.

    There is no default host; the base URL exists only when ``tritonDomain``
    is set.
    """

    type: Literal[ModelProviderTypeEnum.TRITON]
    modelVersion: Optional[str] = None
    tritonDomain: Optional[str] = None

    def get_base_url(self) -> Optional[str]:
        """Return ``<scheme>://<tritonDomain>``, or None when the domain is unset."""
        if self.tritonDomain:
            return f"{self._default_schema}://{self.tritonDomain}"
        return None
class YiConfig(BaseProviderConfig):
    """Configuration for the ``yi`` provider type (host ``api.lingyiwanwu.com``)."""

    type: Literal[ModelProviderTypeEnum.YI]
    _public_endpoint: str = "api.lingyiwanwu.com"
class ZhipuaiConfig(BaseProviderConfig):
    """Configuration for the ``zhipuai`` provider type (host ``open.bigmodel.cn``)."""

    type: Literal[ModelProviderTypeEnum.ZHIPUAI]
    _public_endpoint: str = "open.bigmodel.cn"
# Union of the provider configuration models accepted wherever a provider
# config is declared; each member pins ``type`` to a single enum value.
# NOTE(review): GenericConfig is defined in this module but is not listed
# here — confirm whether the omission is intentional.
ProviderConfigType = Union[
    Ai360Config,
    AzureOpenAIConfig,
    BaichuanConfig,
    BaiduConfig,
    BedrockConfig,
    ClaudeConfig,
    CloudflareConfig,
    CohereConfig,
    CozeConfig,
    DeeplConfig,
    DeepseekConfig,
    DifyConfig,
    DoubaoConfig,
    FireworksConfig,
    GeminiConfig,
    GithubConfig,
    GrokConfig,
    GroqConfig,
    HunyuanConfig,
    LongcatConfig,
    MinimaxConfig,
    MistralConfig,
    MoonshotConfig,
    OllamaConfig,
    OpenAIConfig,
    OpenrouterConfig,
    QwenConfig,
    SparkConfig,
    StepfunConfig,
    TogetherAIConfig,
    TritonConfig,
    YiConfig,
    ZhipuaiConfig,
]
class ProviderModel(BaseModel):
    """A model entry of a provider: a required name and an optional category."""

    name: str
    category: Optional[str] = None
class MaskedAPIToken(BaseModel):
    """An API token given either as raw ``input`` or as a ``hash``.

    Exactly one of the two fields must be provided.
    """

    input: Optional[str] = None
    hash: Optional[str] = None

    @model_validator(mode="after")
    def check_fields(self):
        """Validate that exactly one field is set and ``input`` is not blank.

        Raises:
            ValueError: If neither or both fields are set, or if ``input`` is
                empty or whitespace-only.
        """
        has_input = self.input is not None
        has_hash = self.hash is not None
        if not (has_input or has_hash):
            raise ValueError(
                "Either 'input' or 'hash' must be provided for a masked API token."
            )
        if has_input and has_hash:
            raise ValueError(
                "Only one of 'input' or 'hash' can be provided for a masked API token."
            )
        if has_input and self.input.strip() == "":
            raise ValueError("API token input cannot be empty or just whitespace.")
        return self
class ModelProviderBase(SQLModel):
    """Fields shared by every model provider schema, plus cross-field checks."""

    name: str = Field(index=True, nullable=False, unique=True)
    description: Optional[str] = Field(default=None, nullable=True)
    timeout: int = Field(default=120, nullable=False)
    # The column type is produced by pydantic_column_type with all three
    # ``exclude_*`` options enabled.
    config: ProviderConfigType = Field(
        description="provider specific configuration",
        sa_column=Column(
            pydantic_column_type(
                ProviderConfigType,
                exclude_defaults=True,
                exclude_none=True,
                exclude_unset=True,
            ),
        ),
    )
    models: Optional[List[ProviderModel]] = Field(
        default=[],
        sa_column=Column(
            pydantic_column_type(List[ProviderModel]),
            nullable=True,
        ),
    )
    proxy_url: Optional[str] = Field(default=None, nullable=True)
    proxy_timeout: Optional[int] = Field(default=None, nullable=True)

    @model_validator(mode="after")
    def check_all(self):
        """Validate the timeout and proxy settings.

        Raises:
            ValueError: If ``timeout`` is not positive, if ``proxy_timeout``
                is set but not positive, or if ``proxy_timeout`` is set
                without a ``proxy_url``.
        """
        if self.timeout <= 0:
            raise ValueError("timeout must be a positive integer")

        proxy_timeout = self.proxy_timeout
        if proxy_timeout is None:
            # Nothing else to verify when no proxy timeout was given.
            return self
        if proxy_timeout <= 0:
            raise ValueError("proxy_timeout must be a positive integer")
        if self.proxy_url is None:
            raise ValueError("proxy_url must be set when proxy_timeout is set")
        return self
class ModelProviderUpdate(ModelProviderBase):
    """Payload schema for updating a provider; adds the masked API tokens."""

    api_tokens: List[MaskedAPIToken] = PydanticField(
        default=[],
    )

    @field_validator("api_tokens")
    @classmethod
    def check_api_tokens(cls, v):
        """Reject a supplied ``api_tokens`` value that is not a non-empty list.

        Raises:
            ValueError: If ``v`` is not None and is either not a list or empty.
        """
        if v is None:
            return v
        if not (isinstance(v, list) and len(v) != 0):
            raise ValueError("api_tokens must be a non-empty list")
        return v
class ModelProviderCreate(ModelProviderUpdate):
    """Payload schema for creating a provider; adds an optional ``clone_from_id``."""

    # NOTE(review): presumably the id of an existing provider to copy from — confirm with the caller.
    clone_from_id: Optional[int] = PydanticField(default=None)
class ModelProvider(ModelProviderBase, BaseModelMixin, table=True):
    """Model provider table row (``model_providers``) holding the raw API tokens."""

    __tablename__ = "model_providers"
    id: Optional[int] = Field(default=None, primary_key=True)
    # Tenant scope. NULL = global (admin-managed). Org-owned
    # providers carry the owning Org id.
    owner_principal_id: Optional[int] = Field(
        default=None,
        sa_column=Column(Integer, ForeignKey("principals.id"), nullable=True),
    )
    api_tokens: List[str] = Field(
        sa_column=Column(JSON, nullable=False),
        default=[],
    )
    model_route_targets: List["ModelRouteTarget"] = Relationship(
        back_populates="provider",
        sa_relationship_kwargs={"lazy": "noload", "cascade": "delete"},
    )

    @classmethod
    def _convert_to_public_class(cls, data) -> "ModelProviderPublic":
        """Build the public representation, replacing raw tokens by SHA-256 hashes.

        Args:
            data: Either a dict of provider fields or an object exposing
                ``model_dump``. A dict is modified in place: its raw
                ``api_tokens`` entry is removed and, when tokens exist,
                replaced by the hashed form.

        Returns:
            The validated ``ModelProviderPublic`` instance.
        """
        # When a provider is updated while its targets are being deleted, the
        # loaded object may not be fully correct: e.g. provider.config can be
        # a dict instead of the proper config class, which yields validation
        # warnings on model_dump. warnings=False silences them; the
        # model_validate call below converts it to the proper config class.
        dict_data = data if isinstance(data, dict) else data.model_dump(warnings=False)
        current_tokens: Optional[List[str]] = dict_data.pop("api_tokens", None)
        if current_tokens:
            dict_data["api_tokens"] = [
                {"hash": hashlib.sha256(token.encode()).hexdigest()}
                for token in current_tokens
            ]
        # For a provider without tokens the key is deliberately left out so
        # the field default (an empty list) applies. An explicit empty list
        # would be rejected by the inherited "non-empty list" validator.
        return ModelProviderPublic.model_validate(dict_data)
class ModelProviderPublic(ModelProviderUpdate, PublicFields):
    """Public view of a provider: the update schema combined with ``PublicFields``."""

    pass
# Paginated list whose items are ModelProviderPublic instances.
ModelProvidersPublic = PaginatedList[ModelProviderPublic]
class ModelProviderListParams(ListParams):
    """List parameters for model providers, declaring the sortable fields."""

    sortable_fields: ClassVar[List[str]] = [
        "id",
        "name",
        "created_at",
        "updated_at",
    ]
class ProviderModelsInput(BaseModel):
    """Input carrying an optional API token, provider config and proxy URL."""

    api_token: Optional[str] = None
    config: Optional[ProviderConfigType] = None
    proxy_url: Optional[str] = None
class TestProviderModelInput(ProviderModelsInput):
    """ProviderModelsInput extended with the required ``model_name``."""

    model_name: str
class TestProviderModelResult(BaseModel):
    """Result for one model: its name, an ``accessible`` flag and an optional error message."""

    model_name: str
    accessible: bool
    error_message: Optional[str] = None
|