model_generate.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
@Project : lq-agent-api
@File    : model_generate.py
@IDE     : PyCharm
@Author  :
@Date    : 2025/7/14 14:22
'''
  10. from langchain_core.prompts import ChatPromptTemplate
  11. from foundation.ai.models.model_handler import get_models
  12. from foundation.observability.logger.loggering import server_logger as logger
  13. import asyncio
  14. import time
  15. from typing import Optional, Callable, Any
  16. class GenerateModelClient:
  17. """
  18. 主要是生成式模型
  19. """
  20. def __init__(self, default_timeout: int = 60, max_retries: int = 3, backoff_factor: float = 1.0):
  21. # 获取部署的模型列表
  22. llm, chat, embed = get_models()
  23. self.llm = llm
  24. self.chat = chat
  25. # 配置参数
  26. self.default_timeout = default_timeout
  27. self.max_retries = max_retries
  28. self.backoff_factor = backoff_factor
  29. async def _retry_with_backoff(self, func: Callable, *args, timeout: Optional[int] = None, **kwargs):
  30. """
  31. 带指数退避的重试机制,每次重试都有独立的超时控制
  32. """
  33. current_timeout = timeout or self.default_timeout
  34. for attempt in range(self.max_retries + 1):
  35. try:
  36. # 每次重试都有独立的超时时间
  37. return await asyncio.wait_for(
  38. func(*args, **kwargs),
  39. timeout=current_timeout
  40. )
  41. except asyncio.TimeoutError:
  42. if attempt == self.max_retries:
  43. logger.error(f"[模型调用] 达到最大重试次数 {self.max_retries},最终超时")
  44. raise TimeoutError(f"模型调用在 {self.max_retries} 次重试后均超时")
  45. wait_time = self.backoff_factor * (2 ** attempt)
  46. logger.warning(f"[模型调用] 第 {attempt + 1} 次超时, {wait_time}秒后重试...")
  47. await asyncio.sleep(wait_time)
  48. except Exception as e:
  49. if attempt == self.max_retries:
  50. logger.error(f"[模型调用] 达到最大重试次数 {self.max_retries},最终失败: {str(e)}")
  51. raise
  52. wait_time = self.backoff_factor * (2 ** attempt)
  53. logger.warning(f"[模型调用] 第 {attempt + 1} 次尝试失败: {str(e)}, {wait_time}秒后重试...")
  54. await asyncio.sleep(wait_time)
  55. async def get_model_generate_invoke(self, trace_id: str, task_prompt_info: dict, timeout: Optional[int] = None):
  56. """
  57. 模型非流式生成(异步)
  58. """
  59. start_time = time.time()
  60. current_timeout = timeout or self.default_timeout
  61. try:
  62. logger.info(f"[模型调用] 开始处理 trace_id: {trace_id}, 超时配置: {current_timeout}s")
  63. prompt_template = task_prompt_info["task_prompt"]
  64. messages = prompt_template.format_messages()
  65. async def _invoke_model():
  66. loop = asyncio.get_event_loop()
  67. return await loop.run_in_executor(None, self.llm.invoke, messages)
  68. # 调用带重试机制的方法,超时控制在重试机制内部处理
  69. response = await self._retry_with_backoff(_invoke_model, timeout=current_timeout)
  70. elapsed_time = time.time() - start_time
  71. logger.info(f"[模型调用] 成功 trace_id: {trace_id}, 耗时: {elapsed_time:.2f}s")
  72. return response.content
  73. except asyncio.TimeoutError:
  74. elapsed_time = time.time() - start_time
  75. logger.error(f"[模型调用] 超时 trace_id: {trace_id}, 耗时: {elapsed_time:.2f}s, 超时阈值: {current_timeout}s")
  76. raise TimeoutError(f"模型调用超时,trace_id: {trace_id}, 耗时: {elapsed_time:.2f}s")
  77. except Exception as e:
  78. elapsed_time = time.time() - start_time
  79. logger.error(f"[模型调用] 异常 trace_id: {trace_id}, 耗时: {elapsed_time:.2f}s, 错误类型: {type(e).__name__}, 错误信息: {str(e)}")
  80. raise
  81. def get_model_generate_stream(self, trace_id: str, task_prompt_info: dict, timeout: Optional[int] = None):
  82. """
  83. 模型流式生成(同步生成器)- 带异常处理
  84. """
  85. start_time = time.time()
  86. current_timeout = timeout or self.default_timeout
  87. try:
  88. logger.info(f"[模型流式调用] 开始处理 trace_id: {trace_id}, 超时配置: {current_timeout}s")
  89. prompt_template = task_prompt_info["task_prompt"]
  90. messages = prompt_template.format_messages()
  91. response = self.llm.stream(messages)
  92. chunk_count = 0
  93. for chunk in response:
  94. chunk_count += 1
  95. if hasattr(chunk, 'content') and chunk.content:
  96. yield chunk.content
  97. elif chunk: # 处理直接返回字符串的情况
  98. yield chunk
  99. elapsed_time = time.time() - start_time
  100. logger.info(f"[模型流式调用] 成功 trace_id: {trace_id}, 生成块数: {chunk_count}, 耗时: {elapsed_time:.2f}s")
  101. except Exception as e:
  102. elapsed_time = time.time() - start_time
  103. logger.error(f"[模型流式调用] 异常 trace_id: {trace_id}, 耗时: {elapsed_time:.2f}s, 错误类型: {type(e).__name__}, 错误信息: {str(e)}")
  104. raise
# Module-level singleton: constructed at import time, which triggers
# get_models() as a side effect. Tighter limits than the class defaults
# (15s per attempt, 2 retries, 0.5s backoff base).
generate_model_client = GenerateModelClient(default_timeout=15, max_retries=2, backoff_factor=0.5)