base_agent.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. # !/usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. '''
  4. @Project : lq-agent-api
  5. @File :base_agent.py
  6. @IDE :Cursor
  7. @Author :
  8. @Date :2025/7/26 05:00
  9. '''
  10. from datetime import datetime
  11. from io import StringIO
  12. from contextlib import redirect_stdout
  13. from typing import Dict, List, Optional, AsyncGenerator, Any, OrderedDict
  14. from logger.loggering import server_logger
  15. from utils.redis_utils import get_redis_result_cache_data_and_delete_key
  16. from enums.common_enums import UserRoleEnum
  17. class BaseAgent:
  18. """
  19. 基础智能助手类
  20. """
  21. def __init__(self):
  22. pass
  23. def get_pretty_message_str(self, message) -> str:
  24. """安全地捕获 pretty_print() 的输出"""
  25. captured_output = StringIO()
  26. with redirect_stdout(captured_output):
  27. message.pretty_print()
  28. return captured_output.getvalue()
  29. def log_stream_pretty_message(self , trace_id , event):
  30. """
  31. 流式打印agent 整个推理过程 pretty_print() 的输出
  32. """
  33. event_type = event.get('event', '')
  34. name = event.get('name', '')
  35. data = event.get('data', {})
  36. if event_type not in ['on_chain_start', 'on_chain_end', 'on_tool_start', 'on_tool_end', 'on_chat_model_start']:
  37. return
  38. server_logger.info(trace_id=trace_id , msg=f"\n================================= {event_type} ({name}) =================================")
  39. if 'messages' in event:
  40. for msg in event['messages']:
  41. #msg.pretty_print()
  42. output = self.get_pretty_message_str(msg)
  43. server_logger.info(trace_id=trace_id , msg=f"\n{output}")
  44. elif 'chunk' in data:
  45. chunk = data['chunk']
  46. if hasattr(chunk, 'content') and chunk.content:
  47. server_logger.info(trace_id=trace_id , msg=f"Content: {chunk.content}")
  48. if hasattr(chunk, 'tool_calls') and chunk.tool_calls:
  49. server_logger.info(trace_id=trace_id , msg=f"Tool calls: {chunk.tool_calls}")
  50. elif 'output' in data:
  51. output = data['output']
  52. if hasattr(output, 'pretty_print'):
  53. #output.pretty_print()
  54. output = self.get_pretty_message_str(output)
  55. server_logger.info(trace_id=trace_id , msg=f"\n{output}")
  56. else:
  57. server_logger.info(trace_id=trace_id , msg=f"Output: {output}")
  58. def get_input_context(
  59. self,
  60. trace_id: str,
  61. business_scene: str,
  62. task_prompt_info: dict,
  63. input_query: str,
  64. context: Optional[str] = None,
  65. supplement_info: Optional[str] = None,
  66. header_info: Optional[Dict] = None ,
  67. config_param: Optional[dict] = None
  68. ) -> tuple[str,str]:
  69. """构建场景优化的上下文提示"""
  70. context = context or "无相关数据"
  71. supplement_info = supplement_info or "无补充信息"
  72. token = header_info.get('token', '') if header_info else ''
  73. tenantId = header_info.get('tenantId', '') if header_info else ''
  74. user_role = config_param.get('user_role', UserRoleEnum.COMMON.code) if config_param else UserRoleEnum.COMMON.code
  75. task_prompt_info_str = task_prompt_info["task_prompt"]
  76. call_tools_return_data_type = "text"
  77. final_result_data_type = "text"
  78. # 如果配置按配置要求,如果未配置默认
  79. call_tools_return_data_type = call_tools_return_data_type if call_tools_return_data_type else "text"
  80. final_result_data_type = final_result_data_type if final_result_data_type else "Markdown"
  81. # 场景优化的上下文模板
  82. context_template = """
  83. 助手会话 [ID: {trace_id}]
  84. 时间: {timestamp}
  85. 任务: {task_prompt_info_str}
  86. 用户提供上下文信息:
  87. {context}
  88. 用户补充信息:
  89. {supplement_info}
  90. 用户输入问题:
  91. {input}
  92. 用户角色: {user_role}
  93. 安全验证: {token}
  94. 场ID: {tenantId}
  95. """
  96. input_context = context_template.format(
  97. trace_id=trace_id,
  98. business_scene=business_scene,
  99. task_prompt_info_str=task_prompt_info_str,
  100. context=context,
  101. input=input_query,
  102. call_tools_return_data_type=call_tools_return_data_type,
  103. final_result_data_type=final_result_data_type,
  104. supplement_info=supplement_info,
  105. user_role=user_role,
  106. token=token,
  107. tenantId=tenantId,
  108. timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S")
  109. )
  110. # 场景优化的上下文模板
  111. summary_context_template = """
  112. 助手会话 [ID: {trace_id}]
  113. 用户意图场景: {business_scene}
  114. 上下文信息:
  115. {context}
  116. 补充信息:
  117. {supplement_info}
  118. 用户问题:
  119. {input}
  120. 用户角色: {user_role}
  121. 安全验证: {token}
  122. 场ID: {tenantId}
  123. """
  124. input_summary_context = summary_context_template.format(
  125. trace_id=trace_id,
  126. business_scene=business_scene,
  127. context=context,
  128. input=input_query,
  129. supplement_info=supplement_info,
  130. user_role=user_role,
  131. token=token,
  132. tenantId=tenantId
  133. )
  134. return input_context , input_summary_context
  135. def clean_json_output(self, raw_output: str) -> str:
  136. """去除开头和结尾的 ```json 和 ```"""
  137. cleaned = raw_output.strip()
  138. if cleaned.startswith("```json"):
  139. cleaned = cleaned[7:] # 去掉开头的 ```json
  140. if cleaned.endswith("```"):
  141. cleaned = cleaned[:-3] # 去掉结尾的 ```
  142. return cleaned.strip()
  143. async def get_redis_result_cache_data(self , trace_id: str):
  144. """
  145. 获取redis结果缓存数据
  146. @param data_type: 数据类型,
  147. 基本信息 cattle_info
  148. 体温信息 cattle_temperature
  149. 步数信息 cattle_walk
  150. 知识库检索溯源信息 retriever_resources
  151. @param trace_id: 链路跟踪ID
  152. """
  153. # 基本信息
  154. data_type = "cattle_info"
  155. cattle_info = await get_redis_result_cache_data_and_delete_key(data_type=data_type , trace_id=trace_id)
  156. data_type = "cattle_temperature"
  157. cattle_temperature = await get_redis_result_cache_data_and_delete_key(data_type=data_type , trace_id=trace_id)
  158. data_type = "cattle_walk"
  159. cattle_walk = await get_redis_result_cache_data_and_delete_key(data_type=data_type , trace_id=trace_id)
  160. data_type = "retriever_resources"
  161. retriever_resources = await get_redis_result_cache_data_and_delete_key(data_type=data_type , trace_id=trace_id)
  162. return {
  163. "cattle_info": cattle_info,
  164. "cattle_temperature": cattle_temperature,
  165. "cattle_walk": cattle_walk,
  166. "retriever_resources": retriever_resources
  167. }