base_agent.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. # !/usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. '''
  4. @Project : lq-agent-api
  5. @File :base_agent.py
  6. @IDE :Cursor
  7. @Author :
  8. @Date :2025/7/26 05:00
  9. '''
  10. from datetime import datetime
  11. from io import StringIO
  12. from contextlib import redirect_stdout
  13. from typing import Dict, List, Optional
  14. from foundation.logger.loggering import server_logger
  15. from foundation.utils.redis_utils import get_redis_result_cache_data_and_delete_key
  16. class BaseAgent:
  17. """
  18. 基础智能助手类
  19. """
  20. def __init__(self):
  21. pass
  22. def get_pretty_message_str(self, message) -> str:
  23. """安全地捕获 pretty_print() 的输出"""
  24. captured_output = StringIO()
  25. with redirect_stdout(captured_output):
  26. message.pretty_print()
  27. return captured_output.getvalue()
  28. def log_stream_pretty_message(self , trace_id , event):
  29. """
  30. 流式打印agent 整个推理过程 pretty_print() 的输出
  31. """
  32. event_type = event.get('event', '')
  33. name = event.get('name', '')
  34. data = event.get('data', {})
  35. if event_type not in ['on_chain_start', 'on_chain_end', 'on_tool_start', 'on_tool_end', 'on_chat_model_start']:
  36. return
  37. server_logger.info(trace_id=trace_id , msg=f"\n================================= {event_type} ({name}) =================================")
  38. if 'messages' in event:
  39. for msg in event['messages']:
  40. #msg.pretty_print()
  41. output = self.get_pretty_message_str(msg)
  42. server_logger.info(trace_id=trace_id , msg=f"\n{output}")
  43. elif 'chunk' in data:
  44. chunk = data['chunk']
  45. if hasattr(chunk, 'content') and chunk.content:
  46. server_logger.info(trace_id=trace_id , msg=f"Content: {chunk.content}")
  47. if hasattr(chunk, 'tool_calls') and chunk.tool_calls:
  48. server_logger.info(trace_id=trace_id , msg=f"Tool calls: {chunk.tool_calls}")
  49. elif 'output' in data:
  50. output = data['output']
  51. if hasattr(output, 'pretty_print'):
  52. #output.pretty_print()
  53. output = self.get_pretty_message_str(output)
  54. server_logger.info(trace_id=trace_id , msg=f"\n{output}")
  55. else:
  56. server_logger.info(trace_id=trace_id , msg=f"Output: {output}")
  57. def get_input_context(
  58. self,
  59. trace_id: str,
  60. task_prompt_info: dict,
  61. input_query: str,
  62. context: Optional[str] = None,
  63. supplement_info: Optional[str] = None
  64. ) -> tuple[str,str]:
  65. """构建场景优化的上下文提示"""
  66. context = context or "无相关数据"
  67. task_prompt_info_str = task_prompt_info["task_prompt"]
  68. # 场景优化的上下文模板
  69. context_template = """
  70. 助手会话 [ID: {trace_id}]
  71. 时间: {timestamp}
  72. 任务: {task_prompt_info_str}
  73. 用户提供上下文信息:
  74. {context}
  75. 用户输入问题:
  76. {input}
  77. """
  78. input_context = context_template.format(
  79. trace_id=trace_id,
  80. task_prompt_info_str=task_prompt_info_str,
  81. context=context,
  82. input=input_query,
  83. supplement_info=supplement_info,
  84. timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S")
  85. )
  86. # 场景优化的上下文模板
  87. summary_context_template = """
  88. 助手会话 [ID: {trace_id}]
  89. 上下文信息:
  90. {context}
  91. 用户问题:
  92. {input}
  93. """
  94. input_summary_context = summary_context_template.format(
  95. trace_id=trace_id,
  96. context=context,
  97. input=input_query,
  98. )
  99. return input_context , input_summary_context
  100. def clean_json_output(self, raw_output: str) -> str:
  101. """去除开头和结尾的 ```json 和 ```"""
  102. cleaned = raw_output.strip()
  103. if cleaned.startswith("```json"):
  104. cleaned = cleaned[7:] # 去掉开头的 ```json
  105. if cleaned.endswith("```"):
  106. cleaned = cleaned[:-3] # 去掉结尾的 ```
  107. return cleaned.strip()
  108. async def get_redis_result_cache_data(self , trace_id: str):
  109. """
  110. 获取redis结果缓存数据
  111. @param data_type: 数据类型,
  112. 基本信息 cattle_info
  113. 体温信息 cattle_temperature
  114. 步数信息 cattle_walk
  115. 知识库检索溯源信息 retriever_resources
  116. @param trace_id: 链路跟踪ID
  117. """
  118. # 基本信息
  119. data_type = "cattle_info"
  120. cattle_info = await get_redis_result_cache_data_and_delete_key(data_type=data_type , trace_id=trace_id)
  121. data_type = "cattle_temperature"
  122. cattle_temperature = await get_redis_result_cache_data_and_delete_key(data_type=data_type , trace_id=trace_id)
  123. data_type = "cattle_walk"
  124. cattle_walk = await get_redis_result_cache_data_and_delete_key(data_type=data_type , trace_id=trace_id)
  125. data_type = "retriever_resources"
  126. retriever_resources = await get_redis_result_cache_data_and_delete_key(data_type=data_type , trace_id=trace_id)
  127. return {
  128. "cattle_info": cattle_info,
  129. "cattle_temperature": cattle_temperature,
  130. "cattle_walk": cattle_walk,
  131. "retriever_resources": retriever_resources
  132. }