base_model_provider.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265
  1. # coding=utf-8
  2. from abc import ABC, abstractmethod
  3. from enum import Enum
  4. from functools import reduce
  5. from typing import Dict, Iterator, Type, List
  6. from pydantic import BaseModel
  7. from common.exception.app_exception import AppApiException
  8. from django.utils.translation import gettext_lazy as _
  9. from common.utils.common import encryption
  10. class DownModelChunkStatus(Enum):
  11. success = "success"
  12. error = "error"
  13. pulling = "pulling"
  14. unknown = 'unknown'
  15. class ValidCode(Enum):
  16. valid_error = 500
  17. model_not_fount = 404
  18. class DownModelChunk:
  19. def __init__(self, status: DownModelChunkStatus, digest: str, progress: int, details: str, index: int):
  20. self.details = details
  21. self.status = status
  22. self.digest = digest
  23. self.progress = progress
  24. self.index = index
  25. def to_dict(self):
  26. return {
  27. "details": self.details,
  28. "status": self.status.value,
  29. "digest": self.digest,
  30. "progress": self.progress,
  31. "index": self.index
  32. }
  33. class IModelProvider(ABC):
  34. @abstractmethod
  35. def get_model_info_manage(self):
  36. pass
  37. @abstractmethod
  38. def get_model_provide_info(self):
  39. pass
  40. def get_model_type_list(self):
  41. return self.get_model_info_manage().get_model_type_list()
  42. def get_model_list(self, model_type):
  43. if model_type is None:
  44. raise AppApiException(500, _('Model type cannot be empty'))
  45. return self.get_model_info_manage().get_model_list_by_model_type(model_type)
  46. def get_model_credential(self, model_type, model_name):
  47. model_info = self.get_model_info_manage().get_model_info(model_type, model_name)
  48. model_credential = model_info.model_credential
  49. if model_type == 'TTI' and model_name.startswith(('qwen', 'wan2.6', 'wan')):
  50. if hasattr(model_credential, 'api_base'):
  51. api_base = model_credential.api_base
  52. if hasattr(api_base, 'default_value') and not api_base.default_value:
  53. api_base.default_value = 'https://dashscope.aliyuncs.com/api/v1'
  54. return model_credential
  55. def get_model_params(self, model_type, model_name):
  56. model_info = self.get_model_info_manage().get_model_info(model_type, model_name)
  57. return model_info.model_credential
  58. def is_valid_credential(self, model_type, model_name, model_credential: Dict[str, object],
  59. model_params: Dict[str, object], raise_exception=False):
  60. model_info = self.get_model_info_manage().get_model_info(model_type, model_name)
  61. return model_info.model_credential.is_valid(model_type, model_name, model_credential, model_params, self,
  62. raise_exception=raise_exception)
  63. def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> BaseModel:
  64. model_info = self.get_model_info_manage().get_model_info(model_type, model_name)
  65. return model_info.model_class.new_instance(model_type, model_name, model_credential, **model_kwargs)
  66. def get_dialogue_number(self):
  67. return 3
  68. def down_model(self, model_type: str, model_name, model_credential: Dict[str, object]) -> Iterator[DownModelChunk]:
  69. raise AppApiException(500, _('The current platform does not support downloading models'))
  70. class MaxKBBaseModel(ABC):
  71. @staticmethod
  72. @abstractmethod
  73. def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
  74. pass
  75. @staticmethod
  76. def is_cache_model():
  77. return True
  78. @staticmethod
  79. def filter_optional_params(model_kwargs):
  80. optional_params = {}
  81. for key, value in model_kwargs.items():
  82. if key not in ['model_id', 'use_local', 'streaming', 'show_ref_label', 'stream']:
  83. optional_params[key] = value
  84. return optional_params
  85. class BaseModelCredential(ABC):
  86. @abstractmethod
  87. def is_valid(self, model_type: str, model_name, model: Dict[str, object], model_params, provider,
  88. raise_exception=True):
  89. pass
  90. @abstractmethod
  91. def encryption_dict(self, model_info: Dict[str, object]):
  92. """
  93. :param model_info: 模型数据
  94. :return: 加密后数据
  95. """
  96. pass
  97. def get_model_params_setting_form(self, model_name):
  98. """
  99. 模型参数设置表单
  100. :return:
  101. """
  102. pass
  103. @staticmethod
  104. def encryption(message: str):
  105. """
  106. 加密敏感字段数据 加密方式是 如果密码是 1234567890 那么给前端则是 123******890
  107. :param message:
  108. :return:
  109. """
  110. return encryption(message)
  111. class ModelTypeConst(Enum):
  112. LLM = {'code': 'LLM', 'message': _('LLM')}
  113. EMBEDDING = {'code': 'EMBEDDING', 'message': _('Embedding Model')}
  114. STT = {'code': 'STT', 'message': _('Speech2Text')}
  115. TTS = {'code': 'TTS', 'message': _('TTS')}
  116. IMAGE = {'code': 'IMAGE', 'message': _('Vision Model')}
  117. TTI = {'code': 'TTI', 'message': _('Image Generation')}
  118. RERANKER = {'code': 'RERANKER', 'message': _('Rerank')}
  119. # 文生视频 图生视频
  120. TTV = {'code': 'TTV', 'message': _('Text to Video')}
  121. ITV = {'code': 'ITV', 'message': _('Image to Video')}
  122. class ModelInfo:
  123. def __init__(self, name: str, desc: str, model_type: ModelTypeConst, model_credential: BaseModelCredential,
  124. model_class: Type[MaxKBBaseModel],
  125. **keywords):
  126. self.name = name
  127. self.desc = desc
  128. self.model_type = model_type.name
  129. self.model_credential = model_credential
  130. self.model_class = model_class
  131. if keywords is not None:
  132. for key in keywords.keys():
  133. self.__setattr__(key, keywords.get(key))
  134. def get_name(self):
  135. """
  136. 获取模型名称
  137. :return: 模型名称
  138. """
  139. return self.name
  140. def get_desc(self):
  141. """
  142. 获取模型描述
  143. :return: 模型描述
  144. """
  145. return self.desc
  146. def get_model_type(self):
  147. return self.model_type
  148. def get_model_class(self):
  149. return self.model_class
  150. def to_dict(self):
  151. return reduce(lambda x, y: {**x, **y},
  152. [{attr: self.__getattribute__(attr)} for attr in vars(self) if
  153. not attr.startswith("__") and not attr == 'model_credential' and not attr == 'model_class'], {})
  154. class ModelInfoManage:
  155. def __init__(self):
  156. self.model_dict = {}
  157. self.model_list = []
  158. self.default_model_list = []
  159. self.default_model_dict = {}
  160. def append_model_info(self, model_info: ModelInfo):
  161. self.model_list.append(model_info)
  162. model_type_dict = self.model_dict.get(model_info.model_type)
  163. if model_type_dict is None:
  164. self.model_dict[model_info.model_type] = {model_info.name: model_info}
  165. else:
  166. model_type_dict[model_info.name] = model_info
  167. def append_default_model_info(self, model_info: ModelInfo):
  168. self.default_model_list.append(model_info)
  169. self.default_model_dict[model_info.model_type] = model_info
  170. def get_model_list(self):
  171. return [model.to_dict() for model in self.model_list]
  172. def get_model_list_by_model_type(self, model_type):
  173. return [model.to_dict() for model in self.model_list if model.model_type == model_type]
  174. def get_model_type_list(self):
  175. return [{'key': _type.value.get('message'), 'value': _type.value.get('code')} for _type in ModelTypeConst if
  176. len([model for model in self.model_list if model.model_type == _type.name]) > 0]
  177. def get_model_info(self, model_type, model_name) -> ModelInfo:
  178. model_info = self.model_dict.get(model_type, {}).get(model_name, self.default_model_dict.get(model_type))
  179. if model_info is None:
  180. raise AppApiException(500, _('The model does not support'))
  181. return model_info
  182. class builder:
  183. def __init__(self):
  184. self.modelInfoManage = ModelInfoManage()
  185. def append_model_info(self, model_info: ModelInfo):
  186. self.modelInfoManage.append_model_info(model_info)
  187. return self
  188. def append_model_info_list(self, model_info_list: List[ModelInfo]):
  189. for model_info in model_info_list:
  190. self.modelInfoManage.append_model_info(model_info)
  191. return self
  192. def append_default_model_info(self, model_info: ModelInfo):
  193. self.modelInfoManage.append_default_model_info(model_info)
  194. return self
  195. def build(self):
  196. return self.modelInfoManage
  197. class ModelProvideInfo:
  198. def __init__(self, provider: str, name: str, icon: str):
  199. self.provider = provider
  200. self.name = name
  201. self.icon = icon
  202. def to_dict(self):
  203. return reduce(lambda x, y: {**x, **y},
  204. [{attr: self.__getattribute__(attr)} for attr in vars(self) if
  205. not attr.startswith("__")], {})