from typing import List, Dict, Any, Optional from foundation.ai.models.rerank_model import rerank_model from foundation.infrastructure.config.config import config_handler from foundation.observability.logger.loggering import server_logger from foundation.database.base.vector.milvus_vector import MilvusVectorManager class RetrievalManager: """ 召回管理器,实现多路召回功能 """ def __init__(self): """ 初始化召回管理器 """ self.vector_manager = MilvusVectorManager() self.logger = server_logger self.dense_weight = config_handler.get('hybrid_search', 'DENSE_WEIGHT', 0.7) self.sparse_weight = config_handler.get('hybrid_search', 'SPARSE_WEIGHT', 0.3) def hybrid_search_recall(self, collection_name: str, query_text: str, top_k: int = 10, ranker_type: str = "weighted", dense_weight: float = 0.7, sparse_weight: float = 0.3) -> List[Dict[str, Any]]: """ 混合搜索召回 - 向量+BM25召回 Args: collection_name: 集合名称 query_text: 查询文本 top_k: 返回结果数量 ranker_type: 重排序类型 "weighted" 或 "rrf" dense_weight: 密集向量权重 sparse_weight: 稀疏向量权重 Returns: List[Dict]: 搜索结果列表 """ try: self.logger.info(f"开始混合检索") param = {'collection_name': collection_name} results = self.vector_manager.hybrid_search( param=param, query_text=query_text, top_k=top_k, ranker_type=ranker_type, dense_weight=dense_weight, sparse_weight=sparse_weight ) self.logger.info(f"混合搜索召回返回 {len(results)} 个结果") return results except Exception as e: self.logger.error(f"混合搜索召回失败: {str(e)}") return [] def rerank_recall(self, candidates: List[str], query_text: str, top_k: int = None ) -> List[Dict[str, Any]]: """ 重排序召回 - 使用BGE重排序模型对候选文档重新排序 Args: candidates: 候选文档列表 query_text: 查询文本 top_k: 返回结果数量 Returns: List[Dict]: 重排序后的结果列表,包含原始索引信息 """ try: self.logger.info(f"开始重排序召回,候选文档数量: {len(candidates)}") # 调用重排序执行器 rerank_results = rerank_model.bge_rerank(query_text, candidates, top_k) # 转换结果格式,通过文本匹配找到正确的原始索引 scored_docs = [] for i, api_result in enumerate(rerank_results): rerank_text = api_result.get('text', '') rerank_score = float(api_result.get('score', '0.0')) # 通过文本匹配找到原始在candidates中的索引 original_index = None for j, candidate_text in enumerate(candidates): if candidate_text == rerank_text: original_index = j break if original_index is None: self.logger.warning(f"无法找到重排序结果的原始索引,文本: {rerank_text[:50]}...") original_index = i # 回退到当前索引 scored_docs.append({ 'text_content': rerank_text, 'rerank_score': rerank_score, 'original_index': original_index, # 正确的原始索引 'rerank_rank': i # 重排序后的排名 }) self.logger.debug(f"重排序结果 {i}: 原始索引={original_index}, 重排序分数={rerank_score}") self.logger.info(f"重排序召回返回 {len(scored_docs)} 个结果") return scored_docs except Exception as e: self.logger.error(f"重排序召回失败: {str(e)}") return [] def multi_stage_recall(self, collection_name: str, query_text: str, hybrid_top_k: int = 50, top_k: int = 3, ranker_type: str = "weighted") -> List[Dict[str, Any]]: """ 多路召回 - 先混合搜索召回,再重排序,只返回重排序结果 Args: collection_name: 集合名称 query_text: 查询文本 hybrid_top_k: 混合搜索召回的文档数量 top_k: 最终返回的文档数量 ranker_type: 混合搜索的重排序类型 Returns: List[Dict]: 重排序后的结果列表,只包含重排序分数 """ try: self.logger.info(f"执行多路召回") # 第一阶段:混合搜索召回(向量+BM25) hybrid_results = self.hybrid_search_recall( collection_name=collection_name, query_text=query_text, top_k=hybrid_top_k, ranker_type=ranker_type ) if not hybrid_results: self.logger.warning("混合搜索召回无结果,返回空列表") return [] # 提取候选文档文本 candidates = [result['text_content'] for result in hybrid_results] # 第二阶段:重排序召回 rerank_results = self.rerank_recall( candidates=candidates, query_text=query_text, top_k=top_k ) # 为重排序结果添加混合搜索的原始元数据,优化metadata结构 final_results = [] for rerank_result in rerank_results: # 使用正确的原始索引进行元数据映射 original_index = rerank_result.get('original_index', 0) if original_index < len(hybrid_results): original_metadata = hybrid_results[original_index].get('metadata', {}) # 提取内层metadata并移除重复的content optimized_metadata = original_metadata.copy() # 如果内层有metadata字段,将其提取到外层 if 'metadata' in optimized_metadata and isinstance(optimized_metadata['metadata'], str): import json try: # 解析JSON格式的metadata inner_metadata = json.loads(optimized_metadata['metadata']) optimized_metadata.update(inner_metadata) # 移除内层的metadata字符串,避免重复 del optimized_metadata['metadata'] except (json.JSONDecodeError, TypeError): # 如果解析失败,保持原样 pass # 移除重复的content字段 if 'content' in optimized_metadata: del optimized_metadata['content'] # 输出优化后的结果 final_result = { 'text_content': rerank_result['text_content'], 'metadata': optimized_metadata } final_results.append(final_result) self.logger.debug(f"元数据映射成功: 重排序排名{rerank_result.get('rerank_rank')} -> 原始索引{original_index}") else: self.logger.warning(f"元数据映射失败: 原始索引{original_index}超出范围(0-{len(hybrid_results)-1})") self.logger.info(f"多路召回完成,返回 {len(final_results)} 个重排序结果") return final_results except Exception as e: self.logger.error(f"多路召回失败: {str(e)}") return [] # 创建全局召回管理器实例 retrieval_manager = RetrievalManager()