| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343 |
- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
- """
- RAG链路监控装饰器
- 支持同步/异步函数的输入输出监控
- """
- import time
- import json
- import functools
- import inspect
- from typing import Callable, Optional, Dict, Any, Union
- from pathlib import Path
- from foundation.observability.logger.loggering import review_logger as logger
class RAGMonitor:
    """RAG pipeline monitor.

    Collects per-step input/output snapshots and timings for a trace session
    and writes the finished session to a JSON file under ``save_dir``.

    A single "current" trace id is kept on the instance
    (``current_trace_id``); decorated steps record into whichever trace is
    current at the moment they are called.
    """

    def __init__(self, save_dir: str = "temp/rag_monitoring"):
        """Initialize the monitor.

        Args:
            save_dir: Directory where trace files are written. It is created
                (including parents) if it does not exist.
        """
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)
        # trace_id -> session record (layout is built in start_trace).
        self.pipeline_data = {}
        self.current_trace_id = None

    def start_trace(self, trace_id: str, metadata: Optional[Dict] = None):
        """Start a new trace session and make it the current one.

        Args:
            trace_id: Identifier of the trace session.
            metadata: Optional session metadata; stored as ``{}`` when falsy.
        """
        self.current_trace_id = trace_id
        self.pipeline_data[trace_id] = {
            "trace_id": trace_id,
            "start_time": time.time(),
            "metadata": metadata or {},
            "steps": {},
        }
        logger.info(f"[RAG监控] 开始追踪会话: {trace_id}")

    def end_trace(self, trace_id: Optional[str] = None) -> Optional[Dict]:
        """Finish a trace session and save it as ``<save_dir>/<trace_id>.json``.

        Args:
            trace_id: Identifier of the session to finish. Falls back to the
                current session when omitted.

        Returns:
            The session record, or ``None`` when the session is unknown.
        """
        trace_id = trace_id or self.current_trace_id
        if trace_id not in self.pipeline_data:
            logger.warning(f"[RAG监控] 追踪会话不存在: {trace_id}")
            return None

        data = self.pipeline_data[trace_id]
        data["end_time"] = time.time()
        data["total_duration"] = round(data["end_time"] - data["start_time"], 3)

        # Persist to disk; default=str stringifies anything json cannot encode.
        file_path = self.save_dir / f"{trace_id}.json"
        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2, default=str)

        logger.info(f"[RAG监控] 追踪会话已保存: {file_path}, 总耗时: {data['total_duration']}秒")
        return data

    def get_trace_data(self, trace_id: Optional[str] = None) -> Optional[Dict]:
        """Return the record of a trace session.

        Args:
            trace_id: Identifier of the session. Falls back to the current
                session when omitted.

        Returns:
            The session record, or ``None`` when the session is unknown.
        """
        trace_id = trace_id or self.current_trace_id
        return self.pipeline_data.get(trace_id)

    def monitor_step(
        self,
        step_name: str,
        capture_input: bool = True,
        capture_output: bool = True,
        input_transform: Optional[Callable] = None,
        output_transform: Optional[Callable] = None,
        max_input_length: int = 500,
        max_output_length: int = 1000
    ):
        """Build a monitoring decorator for sync or async functions.

        Args:
            step_name: Name under which the step is recorded.
            capture_input: Whether to record the call arguments.
            capture_output: Whether to record the return value.
            input_transform: Optional callable applied to the serialized
                input (e.g. to drop sensitive fields or shrink the data).
            output_transform: Optional callable applied to the serialized
                output.
            max_input_length: Maximum string length kept for inputs; longer
                strings are truncated.
            max_output_length: Maximum string length kept for outputs; longer
                strings are truncated.

        Returns:
            A decorator. Coroutine functions get an ``async`` wrapper, all
            other callables a plain wrapper.

        Example:
            @rag_monitor.monitor_step("query_extract", capture_input=True)
            def extract_query(content: str):
                return query_rewrite_manager.query_extract(content)
        """
        def decorator(func: Callable):
            # Pick the wrapper flavour once, at decoration time.
            is_async = inspect.iscoroutinefunction(func)

            if is_async:
                @functools.wraps(func)
                async def async_wrapper(*args, **kwargs):
                    return await self._execute_with_monitoring(
                        func, step_name, args, kwargs,
                        capture_input, capture_output,
                        input_transform, output_transform,
                        max_input_length, max_output_length,
                        is_async=True
                    )
                return async_wrapper
            else:
                @functools.wraps(func)
                def sync_wrapper(*args, **kwargs):
                    return self._execute_with_monitoring(
                        func, step_name, args, kwargs,
                        capture_input, capture_output,
                        input_transform, output_transform,
                        max_input_length, max_output_length,
                        is_async=False
                    )
                return sync_wrapper

        return decorator

    def _execute_with_monitoring(
        self,
        func: Callable,
        step_name: str,
        args: tuple,
        kwargs: dict,
        capture_input: bool,
        capture_output: bool,
        input_transform: Optional[Callable],
        output_transform: Optional[Callable],
        max_input_length: int,
        max_output_length: int,
        is_async: bool
    ):
        """Run ``func`` and record the step in the current trace.

        Args:
            func: The wrapped function.
            step_name: Name under which the step is recorded.
            args: Positional arguments for ``func``.
            kwargs: Keyword arguments for ``func``.
            capture_input: Whether to record the call arguments.
            capture_output: Whether to record the return value.
            input_transform: Optional callable applied to the serialized input.
            output_transform: Optional callable applied to the serialized output.
            max_input_length: Maximum string length kept for inputs.
            max_output_length: Maximum string length kept for outputs.
            is_async: ``True`` when ``func`` is a coroutine function.

        Returns:
            The result of ``func`` when ``is_async`` is false; otherwise an
            awaitable that yields the result (the async wrapper awaits it).

        Raises:
            Exception: Whatever ``func`` raises is recorded as a failed step
                and re-raised unchanged.
        """
        trace_id = self.current_trace_id
        if not trace_id:
            logger.warning(f"[RAG监控] 未找到活跃的追踪会话,跳过监控: {step_name}")
            # Without an active trace the function still has to run. For an
            # async function this hands back the coroutine for the wrapper
            # to await; no separate task is needed.
            return func(*args, **kwargs)

        # Step record; the timer starts before input serialization.
        step_data = {
            "step_name": step_name,
            "function_name": func.__name__,
            "start_time": time.time()
        }

        # Capture input.
        if capture_input:
            input_data = {
                "args": self._safe_serialize(args, max_input_length),
                "kwargs": self._safe_serialize(kwargs, max_input_length)
            }
            if input_transform:
                try:
                    input_data = input_transform(input_data)
                except Exception as e:
                    # A failing transform keeps the untransformed snapshot.
                    logger.warning(f"[RAG监控] 输入转换失败: {e}")
            step_data["input"] = input_data

        if is_async:
            async def async_exec():
                # The try/except has to live inside the coroutine: the
                # wrapped function raises while this coroutine is running,
                # not when it is created, so an outer try would miss it and
                # failed async steps would never be recorded.
                try:
                    result = await func(*args, **kwargs)
                    self._finalize_step(step_data, result, trace_id, capture_output,
                                        output_transform, max_output_length, success=True)
                    return result
                except Exception as e:
                    self._finalize_step(step_data, None, trace_id, capture_output,
                                        output_transform, max_output_length,
                                        success=False, error=e)
                    raise

            return async_exec()

        try:
            result = func(*args, **kwargs)
            self._finalize_step(step_data, result, trace_id, capture_output,
                                output_transform, max_output_length, success=True)
            return result
        except Exception as e:
            self._finalize_step(step_data, None, trace_id, capture_output,
                                output_transform, max_output_length,
                                success=False, error=e)
            raise

    def _finalize_step(
        self,
        step_data: Dict,
        result: Any,
        trace_id: str,
        capture_output: bool,
        output_transform: Optional[Callable],
        max_output_length: int,
        success: bool,
        error: Optional[Exception] = None
    ):
        """Complete a step record and store it in its trace.

        Args:
            step_data: Step record started by ``_execute_with_monitoring``;
                mutated in place.
            result: Return value of the step (ignored on failure).
            trace_id: Trace the step belongs to.
            capture_output: Whether to record ``result``.
            output_transform: Optional callable applied to the serialized output.
            max_output_length: Maximum string length kept for outputs.
            success: Whether the step finished without raising.
            error: The exception raised by the step, when ``success`` is false.
        """
        if success:
            step_data["status"] = "success"
            # Capture output.
            if capture_output:
                output_data = self._safe_serialize(result, max_output_length)
                if output_transform:
                    try:
                        output_data = output_transform(output_data)
                    except Exception as e:
                        # A failing transform keeps the untransformed snapshot.
                        logger.warning(f"[RAG监控] 输出转换失败: {e}")
                step_data["output"] = output_data
        else:
            step_data["status"] = "error"
            step_data["error"] = {
                "type": type(error).__name__,
                "message": str(error)
            }
            logger.error(f"[RAG监控] 步骤执行失败: {step_data['step_name']}, 错误: {error}")

        step_data["end_time"] = time.time()
        step_data["duration"] = round(step_data["end_time"] - step_data["start_time"], 3)

        # Store the step; nothing is stored if the trace is not registered.
        if trace_id in self.pipeline_data:
            # Repeated step names get a numeric suffix: name, name_1, name_2, ...
            original_step_name = step_data['step_name']
            step_name = original_step_name
            counter = 1
            while step_name in self.pipeline_data[trace_id]["steps"]:
                step_name = f"{original_step_name}_{counter}"
                counter += 1

            self.pipeline_data[trace_id]["steps"][step_name] = step_data
            logger.info(f"[RAG监控] 步骤完成: {step_name}, 耗时: {step_data['duration']}秒")

    def _safe_serialize(self, obj: Any, max_length: int = 500) -> Any:
        """Turn ``obj`` into a bounded, JSON-friendly snapshot.

        Long strings, long sequences and large dicts are summarized so that a
        big object does not take up too much memory in the trace.

        Args:
            obj: Object to snapshot.
            max_length: Maximum string length kept before truncating.

        Returns:
            ``None`` and ``int``/``float``/``bool``/short ``str`` values
            unchanged; a summary ``dict`` for everything else.
        """
        if obj is None:
            return None

        # Primitive scalars are kept as-is.
        if isinstance(obj, (int, float, bool)):
            return obj

        if isinstance(obj, str):
            if len(obj) > max_length:
                return {
                    "type": "string",
                    "length": len(obj),
                    "preview": obj[:max_length],
                    "truncated": True
                }
            return obj

        # Lists and tuples.
        if isinstance(obj, (list, tuple)):
            result = {
                "type": "list" if isinstance(obj, list) else "tuple",
                "count": len(obj)
            }
            # Preview holds at most the first three items.
            if len(obj) > 0:
                result["preview"] = [self._safe_serialize(item, max_length) for item in obj[:3]]
            # Sequences of up to ten items are also stored in full.
            # NOTE(review): the original indentation was lost; this branch is
            # assumed to sit outside the non-empty check above. The two
            # readings differ only for empty sequences - confirm.
            if len(obj) <= 10:
                result["full_data"] = [self._safe_serialize(item, max_length) for item in obj]
            else:
                result["truncated"] = True
            return result

        # Dicts.
        if isinstance(obj, dict):
            result = {}
            keys_list = list(obj.keys())
            # Keep at most 20 keys.
            for key in keys_list[:20]:
                try:
                    result[str(key)] = self._safe_serialize(obj[key], max_length)
                except Exception as e:
                    result[str(key)] = f"<序列化失败: {e}>"
            if len(keys_list) > 20:
                result["__truncated__"] = f"省略了 {len(keys_list) - 20} 个键"
            return result

        # Anything else is represented by its str() form.
        try:
            str_repr = str(obj)
            if len(str_repr) > max_length:
                return {
                    "type": type(obj).__name__,
                    "preview": str_repr[:max_length],
                    "truncated": True
                }
            return {"type": type(obj).__name__, "value": str_repr}
        except Exception as e:
            return {"type": type(obj).__name__, "error": f"无法序列化: {e}"}
# Global monitor instance, built with the default save_dir at import time.
rag_monitor = RAGMonitor()
|