| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778 |
- # coding=utf-8
- """
- OpenAI 兼容私域模型提供商
- 用于添加自定义 OpenAI 兼容 API 的私域模型
- """
- import os
- from django.utils.translation import gettext as _
- from common.utils.common import get_file_content
- from maxkb.conf import PROJECT_DIR
- from models_provider.base_model_provider import (
- IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, ModelInfoManage
- )
- from models_provider.impl.openai_compatible_provider.credential.llm import (
- OpenAICompatibleLLMModelCredential
- )
- from models_provider.impl.openai_compatible_provider.model.llm import (
- OpenAICompatibleChatModel
- )
# Single credential-form instance shared by every preset model registered below.
openai_compatible_credential = OpenAICompatibleLLMModelCredential()

# Preset chat models offered by this provider (users may also supply their own
# model names). Each entry is (model name, localized description).
# NOTE(review): `_` is Django's non-lazy `gettext`, so these descriptions are
# translated once, at import time — consider `gettext_lazy` if per-request
# locale switching matters; confirm against the other providers.
model_info_list = [
    ModelInfo(
        model_name,
        model_desc,
        ModelTypeConst.LLM,
        openai_compatible_credential,
        OpenAICompatibleChatModel,
    )
    for model_name, model_desc in (
        ('gpt-4', _('GPT-4 via custom API')),
        ('gpt-3.5-turbo', _('GPT-3.5 Turbo via custom API')),
        ('deepseek-chat', _('DeepSeek Chat via custom API')),
    )
]

# Registry handed out by the provider; the first preset is registered as the default.
model_info_manage = ModelInfoManage.builder().append_model_info_list(
    model_info_list
).append_default_model_info(
    model_info_list[0]
).build()
class OpenAICompatibleProvider(IModelProvider):
    """Model provider for private endpoints that expose an OpenAI-compatible API."""

    def get_model_info_manage(self):
        """Return the module-level registry of preset models for this provider."""
        return model_info_manage

    def get_model_provide_info(self):
        """Return this provider's display metadata.

        Returns:
            ModelProvideInfo with the provider key, the localized display name,
            and the content of ``icon/openai_compatible_icon_svg`` read from
            the provider package directory.
        """
        icon_path = os.path.join(
            PROJECT_DIR, 'apps', 'models_provider', 'impl',
            'openai_compatible_provider', 'icon', 'openai_compatible_icon_svg',
        )
        return ModelProvideInfo(
            provider='model_openai_compatible_provider',
            name=_('OpenAI Compatible'),
            icon=get_file_content(icon_path),
        )

    def get_base_model_list(self, model_credential):
        """Return the preset models as ``{'name', 'model_type'}`` dicts.

        The preset list is returned as-is; users may still enter a custom
        model name elsewhere.

        Args:
            model_credential: Accepted for interface compatibility; not used,
                since the list is static rather than fetched from the endpoint.

        Returns:
            One dict per preset in ``model_info_list``, in registration order.
        """
        return [
            {
                'name': info.name,
                # Report each model's own type instead of assuming LLM, so
                # non-LLM presets added later are not mislabelled. `getattr`
                # accepts either an enum member or its already-stored name.
                'model_type': getattr(info.model_type, 'name', info.model_type),
            }
            for info in model_info_list
        ]
|