retrieval.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608
import asyncio
import json
import re
import time
from typing import Any, Dict, List, Optional

from foundation.ai.models.rerank_model import rerank_model
from foundation.database.base.vector.milvus_vector import MilvusVectorManager
from foundation.infrastructure.config.config import config_handler
from foundation.observability.logger.loggering import server_logger
from foundation.observability.monitoring.time_statistics import track_execution_time


class RetrievalManager:
    """
    Recall manager implementing multi-path recall: hybrid (dense + BM25)
    search over Milvus followed by model-based reranking.
    """

    def __init__(self):
        """
        Initialize the recall manager: vector-store client, logger,
        hybrid-search fusion weights and the rerank model selection,
        all read from the project config.
        """
        self.vector_manager = MilvusVectorManager()
        self.logger = server_logger
        # Fusion weights for hybrid (dense + sparse) search
        self.dense_weight = config_handler.get('hybrid_search', 'DENSE_WEIGHT', 0.7)
        self.sparse_weight = config_handler.get('hybrid_search', 'SPARSE_WEIGHT', 0.3)
        # Rerank model selection: 'bge' or 'qwen3'
        self.rerank_model_type = config_handler.get('retrieval', 'RERANK_MODEL_TYPE', 'bge')
        self.logger.info(f"初始化重排序模型类型: {self.rerank_model_type}")
  24. def set_rerank_model(self, model_type: str):
  25. """
  26. 设置重排序模型类型
  27. Args:
  28. model_type: 模型类型 ('bge' 或 'qwen3')
  29. """
  30. if model_type not in ['bge', 'qwen3']:
  31. raise ValueError("model_type 必须是 'bge' 或 'qwen3'")
  32. self.rerank_model_type = model_type
  33. self.logger.info(f"重排序模型类型已设置为: {model_type}")
  34. def _clean_document(self, doc: str) -> str:
  35. """
  36. 清理文档文本,移除HTML标签和特殊字符
  37. Args:
  38. doc: 原始文档文本
  39. Returns:
  40. str: 清理后的文档文本
  41. """
  42. if not isinstance(doc, str):
  43. self.logger.debug(f"文档类型转换: {type(doc)} -> str")
  44. return str(doc)
  45. original_length = len(doc)
  46. # 移除HTML标签
  47. import re
  48. doc = re.sub(r'<[^>]+>', '', doc)
  49. # 移除多余的空白字符
  50. doc = re.sub(r'\s+', ' ', doc)
  51. # 更宽松的字符过滤 - 保留更多字符
  52. doc = re.sub(r'[^\u4e00-\u9fff\w\s.,;:!?()()。,;:!?\-\+\=\*/%&@#¥$【】「」""''""\n\r]', '', doc)
  53. # 截断过长的文本
  54. if len(doc) > 8000: # 设置最大长度限制
  55. doc = doc[:8000] + "..."
  56. cleaned_doc = doc.strip()
  57. self.logger.debug(f"文档清理: {original_length} -> {len(cleaned_doc)} 字符")
  58. return cleaned_doc
  59. def _get_rerank_results(self, query_text: str, documents: List[str], top_k: int = None) -> List[Dict[str, Any]]:
  60. """
  61. 根据配置选择重排序模型并执行重排序
  62. Args:
  63. query_text: 查询文本
  64. documents: 文档列表
  65. top_k: 返回结果数量
  66. Returns:
  67. List[Dict]: 重排序后的结果列表
  68. """
  69. try:
  70. # 清理和验证文档列表
  71. cleaned_documents = []
  72. valid_original_docs = []
  73. for doc in documents:
  74. if doc and isinstance(doc, str) and doc.strip():
  75. cleaned_doc = self._clean_document(doc)
  76. if cleaned_doc and len(cleaned_doc) > 3:
  77. cleaned_documents.append(cleaned_doc)
  78. valid_original_docs.append(doc)
  79. if not cleaned_documents:
  80. return []
  81. if self.rerank_model_type == 'qwen3':
  82. self.logger.info("使用 Qwen3-Reranker-8B 进行重排序")
  83. rerank_results = rerank_model.qwen3_rerank(query_text, cleaned_documents, top_k)
  84. # 将清理后的文本映射回原始文本
  85. for result in rerank_results:
  86. cleaned_text = result.get('text', '')
  87. # 查找原始文本
  88. for i, cleaned in enumerate(cleaned_documents):
  89. if cleaned == cleaned_text:
  90. result['text'] = valid_original_docs[i]
  91. break
  92. return rerank_results
  93. else:
  94. self.logger.info("使用 BGE Reranker 进行重排序")
  95. rerank_results = rerank_model.bge_rerank(query_text, cleaned_documents, top_k)
  96. # 将清理后的文本映射回原始文本
  97. for result in rerank_results:
  98. cleaned_text = result.get('text', '')
  99. # 查找原始文本
  100. for i, cleaned in enumerate(cleaned_documents):
  101. if cleaned == cleaned_text:
  102. result['text'] = valid_original_docs[i]
  103. break
  104. return rerank_results
  105. except Exception as e:
  106. self.logger.error(f"重排序失败,模型类型: {self.rerank_model_type}, 错误: {str(e)}")
  107. # 返回原始顺序作为fallback
  108. return [{"text": doc, "score": 0.0} for i, doc in enumerate(documents[:top_k])]
  109. @track_execution_time
  110. async def entity_recall(self, main_entity: str,assisted_search_entity: list,
  111. top_k: int = 5) -> List[Dict[str, Any]]:
  112. """
  113. 执行实体召回
  114. :param main_entity: 查询实体
  115. :param assisted_search_entity: 辅助搜索实体
  116. :param top_k: 返回结果数量
  117. """
  118. collection_name = "first_bfp_collection_entity"
  119. # 主实体搜索 - 使用异步方法
  120. entity_result = await self.async_multi_stage_recall(
  121. collection_name=collection_name,
  122. query_text=main_entity,
  123. hybrid_top_k=20, # 从默认50降到20
  124. top_k=top_k
  125. )
  126. assist_tasks = [
  127. self.async_multi_stage_recall(
  128. collection_name=collection_name,
  129. query_text=assisted_search_entity,
  130. hybrid_top_k=20, # 从默认50降到20
  131. top_k=top_k
  132. ) for assisted_search_entity in assisted_search_entity
  133. ]
  134. # 辅助搜索,异步并发
  135. assist_results_list = await asyncio.gather(*assist_tasks,return_exceptions=True)
  136. assist_results = []
  137. for res in assist_results_list:
  138. if isinstance(res, Exception):
  139. self.logger.error(f"辅助实体召回失败: {str(res)}")
  140. else:
  141. assist_results.extend(res)
  142. all_results = entity_result + assist_results
  143. entity_list = list(set([item['text_content'] for item in all_results]))
  144. self.logger.info(f"entity_list:{entity_list}")
  145. return entity_list
  146. @track_execution_time
  147. async def async_bfp_recall(self, entity_list: List[str],background: str ,
  148. top_k: int = 3,) -> List[Dict[str, Any]]:
  149. """
  150. 混合搜索召回 - 向量+BM25召回
  151. Args:
  152. entity_list: 实体列表
  153. background: 背景/上下文信息,用于二次重排
  154. top_k: 返回结果数量
  155. """
  156. import time
  157. start_time = time.time()
  158. # 异步并发召回编制依据
  159. collection_name = "first_bfp_collection_test"
  160. gather_start = time.time()
  161. # 优化:降低hybrid_top_k参数从50到20,减少混合搜索时间
  162. bfp_tasks = [
  163. self.async_multi_stage_recall(
  164. collection_name=collection_name,
  165. query_text=entity,
  166. hybrid_top_k=20, # 从50降到20,减少60%的混合搜索时间
  167. top_k=top_k
  168. ) for entity in entity_list
  169. ]
  170. bfp_tasks_list = await asyncio.gather(*bfp_tasks,return_exceptions=True)
  171. gather_end = time.time()
  172. gather_time = gather_end - gather_start
  173. bfp_results = []
  174. for res in bfp_tasks_list:
  175. if isinstance(res, Exception):
  176. self.logger.error(f"辅助实体召回失败: {str(res)}")
  177. else:
  178. bfp_results.extend(res)
  179. # BFP召回结果已经通过multi_stage_recall进行了重排序,保持原有顺序
  180. # 只对第一次重排序得分大于0.8的文档进行二次重排序
  181. high_score_results = [item for item in bfp_results if item.get('rerank_score', 0) > 0.8]
  182. low_score_results = [item for item in bfp_results if item.get('rerank_score', 0) <= 0.8]
  183. self.logger.info(f"筛选结果:高分文档(>0.8) {len(high_score_results)} 个,低分文档(≤0.8) {len(low_score_results)} 个")
  184. # 如果没有高分文档,直接返回原始结果
  185. if not high_score_results:
  186. self.logger.info("没有得分大于0.8的文档,跳过二次重排序,直接返回原始结果")
  187. return bfp_results
  188. # 提取高分文档的文本内容用于二次重排
  189. high_score_text_content = list(set([item['text_content'] for item in high_score_results]))
  190. self.logger.info(f"提取高分文档文本内容,共 {len(high_score_text_content)} 个,准备二次重排")
  191. # 二次重排 - 使用配置的重排序模型
  192. rerank_start = time.time()
  193. bfp_rerank_result = self._get_rerank_results(background, high_score_text_content, 5)
  194. rerank_end = time.time()
  195. self.logger.info(f"二次重排序耗时: {rerank_end - rerank_start:.3f}秒")
  196. # 根据重排结果重新组织数据
  197. reorganize_start = time.time()
  198. final_results = []
  199. text_to_metadata = {item['text_content']: item for item in high_score_results}
  200. # 处理二次重排序的高分文档
  201. for rerank_item in bfp_rerank_result:
  202. text = rerank_item.get('text', '')
  203. score = rerank_item.get('score', 0.0)
  204. if text in text_to_metadata:
  205. original_item = text_to_metadata[text].copy()
  206. original_item['bfp_rerank_score'] = score
  207. final_results.append(original_item)
  208. reorganize_end = time.time()
  209. total_time = reorganize_end - start_time
  210. self.logger.info(f"结果重组耗时: {reorganize_end - reorganize_start:.3f}秒")
  211. self.logger.info(f"二次重排完成,返回 {len(final_results)} 个高分文档,丢弃 {len(low_score_results)} 个低分文档")
  212. self.logger.info(f"[async_bfp_recall] 总耗时: {total_time:.3f}秒 (召回: {gather_end-gather_start:.3f}s + 重排: {rerank_end-rerank_start:.3f}s + 其他: {total_time-(gather_end-gather_start)-(rerank_end-rerank_start):.3f}s)")
  213. return final_results
  214. def hybrid_search_recall(self, collection_name: str, query_text: str,
  215. top_k: int = 10 , ranker_type: str = "weighted",
  216. dense_weight: float = 0.7, sparse_weight: float = 0.3) -> List[Dict[str, Any]]:
  217. """
  218. 混合搜索召回 - 向量+BM25召回
  219. Args:
  220. collection_name: 集合名称
  221. query_text: 查询文本
  222. top_k: 返回结果数量
  223. ranker_type: 重排序类型 "weighted" 或 "rrf"
  224. dense_weight: 密集向量权重
  225. sparse_weight: 稀疏向量权重
  226. Returns:
  227. List[Dict]: 搜索结果列表
  228. """
  229. try:
  230. self.logger.info(f"开始混合检索")
  231. param = {'collection_name': collection_name}
  232. # 直接调用同步的混合搜索(在同步方法中)
  233. results = self.vector_manager.hybrid_search(
  234. param=param,
  235. query_text=query_text,
  236. top_k=top_k,
  237. ranker_type=ranker_type,
  238. dense_weight=dense_weight,
  239. sparse_weight=sparse_weight
  240. )
  241. # 详细记录混合搜索结果
  242. self.logger.info(f"混合搜索召回返回 {len(results)} 个结果")
  243. # for i, result in enumerate(results):
  244. # text_content = result.get('text_content', '')
  245. # metadata = result.get('metadata', {})
  246. # title = metadata.get('title', 'N/A')
  247. # file = metadata.get('file', 'N/A')
  248. # self.logger.info(f"混合搜索结果 {i+1}: 标题='{title}', 文件='{file}', 内容长度={len(text_content)}")
  249. # # self.logger.info(f" 完整元数据: {metadata}")
  250. # # self.logger.info(f" 文本内容: '{text_content}'")
  251. return results
  252. except Exception as e:
  253. self.logger.error(f"混合搜索召回失败: {str(e)}")
  254. return []
  255. def rerank_recall(self, candidates_with_metadata: List[Dict[str, Any]], query_text: str,
  256. top_k: int = None ) -> List[Dict[str, Any]]:
  257. """
  258. 重排序召回 - 使用配置的重排序模型对候选文档重新排序
  259. Args:
  260. candidates_with_metadata: 候选文档列表,包含文本内容和元数据
  261. query_text: 查询文本
  262. top_k: 返回结果数量
  263. Returns:
  264. List[Dict]: 重排序后的结果列表,包含原始索引信息
  265. """
  266. try:
  267. # 第一步:基于文本内容+元数据的组合去重
  268. unique_candidates = []
  269. original_indices_map = [] # 记录每个去重后的候选文档对应的原始索引列表
  270. unique_combinations = set() # 记录已见过的文本+元数据组合
  271. for original_index, candidate in enumerate(candidates_with_metadata):
  272. text_content = candidate.get('text_content', '')
  273. metadata = candidate.get('metadata', {})
  274. # 处理嵌套的metadata字符串
  275. title = ''
  276. file = ''
  277. if 'metadata' in metadata and isinstance(metadata['metadata'], str):
  278. import json
  279. try:
  280. # 解析JSON格式的metadata
  281. inner_metadata = json.loads(metadata['metadata'])
  282. title = inner_metadata.get('title', '')
  283. file = inner_metadata.get('file', '')
  284. except (json.JSONDecodeError, TypeError):
  285. pass
  286. else:
  287. title = metadata.get('title', '')
  288. file = metadata.get('file', '')
  289. # 创建组合键:文本内容 + 关键元数据
  290. combination_key = (text_content, title, file)
  291. if combination_key not in unique_combinations:
  292. # 新的唯一组合
  293. unique_candidates.append(candidate)
  294. original_indices_map.append([original_index])
  295. unique_combinations.add(combination_key)
  296. else:
  297. # 找到对应的唯一候选并添加索引
  298. for unique_idx, unique_candidate in enumerate(unique_candidates):
  299. if unique_candidate.get('text_content', '') == text_content:
  300. # 解析唯一候选的元数据
  301. unique_metadata = unique_candidate.get('metadata', {})
  302. unique_title = ''
  303. unique_file = ''
  304. if 'metadata' in unique_metadata and isinstance(unique_metadata['metadata'], str):
  305. import json
  306. try:
  307. inner_metadata = json.loads(unique_metadata['metadata'])
  308. unique_title = inner_metadata.get('title', '')
  309. unique_file = inner_metadata.get('file', '')
  310. except (json.JSONDecodeError, TypeError):
  311. pass
  312. else:
  313. unique_title = unique_metadata.get('title', '')
  314. unique_file = unique_metadata.get('file', '')
  315. if unique_title == title and unique_file == file:
  316. original_indices_map[unique_idx].append(original_index)
  317. break
  318. # 提取唯一候选文档的文本内容用于重排序
  319. unique_texts = [candidate.get('text_content', '') for candidate in unique_candidates]
  320. # 使用配置的重排序模型进行重排序
  321. rerank_results = self._get_rerank_results(query_text, unique_texts, top_k)
  322. # 转换结果格式,使用索引映射来处理原始索引
  323. scored_docs = []
  324. for i, api_result in enumerate(rerank_results):
  325. rerank_text = api_result.get('text', '')
  326. rerank_score = float(api_result.get('score', '0.0'))
  327. # 使用去重时的索引映射
  328. original_index = original_indices_map[i][0] # 取第一个原始索引
  329. original_candidate = unique_candidates[i] # 获取原始候选文档(包含元数据)
  330. # 获取原始混合搜索的评分信息
  331. hybrid_distance = original_candidate.get('distance', 0.0)
  332. hybrid_similarity = original_candidate.get('similarity', 0.0)
  333. # 解析元数据获取标题用于日志
  334. metadata = original_candidate.get('metadata', {})
  335. title = 'N/A'
  336. if 'metadata' in metadata and isinstance(metadata['metadata'], str):
  337. try:
  338. import json
  339. inner_metadata = json.loads(metadata['metadata'])
  340. title = inner_metadata.get('title', 'N/A')
  341. except:
  342. pass
  343. scored_docs.append({
  344. 'text_content': rerank_text,
  345. 'metadata': original_candidate.get('metadata', {}), # 保留原始元数据
  346. 'rerank_score': rerank_score,
  347. 'original_index': original_index,
  348. 'rerank_rank': i,
  349. 'duplicate_count': len(original_indices_map[i]), # 记录重复数量
  350. 'hybrid_distance': hybrid_distance, # 保留原始混合搜索评分
  351. 'hybrid_similarity': hybrid_similarity
  352. })
  353. # 输出双重评分信息
  354. # self.logger.info(f"重排序评分 #{i+1}: 标题='{title}' | 混合搜索相似度={hybrid_similarity:.4f} | BGE重排序评分={rerank_score:.6f}")
  355. return scored_docs
  356. except Exception as e:
  357. self.logger.error(f"重排序召回失败: {str(e)}")
  358. return []
  359. def multi_stage_recall(self, collection_name: str, query_text: str,
  360. hybrid_top_k: int = 50, top_k: int = 10,
  361. ranker_type: str = "weighted") -> List[Dict[str, Any]]:
  362. """
  363. 多路召回 - 先混合搜索召回,再重排序,只返回重排序结果
  364. Args:
  365. collection_name: 集合名称
  366. query_text: 查询文本
  367. hybrid_top_k: 混合搜索召回的文档数量
  368. top_k: 最终返回的文档数量
  369. ranker_type: 混合搜索的重排序类型
  370. Returns:
  371. List[Dict]: 重排序后的结果列表,只包含重排序分数
  372. """
  373. try:
  374. self.logger.info(f"执行多路召回")
  375. # 第一阶段:混合搜索召回(向量+BM25)
  376. hybrid_results = self.hybrid_search_recall(
  377. collection_name=collection_name,
  378. query_text=query_text,
  379. top_k=hybrid_top_k,
  380. ranker_type=ranker_type
  381. )
  382. if not hybrid_results:
  383. self.logger.warning("混合搜索召回无结果,返回空列表")
  384. return []
  385. # 第二阶段:重排序召回,传递完整的混合搜索结果(包含元数据)
  386. rerank_results = self.rerank_recall(
  387. candidates_with_metadata=hybrid_results,
  388. query_text=query_text,
  389. top_k=top_k
  390. )
  391. # 优化重排序结果的元数据结构
  392. final_results = []
  393. for rerank_result in rerank_results:
  394. metadata = rerank_result.get('metadata', {}).copy()
  395. duplicate_count = rerank_result.get('duplicate_count', 1)
  396. # 如果内层有metadata字段,将其提取到外层
  397. if 'metadata' in metadata and isinstance(metadata['metadata'], str):
  398. import json
  399. try:
  400. # 解析JSON格式的metadata
  401. inner_metadata = json.loads(metadata['metadata'])
  402. metadata.update(inner_metadata)
  403. # 移除内层的metadata字符串,避免重复
  404. del metadata['metadata']
  405. except (json.JSONDecodeError, TypeError):
  406. # 如果解析失败,保持原样
  407. pass
  408. # 移除重复的content字段
  409. if 'content' in metadata:
  410. del metadata['content']
  411. # 添加重复计数信息到元数据中
  412. if duplicate_count > 1:
  413. metadata['duplicate_count'] = duplicate_count
  414. # 输出优化后的结果,包含双重评分
  415. final_result = {
  416. 'text_content': rerank_result['text_content'],
  417. 'metadata': metadata,
  418. 'hybrid_similarity': rerank_result.get('hybrid_similarity', 0.0), # 混合搜索相似度
  419. 'rerank_score': rerank_result.get('rerank_score', 0.0) # BGE重排序评分
  420. }
  421. final_results.append(final_result)
  422. self.logger.debug(f"元数据优化完成: 重排序排名{rerank_result.get('rerank_rank')}, 重复数量={duplicate_count}")
  423. return final_results
  424. except Exception as e:
  425. self.logger.error(f"多路召回失败: {str(e)}")
  426. return []
  427. async def async_multi_stage_recall(self, collection_name: str, query_text: str,
  428. hybrid_top_k: int = 50, top_k: int = 10,
  429. ranker_type: str = "weighted") -> List[Dict[str, Any]]:
  430. """
  431. 多路召回 - 先混合搜索召回,再重排序,只返回重排序结果
  432. Args:
  433. collection_name: 集合名称
  434. query_text: 查询文本
  435. hybrid_top_k: 混合搜索召回的文档数量
  436. top_k: 最终返回的文档数量
  437. ranker_type: 混合搜索的重排序类型
  438. Returns:
  439. List[Dict]: 重排序后的结果列表,只包含重排序分数
  440. """
  441. import time
  442. try:
  443. start_time = time.time()
  444. # 第一阶段:混合搜索召回(向量+BM25)
  445. hybrid_start = time.time()
  446. hybrid_results = await asyncio.to_thread(
  447. self.hybrid_search_recall,
  448. collection_name=collection_name,
  449. query_text=query_text,
  450. top_k=hybrid_top_k,
  451. ranker_type=ranker_type
  452. )
  453. if not hybrid_results:
  454. return []
  455. # 第二阶段:重排序召回
  456. rerank_results = self.rerank_recall(
  457. candidates_with_metadata=hybrid_results,
  458. query_text=query_text,
  459. top_k=top_k
  460. )
  461. # 优化重排序结果的元数据结构
  462. final_results = []
  463. for rerank_result in rerank_results:
  464. metadata = rerank_result.get('metadata', {}).copy()
  465. duplicate_count = rerank_result.get('duplicate_count', 1)
  466. # 如果内层有metadata字段,将其提取到外层
  467. if 'metadata' in metadata and isinstance(metadata['metadata'], str):
  468. import json
  469. try:
  470. # 解析JSON格式的metadata
  471. inner_metadata = json.loads(metadata['metadata'])
  472. metadata.update(inner_metadata)
  473. # 移除内层的metadata字符串,避免重复
  474. del metadata['metadata']
  475. except (json.JSONDecodeError, TypeError):
  476. # 如果解析失败,保持原样
  477. pass
  478. # 移除重复的content字段
  479. if 'content' in metadata:
  480. del metadata['content']
  481. # 添加重复计数信息到元数据中
  482. if duplicate_count > 1:
  483. metadata['duplicate_count'] = duplicate_count
  484. # 输出优化后的结果,包含双重评分
  485. final_result = {
  486. 'text_content': rerank_result['text_content'],
  487. 'metadata': metadata,
  488. 'hybrid_similarity': rerank_result.get('hybrid_similarity', 0.0), # 混合搜索相似度
  489. 'rerank_score': rerank_result.get('rerank_score', 0.0) # BGE重排序评分
  490. }
  491. final_results.append(final_result)
  492. self.logger.debug(f"元数据优化完成: 重排序排名{rerank_result.get('rerank_rank')}, 重复数量={duplicate_count}")
  493. return final_results
  494. except Exception as e:
  495. self.logger.error(f"多路召回失败: {str(e)}")
  496. return []
# Module-level singleton shared by the application
retrieval_manager = RetrievalManager()