model_generate.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307
  1. # !/usr/bin/ python
  2. # -*- coding: utf-8 -*-
  3. '''
  4. @Project : lq-agent-api
  5. @File :model_generate.py
  6. @IDE :PyCharm
  7. @Author :
  8. @Date :2025/7/14 14:22
  9. '''
  10. from langchain_core.prompts import ChatPromptTemplate
  11. from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage
  12. from foundation.ai.models.model_handler import model_handler
  13. from foundation.observability.logger.loggering import review_logger as logger
  14. import asyncio
  15. import time
  16. from typing import Optional, Callable, Any, List, Union
  17. class GenerateModelClient:
  18. """
  19. 主要是生成式模型
  20. """
  21. def __init__(self, default_timeout: int = 60, max_retries: int = 3, backoff_factor: float = 1.0):
  22. # 获取默认模型
  23. self.llm = model_handler.get_models()
  24. self.chat = self.llm # 当前chat和llm使用相同模型
  25. # 配置参数
  26. self.default_timeout = default_timeout
  27. self.max_retries = max_retries
  28. self.backoff_factor = backoff_factor
  29. # 保存model_handler引用,用于动态获取模型
  30. self.model_handler = model_handler
  31. async def _retry_with_backoff(self, func: Callable, *args, timeout: Optional[int] = None, trace_id: Optional[str] = None, model_name: Optional[str] = None, **kwargs):
  32. """
  33. 带指数退避的重试机制,每次重试都有独立的超时控制
  34. 注意:对于 502/503/504 等服务不可用错误,立即失败不重试,
  35. 避免在服务端过载时继续加重负载。
  36. """
  37. current_timeout = timeout or self.default_timeout
  38. model_info = model_name or "default"
  39. def _is_server_unavailable_error(error: Exception) -> bool:
  40. """判断是否为服务端不可用错误(应立即失败)"""
  41. error_str = str(error).lower()
  42. # 502: Bad Gateway, 503: Service Unavailable, 504: Gateway Timeout
  43. unavailable_codes = ['502', '503', '504', 'internal server error']
  44. return any(code in error_str for code in unavailable_codes)
  45. for attempt in range(self.max_retries + 1):
  46. try:
  47. # 每次重试都有独立的超时时间
  48. return await asyncio.wait_for(
  49. func(*args, **kwargs),
  50. timeout=current_timeout
  51. )
  52. except asyncio.TimeoutError as e:
  53. if attempt == self.max_retries:
  54. logger.error(f"[模型调用] 达到最大重试次数 {self.max_retries},最终超时 | trace_id: {trace_id}, model: {model_info}, timeout: {current_timeout}s, error_type: {type(e).__name__}, error_msg: {str(e)}")
  55. raise TimeoutError(f"模型调用在 {self.max_retries} 次重试后均超时")
  56. wait_time = self.backoff_factor * (2 ** attempt)
  57. logger.warning(f"[模型调用] 第 {attempt + 1} 次超时, {wait_time}秒后重试... | trace_id: {trace_id}, model: {model_info}, timeout: {current_timeout}s, error_type: {type(e).__name__}, error_msg: {str(e)}")
  58. await asyncio.sleep(wait_time)
  59. except Exception as e:
  60. error_str = str(e)
  61. # 服务端不可用错误(502/503/504)立即失败,不重试
  62. if _is_server_unavailable_error(e):
  63. logger.error(f"[模型调用] 服务端不可用,立即失败: {error_str} | trace_id: {trace_id}, model: {model_info}")
  64. raise
  65. if attempt == self.max_retries:
  66. logger.error(f"[模型调用] 达到最大重试次数 {self.max_retries},最终失败: {error_str} | trace_id: {trace_id}, model: {model_info}")
  67. raise
  68. wait_time = self.backoff_factor * (2 ** attempt)
  69. logger.warning(f"[模型调用] 第 {attempt + 1} 次尝试失败: {error_str}, {wait_time}秒后重试... | trace_id: {trace_id}, model: {model_info}")
  70. await asyncio.sleep(wait_time)
  71. async def get_model_generate_invoke(
  72. self,
  73. trace_id: str,
  74. task_prompt_info: Optional[dict] = None,
  75. messages: Optional[List[BaseMessage]] = None,
  76. system_prompt: Optional[str] = None,
  77. user_prompt: Optional[str] = None,
  78. prompt: Optional[str] = None,
  79. timeout: Optional[int] = None,
  80. model_name: Optional[str] = None
  81. ) -> str:
  82. """模型非流式生成(异步)
  83. 支持多种调用方式(优先级从高到低):
  84. 1. messages: 直接传入 LangChain Message 对象列表
  85. 2. system_prompt + user_prompt: 分别传入系统和用户提示词
  86. 3. prompt: 传入单条用户提示词字符串
  87. 4. task_prompt_info: 传入包含 ChatPromptTemplate 的字典(兼容旧接口)
  88. Args:
  89. trace_id: 追踪ID
  90. task_prompt_info: 任务提示词信息(兼容旧接口),需包含 format_messages() 方法
  91. messages: LangChain Message 对象列表(如 [SystemMessage, HumanMessage])
  92. system_prompt: 系统提示词字符串
  93. user_prompt: 用户提示词字符串
  94. prompt: 单条用户提示词字符串(无系统提示时使用)
  95. timeout: 超时时间(秒),默认使用构造时的 default_timeout
  96. model_name: 模型名称(可选),支持 doubao/qwen/deepseek/gemini 等
  97. Returns:
  98. str: 模型生成的文本内容
  99. Raises:
  100. ValueError: 参数组合错误
  101. TimeoutError: 调用超时
  102. Exception: 模型调用异常
  103. Examples:
  104. # 方式1: 使用 Message 列表(推荐)
  105. messages = [SystemMessage(content="你是专家"), HumanMessage(content="请分析...")]
  106. result = await client.get_model_generate_invoke("trace-001", messages=messages)
  107. # 方式2: 分别传入系统和用户提示词
  108. result = await client.get_model_generate_invoke(
  109. "trace-001",
  110. system_prompt="你是专家",
  111. user_prompt="请分析..."
  112. )
  113. # 方式3: 传入单条提示词
  114. result = await client.get_model_generate_invoke("trace-001", prompt="请分析...")
  115. # 方式4: 兼容旧接口(使用 PromptLoader)
  116. task_prompt_info = {"task_prompt": chat_template}
  117. result = await client.get_model_generate_invoke("trace-001", task_prompt_info=task_prompt_info)
  118. """
  119. start_time = time.time()
  120. current_timeout = timeout or self.default_timeout
  121. try:
  122. # 选择模型
  123. llm_to_use = self.model_handler.get_model_by_name(model_name) if model_name else self.llm
  124. logger.info(f"[模型调用] 使用{'指定' if model_name else '默认'}模型: {model_name or 'default'}, trace_id: {trace_id}")
  125. # 构建消息列表(按优先级)
  126. final_messages = self._build_messages(
  127. messages=messages,
  128. system_prompt=system_prompt,
  129. user_prompt=user_prompt,
  130. prompt=prompt,
  131. task_prompt_info=task_prompt_info
  132. )
  133. # 定义模型调用函数,使用原生 ainvoke
  134. async def _invoke():
  135. return await llm_to_use.ainvoke(final_messages)
  136. # 调用带重试机制
  137. response = await self._retry_with_backoff(_invoke, timeout=current_timeout, trace_id=trace_id, model_name=model_name or "default")
  138. elapsed_time = time.time() - start_time
  139. logger.info(f"[模型调用] 成功 trace_id: {trace_id}, 耗时: {elapsed_time:.2f}s")
  140. return response.content
  141. except asyncio.TimeoutError:
  142. elapsed_time = time.time() - start_time
  143. logger.error(f"[模型调用] 超时 trace_id: {trace_id}, 耗时: {elapsed_time:.2f}s, 超时阈值: {current_timeout}s")
  144. raise TimeoutError(f"模型调用超时,trace_id: {trace_id}")
  145. except Exception as e:
  146. elapsed_time = time.time() - start_time
  147. logger.error(f"[模型调用] 异常 trace_id: {trace_id}, 耗时: {elapsed_time:.2f}s, 错误: {type(e).__name__}: {str(e)}")
  148. raise
  149. def _build_messages(
  150. self,
  151. messages: Optional[List[BaseMessage]] = None,
  152. system_prompt: Optional[str] = None,
  153. user_prompt: Optional[str] = None,
  154. prompt: Optional[str] = None,
  155. task_prompt_info: Optional[dict] = None
  156. ) -> List[BaseMessage]:
  157. """构建消息列表(内部方法)
  158. 优先级:messages > system_prompt+user_prompt > prompt > task_prompt_info
  159. """
  160. # 方式1: 直接使用传入的 Message 列表
  161. if messages is not None:
  162. if not isinstance(messages, list):
  163. raise ValueError("messages 必须是列表")
  164. if len(messages) == 0:
  165. raise ValueError("messages 不能为空列表")
  166. logger.debug(f"使用传入的 messages 列表,共 {len(messages)} 条消息")
  167. return messages
  168. # 方式2: system_prompt + user_prompt
  169. if system_prompt is not None and user_prompt is not None:
  170. logger.debug("使用 system_prompt + user_prompt 构建消息")
  171. return [SystemMessage(content=system_prompt), HumanMessage(content=user_prompt)]
  172. # 方式3: 单独 system_prompt(可能是特殊情况)
  173. if system_prompt is not None:
  174. logger.debug("使用单独的 system_prompt 构建消息")
  175. return [SystemMessage(content=system_prompt)]
  176. # 方式4: 单条 prompt 字符串
  177. if prompt is not None:
  178. logger.debug("使用单条 prompt 字符串构建消息")
  179. return [HumanMessage(content=prompt)]
  180. # 方式5: 兼容旧接口 task_prompt_info
  181. if task_prompt_info is not None:
  182. if "task_prompt" not in task_prompt_info:
  183. raise ValueError("task_prompt_info 必须包含 'task_prompt' 键")
  184. task_prompt = task_prompt_info["task_prompt"]
  185. if hasattr(task_prompt, 'format_messages'):
  186. logger.debug("使用 task_prompt_info 中的 ChatPromptTemplate 构建消息")
  187. return task_prompt.format_messages()
  188. elif isinstance(task_prompt, str):
  189. logger.debug("使用 task_prompt_info 中的字符串构建消息")
  190. return [HumanMessage(content=task_prompt)]
  191. else:
  192. raise ValueError(f"task_prompt 类型不支持: {type(task_prompt)}")
  193. # 没有提供任何有效参数
  194. raise ValueError(
  195. "必须提供以下参数之一: "
  196. "messages, system_prompt+user_prompt, prompt, 或 task_prompt_info"
  197. )
  198. def get_model_generate_stream(
  199. self,
  200. trace_id: str,
  201. task_prompt_info: Optional[dict] = None,
  202. messages: Optional[List[BaseMessage]] = None,
  203. system_prompt: Optional[str] = None,
  204. user_prompt: Optional[str] = None,
  205. prompt: Optional[str] = None,
  206. timeout: Optional[int] = None,
  207. model_name: Optional[str] = None
  208. ):
  209. """模型流式生成(同步生成器)
  210. 支持多种调用方式(优先级从高到低):
  211. 1. messages: 直接传入 LangChain Message 对象列表
  212. 2. system_prompt + user_prompt: 分别传入系统和用户提示词
  213. 3. prompt: 传入单条用户提示词字符串
  214. 4. task_prompt_info: 传入包含 ChatPromptTemplate 的字典(兼容旧接口)
  215. Args:
  216. trace_id: 追踪ID
  217. task_prompt_info: 任务提示词信息(兼容旧接口)
  218. messages: LangChain Message 对象列表
  219. system_prompt: 系统提示词字符串
  220. user_prompt: 用户提示词字符串
  221. prompt: 单条用户提示词字符串
  222. timeout: 超时时间(秒)
  223. model_name: 模型名称(可选),支持 doubao/qwen/deepseek/gemini 等
  224. Yields:
  225. str: 生成的文本块
  226. Raises:
  227. ValueError: 参数组合错误
  228. """
  229. start_time = time.time()
  230. current_timeout = timeout or self.default_timeout
  231. try:
  232. # 选择模型
  233. llm_to_use = self.model_handler.get_model_by_name(model_name) if model_name else self.llm
  234. logger.info(f"[模型流式调用] 使用{'指定' if model_name else '默认'}模型:{model_name or 'default'}, trace_id: {trace_id}")
  235. logger.info(f"[模型流式调用] 开始处理 trace_id: {trace_id}, 超时配置: {current_timeout}s")
  236. # 构建消息列表
  237. final_messages = self._build_messages(
  238. messages=messages,
  239. system_prompt=system_prompt,
  240. user_prompt=user_prompt,
  241. prompt=prompt,
  242. task_prompt_info=task_prompt_info
  243. )
  244. response = llm_to_use.stream(final_messages)
  245. chunk_count = 0
  246. for chunk in response:
  247. chunk_count += 1
  248. if hasattr(chunk, 'content') and chunk.content:
  249. yield chunk.content
  250. elif chunk:
  251. yield chunk
  252. elapsed_time = time.time() - start_time
  253. logger.info(f"[模型流式调用] 成功 trace_id: {trace_id}, 生成块数: {chunk_count}, 耗时: {elapsed_time:.2f}s")
  254. except Exception as e:
  255. elapsed_time = time.time() - start_time
  256. logger.error(f"[模型流式调用] 异常 trace_id: {trace_id}, 耗时: {elapsed_time:.2f}s, 错误: {type(e).__name__}: {str(e)}")
  257. raise
  258. generate_model_client = GenerateModelClient(default_timeout=60, max_retries=10, backoff_factor=0.5)