| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195 |
- """
- RAG 管线执行器 — 封装完整的审查要点提取 + 语义检索管线
- 用法:
- runner = RAGPipelineRunner()
- result = runner.run_single(content_text)
- print(result.summary())
- """
- import time
- import asyncio
- from dataclasses import dataclass, field
- from typing import Optional
- from foundation.observability.logger.loggering import review_logger as logger
@dataclass
class PipelineResult:
    """Result of one pipeline run over a single text chunk.

    Records, for each of the two steps (review-point extraction and RAG
    retrieval), its output, elapsed time in seconds and error message (if any).
    """

    chunk_id: str = ""
    content: str = ""
    # Step 1: review-point extraction.
    # Entries are dict-like; summary() reads 'label'/'entity',
    # 'search_queries'/'search_keywords', 'original_text'/'background' and
    # 'parameter' from them.
    review_points: list = field(default_factory=list)
    extract_time: float = 0.0
    extract_error: Optional[str] = None
    # Step 2: RAG retrieval.
    # 2-D list: one inner list of hit dicts per review point. summary() pairs
    # it with review_points by position (assumed index-aligned — confirm
    # against the retriever).
    retrieval_results: list = field(default_factory=list)
    retrieval_time: float = 0.0
    retrieval_error: Optional[str] = None

    @property
    def total_time(self) -> float:
        """Combined extraction + retrieval time in seconds."""
        return self.extract_time + self.retrieval_time

    @property
    def review_point_count(self) -> int:
        """Number of extracted review points (0 when None or empty)."""
        return len(self.review_points) if self.review_points else 0

    @property
    def total_retrieved(self) -> int:
        """Total number of retrieved hits across all review points."""
        if not self.retrieval_results:
            return 0
        return sum(len(r) for r in self.retrieval_results if r)

    @property
    def non_empty_pairs(self) -> int:
        """Number of review points whose retrieval returned at least one hit."""
        if not self.retrieval_results:
            return 0
        return sum(1 for r in self.retrieval_results if r)

    @staticmethod
    def _label_of(review_point):
        """Display label of a review point: 'label', else 'entity', else '?'.

        Shared by both summary sections so a point is labelled identically in
        each of them.
        """
        return review_point.get('label', review_point.get('entity', '?'))

    @staticmethod
    def _best_score(results, key):
        """Highest value of *key* among *results*; missing or None counts as 0.

        A None score would otherwise make max() raise TypeError.
        """
        return max((r.get(key) or 0) for r in results)

    def summary(self) -> str:
        """Build a human-readable, multi-line summary of this result."""
        preview_len = 60
        lines = [
            f"=== Pipeline Result: {self.chunk_id} ===",
            f" 审查要点数: {self.review_point_count}",
            f" 提取耗时: {self.extract_time:.2f}s",
            f" 检索耗时: {self.retrieval_time:.2f}s",
            f" 总耗时: {self.total_time:.2f}s",
            f" 有结果的查询对: {self.non_empty_pairs}/{self.review_point_count}",
            f" 检索结果总数: {self.total_retrieved}",
        ]
        if self.extract_error:
            lines.append(f" [ERROR] 提取失败: {self.extract_error}")
        if self.retrieval_error:
            lines.append(f" [ERROR] 检索失败: {self.retrieval_error}")

        points = self.review_points or []

        # Review-point details
        if points:
            lines.append(" --- 审查要点 ---")
            for i, rp in enumerate(points):
                queries = rp.get('search_queries', rp.get('search_keywords', []))
                # `or ''` guards against an explicit None value, which cannot be sliced.
                full_text = rp.get('original_text', rp.get('background', '')) or ''
                original = full_text[:preview_len]
                # Only mark the preview as elided when text was actually cut off.
                ellipsis = '...' if len(full_text) > preview_len else ''
                param = rp.get('parameter', '')
                lines.append(f" [{i}] {self._label_of(rp)}")
                lines.append(f" queries: {queries}")
                lines.append(f" original: {original}{ellipsis}")
                if param:
                    lines.append(f" parameter: {param}")

        # Retrieval overview
        if self.retrieval_results:
            lines.append(" --- 检索结果概况 ---")
            for i, results in enumerate(self.retrieval_results):
                rp_label = self._label_of(points[i]) if i < len(points) else '?'
                if not results:
                    lines.append(f" [{i}] {rp_label}: 无结果")
                    continue
                top_score = self._best_score(results, 'rerank_score')
                bfp_score = self._best_score(results, 'bfp_rerank_score')
                lines.append(
                    f" [{i}] {rp_label}: {len(results)} 条结果, "
                    f"top_rerank={top_score:.4f}, top_bfp={bfp_score:.4f}"
                )
        return "\n".join(lines)
class RAGPipelineRunner:
    """RAG pipeline runner: review-point extraction followed by retrieval.

    Step 1 uses ``query_rewrite_manager.query_extract``; step 2 uses
    ``review_point_retriever.review_point_retrieval``. Timings and errors of
    both steps are recorded on a ``PipelineResult``.
    """

    def __init__(self):
        # Imported inside __init__, so these modules are only loaded when a
        # runner is constructed (presumably to keep module import cheap).
        from foundation.ai.rag.retrieval.query_rewrite import query_rewrite_manager
        from foundation.ai.rag.retrieval.entities_enhance import review_point_retriever

        self.query_rewrite_manager = query_rewrite_manager
        self.review_point_retriever = review_point_retriever

    def run_single(self, content: str, chunk_id: str = "") -> PipelineResult:
        """Run the full RAG pipeline on one piece of text.

        If extraction fails or yields nothing, retrieval is skipped. Exceptions
        raised by either step are caught and stored on the result (as
        ``extract_error`` / ``retrieval_error``) instead of propagating.

        Args:
            content: Construction-plan text to analyse.
            chunk_id: Identifier used in logs and stored on the result.

        Returns:
            PipelineResult with per-step outputs, timings and errors.
        """
        result = PipelineResult(chunk_id=chunk_id, content=content)
        review_points = self._extract_review_points(result)
        if review_points:
            self._retrieve(result, review_points)
        return result

    def _extract_review_points(self, result: PipelineResult) -> Optional[list]:
        """Step 1: extract review points into *result*.

        Returns the review points, or None when extraction raised or returned
        an empty result (``result.extract_error`` is set in both cases).
        """
        logger.info(f"[RAG管线测试] 开始审查要点提取: {result.chunk_id}")
        # perf_counter is monotonic; time.time() can jump with clock changes.
        t0 = time.perf_counter()
        # The whole step is a deliberate best-effort boundary: any failure is
        # recorded on the result so batch runs continue with the next sample.
        try:
            review_points = self.query_rewrite_manager.query_extract(result.content)
            result.extract_time = time.perf_counter() - t0
            if not review_points:
                result.extract_error = "提取结果为空"
                logger.warning(f"[RAG管线测试] 审查要点提取为空: {result.chunk_id}")
                return None
            result.review_points = review_points
            logger.info(
                f"[RAG管线测试] 提取到 {len(review_points)} 个审查要点, "
                f"耗时 {result.extract_time:.2f}s"
            )
            return review_points
        except Exception as e:
            result.extract_time = time.perf_counter() - t0
            result.extract_error = str(e)
            logger.error(f"[RAG管线测试] 审查要点提取失败: {e}")
            return None

    def _retrieve(self, result: PipelineResult, review_points: list) -> None:
        """Step 2: retrieve hits for *review_points* and record them on *result*."""
        logger.info(f"[RAG管线测试] 开始 RAG 检索: {result.chunk_id}")
        t1 = time.perf_counter()
        # Best-effort boundary, as in step 1: failures are recorded, not raised.
        try:
            retrieval_results = self.review_point_retriever.review_point_retrieval(
                review_points
            )
            result.retrieval_time = time.perf_counter() - t1
            result.retrieval_results = retrieval_results
            logger.info(
                f"[RAG管线测试] 检索完成, "
                f"{result.non_empty_pairs}/{result.review_point_count} 个查询对有结果, "
                f"耗时 {result.retrieval_time:.2f}s"
            )
        except Exception as e:
            result.retrieval_time = time.perf_counter() - t1
            result.retrieval_error = str(e)
            logger.error(f"[RAG管线测试] RAG 检索失败: {e}")

    def run_batch(self, test_samples: list) -> list:
        """Run the pipeline over several samples, sequentially.

        Args:
            test_samples: Mappings with a required "content" key and an
                optional "chunk_id" key (defaults to "sample_<n>", 1-based).

        Returns:
            list[PipelineResult], in the same order as *test_samples*.
        """
        total = len(test_samples)
        results = []
        for index, sample in enumerate(test_samples, start=1):
            chunk_id = sample.get("chunk_id", f"sample_{index}")
            logger.info(
                f"[RAG管线测试] ====== 样本 {index}/{total}: {chunk_id} ======"
            )
            result = self.run_single(sample["content"], chunk_id=chunk_id)
            results.append(result)
            logger.info(f"[RAG管线测试] {result.summary()}")
        return results
|