| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283 |
- # !/usr/bin/ python
- # -*- coding: utf-8 -*-
- '''
- @Project : lq-agent-api
- @File :model_generate.py
- @IDE :PyCharm
- @Author :
- @Date :2025/7/14 14:22
- '''
- from langchain_core.prompts import ChatPromptTemplate
- from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage
- from foundation.ai.models.model_handler import model_handler
- from foundation.observability.logger.loggering import review_logger as logger
- import asyncio
- import time
- from typing import Optional, Callable, Any, List, Union
- class GenerateModelClient:
- """
- 主要是生成式模型
- """
- def __init__(self, default_timeout: int = 60, max_retries: int = 3, backoff_factor: float = 1.0):
- # 获取默认模型
- self.llm = model_handler.get_models()
- self.chat = self.llm # 当前chat和llm使用相同模型
- # 配置参数
- self.default_timeout = default_timeout
- self.max_retries = max_retries
- self.backoff_factor = backoff_factor
- # 保存model_handler引用,用于动态获取模型
- self.model_handler = model_handler
- async def _retry_with_backoff(self, func: Callable, *args, timeout: Optional[int] = None, **kwargs):
- """
- 带指数退避的重试机制,每次重试都有独立的超时控制
- """
- current_timeout = timeout or self.default_timeout
- for attempt in range(self.max_retries + 1):
- try:
- # 每次重试都有独立的超时时间
- return await asyncio.wait_for(
- func(*args, **kwargs),
- timeout=current_timeout
- )
- except asyncio.TimeoutError:
- if attempt == self.max_retries:
- logger.error(f"[模型调用] 达到最大重试次数 {self.max_retries},最终超时")
- raise TimeoutError(f"模型调用在 {self.max_retries} 次重试后均超时")
- wait_time = self.backoff_factor * (2 ** attempt)
- logger.warning(f"[模型调用] 第 {attempt + 1} 次超时, {wait_time}秒后重试...")
- await asyncio.sleep(wait_time)
- except Exception as e:
- if attempt == self.max_retries:
- logger.error(f"[模型调用] 达到最大重试次数 {self.max_retries},最终失败: {str(e)}")
- raise
- wait_time = self.backoff_factor * (2 ** attempt)
- logger.warning(f"[模型调用] 第 {attempt + 1} 次尝试失败: {str(e)}, {wait_time}秒后重试...")
- await asyncio.sleep(wait_time)
- async def get_model_generate_invoke(
- self,
- trace_id: str,
- task_prompt_info: Optional[dict] = None,
- messages: Optional[List[BaseMessage]] = None,
- system_prompt: Optional[str] = None,
- user_prompt: Optional[str] = None,
- prompt: Optional[str] = None,
- timeout: Optional[int] = None,
- model_name: Optional[str] = None
- ) -> str:
- """模型非流式生成(异步)
- 支持多种调用方式(优先级从高到低):
- 1. messages: 直接传入 LangChain Message 对象列表
- 2. system_prompt + user_prompt: 分别传入系统和用户提示词
- 3. prompt: 传入单条用户提示词字符串
- 4. task_prompt_info: 传入包含 ChatPromptTemplate 的字典(兼容旧接口)
- Args:
- trace_id: 追踪ID
- task_prompt_info: 任务提示词信息(兼容旧接口),需包含 format_messages() 方法
- messages: LangChain Message 对象列表(如 [SystemMessage, HumanMessage])
- system_prompt: 系统提示词字符串
- user_prompt: 用户提示词字符串
- prompt: 单条用户提示词字符串(无系统提示时使用)
- timeout: 超时时间(秒),默认使用构造时的 default_timeout
- model_name: 模型名称(可选),支持 doubao/qwen/deepseek/gemini 等
- Returns:
- str: 模型生成的文本内容
- Raises:
- ValueError: 参数组合错误
- TimeoutError: 调用超时
- Exception: 模型调用异常
- Examples:
- # 方式1: 使用 Message 列表(推荐)
- messages = [SystemMessage(content="你是专家"), HumanMessage(content="请分析...")]
- result = await client.get_model_generate_invoke("trace-001", messages=messages)
- # 方式2: 分别传入系统和用户提示词
- result = await client.get_model_generate_invoke(
- "trace-001",
- system_prompt="你是专家",
- user_prompt="请分析..."
- )
- # 方式3: 传入单条提示词
- result = await client.get_model_generate_invoke("trace-001", prompt="请分析...")
- # 方式4: 兼容旧接口(使用 PromptLoader)
- task_prompt_info = {"task_prompt": chat_template}
- result = await client.get_model_generate_invoke("trace-001", task_prompt_info=task_prompt_info)
- """
- start_time = time.time()
- current_timeout = timeout or self.default_timeout
- try:
- # 选择模型
- llm_to_use = self.model_handler.get_model_by_name(model_name) if model_name else self.llm
- logger.info(f"[模型调用] 使用{'指定' if model_name else '默认'}模型: {model_name or 'default'}, trace_id: {trace_id}")
- # 构建消息列表(按优先级)
- final_messages = self._build_messages(
- messages=messages,
- system_prompt=system_prompt,
- user_prompt=user_prompt,
- prompt=prompt,
- task_prompt_info=task_prompt_info
- )
- # 定义模型调用函数,使用原生 ainvoke
- async def _invoke():
- return await llm_to_use.ainvoke(final_messages)
- # 调用带重试机制
- response = await self._retry_with_backoff(_invoke, timeout=current_timeout)
- elapsed_time = time.time() - start_time
- logger.info(f"[模型调用] 成功 trace_id: {trace_id}, 耗时: {elapsed_time:.2f}s")
- return response.content
- except asyncio.TimeoutError:
- elapsed_time = time.time() - start_time
- logger.error(f"[模型调用] 超时 trace_id: {trace_id}, 耗时: {elapsed_time:.2f}s, 超时阈值: {current_timeout}s")
- raise TimeoutError(f"模型调用超时,trace_id: {trace_id}")
- except Exception as e:
- elapsed_time = time.time() - start_time
- logger.error(f"[模型调用] 异常 trace_id: {trace_id}, 耗时: {elapsed_time:.2f}s, 错误: {type(e).__name__}: {str(e)}")
- raise
- def _build_messages(
- self,
- messages: Optional[List[BaseMessage]] = None,
- system_prompt: Optional[str] = None,
- user_prompt: Optional[str] = None,
- prompt: Optional[str] = None,
- task_prompt_info: Optional[dict] = None
- ) -> List[BaseMessage]:
- """构建消息列表(内部方法)
- 优先级:messages > system_prompt+user_prompt > prompt > task_prompt_info
- """
- # 方式1: 直接使用传入的 Message 列表
- if messages is not None:
- if not isinstance(messages, list):
- raise ValueError("messages 必须是列表")
- if len(messages) == 0:
- raise ValueError("messages 不能为空列表")
- logger.debug(f"使用传入的 messages 列表,共 {len(messages)} 条消息")
- return messages
- # 方式2: system_prompt + user_prompt
- if system_prompt is not None and user_prompt is not None:
- logger.debug("使用 system_prompt + user_prompt 构建消息")
- return [SystemMessage(content=system_prompt), HumanMessage(content=user_prompt)]
- # 方式3: 单独 system_prompt(可能是特殊情况)
- if system_prompt is not None:
- logger.debug("使用单独的 system_prompt 构建消息")
- return [SystemMessage(content=system_prompt)]
- # 方式4: 单条 prompt 字符串
- if prompt is not None:
- logger.debug("使用单条 prompt 字符串构建消息")
- return [HumanMessage(content=prompt)]
- # 方式5: 兼容旧接口 task_prompt_info
- if task_prompt_info is not None:
- if "task_prompt" not in task_prompt_info:
- raise ValueError("task_prompt_info 必须包含 'task_prompt' 键")
- task_prompt = task_prompt_info["task_prompt"]
- if hasattr(task_prompt, 'format_messages'):
- logger.debug("使用 task_prompt_info 中的 ChatPromptTemplate 构建消息")
- return task_prompt.format_messages()
- elif isinstance(task_prompt, str):
- logger.debug("使用 task_prompt_info 中的字符串构建消息")
- return [HumanMessage(content=task_prompt)]
- else:
- raise ValueError(f"task_prompt 类型不支持: {type(task_prompt)}")
- # 没有提供任何有效参数
- raise ValueError(
- "必须提供以下参数之一: "
- "messages, system_prompt+user_prompt, prompt, 或 task_prompt_info"
- )
- def get_model_generate_stream(
- self,
- trace_id: str,
- task_prompt_info: Optional[dict] = None,
- messages: Optional[List[BaseMessage]] = None,
- system_prompt: Optional[str] = None,
- user_prompt: Optional[str] = None,
- prompt: Optional[str] = None,
- timeout: Optional[int] = None
- ):
- """模型流式生成(同步生成器)
- 支持多种调用方式(优先级从高到低):
- 1. messages: 直接传入 LangChain Message 对象列表
- 2. system_prompt + user_prompt: 分别传入系统和用户提示词
- 3. prompt: 传入单条用户提示词字符串
- 4. task_prompt_info: 传入包含 ChatPromptTemplate 的字典(兼容旧接口)
- Args:
- trace_id: 追踪ID
- task_prompt_info: 任务提示词信息(兼容旧接口)
- messages: LangChain Message 对象列表
- system_prompt: 系统提示词字符串
- user_prompt: 用户提示词字符串
- prompt: 单条用户提示词字符串
- timeout: 超时时间(秒)
- Yields:
- str: 生成的文本块
- Raises:
- ValueError: 参数组合错误
- """
- start_time = time.time()
- current_timeout = timeout or self.default_timeout
- try:
- logger.info(f"[模型流式调用] 开始处理 trace_id: {trace_id}, 超时配置: {current_timeout}s")
- # 构建消息列表
- final_messages = self._build_messages(
- messages=messages,
- system_prompt=system_prompt,
- user_prompt=user_prompt,
- prompt=prompt,
- task_prompt_info=task_prompt_info
- )
- response = self.llm.stream(final_messages)
- chunk_count = 0
- for chunk in response:
- chunk_count += 1
- if hasattr(chunk, 'content') and chunk.content:
- yield chunk.content
- elif chunk:
- yield chunk
- elapsed_time = time.time() - start_time
- logger.info(f"[模型流式调用] 成功 trace_id: {trace_id}, 生成块数: {chunk_count}, 耗时: {elapsed_time:.2f}s")
- except Exception as e:
- elapsed_time = time.time() - start_time
- logger.error(f"[模型流式调用] 异常 trace_id: {trace_id}, 耗时: {elapsed_time:.2f}s, 错误: {type(e).__name__}: {str(e)}")
- raise
- generate_model_client = GenerateModelClient(default_timeout=15, max_retries=2, backoff_factor=0.5)
|