test_agent.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
#!/usr/bin/python
# -*- coding: utf-8 -*-
'''
@Project : lq-agent-api
@File    : test_agent.py
@IDE     : PyCharm
@Author  :
@Date    : 2025/7/21 10:12
'''
  10. import json
  11. from langgraph.prebuilt import create_react_agent
  12. from sqlalchemy.sql.functions import user
  13. from logger.loggering import server_logger
  14. from utils.common import handler_err
  15. from utils.utils import get_models
  16. from utils.yaml_utils import system_prompt_config
  17. import threading
  18. import time
  19. from typing import Dict, List, Optional, AsyncGenerator, Any, OrderedDict
  20. from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
  21. from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
  22. from langchain_core.runnables import RunnableConfig
  23. from agent.base_agent import BaseAgent
  24. from schemas.test_schemas import FormConfig
  25. class TestAgentClient(BaseAgent):
  26. """
  27. Xiwuzc 智能助手+MCP(带完整会话管理) - 针对场景优化
  28. 添加会话锁定机制,确保同一时间只有一个客户端可以使用特定会话
  29. """
  30. # 单例实例和线程锁
  31. _instance = None
  32. _singleton_lock = threading.Lock()
  33. def __new__(cls):
  34. """线程安全的单例模式实现"""
  35. if cls._instance is None:
  36. with cls._singleton_lock:
  37. if cls._instance is None:
  38. cls._instance = super().__new__(cls)
  39. cls._instance._initialize()
  40. return cls._instance
  41. def _initialize(self):
  42. """初始化模型和会话管理"""
  43. llm, chat, embed = get_models()
  44. self.llm = llm
  45. self.chat = chat
  46. self.embed = embed
  47. self.agent_executor = None
  48. self.initialized = False
  49. self.psutil_available = True
  50. # 固定系统提示词
  51. self.system_prompt = system_prompt_config["system_prompt"]
  52. # 清理任务
  53. self.cleanup_task = None
  54. server_logger.info(" client initialized")
  55. async def init_agent(self):
  56. """初始化agent_executor(只需一次)"""
  57. if self.initialized:
  58. return
  59. # 获取部署的模型列表
  60. server_logger.info(f"系统提示词 system_prompt:{self.system_prompt}")
  61. # 创建提示词模板 - 使用固定的系统提示词
  62. prompt = ChatPromptTemplate.from_messages([
  63. ("system", self.system_prompt),
  64. MessagesPlaceholder(variable_name="messages"),
  65. ("placeholder", "{agent_scratchpad}")
  66. ])
  67. # 创建Agent - 不再使用MemorySaver
  68. self.agent_executor = create_react_agent(
  69. self.llm,
  70. tools=[] , # 专用工具集 + 私有知识库检索工具
  71. prompt=prompt
  72. )
  73. self.initialized = True
  74. server_logger.info(" agent initialized")
  75. async def handle_query(self, trace_id: str, task_prompt_info: dict, input_query, context=None,
  76. config_param: FormConfig = None):
  77. try:
  78. # 确保agent已初始化
  79. if not self.initialized:
  80. await self.init_agent()
  81. session_id = config_param.session_id
  82. try:
  83. # 构建输入消息
  84. input_message , input_summary_context = self.get_input_context(
  85. trace_id=trace_id,
  86. task_prompt_info=task_prompt_info,
  87. input_query=input_query,
  88. context=context
  89. )
  90. # 用于模型对话使用
  91. input_human_message = HumanMessage(content=input_message)
  92. # 用于对话历史记录摘要
  93. input_human_summary_message = HumanMessage(content=input_summary_context)
  94. # 获取历史消息
  95. history_messages = []
  96. # 构造完整的消息列表
  97. all_messages = list(history_messages) + [input_human_message]
  98. # 配置执行上下文
  99. config = RunnableConfig(
  100. configurable={"thread_id": session_id},
  101. runnable_kwargs={"recursion_limit": 15}
  102. )
  103. # 执行智能体
  104. events = self.agent_executor.astream(
  105. {"messages": all_messages},
  106. config=config,
  107. stream_mode="values"
  108. )
  109. # 处理结果
  110. full_response = []
  111. async for event in events:
  112. if isinstance(event["messages"][-1], AIMessage):
  113. chunk = event["messages"][-1].content
  114. full_response.append(chunk)
  115. log_content = self.get_pretty_message_str(event["messages"][-1])
  116. server_logger.info("\n" + log_content.strip(), trace_id=trace_id)
  117. if full_response:
  118. full_text = "".join(full_response)
  119. server_logger.info(trace_id=trace_id, msg=f"full_response: {full_text}")
  120. full_text = self.clean_json_output(full_text)
  121. return full_text
  122. finally:
  123. # 确保释放会话锁
  124. pass
  125. except PermissionError as e:
  126. # 处理会话被其他设备锁定的情况
  127. return str(e)
  128. except Exception as e:
  129. handler_err(server_logger, trace_id=trace_id, err=e, err_name='agent/chat')
  130. return f"系统错误: {str(e)}"
  131. async def handle_query_stream(
  132. self,
  133. trace_id: str,
  134. task_prompt_info: dict,
  135. input_query: str,
  136. context: Optional[str] = None,
  137. header_info: Optional[Dict] = None,
  138. config_param: FormConfig = None,
  139. ) -> AsyncGenerator[str, None]:
  140. """流式处理查询(优化缓冲管理)"""
  141. try:
  142. # 确保agent已初始化
  143. if not self.initialized:
  144. await self.init_agent()
  145. session_id = config_param.session_id
  146. try:
  147. # 构建输入消息
  148. input_message , input_summary_context = self.get_input_context(
  149. trace_id=trace_id,
  150. task_prompt_info=task_prompt_info,
  151. input_query=input_query,
  152. context=context
  153. )
  154. server_logger.info(trace_id=trace_id, msg=f"input_context: {input_message}")
  155. # 用于模型对话使用
  156. input_human_message = HumanMessage(content=input_message)
  157. # 用于对话历史记录摘要
  158. input_human_summary_message = HumanMessage(content=input_summary_context)
  159. # 获取历史消息
  160. history_messages = []
  161. # 构造完整的消息列表
  162. all_messages = list(history_messages) + [input_human_message]
  163. # 配置执行上下文
  164. config = RunnableConfig(
  165. configurable={"thread_id": session_id},
  166. runnable_kwargs={"recursion_limit": 15}
  167. )
  168. # 流式执行
  169. events = self.agent_executor.astream_events(
  170. {"messages": all_messages},
  171. config=config,
  172. stream_mode="values"
  173. )
  174. full_response = []
  175. buffer = []
  176. last_flush_time = time.time()
  177. # 流式处理事件
  178. async for event in events:
  179. # 只在特定事件类型时打印日志
  180. self.log_stream_pretty_message(trace_id=trace_id, event=event)
  181. if 'chunk' in event['data'] and "on_chat_model_stream" in event['event']:
  182. chunk = event['data']['chunk'].content
  183. full_response.append(chunk)
  184. # 缓冲管理策略
  185. buffer.append(chunk)
  186. current_time = time.time()
  187. # 满足以下任一条件即刷新缓冲区
  188. if (len(buffer) >= 3 or # 达到最小块数
  189. (current_time - last_flush_time) > 0.5 or # 超时
  190. any(chunk.endswith((c, f"{c} ")) for c in
  191. ['.', '。', '!', '?', '\n', ';', ';'])): # 自然断点
  192. # 合并并发送缓冲内容
  193. combined = ''.join(buffer)
  194. yield combined
  195. # 重置缓冲
  196. buffer.clear()
  197. last_flush_time = current_time
  198. # 处理剩余内容
  199. if buffer:
  200. yield ''.join(buffer)
  201. # 将完整响应添加到历史并进行压缩
  202. if full_response:
  203. full_text = "".join(full_response)
  204. server_logger.info(trace_id=trace_id, msg=f"full_response: {full_text}")
  205. finally:
  206. # 确保释放会话锁
  207. pass
  208. except PermissionError as e:
  209. yield json.dumps({"error": str(e)})
  210. except Exception as e:
  211. handler_err(server_logger, trace_id=trace_id, err=e, err_name='test_stream')
  212. yield json.dumps({"error": f"系统错误: {str(e)}"})
  213. test_agent_client = TestAgentClient()