tools.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. # coding=utf-8
  2. """
  3. @project: MaxKB
  4. @Author:虎
  5. @file: tools.py
  6. @date:2024/7/22 11:18
  7. @desc:
  8. """
  9. from django.db import connection
  10. from django.db.models import QuerySet
  11. from common.config.embedding_config import ModelManage
  12. from common.database_model_manage.database_model_manage import DatabaseModelManage
  13. from models_provider.models import Model
  14. from django.utils.translation import gettext_lazy as _
  15. import json
  16. from typing import Dict
  17. from common.utils.rsa_util import rsa_long_decrypt
  18. from models_provider.constants.model_provider_constants import ModelProvideConstants
  19. def get_model_(provider, model_type, model_name, credential, model_id, use_local=False, **kwargs):
  20. """
  21. 获取模型实例
  22. @param provider: 供应商
  23. @param model_type: 模型类型
  24. @param model_name: 模型名称
  25. @param credential: 认证信息
  26. @param model_id: 模型id
  27. @param use_local: 是否调用本地模型 只适用于本地供应商
  28. @return: 模型实例
  29. """
  30. model = get_provider(provider).get_model(model_type, model_name,
  31. json.loads(
  32. rsa_long_decrypt(credential)),
  33. model_id=model_id,
  34. use_local=use_local,
  35. streaming=True, **kwargs)
  36. return model
  37. def get_model(model, **kwargs):
  38. """
  39. 获取模型实例
  40. @param model: model 数据库Model实例对象
  41. @return: 模型实例
  42. """
  43. return get_model_(model.provider, model.model_type, model.model_name, model.credential, str(model.id), **kwargs)
  44. def get_provider(provider):
  45. """
  46. 获取供应商实例
  47. @param provider: 供应商字符串
  48. @return: 供应商实例
  49. """
  50. return ModelProvideConstants[provider].value
  51. def get_model_list(provider, model_type):
  52. """
  53. 获取模型列表
  54. @param provider: 供应商字符串
  55. @param model_type: 模型类型
  56. @return: 模型列表
  57. """
  58. return get_provider(provider).get_model_list(model_type)
  59. def get_model_credential(provider, model_type, model_name):
  60. """
  61. 获取模型认证实例
  62. @param provider: 供应商字符串
  63. @param model_type: 模型类型
  64. @param model_name: 模型名称
  65. @return: 认证实例对象
  66. """
  67. return get_provider(provider).get_model_credential(model_type, model_name)
  68. def get_model_type_list(provider):
  69. """
  70. 获取模型类型列表
  71. @param provider: 供应商字符串
  72. @return: 模型类型列表
  73. """
  74. return get_provider(provider).get_model_type_list()
  75. def is_valid_credential(provider, model_type, model_name, model_credential: Dict[str, object], model_params,
  76. raise_exception=False):
  77. """
  78. 校验模型认证参数
  79. @param provider: 供应商字符串
  80. @param model_type: 模型类型
  81. @param model_name: 模型名称
  82. @param model_credential: 模型认证数据
  83. @param raise_exception: 是否抛出错误
  84. @return: True|False
  85. """
  86. return get_provider(provider).is_valid_credential(model_type, model_name, model_credential, model_params,
  87. raise_exception)
  88. def get_model_by_id(_id, workspace_id):
  89. model = QuerySet(Model).filter(id=_id).first()
  90. # 归还链接到连接池
  91. connection.close()
  92. get_authorized_model = DatabaseModelManage.get_model("get_authorized_model")
  93. if model and model.workspace_id != workspace_id and get_authorized_model is not None:
  94. model = get_authorized_model(QuerySet(Model).filter(id=_id), workspace_id).first()
  95. if model is None:
  96. raise Exception(_("Model does not exist"))
  97. return model
  98. def get_model_default_params(model):
  99. def convert_to_int(value):
  100. if isinstance(value, str):
  101. try:
  102. return int(value)
  103. except ValueError:
  104. return value
  105. return value
  106. return {
  107. p.get('field'): convert_to_int(p.get('default_value'))
  108. for p in model.model_params_form
  109. if p.get('default_value') is not None
  110. }
  111. def get_model_instance_by_model_workspace_id(model_id, workspace_id, **kwargs):
  112. """
  113. 获取模型实例,根据模型相关数据
  114. @param model_id: 模型id
  115. @param workspace_id: 工作空间id
  116. @return: 模型实例
  117. """
  118. model = get_model_by_id(model_id, workspace_id)
  119. s = get_model_default_params(model)
  120. return ModelManage.get_model(model_id, lambda _id: get_model(model, **{**s, **kwargs}))