#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
RAG pipeline monitoring decorator.

Records inputs, outputs, timing and errors of the synchronous and
asynchronous functions that make up a RAG pipeline, grouped per trace
session, and dumps each finished session to a JSON file.
"""

import functools
import inspect
import json
import logging
import time
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Union

try:
    from foundation.observability.logger.loggering import review_logger as logger
except ImportError:
    # Fall back to the standard library so the monitor stays importable
    # (scripts, unit tests) when the project logging package is unavailable.
    logger = logging.getLogger(__name__)


class RAGMonitor:
    """Collects per-step monitoring data for RAG pipeline trace sessions."""

    def __init__(self, save_dir: str = "temp/rag_monitoring"):
        """
        Initialise the monitor and make sure the output directory exists.

        Args:
            save_dir: Directory where finished trace sessions are written.
        """
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)
        # trace_id -> session record (metadata, timings and recorded steps).
        self.pipeline_data = {}
        # Steps are attached to whichever session was started most recently.
        self.current_trace_id = None

    def start_trace(self, trace_id: str, metadata: Optional[Dict] = None):
        """
        Start a new trace session and make it the current one.

        Args:
            trace_id: Identifier of the trace session.
            metadata: Optional session metadata stored alongside the steps.
        """
        self.current_trace_id = trace_id
        self.pipeline_data[trace_id] = {
            "trace_id": trace_id,
            "start_time": time.time(),
            "metadata": metadata or {},
            "steps": {},
        }
        logger.info(f"[RAG监控] 开始追踪会话: {trace_id}")

    def end_trace(self, trace_id: Optional[str] = None) -> Optional[Dict]:
        """
        Close a trace session and write it to ``<save_dir>/<trace_id>.json``.

        Args:
            trace_id: Session to close; defaults to the current session.

        Returns:
            The session record, or ``None`` when the session is unknown.
        """
        trace_id = trace_id or self.current_trace_id
        if trace_id not in self.pipeline_data:
            logger.warning(f"[RAG监控] 追踪会话不存在: {trace_id}")
            return None

        data = self.pipeline_data[trace_id]
        data["end_time"] = time.time()
        data["total_duration"] = round(data["end_time"] - data["start_time"], 3)

        # Persist the session; default=str stringifies anything json cannot
        # encode natively.
        file_path = self.save_dir / f"{trace_id}.json"
        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2, default=str)

        logger.info(f"[RAG监控] 追踪会话已保存: {file_path}, 总耗时: {data['total_duration']}秒")
        return data

    def get_trace_data(self, trace_id: Optional[str] = None) -> Optional[Dict]:
        """
        Return the record of a trace session.

        Args:
            trace_id: Session to look up; defaults to the current session.

        Returns:
            The session record, or ``None`` when the session is unknown.
        """
        trace_id = trace_id or self.current_trace_id
        return self.pipeline_data.get(trace_id)

    def monitor_step(
        self,
        step_name: str,
        capture_input: bool = True,
        capture_output: bool = True,
        input_transform: Optional[Callable] = None,
        output_transform: Optional[Callable] = None,
        max_input_length: int = 500,
        max_output_length: int = 1000
    ):
        """
        Build a decorator that monitors a sync or async function as one step.

        Args:
            step_name: Name under which the step is recorded.
            capture_input: Whether to record the call arguments.
            capture_output: Whether to record the return value.
            input_transform: Optional callable applied to the captured input
                (e.g. to filter sensitive fields or shrink the data).
            output_transform: Optional callable applied to the captured output.
            max_input_length: Maximum string length kept for input values
                (longer strings are truncated).
            max_output_length: Maximum string length kept for output values
                (longer strings are truncated).

        Returns:
            A decorator; the wrapped function keeps its sync/async nature.

        Example:
            @rag_monitor.monitor_step("query_extract", capture_input=True)
            def extract_query(content: str):
                return query_rewrite_manager.query_extract(content)
        """
        def decorator(func: Callable):
            is_async = inspect.iscoroutinefunction(func)

            def run(args: tuple, kwargs: dict):
                return self._execute_with_monitoring(
                    func, step_name, args, kwargs,
                    capture_input, capture_output,
                    input_transform, output_transform,
                    max_input_length, max_output_length,
                    is_async=is_async
                )

            if is_async:
                @functools.wraps(func)
                async def async_wrapper(*args, **kwargs):
                    return await run(args, kwargs)
                return async_wrapper

            @functools.wraps(func)
            def sync_wrapper(*args, **kwargs):
                return run(args, kwargs)
            return sync_wrapper

        return decorator

    def _execute_with_monitoring(
        self,
        func: Callable,
        step_name: str,
        args: tuple,
        kwargs: dict,
        capture_input: bool,
        capture_output: bool,
        input_transform: Optional[Callable],
        output_transform: Optional[Callable],
        max_input_length: int,
        max_output_length: int,
        is_async: bool
    ):
        """
        Run ``func`` and record it as a step of the current trace session.

        Returns:
            The function result when ``is_async`` is false; otherwise an
            awaitable that yields the result (the async wrapper awaits it).

        Raises:
            Whatever ``func`` raises; the failure is recorded first.
        """
        trace_id = self.current_trace_id

        if not trace_id:
            logger.warning(f"[RAG监控] 未找到活跃的追踪会话,跳过监控: {step_name}")
            # No active session: still run the function. For a coroutine
            # function this returns the coroutine for the wrapper to await.
            return func(*args, **kwargs)

        step_data = {
            "step_name": step_name,
            "function_name": func.__name__,
            "start_time": time.time()
        }

        if capture_input:
            input_data = {
                "args": self._safe_serialize(args, max_input_length),
                "kwargs": self._safe_serialize(kwargs, max_input_length)
            }
            if input_transform:
                try:
                    input_data = input_transform(input_data)
                except Exception as e:
                    # Best effort: keep the untransformed capture.
                    logger.warning(f"[RAG监控] 输入转换失败: {e}")
            step_data["input"] = input_data

        if is_async:
            # The step is finalized inside the coroutine so that exceptions
            # raised while awaiting ``func`` are recorded as well.
            return self._await_and_record(
                func, args, kwargs, step_data, trace_id,
                capture_output, output_transform, max_output_length
            )

        try:
            result = func(*args, **kwargs)
        except Exception as e:
            self._finalize_step(step_data, None, trace_id, capture_output,
                                output_transform, max_output_length,
                                success=False, error=e)
            raise
        else:
            # Outside the try block so a problem while recording is not
            # misreported as a failure of the monitored function.
            self._finalize_step(step_data, result, trace_id, capture_output,
                                output_transform, max_output_length, success=True)
            return result

    async def _await_and_record(
        self,
        func: Callable,
        args: tuple,
        kwargs: dict,
        step_data: Dict,
        trace_id: str,
        capture_output: bool,
        output_transform: Optional[Callable],
        max_output_length: int
    ):
        """Await ``func(*args, **kwargs)`` and record its success or failure."""
        try:
            result = await func(*args, **kwargs)
        except Exception as e:
            self._finalize_step(step_data, None, trace_id, capture_output,
                                output_transform, max_output_length,
                                success=False, error=e)
            raise
        else:
            self._finalize_step(step_data, result, trace_id, capture_output,
                                output_transform, max_output_length, success=True)
            return result

    def _finalize_step(
        self,
        step_data: Dict,
        result: Any,
        trace_id: str,
        capture_output: bool,
        output_transform: Optional[Callable],
        max_output_length: int,
        success: bool,
        error: Optional[Exception] = None
    ):
        """
        Complete a step record (status, output/error, timing) and store it.

        The record is stored only when ``trace_id`` is still a known session.
        """
        if success:
            step_data["status"] = "success"
            if capture_output:
                output_data = self._safe_serialize(result, max_output_length)
                if output_transform:
                    try:
                        output_data = output_transform(output_data)
                    except Exception as e:
                        # Best effort: keep the untransformed capture.
                        logger.warning(f"[RAG监控] 输出转换失败: {e}")
                step_data["output"] = output_data
        else:
            step_data["status"] = "error"
            step_data["error"] = {
                "type": type(error).__name__,
                "message": str(error)
            }
            logger.error(f"[RAG监控] 步骤执行失败: {step_data['step_name']}, 错误: {error}")

        step_data["end_time"] = time.time()
        step_data["duration"] = round(step_data["end_time"] - step_data["start_time"], 3)

        if trace_id in self.pipeline_data:
            steps = self.pipeline_data[trace_id]["steps"]
            # Repeated step names get a numeric suffix: name, name_1, name_2...
            original_step_name = step_data['step_name']
            step_name = original_step_name
            counter = 1
            while step_name in steps:
                step_name = f"{original_step_name}_{counter}"
                counter += 1

            steps[step_name] = step_data
            logger.info(f"[RAG监控] 步骤完成: {step_name}, 耗时: {step_data['duration']}秒")

    def _safe_serialize(self, obj: Any, max_length: int = 500) -> Any:
        """
        Convert ``obj`` into a size-bounded, JSON-friendly summary.

        Args:
            obj: Object to summarise.
            max_length: Maximum number of characters kept for strings and
                string representations.

        Returns:
            ``None``, numbers, booleans and short strings unchanged; long
            strings, lists/tuples, dicts and other objects as summary
            structures (see the inline comments).
        """
        if obj is None:
            return None

        # Primitive values are stored as-is.
        if isinstance(obj, (int, float, bool)):
            return obj

        if isinstance(obj, str):
            if len(obj) > max_length:
                return {
                    "type": "string",
                    "length": len(obj),
                    "preview": obj[:max_length],
                    "truncated": True
                }
            return obj

        # Sequences: always a preview of the first 3 items, plus the full
        # content when there are at most 10 items.
        if isinstance(obj, (list, tuple)):
            result = {
                "type": "list" if isinstance(obj, list) else "tuple",
                "count": len(obj)
            }
            if len(obj) > 0:
                result["preview"] = [self._safe_serialize(item, max_length) for item in obj[:3]]
            if len(obj) <= 10:
                result["full_data"] = [self._safe_serialize(item, max_length) for item in obj]
            else:
                result["truncated"] = True
            return result

        # Mappings: keep at most the first 20 keys.
        if isinstance(obj, dict):
            result = {}
            keys_list = list(obj.keys())
            for key in keys_list[:20]:
                try:
                    result[str(key)] = self._safe_serialize(obj[key], max_length)
                except Exception as e:
                    result[str(key)] = f"<序列化失败: {e}>"
            if len(keys_list) > 20:
                result["__truncated__"] = f"省略了 {len(keys_list) - 20} 个键"
            return result

        # Anything else is represented by its (possibly truncated) str().
        try:
            str_repr = str(obj)
            if len(str_repr) > max_length:
                return {
                    "type": type(obj).__name__,
                    "preview": str_repr[:max_length],
                    "truncated": True
                }
            return {"type": type(obj).__name__, "value": str_repr}
        except Exception as e:
            return {"type": type(obj).__name__, "error": f"无法序列化: {e}"}


# Module-level shared monitor instance.
rag_monitor = RAGMonitor()