|
|
@@ -1,8 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
+import asyncio
|
|
|
+import json
|
|
|
from typing import List, Dict, Any, Optional
|
|
|
from foundation.ai.models.rerank_model import rerank_model
|
|
|
+from foundation.observability.monitoring.time_statistics import track_execution_time
|
|
|
from foundation.infrastructure.config.config import config_handler
|
|
|
from foundation.observability.logger.loggering import server_logger
|
|
|
from foundation.database.base.vector.milvus_vector import MilvusVectorManager
|
|
|
@@ -21,15 +24,246 @@ class RetrievalManager:
|
|
|
self.dense_weight = config_handler.get('hybrid_search', 'DENSE_WEIGHT', 0.7)
|
|
|
self.sparse_weight = config_handler.get('hybrid_search', 'SPARSE_WEIGHT', 0.3)
|
|
|
|
|
|
- def entity_recall(self, collection_name: str, query_text: str,
|
|
|
- top_k: int = 10) -> List[Dict[str, Any]]:
|
|
|
+ # 重排序模型配置
|
|
|
+ self.rerank_model_type = config_handler.get('retrieval', 'RERANK_MODEL_TYPE', 'bge') # 'bge' 或 'qwen3'
|
|
|
+ self.logger.info(f"初始化重排序模型类型: {self.rerank_model_type}")
|
|
|
+
|
|
|
+ def set_rerank_model(self, model_type: str):
|
|
|
+ """
|
|
|
+ 设置重排序模型类型
|
|
|
+
|
|
|
+ Args:
|
|
|
+ model_type: 模型类型 ('bge' 或 'qwen3')
|
|
|
+ """
|
|
|
+ if model_type not in ['bge', 'qwen3']:
|
|
|
+ raise ValueError("model_type 必须是 'bge' 或 'qwen3'")
|
|
|
+
|
|
|
+ self.rerank_model_type = model_type
|
|
|
+ self.logger.info(f"重排序模型类型已设置为: {model_type}")
|
|
|
+
|
|
|
+ def _clean_document(self, doc: str) -> str:
|
|
|
+ """
|
|
|
+ 清理文档文本,移除HTML标签和特殊字符
|
|
|
+
|
|
|
+ Args:
|
|
|
+ doc: 原始文档文本
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ str: 清理后的文档文本
|
|
|
+ """
|
|
|
+ if not isinstance(doc, str):
|
|
|
+ self.logger.debug(f"文档类型转换: {type(doc)} -> str")
|
|
|
+ return str(doc)
|
|
|
+
|
|
|
+ original_length = len(doc)
|
|
|
+
|
|
|
+ # 移除HTML标签
|
|
|
+ import re
|
|
|
+ doc = re.sub(r'<[^>]+>', '', doc)
|
|
|
+
|
|
|
+ # 移除多余的空白字符
|
|
|
+ doc = re.sub(r'\s+', ' ', doc)
|
|
|
+
|
|
|
+ # 更宽松的字符过滤 - 保留更多字符
|
|
|
+ doc = re.sub(r'[^\u4e00-\u9fff\w\s.,;:!?()()。,;:!?\-\+\=\*/%&@#¥$【】「」""''""\n\r]', '', doc)
|
|
|
+
|
|
|
+ # 截断过长的文本
|
|
|
+ if len(doc) > 8000: # 设置最大长度限制
|
|
|
+ doc = doc[:8000] + "..."
|
|
|
+
|
|
|
+ cleaned_doc = doc.strip()
|
|
|
+ self.logger.debug(f"文档清理: {original_length} -> {len(cleaned_doc)} 字符")
|
|
|
+
|
|
|
+ return cleaned_doc
|
|
|
+
|
|
|
+ def _get_rerank_results(self, query_text: str, documents: List[str], top_k: int = None) -> List[Dict[str, Any]]:
|
|
|
+ """
|
|
|
+ 根据配置选择重排序模型并执行重排序
|
|
|
+
|
|
|
+ Args:
|
|
|
+ query_text: 查询文本
|
|
|
+ documents: 文档列表
|
|
|
+ top_k: 返回结果数量
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ List[Dict]: 重排序后的结果列表
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ # 清理和验证文档列表
|
|
|
+ cleaned_documents = []
|
|
|
+ valid_original_docs = []
|
|
|
+
|
|
|
+ for doc in documents:
|
|
|
+ if doc and isinstance(doc, str) and doc.strip():
|
|
|
+ cleaned_doc = self._clean_document(doc)
|
|
|
+ if cleaned_doc and len(cleaned_doc) > 3:
|
|
|
+ cleaned_documents.append(cleaned_doc)
|
|
|
+ valid_original_docs.append(doc)
|
|
|
+
|
|
|
+ if not cleaned_documents:
|
|
|
+ return []
|
|
|
+
|
|
|
+ if self.rerank_model_type == 'qwen3':
|
|
|
+ self.logger.info("使用 Qwen3-Reranker-8B 进行重排序")
|
|
|
+ rerank_results = rerank_model.qwen3_rerank(query_text, cleaned_documents, top_k)
|
|
|
+
|
|
|
+ # 将清理后的文本映射回原始文本
|
|
|
+ for result in rerank_results:
|
|
|
+ cleaned_text = result.get('text', '')
|
|
|
+ # 查找原始文本
|
|
|
+ for i, cleaned in enumerate(cleaned_documents):
|
|
|
+ if cleaned == cleaned_text:
|
|
|
+ result['text'] = valid_original_docs[i]
|
|
|
+ break
|
|
|
+
|
|
|
+ return rerank_results
|
|
|
+ else:
|
|
|
+ self.logger.info("使用 BGE Reranker 进行重排序")
|
|
|
+ rerank_results = rerank_model.bge_rerank(query_text, cleaned_documents, top_k)
|
|
|
+
|
|
|
+ # 将清理后的文本映射回原始文本
|
|
|
+ for result in rerank_results:
|
|
|
+ cleaned_text = result.get('text', '')
|
|
|
+ # 查找原始文本
|
|
|
+ for i, cleaned in enumerate(cleaned_documents):
|
|
|
+ if cleaned == cleaned_text:
|
|
|
+ result['text'] = valid_original_docs[i]
|
|
|
+ break
|
|
|
+
|
|
|
+ return rerank_results
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ self.logger.error(f"重排序失败,模型类型: {self.rerank_model_type}, 错误: {str(e)}")
|
|
|
+ # 返回原始顺序作为fallback
|
|
|
+ return [{"text": doc, "score": 0.0} for i, doc in enumerate(documents[:top_k])]
|
|
|
+
|
|
|
+ @track_execution_time
|
|
|
+ async def entity_recall(self, main_entity: str,assisted_search_entity: list,
|
|
|
+ top_k: int = 5) -> List[Dict[str, Any]]:
|
|
|
"""
|
|
|
执行实体召回
|
|
|
- :param collection_name: 集合名称
|
|
|
- :param query_text: 查询文本
|
|
|
+ :param main_entity: 查询实体
|
|
|
+ :param assisted_search_entity: 辅助搜索实体
|
|
|
:param top_k: 返回结果数量
|
|
|
- :return: 召回结果列表
|
|
|
"""
|
|
|
+ collection_name = "first_bfp_collection_entity"
|
|
|
+ # 主实体搜索 - 使用异步方法
|
|
|
+ entity_result = await self.async_multi_stage_recall(
|
|
|
+ collection_name=collection_name,
|
|
|
+ query_text=main_entity,
|
|
|
+ hybrid_top_k=20, # 从默认50降到20
|
|
|
+ top_k=top_k
|
|
|
+ )
|
|
|
+
|
|
|
+ assist_tasks = [
|
|
|
+ self.async_multi_stage_recall(
|
|
|
+ collection_name=collection_name,
|
|
|
+ query_text=assisted_search_entity,
|
|
|
+ hybrid_top_k=20, # 从默认50降到20
|
|
|
+ top_k=top_k
|
|
|
+ ) for assisted_search_entity in assisted_search_entity
|
|
|
+ ]
|
|
|
+ # 辅助搜索,异步并发
|
|
|
+ assist_results_list = await asyncio.gather(*assist_tasks,return_exceptions=True)
|
|
|
+ assist_results = []
|
|
|
+ for res in assist_results_list:
|
|
|
+ if isinstance(res, Exception):
|
|
|
+ self.logger.error(f"辅助实体召回失败: {str(res)}")
|
|
|
+ else:
|
|
|
+ assist_results.extend(res)
|
|
|
+
|
|
|
+ all_results = entity_result + assist_results
|
|
|
+
|
|
|
+ entity_list = list(set([item['text_content'] for item in all_results]))
|
|
|
+ self.logger.info(f"entity_list:{entity_list}")
|
|
|
+
|
|
|
+ return entity_list
|
|
|
+
|
|
|
+ @track_execution_time
|
|
|
+ async def async_bfp_recall(self, entity_list: List[str],background: str ,
|
|
|
+ top_k: int = 3,) -> List[Dict[str, Any]]:
|
|
|
+ """
|
|
|
+ 混合搜索召回 - 向量+BM25召回
|
|
|
+
|
|
|
+ Args:
|
|
|
+ entity_list: 实体列表
|
|
|
+ background: 背景/上下文信息,用于二次重排
|
|
|
+ top_k: 返回结果数量
|
|
|
+ """
|
|
|
+ import time
|
|
|
+ start_time = time.time()
|
|
|
+
|
|
|
+ # 异步并发召回编制依据
|
|
|
+ collection_name = "first_bfp_collection_test"
|
|
|
+
|
|
|
+ gather_start = time.time()
|
|
|
+ # 优化:降低hybrid_top_k参数从50到20,减少混合搜索时间
|
|
|
+ bfp_tasks = [
|
|
|
+ self.async_multi_stage_recall(
|
|
|
+ collection_name=collection_name,
|
|
|
+ query_text=entity,
|
|
|
+ hybrid_top_k=20, # 从50降到20,减少60%的混合搜索时间
|
|
|
+ top_k=top_k
|
|
|
+ ) for entity in entity_list
|
|
|
+ ]
|
|
|
+
|
|
|
+ bfp_tasks_list = await asyncio.gather(*bfp_tasks,return_exceptions=True)
|
|
|
+ gather_end = time.time()
|
|
|
+ gather_time = gather_end - gather_start
|
|
|
+
|
|
|
+ bfp_results = []
|
|
|
+ for res in bfp_tasks_list:
|
|
|
+ if isinstance(res, Exception):
|
|
|
+ self.logger.error(f"辅助实体召回失败: {str(res)}")
|
|
|
+ else:
|
|
|
+ bfp_results.extend(res)
|
|
|
+
|
|
|
+ # BFP召回结果已经通过multi_stage_recall进行了重排序,保持原有顺序
|
|
|
+ # 只对第一次重排序得分大于0.8的文档进行二次重排序
|
|
|
+ high_score_results = [item for item in bfp_results if item.get('rerank_score', 0) > 0.8]
|
|
|
+ low_score_results = [item for item in bfp_results if item.get('rerank_score', 0) <= 0.8]
|
|
|
+
|
|
|
+ self.logger.info(f"筛选结果:高分文档(>0.8) {len(high_score_results)} 个,低分文档(≤0.8) {len(low_score_results)} 个")
|
|
|
+
|
|
|
+ # 如果没有高分文档,直接返回原始结果
|
|
|
+ if not high_score_results:
|
|
|
+ self.logger.info("没有得分大于0.8的文档,跳过二次重排序,直接返回原始结果")
|
|
|
+ return bfp_results
|
|
|
+
|
|
|
+ # 提取高分文档的文本内容用于二次重排
|
|
|
+ high_score_text_content = list(set([item['text_content'] for item in high_score_results]))
|
|
|
+ self.logger.info(f"提取高分文档文本内容,共 {len(high_score_text_content)} 个,准备二次重排")
|
|
|
+
|
|
|
+ # 二次重排 - 使用配置的重排序模型
|
|
|
+ rerank_start = time.time()
|
|
|
+ bfp_rerank_result = self._get_rerank_results(background, high_score_text_content, 5)
|
|
|
+ rerank_end = time.time()
|
|
|
+ self.logger.info(f"二次重排序耗时: {rerank_end - rerank_start:.3f}秒")
|
|
|
+
|
|
|
+ # 根据重排结果重新组织数据
|
|
|
+ reorganize_start = time.time()
|
|
|
+ final_results = []
|
|
|
+ text_to_metadata = {item['text_content']: item for item in high_score_results}
|
|
|
+
|
|
|
+ # 处理二次重排序的高分文档
|
|
|
+ for rerank_item in bfp_rerank_result:
|
|
|
+ text = rerank_item.get('text', '')
|
|
|
+ score = rerank_item.get('score', 0.0)
|
|
|
+
|
|
|
+ if text in text_to_metadata:
|
|
|
+ original_item = text_to_metadata[text].copy()
|
|
|
+ original_item['bfp_rerank_score'] = score
|
|
|
+ final_results.append(original_item)
|
|
|
+
|
|
|
+ reorganize_end = time.time()
|
|
|
+ total_time = reorganize_end - start_time
|
|
|
+
|
|
|
+ self.logger.info(f"结果重组耗时: {reorganize_end - reorganize_start:.3f}秒")
|
|
|
+ self.logger.info(f"二次重排完成,返回 {len(final_results)} 个高分文档,丢弃 {len(low_score_results)} 个低分文档")
|
|
|
+ self.logger.info(f"[async_bfp_recall] 总耗时: {total_time:.3f}秒 (召回: {gather_end-gather_start:.3f}s + 重排: {rerank_end-rerank_start:.3f}s + 其他: {total_time-(gather_end-gather_start)-(rerank_end-rerank_start):.3f}s)")
|
|
|
+
|
|
|
+ return final_results
|
|
|
+
|
|
|
|
|
|
def hybrid_search_recall(self, collection_name: str, query_text: str,
|
|
|
top_k: int = 10 , ranker_type: str = "weighted",
|
|
|
@@ -52,6 +286,8 @@ class RetrievalManager:
|
|
|
self.logger.info(f"开始混合检索")
|
|
|
|
|
|
param = {'collection_name': collection_name}
|
|
|
+
|
|
|
+ # 直接调用同步的混合搜索(在同步方法中)
|
|
|
results = self.vector_manager.hybrid_search(
|
|
|
param=param,
|
|
|
query_text=query_text,
|
|
|
@@ -81,7 +317,7 @@ class RetrievalManager:
|
|
|
def rerank_recall(self, candidates_with_metadata: List[Dict[str, Any]], query_text: str,
|
|
|
top_k: int = None ) -> List[Dict[str, Any]]:
|
|
|
"""
|
|
|
- 重排序召回 - 使用BGE重排序模型对候选文档重新排序
|
|
|
+ 重排序召回 - 使用配置的重排序模型对候选文档重新排序
|
|
|
|
|
|
Args:
|
|
|
candidates_with_metadata: 候选文档列表,包含文本内容和元数据
|
|
|
@@ -92,8 +328,6 @@ class RetrievalManager:
|
|
|
List[Dict]: 重排序后的结果列表,包含原始索引信息
|
|
|
"""
|
|
|
try:
|
|
|
- self.logger.info(f"开始重排序召回,候选文档数量: {len(candidates_with_metadata)}")
|
|
|
-
|
|
|
# 第一步:基于文本内容+元数据的组合去重
|
|
|
unique_candidates = []
|
|
|
original_indices_map = [] # 记录每个去重后的候选文档对应的原始索引列表
|
|
|
@@ -151,13 +385,11 @@ class RetrievalManager:
|
|
|
original_indices_map[unique_idx].append(original_index)
|
|
|
break
|
|
|
|
|
|
- self.logger.info(f"基于内容+元数据去重后候选文档数量: {len(unique_candidates)}")
|
|
|
-
|
|
|
# 提取唯一候选文档的文本内容用于重排序
|
|
|
unique_texts = [candidate.get('text_content', '') for candidate in unique_candidates]
|
|
|
|
|
|
- # 调用重排序执行器,使用去重后的候选文档文本
|
|
|
- rerank_results = rerank_model.bge_rerank(query_text, unique_texts, top_k)
|
|
|
+ # 使用配置的重排序模型进行重排序
|
|
|
+ rerank_results = self._get_rerank_results(query_text, unique_texts, top_k)
|
|
|
|
|
|
# 转换结果格式,使用索引映射来处理原始索引
|
|
|
scored_docs = []
|
|
|
@@ -196,9 +428,8 @@ class RetrievalManager:
|
|
|
})
|
|
|
|
|
|
# 输出双重评分信息
|
|
|
- self.logger.info(f"重排序评分 #{i+1}: 标题='{title}' | 混合搜索相似度={hybrid_similarity:.4f} | BGE重排序评分={rerank_score:.6f}")
|
|
|
+ # self.logger.info(f"重排序评分 #{i+1}: 标题='{title}' | 混合搜索相似度={hybrid_similarity:.4f} | BGE重排序评分={rerank_score:.6f}")
|
|
|
|
|
|
- self.logger.info(f"重排序召回返回 {len(scored_docs)} 个结果")
|
|
|
return scored_docs
|
|
|
|
|
|
except Exception as e:
|
|
|
@@ -281,7 +512,6 @@ class RetrievalManager:
|
|
|
|
|
|
self.logger.debug(f"元数据优化完成: 重排序排名{rerank_result.get('rerank_rank')}, 重复数量={duplicate_count}")
|
|
|
|
|
|
- self.logger.info(f"多路召回完成,返回 {len(final_results)} 个重排序结果")
|
|
|
return final_results
|
|
|
|
|
|
except Exception as e:
|
|
|
@@ -289,5 +519,90 @@ class RetrievalManager:
|
|
|
return []
|
|
|
|
|
|
|
|
|
+ async def async_multi_stage_recall(self, collection_name: str, query_text: str,
|
|
|
+ hybrid_top_k: int = 50, top_k: int = 10,
|
|
|
+ ranker_type: str = "weighted") -> List[Dict[str, Any]]:
|
|
|
+ """
|
|
|
+ 多路召回 - 先混合搜索召回,再重排序,只返回重排序结果
|
|
|
+
|
|
|
+ Args:
|
|
|
+ collection_name: 集合名称
|
|
|
+ query_text: 查询文本
|
|
|
+ hybrid_top_k: 混合搜索召回的文档数量
|
|
|
+ top_k: 最终返回的文档数量
|
|
|
+ ranker_type: 混合搜索的重排序类型
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ List[Dict]: 重排序后的结果列表,只包含重排序分数
|
|
|
+ """
|
|
|
+ import time
|
|
|
+ try:
|
|
|
+ start_time = time.time()
|
|
|
+
|
|
|
+ # 第一阶段:混合搜索召回(向量+BM25)
|
|
|
+ hybrid_start = time.time()
|
|
|
+ hybrid_results = await asyncio.to_thread(
|
|
|
+ self.hybrid_search_recall,
|
|
|
+ collection_name=collection_name,
|
|
|
+ query_text=query_text,
|
|
|
+ top_k=hybrid_top_k,
|
|
|
+ ranker_type=ranker_type
|
|
|
+ )
|
|
|
+
|
|
|
+ if not hybrid_results:
|
|
|
+ return []
|
|
|
+
|
|
|
+ # 第二阶段:重排序召回
|
|
|
+ rerank_results = self.rerank_recall(
|
|
|
+ candidates_with_metadata=hybrid_results,
|
|
|
+ query_text=query_text,
|
|
|
+ top_k=top_k
|
|
|
+ )
|
|
|
+
|
|
|
+ # 优化重排序结果的元数据结构
|
|
|
+ final_results = []
|
|
|
+ for rerank_result in rerank_results:
|
|
|
+ metadata = rerank_result.get('metadata', {}).copy()
|
|
|
+ duplicate_count = rerank_result.get('duplicate_count', 1)
|
|
|
+
|
|
|
+ # 如果内层有metadata字段,将其提取到外层
|
|
|
+ if 'metadata' in metadata and isinstance(metadata['metadata'], str):
|
|
|
+ import json
|
|
|
+ try:
|
|
|
+ # 解析JSON格式的metadata
|
|
|
+ inner_metadata = json.loads(metadata['metadata'])
|
|
|
+ metadata.update(inner_metadata)
|
|
|
+ # 移除内层的metadata字符串,避免重复
|
|
|
+ del metadata['metadata']
|
|
|
+ except (json.JSONDecodeError, TypeError):
|
|
|
+ # 如果解析失败,保持原样
|
|
|
+ pass
|
|
|
+
|
|
|
+ # 移除重复的content字段
|
|
|
+ if 'content' in metadata:
|
|
|
+ del metadata['content']
|
|
|
+
|
|
|
+ # 添加重复计数信息到元数据中
|
|
|
+ if duplicate_count > 1:
|
|
|
+ metadata['duplicate_count'] = duplicate_count
|
|
|
+
|
|
|
+ # 输出优化后的结果,包含双重评分
|
|
|
+ final_result = {
|
|
|
+ 'text_content': rerank_result['text_content'],
|
|
|
+ 'metadata': metadata,
|
|
|
+ 'hybrid_similarity': rerank_result.get('hybrid_similarity', 0.0), # 混合搜索相似度
|
|
|
+ 'rerank_score': rerank_result.get('rerank_score', 0.0) # BGE重排序评分
|
|
|
+ }
|
|
|
+ final_results.append(final_result)
|
|
|
+
|
|
|
+ self.logger.debug(f"元数据优化完成: 重排序排名{rerank_result.get('rerank_rank')}, 重复数量={duplicate_count}")
|
|
|
+
|
|
|
+ return final_results
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ self.logger.error(f"多路召回失败: {str(e)}")
|
|
|
+ return []
|
|
|
+
|
|
|
# Module-level singleton retrieval manager shared by importers of this module.
retrieval_manager = RetrievalManager()
|
|
|
+
|