| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161 |
- # !/usr/bin/python
- # -*- coding: utf-8 -*-
- '''
- @Project : lq-agent-api
- @File :base_agent.py
- @IDE :Cursor
- @Author :
- @Date :2025/7/26 05:00
- '''
- from datetime import datetime
- from io import StringIO
- from contextlib import redirect_stdout
- from typing import Dict, List, Optional
- from foundation.logger.loggering import server_logger
- from foundation.utils.redis_utils import get_redis_result_cache_data_and_delete_key
class BaseAgent:
    """Base assistant-agent class.

    Shared helpers for logging streamed agent events, building prompt context
    strings, cleaning fenced JSON model output, and reading the per-trace
    result cache from Redis.
    """

    # Event types that ``log_stream_pretty_message`` logs; every other event
    # type is ignored.
    _LOGGED_EVENT_TYPES = frozenset({
        "on_chain_start",
        "on_chain_end",
        "on_tool_start",
        "on_tool_end",
        "on_chat_model_start",
    })

    # Full context template returned as the first element of
    # ``get_input_context``. Built from explicit pieces so that method
    # indentation cannot leak into the rendered prompt.
    _CONTEXT_TEMPLATE = (
        "\n"
        "助手会话 [ID: {trace_id}]\n"
        "时间: {timestamp}\n"
        "任务: {task_prompt_info_str}\n"
        "\n"
        "用户提供上下文信息:\n"
        "{context}\n"
        "用户输入问题:\n"
        "{input}\n"
        "\n"
    )

    # Condensed template (no timestamp, no task) returned as the second
    # element of ``get_input_context``.
    _SUMMARY_CONTEXT_TEMPLATE = (
        "\n"
        "助手会话 [ID: {trace_id}]\n"
        "上下文信息:\n"
        "{context}\n"
        "用户问题:\n"
        "{input}\n"
    )

    # Result-cache data types fetched by ``get_redis_result_cache_data``, in
    # fetch order. Each one also becomes a key of the returned dict.
    _RESULT_CACHE_DATA_TYPES = (
        "cattle_info",           # basic info
        "cattle_temperature",    # temperature readings
        "cattle_walk",           # step-count readings
        "retriever_resources",   # knowledge-base retrieval provenance
    )

    def __init__(self):
        # No per-instance state is initialised here.
        pass

    def get_pretty_message_str(self, message) -> str:
        """Return the text that ``message.pretty_print()`` writes to stdout.

        ``pretty_print()`` prints instead of returning a string, so stdout is
        temporarily redirected into an in-memory buffer.

        NOTE: ``redirect_stdout`` swaps the process-global ``sys.stdout``, so
        output from other threads printed during the call is captured too.

        @param message: any object with a ``pretty_print()`` method
            (presumably a LangChain message; confirm against callers).
        @return: the captured text, including any trailing newline.
        """
        buffer = StringIO()
        with redirect_stdout(buffer):
            message.pretty_print()
        return buffer.getvalue()

    def log_stream_pretty_message(self, trace_id: str, event: dict) -> None:
        """Log one event of the agent's streamed reasoning process.

        Only event types in ``_LOGGED_EVENT_TYPES`` are logged; all others
        return immediately. A logged event writes a banner line, then one of
        the following:
        - each entry of ``event['messages']``, pretty-printed;
        - the content and/or tool calls of ``event['data']['chunk']``;
        - ``event['data']['output']``, pretty-printed when it supports
          ``pretty_print()``.

        @param trace_id: trace ID attached to every log record.
        @param event: event dict with optional ``event``, ``name``, ``data``
            and ``messages`` keys (presumably a LangChain ``astream_events``
            item; confirm).
        """
        event_type = event.get('event', '')
        if event_type not in self._LOGGED_EVENT_TYPES:
            return
        name = event.get('name', '')
        # ``data`` may be present but None; treat that as "no payload" instead
        # of failing on the ``in`` checks below.
        data = event.get('data') or {}

        server_logger.info(trace_id=trace_id, msg=f"\n================================= {event_type} ({name}) =================================")
        if 'messages' in event:
            for msg in event['messages']:
                pretty = self.get_pretty_message_str(msg)
                server_logger.info(trace_id=trace_id, msg=f"\n{pretty}")
        elif 'chunk' in data:
            chunk = data['chunk']
            if getattr(chunk, 'content', None):
                server_logger.info(trace_id=trace_id, msg=f"Content: {chunk.content}")
            if getattr(chunk, 'tool_calls', None):
                server_logger.info(trace_id=trace_id, msg=f"Tool calls: {chunk.tool_calls}")
        elif 'output' in data:
            output = data['output']
            if hasattr(output, 'pretty_print'):
                pretty = self.get_pretty_message_str(output)
                server_logger.info(trace_id=trace_id, msg=f"\n{pretty}")
            else:
                server_logger.info(trace_id=trace_id, msg=f"Output: {output}")

    def get_input_context(
        self,
        trace_id: str,
        task_prompt_info: dict,
        input_query: str,
        context: Optional[str] = None,
        supplement_info: Optional[str] = None,
    ) -> tuple[str, str]:
        """Build the prompt context strings for one assistant turn.

        @param trace_id: trace ID embedded in both strings.
        @param task_prompt_info: dict whose ``"task_prompt"`` entry is the task
            text inserted into the full context.
        @param input_query: the user's question.
        @param context: user-supplied context. ``None`` or ``""`` is replaced
            by the placeholder "无相关数据" ("no relevant data").
        @param supplement_info: accepted for backward compatibility but not
            rendered, because neither template has a placeholder for it.
        @return: ``(input_context, input_summary_context)``. The first is the
            full context (trace ID, timestamp, task, context, question); the
            second is the condensed one (trace ID, context, question).
        @raise KeyError: if ``task_prompt_info`` has no ``"task_prompt"`` key.
        """
        context = context or "无相关数据"
        task_prompt_info_str = task_prompt_info["task_prompt"]
        # NOTE(review): supplement_info is deliberately not passed. str.format
        # ignores unused kwargs, so passing it never had any effect. Add a
        # {supplement_info} placeholder if it is meant to reach the prompt.
        input_context = self._CONTEXT_TEMPLATE.format(
            trace_id=trace_id,
            timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            task_prompt_info_str=task_prompt_info_str,
            context=context,
            input=input_query,
        )
        input_summary_context = self._SUMMARY_CONTEXT_TEMPLATE.format(
            trace_id=trace_id,
            context=context,
            input=input_query,
        )
        return input_context, input_summary_context

    def clean_json_output(self, raw_output: str) -> str:
        """Strip a surrounding Markdown code fence from model JSON output.

        Removes a leading "```" fence, optionally followed by a ``json``
        language tag in any case, and a trailing "```", then trims whitespace.
        Text without fences is only whitespace-trimmed.

        @param raw_output: raw model text.
        @return: the unfenced, stripped text.
        """
        cleaned = raw_output.strip()
        if cleaned.startswith("```"):
            cleaned = cleaned[3:]
            # Drop the optional language tag right after the opening fence.
            if cleaned[:4].lower() == "json":
                cleaned = cleaned[4:]
        if cleaned.endswith("```"):
            cleaned = cleaned[:-3]
        return cleaned.strip()

    async def get_redis_result_cache_data(self, trace_id: str) -> dict:
        """Fetch every result-cache entry for one trace from Redis.

        Each data type in ``_RESULT_CACHE_DATA_TYPES`` is fetched sequentially
        through ``get_redis_result_cache_data_and_delete_key``. Judging by its
        name, that helper also deletes the key; confirm against its
        implementation.

        @param trace_id: trace ID whose cached results are read.
        @return: dict mapping each data type ("cattle_info",
            "cattle_temperature", "cattle_walk", "retriever_resources") to
            whatever the helper returned for it.
        """
        results = {}
        for data_type in self._RESULT_CACHE_DATA_TYPES:
            results[data_type] = await get_redis_result_cache_data_and_delete_key(
                data_type=data_type, trace_id=trace_id
            )
        return results
|