rag_pipeline_runner.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. """
  2. RAG 管线执行器 — 封装完整的审查要点提取 + 语义检索管线
  3. 用法:
  4. runner = RAGPipelineRunner()
  5. result = runner.run_single(content_text)
  6. print(result.summary())
  7. """
  8. import time
  9. import asyncio
  10. from dataclasses import dataclass, field
  11. from typing import Optional
  12. from foundation.observability.logger.loggering import review_logger as logger
  13. @dataclass
  14. class PipelineResult:
  15. """单次管线执行结果"""
  16. chunk_id: str = ""
  17. content: str = ""
  18. # Step 1: 审查要点提取
  19. review_points: list = field(default_factory=list)
  20. extract_time: float = 0.0
  21. extract_error: Optional[str] = None
  22. # Step 2: RAG 检索
  23. retrieval_results: list = field(default_factory=list) # 二维列表
  24. retrieval_time: float = 0.0
  25. retrieval_error: Optional[str] = None
  26. @property
  27. def total_time(self) -> float:
  28. return self.extract_time + self.retrieval_time
  29. @property
  30. def review_point_count(self) -> int:
  31. return len(self.review_points) if self.review_points else 0
  32. @property
  33. def total_retrieved(self) -> int:
  34. """检索到的总结果数"""
  35. if not self.retrieval_results:
  36. return 0
  37. return sum(len(r) for r in self.retrieval_results if r)
  38. @property
  39. def non_empty_pairs(self) -> int:
  40. """有检索结果的查询对数量"""
  41. if not self.retrieval_results:
  42. return 0
  43. return sum(1 for r in self.retrieval_results if r)
  44. def summary(self) -> str:
  45. """生成结果摘要"""
  46. lines = [
  47. f"=== Pipeline Result: {self.chunk_id} ===",
  48. f" 审查要点数: {self.review_point_count}",
  49. f" 提取耗时: {self.extract_time:.2f}s",
  50. f" 检索耗时: {self.retrieval_time:.2f}s",
  51. f" 总耗时: {self.total_time:.2f}s",
  52. f" 有结果的查询对: {self.non_empty_pairs}/{self.review_point_count}",
  53. f" 检索结果总数: {self.total_retrieved}",
  54. ]
  55. if self.extract_error:
  56. lines.append(f" [ERROR] 提取失败: {self.extract_error}")
  57. if self.retrieval_error:
  58. lines.append(f" [ERROR] 检索失败: {self.retrieval_error}")
  59. # 审查要点详情
  60. if self.review_points:
  61. lines.append(" --- 审查要点 ---")
  62. for i, rp in enumerate(self.review_points):
  63. label = rp.get('label', rp.get('entity', '?'))
  64. queries = rp.get('search_queries', rp.get('search_keywords', []))
  65. original = rp.get('original_text', rp.get('background', ''))[:60]
  66. param = rp.get('parameter', '')
  67. lines.append(f" [{i}] {label}")
  68. lines.append(f" queries: {queries}")
  69. lines.append(f" original: {original}...")
  70. if param:
  71. lines.append(f" parameter: {param}")
  72. # 检索结果概况
  73. if self.retrieval_results:
  74. lines.append(" --- 检索结果概况 ---")
  75. for i, results in enumerate(self.retrieval_results):
  76. if not results:
  77. rp_label = self.review_points[i].get('label', '?') if i < len(self.review_points) else '?'
  78. lines.append(f" [{i}] {rp_label}: 无结果")
  79. continue
  80. rp_label = self.review_points[i].get('label', '?') if i < len(self.review_points) else '?'
  81. top_score = max(r.get('rerank_score', 0) for r in results)
  82. bfp_score = max(r.get('bfp_rerank_score', 0) for r in results)
  83. lines.append(
  84. f" [{i}] {rp_label}: {len(results)} 条结果, "
  85. f"top_rerank={top_score:.4f}, top_bfp={bfp_score:.4f}"
  86. )
  87. return "\n".join(lines)
  88. class RAGPipelineRunner:
  89. """RAG 管线执行器"""
  90. def __init__(self):
  91. from foundation.ai.rag.retrieval.query_rewrite import query_rewrite_manager
  92. from foundation.ai.rag.retrieval.entities_enhance import review_point_retriever
  93. self.query_rewrite_manager = query_rewrite_manager
  94. self.review_point_retriever = review_point_retriever
  95. def run_single(self, content: str, chunk_id: str = "") -> PipelineResult:
  96. """
  97. 执行单条文本的完整 RAG 管线
  98. Args:
  99. content: 施工方案文本
  100. chunk_id: 文本标识符
  101. Returns:
  102. PipelineResult
  103. """
  104. result = PipelineResult(chunk_id=chunk_id, content=content)
  105. # Step 1: 审查要点提取
  106. logger.info(f"[RAG管线测试] 开始审查要点提取: {chunk_id}")
  107. t0 = time.time()
  108. try:
  109. review_points = self.query_rewrite_manager.query_extract(content)
  110. result.extract_time = time.time() - t0
  111. if review_points:
  112. result.review_points = review_points
  113. logger.info(
  114. f"[RAG管线测试] 提取到 {len(review_points)} 个审查要点, "
  115. f"耗时 {result.extract_time:.2f}s"
  116. )
  117. else:
  118. result.extract_error = "提取结果为空"
  119. logger.warning(f"[RAG管线测试] 审查要点提取为空: {chunk_id}")
  120. return result
  121. except Exception as e:
  122. result.extract_time = time.time() - t0
  123. result.extract_error = str(e)
  124. logger.error(f"[RAG管线测试] 审查要点提取失败: {e}")
  125. return result
  126. # Step 2: RAG 检索
  127. logger.info(f"[RAG管线测试] 开始 RAG 检索: {chunk_id}")
  128. t1 = time.time()
  129. try:
  130. retrieval_results = self.review_point_retriever.review_point_retrieval(
  131. review_points
  132. )
  133. result.retrieval_time = time.time() - t1
  134. result.retrieval_results = retrieval_results
  135. logger.info(
  136. f"[RAG管线测试] 检索完成, "
  137. f"{result.non_empty_pairs}/{result.review_point_count} 个查询对有结果, "
  138. f"耗时 {result.retrieval_time:.2f}s"
  139. )
  140. except Exception as e:
  141. result.retrieval_time = time.time() - t1
  142. result.retrieval_error = str(e)
  143. logger.error(f"[RAG管线测试] RAG 检索失败: {e}")
  144. return result
  145. def run_batch(self, test_samples: list) -> list:
  146. """
  147. 批量执行管线
  148. Args:
  149. test_samples: 测试样本列表,每个包含 chunk_id 和 content
  150. Returns:
  151. list[PipelineResult]
  152. """
  153. results = []
  154. for i, sample in enumerate(test_samples):
  155. chunk_id = sample["chunk_id"]
  156. logger.info(
  157. f"[RAG管线测试] ====== 样本 {i+1}/{len(test_samples)}: {chunk_id} ======"
  158. )
  159. result = self.run_single(sample["content"], chunk_id=chunk_id)
  160. results.append(result)
  161. logger.info(f"[RAG管线测试] {result.summary()}")
  162. return results