model_handler.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. """
  4. AI模型处理器
  5. 用于管理生成、与嵌入模型的创建和配置
  6. 支持的模型类型:
  7. - doubao: 豆包模型
  8. - qwen: 通义千问模型
  9. - deepseek: DeepSeek模型
  10. - gemini: Gemini模型
  11. - lq_qwen3_8b: 本地Qwen3-8B模型
  12. - lq_qwen3_4b: 本地Qwen3-4B模型
  13. - qwen_local_14b: 本地Qwen3-14B模型
  14. - lq_qwen3_8b_emd: 本地Qwen3-Embedding-8B嵌入模型
  15. - lq_bge_reranker_v2_m3: 本地BGE-reranker-v2-m3重排序模型
  16. """
  17. from langchain_openai import ChatOpenAI, OpenAIEmbeddings
  18. from foundation.infrastructure.config.config import config_handler
  19. from foundation.observability.logger.loggering import server_logger as logger
  20. class ModelHandler:
  21. """
  22. AI模型处理器类,用于管理多种AI模型的创建和配置
  23. """
  24. def __init__(self):
  25. """
  26. 初始化模型处理器
  27. 加载配置处理器,用于后续读取各种模型的配置信息
  28. """
  29. self.config = config_handler
  30. def get_models(self):
  31. """
  32. 获取AI模型实例
  33. Returns:
  34. ChatOpenAI: 配置好的AI模型实例
  35. Note:
  36. 根据配置文件中的MODEL_TYPE参数选择对应模型
  37. 支持的模型类型:doubao, qwen, deepseek, lq_qwen3_8b, lq_qwen3_4b, qwen_local_14b
  38. 默认返回豆包模型
  39. """
  40. model_type = self.config.get("model", "MODEL_TYPE")
  41. logger.info(f"正在初始化AI模型,模型类型: {model_type}")
  42. if model_type == "doubao":
  43. model = self._get_doubao_model()
  44. elif model_type == "gemini":
  45. model = self._get_gemini_model()
  46. elif model_type == "qwen":
  47. model = self._get_qwen_model()
  48. elif model_type == "deepseek":
  49. model = self._get_deepseek_model()
  50. elif model_type == "lq_qwen3_8b":
  51. model = self._get_lq_qwen3_8b_model()
  52. elif model_type == "lq_qwen3_4b":
  53. model = self._get_lq_qwen3_4b_model()
  54. elif model_type == "qwen_local_14b":
  55. model = self._get_qwen_local_14b_model()
  56. else:
  57. # 默认返回gemini
  58. logger.warning(f"未知的模型类型 '{model_type}',使用默认gemini模型")
  59. model = self._get_gemini_model()
  60. logger.info(f"AI模型初始化完成: {model_type}")
  61. return model
  62. def _get_doubao_model(self):
  63. """
  64. 获取豆包模型
  65. Returns:
  66. ChatOpenAI: 配置好的豆包模型实例
  67. """
  68. doubao_url = self.config.get("doubao", "DOUBAO_SERVER_URL")
  69. doubao_model_id = self.config.get("doubao", "DOUBAO_MODEL_ID")
  70. doubao_api_key = self.config.get("doubao", "DOUBAO_API_KEY")
  71. llm = ChatOpenAI(
  72. base_url=doubao_url,
  73. model=doubao_model_id,
  74. api_key=doubao_api_key,
  75. temperature=0.7,
  76. extra_body={
  77. "enable_thinking": False,
  78. })
  79. return llm
  80. def _get_qwen_model(self):
  81. """
  82. 获取通义千问模型
  83. Returns:
  84. ChatOpenAI: 配置好的通义千问模型实例
  85. """
  86. qwen_url = self.config.get("qwen", "QWEN_SERVER_URL")
  87. qwen_model_id = self.config.get("qwen", "QWEN_MODEL_ID")
  88. qwen_api_key = self.config.get("qwen", "QWEN_API_KEY")
  89. llm = ChatOpenAI(
  90. base_url=qwen_url,
  91. model=qwen_model_id,
  92. api_key=qwen_api_key,
  93. temperature=0.7,
  94. extra_body={
  95. "enable_thinking": False,
  96. })
  97. return llm
  98. def _get_deepseek_model(self):
  99. """
  100. 获取DeepSeek模型
  101. Returns:
  102. ChatOpenAI: 配置好的DeepSeek模型实例
  103. """
  104. deepseek_url = self.config.get("deepseek", "DEEPSEEK_SERVER_URL")
  105. deepseek_model_id = self.config.get("deepseek", "DEEPSEEK_MODEL_ID")
  106. deepseek_api_key = self.config.get("deepseek", "DEEPSEEK_API_KEY")
  107. llm = ChatOpenAI(
  108. base_url=deepseek_url,
  109. model=deepseek_model_id,
  110. api_key=deepseek_api_key,
  111. temperature=0.7,
  112. extra_body={
  113. "enable_thinking": False,
  114. })
  115. return llm
  116. def _get_gemini_model(self):
  117. """
  118. 获取Gemini模型
  119. Returns:
  120. ChatOpenAI: 配置好的Gemini模型实例
  121. """
  122. gemini_url = self.config.get("gemini", "GEMINI_SERVER_URL")
  123. gemini_model_id = self.config.get("gemini", "GEMINI_MODEL_ID")
  124. gemini_api_key = self.config.get("gemini", "GEMINI_API_KEY")
  125. llm = ChatOpenAI(
  126. base_url=gemini_url,
  127. model=gemini_model_id,
  128. api_key=gemini_api_key,
  129. temperature=0.7,
  130. )
  131. return llm
  132. def _get_lq_qwen3_8b_model(self):
  133. """
  134. 获取本地Qwen3-8B-Instruct模型
  135. Returns:
  136. ChatOpenAI: 配置好的本地Qwen3-8B模型实例
  137. """
  138. llm = ChatOpenAI(
  139. base_url="http://192.168.91.253:9002/v1",
  140. model="Qwen3-8B",
  141. api_key="dummy", # 本地模型使用虚拟API key
  142. temperature=0.7,
  143. )
  144. return llm
  145. def _get_lq_qwen3_4b_model(self):
  146. """
  147. 获取本地Qwen3-4B-Instruct模型
  148. Returns:
  149. ChatOpenAI: 配置好的本地Qwen3-4B模型实例
  150. """
  151. llm = ChatOpenAI(
  152. base_url="http://192.168.91.253:9001/v1",
  153. model="Qwen3-4B",
  154. api_key="dummy", # 本地模型使用虚拟API key
  155. temperature=0.7,
  156. )
  157. return llm
  158. def _get_qwen_local_14b_model(self):
  159. """
  160. 获取本地Qwen3-14B-Instruct模型
  161. Returns:
  162. ChatOpenAI: 配置好的本地Qwen3-14B模型实例
  163. """
  164. llm = ChatOpenAI(
  165. base_url="http://192.168.91.253:9003/v1",
  166. model="Qwen3-14B",
  167. api_key="dummy", # 本地模型使用虚拟API key
  168. temperature=0.7,
  169. )
  170. return llm
  171. def _get_lq_qwen3_8b_emd(self):
  172. """
  173. 获取本地Qwen3-Embedding-8B嵌入模型
  174. Returns:
  175. OpenAIEmbeddings: 配置好的本地Qwen3-Embedding-8B嵌入模型实例
  176. """
  177. embeddings = OpenAIEmbeddings(
  178. base_url="http://192.168.91.253:9003/v1",
  179. model="Qwen3-Embedding-8B",
  180. api_key="dummy", # 本地模型使用虚拟API key
  181. )
  182. return embeddings
  183. # 创建全局实例
  184. model_handler = ModelHandler()
  185. def get_models():
  186. """
  187. 获取模型的全局函数
  188. Returns:
  189. tuple: (llm, chat, embed) - LLM模型、聊天模型和嵌入模型实例
  190. 注意:当前llm和chat使用相同模型实例,embed暂时返回None
  191. Note:
  192. 这是一个便捷函数,直接使用全局model_handler实例获取模型
  193. """
  194. llm = model_handler.get_models()
  195. # 暂时返回相同的模型作为chat和embed
  196. return llm, llm, None