entities_enhance.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. import asyncio
  2. from foundation.observability.monitoring.time_statistics import track_execution_time
  3. from foundation.ai.rag.retrieval.retrieval import retrieval_manager
  4. from foundation.observability.logger.loggering import review_logger as server_logger
  5. from foundation.infrastructure.config.config import config_handler
  6. class ReviewPointRetriever():
  7. """
  8. 审查要点检索器 — 直接从施工方案文本搜索规范条文
  9. 替代旧的 EntitiesEnhance(实体增强检索),新流程:
  10. 1. 对每个 search_query → hybrid_search(CHILDREN_COLLECTION) → 一次重排序(rerank)
  11. 2. 合并所有 search_query 的候选结果并去重
  12. 3. 用 original_text 做二次重排序(语义锚点对齐)
  13. 4. 返回 top-K 结果
  14. 核心改进:
  15. - 跳过 ENTITY_COLLECTION 中间跳,直接搜索规范条文集合
  16. - 用原文摘录(而非 LLM 概括的 background)做二次重排序,语义更精确
  17. """
  18. def __init__(self):
  19. self.result_lists = []
  20. self._search_cache = {} # 检索结果缓存
  21. def _get_cache_key(self, query: str) -> str:
  22. return f"search::{query[:100]}"
  23. def _get_children_collection(self) -> str:
  24. return config_handler.get('rag_collections', 'CHILDREN_COLLECTION', 'rag_children_hybrid')
  25. @track_execution_time
  26. def review_point_retrieval(self, review_points):
  27. """
  28. 审查要点检索 — 替代 entities_enhance_retrieval
  29. Args:
  30. review_points: 审查要点列表, 每个要点包含:
  31. - label: 标签
  32. - search_queries: 规范检索语句列表
  33. - original_text: 原文摘录
  34. - parameter: 技术参数
  35. Returns:
  36. list[list[dict]]: 二维列表, 每个子列表对应一个审查要点的检索结果
  37. """
  38. def run_async(coro):
  39. """在合适的环境中运行异步函数"""
  40. try:
  41. loop = asyncio.get_running_loop()
  42. import concurrent.futures
  43. with concurrent.futures.ThreadPoolExecutor() as executor:
  44. future = executor.submit(asyncio.run, coro)
  45. return future.result()
  46. except RuntimeError:
  47. return asyncio.run(coro)
  48. self.result_lists = []
  49. children_collection = self._get_children_collection()
  50. for point_idx, point in enumerate(review_points):
  51. # 兼容新旧字段名
  52. label = point.get('label', point.get('entity', ''))
  53. search_queries = point.get('search_queries', point.get('search_keywords', []))
  54. original_text = point.get('original_text', point.get('background', ''))
  55. server_logger.info(
  56. f"正在处理审查要点 [{point_idx}]: {label}, "
  57. f"检索语句数: {len(search_queries)}, "
  58. f"原文长度: {len(original_text)}"
  59. )
  60. # Step 1: 对每个 search_query 执行 hybrid_search + 一次重排序
  61. all_candidates = []
  62. for query in search_queries:
  63. cache_key = self._get_cache_key(query)
  64. if cache_key in self._search_cache:
  65. query_results = self._search_cache[cache_key]
  66. server_logger.info(f"[缓存命中] search_query: {query[:30]}...")
  67. else:
  68. query_results = run_async(
  69. retrieval_manager.async_multi_stage_recall(
  70. collection_name=children_collection,
  71. query_text=query,
  72. hybrid_top_k=10,
  73. top_k=5
  74. )
  75. )
  76. self._search_cache[cache_key] = query_results
  77. server_logger.info(
  78. f"[检索完成] search_query: {query[:30]}... "
  79. f"召回 {len(query_results)} 个候选"
  80. )
  81. all_candidates.extend(query_results)
  82. if not all_candidates:
  83. server_logger.warning(f"审查要点 '{label}' 所有检索语句均无结果")
  84. self.result_lists.append([])
  85. continue
  86. # Step 2: 去重 (基于 text_content)
  87. seen_texts = set()
  88. unique_candidates = []
  89. for item in all_candidates:
  90. text = item.get('text_content', '')
  91. if text and text not in seen_texts:
  92. seen_texts.add(text)
  93. unique_candidates.append(item)
  94. server_logger.info(
  95. f"审查要点 '{label}': 合并 {len(all_candidates)} 个候选, "
  96. f"去重后 {len(unique_candidates)} 个"
  97. )
  98. # Step 3: 筛选高分候选 (rerank_score > 0.5)
  99. high_score = [c for c in unique_candidates if (c.get('rerank_score') or 0) > 0.5]
  100. if not high_score:
  101. # 无高分候选 → 该审查要点无相关规范,直接跳过(不进入后续流程)
  102. max_score = max((c.get('rerank_score') or 0) for c in unique_candidates) if unique_candidates else 0
  103. server_logger.warning(
  104. f"审查要点 '{label}': 无高分候选(>0.5), 共 {len(unique_candidates)} 个候选均低于阈值, "
  105. f"最高分={max_score:.4f}, 跳过该审查要点"
  106. )
  107. self.result_lists.append([])
  108. continue
  109. # Step 4: 二次重排序 — 用 original_text 作为语义锚点
  110. if original_text and len(original_text) > 10 and len(high_score) > 1:
  111. final_results = self._secondary_rerank(original_text, high_score, top_k=5)
  112. server_logger.info(
  113. f"审查要点 '{label}': 二次重排序完成, "
  114. f"返回 {len(final_results)} 个结果"
  115. )
  116. else:
  117. final_results = high_score[:5]
  118. server_logger.info(
  119. f"审查要点 '{label}': 跳过二次重排序 "
  120. f"(原文长度={len(original_text)}, 候选数={len(high_score)})"
  121. )
  122. # Step 5: 标记来源信息 (backward compat)
  123. for result in final_results:
  124. result['source_entity'] = label
  125. self.result_lists.append(final_results)
  126. return self.result_lists
  127. def _secondary_rerank(self, original_text, candidates, top_k=5):
  128. """
  129. 二次重排序: 用 original_text(原文摘录)作为 query,对候选文档重排序
  130. 核心创新: 用施工原文(而非 entity description 或 LLM 概括的 background)做 rerank,
  131. 确保检索到的规范条文与施工文本的语义精确对齐
  132. """
  133. # 提取候选文本(去重)
  134. candidate_texts = []
  135. seen = set()
  136. for item in candidates:
  137. text = item.get('text_content', '')
  138. if text and text not in seen:
  139. seen.add(text)
  140. candidate_texts.append(text)
  141. if not candidate_texts:
  142. return candidates[:top_k]
  143. try:
  144. rerank_results = retrieval_manager._get_rerank_results(
  145. original_text, candidate_texts, top_k
  146. )
  147. except Exception as e:
  148. server_logger.error(f"二次重排序失败: {e}")
  149. return candidates[:top_k]
  150. # 将 rerank 分数映射回原始结果
  151. text_to_items = {}
  152. for item in candidates:
  153. text = item.get('text_content', '')
  154. if text not in text_to_items:
  155. text_to_items[text] = []
  156. text_to_items[text].append(item)
  157. final_results = []
  158. added_texts = set()
  159. for rerank_item in rerank_results:
  160. text = rerank_item.get('text', '')
  161. score = rerank_item.get('score', 0.0)
  162. if text in text_to_items and text not in added_texts:
  163. best_candidate = max(
  164. text_to_items[text],
  165. key=lambda x: x.get('rerank_score', 0.0)
  166. )
  167. result_item = best_candidate.copy()
  168. result_item['bfp_rerank_score'] = score # 二次重排序分数 (backward compat)
  169. result_item['bfp_rerank_parent_id'] = result_item.get(
  170. 'metadata', {}
  171. ).get('parent_id', '')
  172. final_results.append(result_item)
  173. added_texts.add(text)
  174. return final_results
  175. def clear_cache(self):
  176. """清空检索缓存"""
  177. self._search_cache.clear()
  178. server_logger.info("[缓存清理] 审查要点检索缓存已清空")
  179. # 向后兼容:旧代码调用 entities_enhance_retrieval 时自动转发
  180. def entities_enhance_retrieval(self, query_pairs):
  181. """向后兼容入口,转发到 review_point_retrieval"""
  182. return self.review_point_retrieval(query_pairs)
  183. # 全局实例 — 新名称
  184. review_point_retriever = ReviewPointRetriever()
  185. # 向后兼容:旧代码 import entity_enhance 时不会报错
  186. entity_enhance = review_point_retriever