test_agent.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
  1. # !/usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. '''
  4. @Project : lq-agent-api
  5. @File :test_agent.py
  6. @IDE :PyCharm
  7. @Author :
  8. @Date :2025/7/21 10:12
  9. '''
  10. import json
  11. from langgraph.prebuilt import create_react_agent
  12. from sqlalchemy.sql.functions import user
  13. from foundation.observability.logger.loggering import server_logger
  14. from foundation.utils.common import handler_err
  15. from foundation.ai.models import get_models
  16. from foundation.utils.yaml_utils import get_system_prompt_config
  17. import threading
  18. import time
  19. from typing import Dict, List, Optional, AsyncGenerator, Any, OrderedDict
  20. from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
  21. from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
  22. from langchain_core.runnables import RunnableConfig
  23. from foundation.ai.agent.base_agent import BaseAgent
  24. from foundation.schemas.test_schemas import TestForm
  25. # from foundation.agent.function.test_funciton import test_funtion
  26. class TestAgentClient(BaseAgent):
  27. """
  28. Xiwuzc 智能助手+MCP(带完整会话管理) - 针对场景优化
  29. 添加会话锁定机制,确保同一时间只有一个客户端可以使用特定会话
  30. """
  31. # 单例实例和线程锁
  32. _instance = None
  33. _singleton_lock = threading.Lock()
  34. def __new__(cls):
  35. """线程安全的单例模式实现"""
  36. if cls._instance is None:
  37. with cls._singleton_lock:
  38. if cls._instance is None:
  39. cls._instance = super().__new__(cls)
  40. cls._instance._initialize()
  41. return cls._instance
  42. def _initialize(self):
  43. """初始化模型和会话管理"""
  44. llm, chat, embed = get_models()
  45. self.llm = llm
  46. self.chat = chat
  47. self.embed = embed
  48. self.agent_executor = None
  49. self.initialized = False
  50. self.psutil_available = True
  51. # 固定系统提示词
  52. self.system_prompt = get_system_prompt_config()["system_prompt"]
  53. # 清理任务
  54. self.cleanup_task = None
  55. server_logger.info(" client initialized")
  56. async def init_agent(self):
  57. """初始化agent_executor(只需一次)"""
  58. if self.initialized:
  59. return
  60. # 获取部署的模型列表
  61. server_logger.info(f"系统提示词 system_prompt:{self.system_prompt}")
  62. # 创建提示词模板 - 使用固定的系统提示词
  63. prompt = ChatPromptTemplate.from_messages([
  64. ("system", self.system_prompt),
  65. MessagesPlaceholder(variable_name="messages"),
  66. ("placeholder", "{agent_scratchpad}")
  67. ])
  68. # # 创建Agent - 不再使用MemorySaver
  69. # self.agent_executor = create_react_agent(
  70. # self.llm,
  71. # tools=[test_funtion.query_info , test_funtion.execute , test_funtion.handle] , # 专用工具集 + 私有知识库检索工具
  72. # prompt=prompt
  73. # )
  74. self.initialized = True
  75. server_logger.info(" agent initialized")
  76. async def handle_query(self, trace_id: str, task_prompt_info: dict, input_query, context=None,
  77. config_param: TestForm = None):
  78. try:
  79. # 确保agent已初始化
  80. if not self.initialized:
  81. await self.init_agent()
  82. session_id = config_param.session_id
  83. try:
  84. # 构建输入消息
  85. input_message , input_summary_context = self.get_input_context(
  86. trace_id=trace_id,
  87. task_prompt_info=task_prompt_info,
  88. input_query=input_query,
  89. context=context
  90. )
  91. # 用于模型对话使用
  92. input_human_message = HumanMessage(content=input_message)
  93. # 用于对话历史记录摘要
  94. input_human_summary_message = HumanMessage(content=input_summary_context)
  95. # 获取历史消息
  96. history_messages = []
  97. # 构造完整的消息列表
  98. all_messages = list(history_messages) + [input_human_message]
  99. # 配置执行上下文
  100. config = RunnableConfig(
  101. configurable={"thread_id": session_id},
  102. runnable_kwargs={"recursion_limit": 15}
  103. )
  104. # 执行智能体
  105. events = self.agent_executor.astream(
  106. {"messages": all_messages},
  107. config=config,
  108. stream_mode="values"
  109. )
  110. # 处理结果
  111. full_response = []
  112. async for event in events:
  113. if isinstance(event["messages"][-1], AIMessage):
  114. chunk = event["messages"][-1].content
  115. full_response.append(chunk)
  116. log_content = self.get_pretty_message_str(event["messages"][-1])
  117. server_logger.info("\n" + log_content.strip(), trace_id=trace_id)
  118. if full_response:
  119. full_text = "".join(full_response)
  120. server_logger.info(trace_id=trace_id, msg=f"full_response: {full_text}")
  121. full_text = self.clean_json_output(full_text)
  122. return full_text
  123. finally:
  124. # 确保释放会话锁
  125. pass
  126. except PermissionError as e:
  127. # 处理会话被其他设备锁定的情况
  128. return str(e)
  129. except Exception as e:
  130. handler_err(server_logger, trace_id=trace_id, err=e, err_name='agent/chat')
  131. return f"系统错误: {str(e)}"
  132. async def handle_query_stream(
  133. self,
  134. trace_id: str,
  135. task_prompt_info: dict,
  136. input_query: str,
  137. context: Optional[str] = None,
  138. header_info: Optional[Dict] = None,
  139. config_param: TestForm = None,
  140. ) -> AsyncGenerator[str, None]:
  141. """流式处理查询(优化缓冲管理)"""
  142. try:
  143. # 确保agent已初始化
  144. if not self.initialized:
  145. await self.init_agent()
  146. session_id = config_param.session_id
  147. try:
  148. # 构建输入消息
  149. input_message , input_summary_context = self.get_input_context(
  150. trace_id=trace_id,
  151. task_prompt_info=task_prompt_info,
  152. input_query=input_query,
  153. context=context
  154. )
  155. server_logger.info(trace_id=trace_id, msg=f"input_context: {input_message}")
  156. # 用于模型对话使用
  157. input_human_message = HumanMessage(content=input_message)
  158. # 用于对话历史记录摘要
  159. input_human_summary_message = HumanMessage(content=input_summary_context)
  160. # 获取历史消息
  161. history_messages = []
  162. # 构造完整的消息列表
  163. all_messages = list(history_messages) + [input_human_message]
  164. # 配置执行上下文
  165. config = RunnableConfig(
  166. configurable={"thread_id": session_id},
  167. runnable_kwargs={"recursion_limit": 15}
  168. )
  169. # 流式执行
  170. events = self.agent_executor.astream_events(
  171. {"messages": all_messages},
  172. config=config,
  173. stream_mode="values"
  174. )
  175. full_response = []
  176. buffer = []
  177. last_flush_time = time.time()
  178. # 流式处理事件
  179. async for event in events:
  180. # 只在特定事件类型时打印日志
  181. self.log_stream_pretty_message(trace_id=trace_id, event=event)
  182. if 'chunk' in event['data'] and "on_chat_model_stream" in event['event']:
  183. chunk = event['data']['chunk'].content
  184. full_response.append(chunk)
  185. # 缓冲管理策略
  186. buffer.append(chunk)
  187. current_time = time.time()
  188. # 满足以下任一条件即刷新缓冲区
  189. if (len(buffer) >= 3 or # 达到最小块数
  190. (current_time - last_flush_time) > 0.5 or # 超时
  191. any(chunk.endswith((c, f"{c} ")) for c in
  192. ['.', '。', '!', '?', '\n', ';', ';'])): # 自然断点
  193. # 合并并发送缓冲内容
  194. combined = ''.join(buffer)
  195. yield combined
  196. # 重置缓冲
  197. buffer.clear()
  198. last_flush_time = current_time
  199. # 处理剩余内容
  200. if buffer:
  201. yield ''.join(buffer)
  202. # 将完整响应添加到历史并进行压缩
  203. if full_response:
  204. full_text = "".join(full_response)
  205. server_logger.info(trace_id=trace_id, msg=f"full_response: {full_text}")
  206. finally:
  207. # 确保释放会话锁
  208. pass
  209. except PermissionError as e:
  210. yield json.dumps({"error": str(e)})
  211. except Exception as e:
  212. handler_err(server_logger, trace_id=trace_id, err=e, err_name='test_stream')
  213. yield json.dumps({"error": f"系统错误: {str(e)}"})
  214. test_agent_client = TestAgentClient()