# model_generate.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
@Project : lq-agent-api
@File    : model_generate.py
@IDE     : PyCharm
@Author  :
@Date    : 2025/7/14 14:22
'''
import asyncio
import time
from typing import Any, Callable, Optional

from langchain_core.prompts import ChatPromptTemplate

from foundation.ai.models.model_handler import get_models
from foundation.observability.logger.loggering import server_logger as logger
  16. class GenerateModelClient:
  17. """
  18. 主要是生成式模型
  19. """
  20. def __init__(self, default_timeout: int = 60, max_retries: int = 3, backoff_factor: float = 1.0):
  21. # 获取部署的模型列表
  22. llm, chat, embed = get_models()
  23. self.llm = llm
  24. self.chat = chat
  25. # 配置参数
  26. self.default_timeout = default_timeout
  27. self.max_retries = max_retries
  28. self.backoff_factor = backoff_factor
  29. async def _retry_with_backoff(self, func: Callable, *args, **kwargs):
  30. """
  31. 带指数退避的重试机制
  32. """
  33. for attempt in range(self.max_retries + 1):
  34. try:
  35. return await func(*args, **kwargs)
  36. except Exception as e:
  37. if attempt == self.max_retries:
  38. logger.error(f"[模型调用] 达到最大重试次数 {self.max_retries},最终失败: {str(e)}")
  39. raise
  40. wait_time = self.backoff_factor * (2 ** attempt)
  41. logger.warning(f"[模型调用] 第 {attempt + 1} 次尝试失败: {str(e)}, {wait_time}秒后重试...")
  42. await asyncio.sleep(wait_time)
  43. async def get_model_generate_invoke(self, trace_id: str, task_prompt_info: dict, timeout: Optional[int] = None):
  44. """
  45. 模型非流式生成(异步)
  46. """
  47. start_time = time.time()
  48. current_timeout = timeout or self.default_timeout
  49. try:
  50. logger.info(f"[模型调用] 开始处理 trace_id: {trace_id}, 超时时间: {current_timeout}s")
  51. prompt_template = task_prompt_info["task_prompt"]
  52. messages = prompt_template.format_messages()
  53. async def _invoke_model():
  54. loop = asyncio.get_event_loop()
  55. return await loop.run_in_executor(None, self.llm.invoke, messages)
  56. # 使用超时包装调用
  57. response = await asyncio.wait_for(
  58. self._retry_with_backoff(_invoke_model),
  59. timeout=current_timeout
  60. )
  61. elapsed_time = time.time() - start_time
  62. logger.info(f"[模型调用] 成功 trace_id: {trace_id}, 耗时: {elapsed_time:.2f}s")
  63. return response.content
  64. except asyncio.TimeoutError:
  65. elapsed_time = time.time() - start_time
  66. logger.error(f"[模型调用] 超时 trace_id: {trace_id}, 耗时: {elapsed_time:.2f}s, 超时阈值: {current_timeout}s")
  67. raise TimeoutError(f"模型调用超时,trace_id: {trace_id}, 耗时: {elapsed_time:.2f}s")
  68. except Exception as e:
  69. elapsed_time = time.time() - start_time
  70. logger.error(f"[模型调用] 异常 trace_id: {trace_id}, 耗时: {elapsed_time:.2f}s, 错误类型: {type(e).__name__}, 错误信息: {str(e)}")
  71. raise
  72. def get_model_generate_stream(self, trace_id: str, task_prompt_info: dict, timeout: Optional[int] = None):
  73. """
  74. 模型流式生成(同步生成器)- 带异常处理
  75. """
  76. start_time = time.time()
  77. current_timeout = timeout or self.default_timeout
  78. try:
  79. logger.info(f"[模型流式调用] 开始处理 trace_id: {trace_id}, 超时时间: {current_timeout}s")
  80. prompt_template = task_prompt_info["task_prompt"]
  81. messages = prompt_template.format_messages()
  82. response = self.llm.stream(messages)
  83. chunk_count = 0
  84. for chunk in response:
  85. chunk_count += 1
  86. if hasattr(chunk, 'content') and chunk.content:
  87. yield chunk.content
  88. elif chunk: # 处理直接返回字符串的情况
  89. yield chunk
  90. elapsed_time = time.time() - start_time
  91. logger.info(f"[模型流式调用] 成功 trace_id: {trace_id}, 生成块数: {chunk_count}, 耗时: {elapsed_time:.2f}s")
  92. except Exception as e:
  93. elapsed_time = time.time() - start_time
  94. logger.error(f"[模型流式调用] 异常 trace_id: {trace_id}, 耗时: {elapsed_time:.2f}s, 错误类型: {type(e).__name__}, 错误信息: {str(e)}")
  95. raise
  96. generate_model_client = GenerateModelClient(default_timeout=120, max_retries=3, backoff_factor=1.0)