rag_monitor.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. """
  4. RAG链路监控装饰器
  5. 支持同步/异步函数的输入输出监控
  6. """
  7. import time
  8. import json
  9. import functools
  10. import inspect
  11. from typing import Callable, Optional, Dict, Any, Union
  12. from pathlib import Path
  13. from foundation.observability.logger.loggering import review_logger as logger
  14. class RAGMonitor:
  15. """RAG监控管理器"""
  16. def __init__(self, save_dir: str = "temp/rag_monitoring"):
  17. """
  18. 初始化RAG监控器
  19. Args:
  20. save_dir: 监控数据保存目录
  21. """
  22. self.save_dir = Path(save_dir)
  23. self.save_dir.mkdir(parents=True, exist_ok=True)
  24. self.pipeline_data = {}
  25. self.current_trace_id = None
  26. def start_trace(self, trace_id: str, metadata: Optional[Dict] = None):
  27. """
  28. 开始一个新的追踪会话
  29. Args:
  30. trace_id: 追踪会话ID
  31. metadata: 会话元数据
  32. """
  33. self.current_trace_id = trace_id
  34. self.pipeline_data[trace_id] = {
  35. "trace_id": trace_id,
  36. "start_time": time.time(),
  37. "metadata": metadata or {},
  38. "steps": {}
  39. }
  40. logger.info(f"[RAG监控] 开始追踪会话: {trace_id}")
  41. def end_trace(self, trace_id: str = None) -> Optional[Dict]:
  42. """
  43. 结束追踪会话并保存数据
  44. Args:
  45. trace_id: 追踪会话ID,如果为None则使用当前会话
  46. Returns:
  47. 追踪会话数据
  48. """
  49. trace_id = trace_id or self.current_trace_id
  50. if trace_id not in self.pipeline_data:
  51. logger.warning(f"[RAG监控] 追踪会话不存在: {trace_id}")
  52. return None
  53. data = self.pipeline_data[trace_id]
  54. data["end_time"] = time.time()
  55. data["total_duration"] = round(data["end_time"] - data["start_time"], 3)
  56. # 保存到文件
  57. file_path = self.save_dir / f"{trace_id}.json"
  58. with open(file_path, 'w', encoding='utf-8') as f:
  59. json.dump(data, f, ensure_ascii=False, indent=2, default=str)
  60. logger.info(f"[RAG监控] 追踪会话已保存: {file_path}, 总耗时: {data['total_duration']}秒")
  61. return data
  62. def get_trace_data(self, trace_id: str = None) -> Optional[Dict]:
  63. """
  64. 获取追踪会话数据
  65. Args:
  66. trace_id: 追踪会话ID,如果为None则使用当前会话
  67. Returns:
  68. 追踪会话数据
  69. """
  70. trace_id = trace_id or self.current_trace_id
  71. return self.pipeline_data.get(trace_id)
  72. def monitor_step(
  73. self,
  74. step_name: str,
  75. capture_input: bool = True,
  76. capture_output: bool = True,
  77. input_transform: Optional[Callable] = None,
  78. output_transform: Optional[Callable] = None,
  79. max_input_length: int = 500,
  80. max_output_length: int = 1000
  81. ):
  82. """
  83. 监控装饰器 - 支持同步和异步函数
  84. Args:
  85. step_name: 步骤名称
  86. capture_input: 是否捕获输入参数
  87. capture_output: 是否捕获输出结果
  88. input_transform: 输入数据转换函数(用于过滤敏感信息或压缩数据)
  89. output_transform: 输出数据转换函数
  90. max_input_length: 输入数据最大长度(超过会截断)
  91. max_output_length: 输出数据最大长度(超过会截断)
  92. Example:
  93. @rag_monitor.monitor_step("query_extract", capture_input=True)
  94. def extract_query(content: str):
  95. return query_rewrite_manager.query_extract(content)
  96. """
  97. def decorator(func: Callable):
  98. # 判断是否为异步函数
  99. is_async = inspect.iscoroutinefunction(func)
  100. if is_async:
  101. @functools.wraps(func)
  102. async def async_wrapper(*args, **kwargs):
  103. return await self._execute_with_monitoring(
  104. func, step_name, args, kwargs,
  105. capture_input, capture_output,
  106. input_transform, output_transform,
  107. max_input_length, max_output_length,
  108. is_async=True
  109. )
  110. return async_wrapper
  111. else:
  112. @functools.wraps(func)
  113. def sync_wrapper(*args, **kwargs):
  114. return self._execute_with_monitoring(
  115. func, step_name, args, kwargs,
  116. capture_input, capture_output,
  117. input_transform, output_transform,
  118. max_input_length, max_output_length,
  119. is_async=False
  120. )
  121. return sync_wrapper
  122. return decorator
  123. def _execute_with_monitoring(
  124. self,
  125. func: Callable,
  126. step_name: str,
  127. args: tuple,
  128. kwargs: dict,
  129. capture_input: bool,
  130. capture_output: bool,
  131. input_transform: Optional[Callable],
  132. output_transform: Optional[Callable],
  133. max_input_length: int,
  134. max_output_length: int,
  135. is_async: bool
  136. ):
  137. """执行函数并监控"""
  138. trace_id = self.current_trace_id
  139. if not trace_id:
  140. logger.warning(f"[RAG监控] 未找到活跃的追踪会话,跳过监控: {step_name}")
  141. # 即使没有追踪会话,也要正常执行函数
  142. if is_async:
  143. import asyncio
  144. return asyncio.create_task(func(*args, **kwargs))
  145. else:
  146. return func(*args, **kwargs)
  147. # 记录步骤数据
  148. step_data = {
  149. "step_name": step_name,
  150. "function_name": func.__name__,
  151. "start_time": time.time()
  152. }
  153. # 捕获输入
  154. if capture_input:
  155. input_data = {
  156. "args": self._safe_serialize(args, max_input_length),
  157. "kwargs": self._safe_serialize(kwargs, max_input_length)
  158. }
  159. if input_transform:
  160. try:
  161. input_data = input_transform(input_data)
  162. except Exception as e:
  163. logger.warning(f"[RAG监控] 输入转换失败: {e}")
  164. step_data["input"] = input_data
  165. # 执行函数
  166. try:
  167. if is_async:
  168. # 对于异步函数,需要特殊处理
  169. import asyncio
  170. async def async_exec():
  171. result = await func(*args, **kwargs)
  172. self._finalize_step(step_data, result, trace_id, capture_output,
  173. output_transform, max_output_length, success=True)
  174. return result
  175. return asyncio.create_task(async_exec())
  176. else:
  177. result = func(*args, **kwargs)
  178. self._finalize_step(step_data, result, trace_id, capture_output,
  179. output_transform, max_output_length, success=True)
  180. return result
  181. except Exception as e:
  182. self._finalize_step(step_data, None, trace_id, capture_output,
  183. output_transform, max_output_length, success=False, error=e)
  184. raise
  185. def _finalize_step(
  186. self,
  187. step_data: Dict,
  188. result: Any,
  189. trace_id: str,
  190. capture_output: bool,
  191. output_transform: Optional[Callable],
  192. max_output_length: int,
  193. success: bool,
  194. error: Optional[Exception] = None
  195. ):
  196. """完成步骤监控数据记录"""
  197. if success:
  198. step_data["status"] = "success"
  199. # 捕获输出
  200. if capture_output:
  201. output_data = self._safe_serialize(result, max_output_length)
  202. if output_transform:
  203. try:
  204. output_data = output_transform(output_data)
  205. except Exception as e:
  206. logger.warning(f"[RAG监控] 输出转换失败: {e}")
  207. step_data["output"] = output_data
  208. else:
  209. step_data["status"] = "error"
  210. step_data["error"] = {
  211. "type": type(error).__name__,
  212. "message": str(error)
  213. }
  214. logger.error(f"[RAG监控] 步骤执行失败: {step_data['step_name']}, 错误: {error}")
  215. step_data["end_time"] = time.time()
  216. step_data["duration"] = round(step_data["end_time"] - step_data["start_time"], 3)
  217. # 保存步骤数据
  218. if trace_id in self.pipeline_data:
  219. # 如果步骤名称已存在,添加序号
  220. original_step_name = step_data['step_name']
  221. step_name = original_step_name
  222. counter = 1
  223. while step_name in self.pipeline_data[trace_id]["steps"]:
  224. step_name = f"{original_step_name}_{counter}"
  225. counter += 1
  226. self.pipeline_data[trace_id]["steps"][step_name] = step_data
  227. logger.info(f"[RAG监控] 步骤完成: {step_name}, 耗时: {step_data['duration']}秒")
  228. def _safe_serialize(self, obj: Any, max_length: int = 500) -> Any:
  229. """
  230. 安全序列化对象(防止大对象占用过多内存)
  231. Args:
  232. obj: 要序列化的对象
  233. max_length: 字符串最大长度
  234. Returns:
  235. 序列化后的对象
  236. """
  237. if obj is None:
  238. return None
  239. # 基本类型直接返回
  240. if isinstance(obj, (int, float, bool)):
  241. return obj
  242. if isinstance(obj, str):
  243. if len(obj) > max_length:
  244. return {
  245. "type": "string",
  246. "length": len(obj),
  247. "preview": obj[:max_length],
  248. "truncated": True
  249. }
  250. return obj
  251. # 列表类型
  252. if isinstance(obj, (list, tuple)):
  253. result = {
  254. "type": "list" if isinstance(obj, list) else "tuple",
  255. "count": len(obj)
  256. }
  257. # 只保留前3项的预览
  258. if len(obj) > 0:
  259. result["preview"] = [self._safe_serialize(item, max_length) for item in obj[:3]]
  260. # 如果列表项少于10项,保存完整数据
  261. if len(obj) <= 10:
  262. result["full_data"] = [self._safe_serialize(item, max_length) for item in obj]
  263. else:
  264. result["truncated"] = True
  265. return result
  266. # 字典类型
  267. if isinstance(obj, dict):
  268. result = {}
  269. keys_list = list(obj.keys())
  270. # 最多保留20个键
  271. for key in keys_list[:20]:
  272. try:
  273. result[str(key)] = self._safe_serialize(obj[key], max_length)
  274. except Exception as e:
  275. result[str(key)] = f"<序列化失败: {e}>"
  276. if len(keys_list) > 20:
  277. result["__truncated__"] = f"省略了 {len(keys_list) - 20} 个键"
  278. return result
  279. # 其他类型尝试转为字符串
  280. try:
  281. str_repr = str(obj)
  282. if len(str_repr) > max_length:
  283. return {
  284. "type": type(obj).__name__,
  285. "preview": str_repr[:max_length],
  286. "truncated": True
  287. }
  288. return {"type": type(obj).__name__, "value": str_repr}
  289. except Exception as e:
  290. return {"type": type(obj).__name__, "error": f"无法序列化: {e}"}
  291. # 全局监控实例
  292. rag_monitor = RAGMonitor()