"""
RAG pipeline runner — wraps the complete review-point extraction + semantic
retrieval pipeline.

Usage:
    runner = RAGPipelineRunner()
    result = runner.run_single(content_text)
    print(result.summary())
"""

import time
import asyncio
from dataclasses import dataclass, field
from typing import Optional

from foundation.observability.logger.loggering import review_logger as logger


@dataclass
class PipelineResult:
    """Outcome of a single pipeline execution (extraction step + retrieval step)."""

    chunk_id: str = ""
    content: str = ""

    # Step 1: review-point extraction
    review_points: list = field(default_factory=list)
    extract_time: float = 0.0
    extract_error: Optional[str] = None

    # Step 2: RAG retrieval
    # Two-dimensional list; summary() indexes it in parallel with review_points.
    retrieval_results: list = field(default_factory=list)
    retrieval_time: float = 0.0
    retrieval_error: Optional[str] = None

    @property
    def total_time(self) -> float:
        """Sum of the extraction and retrieval durations, in seconds."""
        return self.extract_time + self.retrieval_time

    @property
    def review_point_count(self) -> int:
        """Number of extracted review points (0 when none / unset)."""
        return len(self.review_points) if self.review_points else 0

    @property
    def total_retrieved(self) -> int:
        """Total number of retrieved items across all result lists."""
        if not self.retrieval_results:
            return 0
        return sum(len(r) for r in self.retrieval_results if r)

    @property
    def non_empty_pairs(self) -> int:
        """Number of result lists that contain at least one item."""
        if not self.retrieval_results:
            return 0
        return sum(1 for r in self.retrieval_results if r)

    @staticmethod
    def _point_label(rp: dict) -> str:
        """Return the display label of a review point: 'label', else 'entity', else '?'."""
        return rp.get('label', rp.get('entity', '?'))

    def summary(self) -> str:
        """Build a human-readable, multi-line summary of this result.

        Returns:
            A newline-joined string with timings, error messages (if any),
            per-review-point details and a per-query retrieval overview.
        """
        # Tolerate review_points being None so the overview section below
        # can still be rendered when only retrieval results are present.
        points = self.review_points or []

        lines = [
            f"=== Pipeline Result: {self.chunk_id} ===",
            f" 审查要点数: {self.review_point_count}",
            f" 提取耗时: {self.extract_time:.2f}s",
            f" 检索耗时: {self.retrieval_time:.2f}s",
            f" 总耗时: {self.total_time:.2f}s",
            f" 有结果的查询对: {self.non_empty_pairs}/{self.review_point_count}",
            f" 检索结果总数: {self.total_retrieved}",
        ]

        if self.extract_error:
            lines.append(f" [ERROR] 提取失败: {self.extract_error}")
        if self.retrieval_error:
            lines.append(f" [ERROR] 检索失败: {self.retrieval_error}")

        # Review-point details
        if points:
            lines.append(" --- 审查要点 ---")
            for i, rp in enumerate(points):
                label = self._point_label(rp)
                queries = rp.get('search_queries', rp.get('search_keywords', []))
                # `or ''` guards against an explicit None value before slicing.
                original = (rp.get('original_text', rp.get('background', '')) or '')[:60]
                param = rp.get('parameter', '')
                lines.append(f" [{i}] {label}")
                lines.append(f" queries: {queries}")
                lines.append(f" original: {original}...")
                if param:
                    lines.append(f" parameter: {param}")

        # Retrieval overview
        if self.retrieval_results:
            lines.append(" --- 检索结果概况 ---")
            for i, results in enumerate(self.retrieval_results):
                # Same label fallback as the details section, so one review
                # point is never shown under two different names.
                rp_label = self._point_label(points[i]) if i < len(points) else '?'
                if not results:
                    lines.append(f" [{i}] {rp_label}: 无结果")
                    continue
                # `or 0` keeps max() from comparing None with numbers.
                top_score = max(r.get('rerank_score', 0) or 0 for r in results)
                bfp_score = max(r.get('bfp_rerank_score', 0) or 0 for r in results)
                lines.append(
                    f" [{i}] {rp_label}: {len(results)} 条结果, "
                    f"top_rerank={top_score:.4f}, top_bfp={bfp_score:.4f}"
                )

        return "\n".join(lines)


class RAGPipelineRunner:
    """Runs the RAG pipeline: review-point extraction followed by retrieval."""

    def __init__(self):
        # Imported here rather than at module level, as in the original code.
        from foundation.ai.rag.retrieval.query_rewrite import query_rewrite_manager
        from foundation.ai.rag.retrieval.entities_enhance import review_point_retriever

        self.query_rewrite_manager = query_rewrite_manager
        self.review_point_retriever = review_point_retriever

    def run_single(self, content: str, chunk_id: str = "") -> PipelineResult:
        """Run the complete RAG pipeline on one piece of text.

        Exceptions raised by either step are caught, logged and recorded on
        the returned result (``extract_error`` / ``retrieval_error``) instead
        of being propagated.

        Args:
            content: Construction-plan text to process.
            chunk_id: Identifier of the text, used in logs and in the result.

        Returns:
            PipelineResult with timings, review points and retrieval results.
            Retrieval is skipped when extraction fails or yields nothing.
        """
        result = PipelineResult(chunk_id=chunk_id, content=content)

        # Step 1: review-point extraction
        logger.info(f"[RAG管线测试] 开始审查要点提取: {chunk_id}")
        # perf_counter() is monotonic, so elapsed times cannot be skewed by
        # system clock adjustments the way time.time() differences can.
        extract_start = time.perf_counter()
        try:
            review_points = self.query_rewrite_manager.query_extract(content)
            result.extract_time = time.perf_counter() - extract_start

            if review_points:
                result.review_points = review_points
                logger.info(
                    f"[RAG管线测试] 提取到 {len(review_points)} 个审查要点, "
                    f"耗时 {result.extract_time:.2f}s"
                )
            else:
                result.extract_error = "提取结果为空"
                logger.warning(f"[RAG管线测试] 审查要点提取为空: {chunk_id}")
                return result
        except Exception as e:
            result.extract_time = time.perf_counter() - extract_start
            result.extract_error = str(e)
            logger.error(f"[RAG管线测试] 审查要点提取失败: {e}")
            return result

        # Step 2: RAG retrieval
        logger.info(f"[RAG管线测试] 开始 RAG 检索: {chunk_id}")
        retrieval_start = time.perf_counter()
        try:
            retrieval_results = self.review_point_retriever.review_point_retrieval(
                review_points
            )
            result.retrieval_time = time.perf_counter() - retrieval_start
            # Keep the field a list even if the retriever hands back None.
            result.retrieval_results = retrieval_results or []
            logger.info(
                f"[RAG管线测试] 检索完成, "
                f"{result.non_empty_pairs}/{result.review_point_count} 个查询对有结果, "
                f"耗时 {result.retrieval_time:.2f}s"
            )
        except Exception as e:
            result.retrieval_time = time.perf_counter() - retrieval_start
            result.retrieval_error = str(e)
            logger.error(f"[RAG管线测试] RAG 检索失败: {e}")

        return result

    def run_batch(self, test_samples: list) -> list:
        """Run the pipeline sequentially over a batch of samples.

        Args:
            test_samples: List of samples; each must provide the keys
                ``chunk_id`` and ``content``.

        Returns:
            List of PipelineResult, one per sample, in input order.

        Raises:
            KeyError: If a sample lacks ``chunk_id`` or ``content``.
        """
        results = []
        total = len(test_samples)
        for i, sample in enumerate(test_samples):
            chunk_id = sample["chunk_id"]
            logger.info(
                f"[RAG管线测试] ====== 样本 {i+1}/{total}: {chunk_id} ======"
            )
            result = self.run_single(sample["content"], chunk_id=chunk_id)
            results.append(result)
            logger.info(f"[RAG管线测试] {result.summary()}")
        return results