openai_compatible_provider.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. # coding=utf-8
  2. """
  3. OpenAI 兼容私域模型提供商
  4. 用于添加自定义 OpenAI 兼容 API 的私域模型
  5. """
import os

from django.utils.translation import gettext as _
from django.utils.translation import gettext_lazy

from common.utils.common import get_file_content
from maxkb.conf import PROJECT_DIR
from models_provider.base_model_provider import (
    IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, ModelInfoManage
)
from models_provider.impl.openai_compatible_provider.credential.llm import (
    OpenAICompatibleLLMModelCredential
)
from models_provider.impl.openai_compatible_provider.model.llm import (
    OpenAICompatibleChatModel
)
  19. # 默认凭证实例
  20. openai_compatible_credential = OpenAICompatibleLLMModelCredential()
  21. # 预置模型列表(用户也可自定义)
  22. model_info_list = [
  23. ModelInfo(
  24. 'gpt-4',
  25. _('GPT-4 via custom API'),
  26. ModelTypeConst.LLM,
  27. openai_compatible_credential,
  28. OpenAICompatibleChatModel
  29. ),
  30. ModelInfo(
  31. 'gpt-3.5-turbo',
  32. _('GPT-3.5 Turbo via custom API'),
  33. ModelTypeConst.LLM,
  34. openai_compatible_credential,
  35. OpenAICompatibleChatModel
  36. ),
  37. ModelInfo(
  38. 'deepseek-chat',
  39. _('DeepSeek Chat via custom API'),
  40. ModelTypeConst.LLM,
  41. openai_compatible_credential,
  42. OpenAICompatibleChatModel
  43. ),
  44. ]
  45. model_info_manage = (
  46. ModelInfoManage.builder()
  47. .append_model_info_list(model_info_list)
  48. .append_default_model_info(model_info_list[0])
  49. .build()
  50. )
  51. class OpenAICompatibleProvider(IModelProvider):
  52. """OpenAI 兼容私域模型提供商"""
  53. def get_model_info_manage(self):
  54. return model_info_manage
  55. def get_model_provide_info(self):
  56. return ModelProvideInfo(
  57. provider='model_openai_compatible_provider',
  58. name=_('OpenAI Compatible'),
  59. icon=get_file_content(
  60. os.path.join(
  61. PROJECT_DIR, "apps", 'models_provider', 'impl',
  62. 'openai_compatible_provider', 'icon', 'openai_compatible_icon_svg'
  63. )
  64. )
  65. )
  66. def get_base_model_list(self, model_credential):
  67. """获取可用模型列表 - 直接返回预置列表,用户也可自定义输入"""
  68. return [{'name': m.name, 'model_type': ModelTypeConst.LLM.name} for m in model_info_list]