| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586 |
- import json
- import asyncio
- from foundation.observability.monitoring.time_statistics import track_execution_time
- from foundation.ai.rag.retrieval.retrieval import retrieval_manager
- from foundation.observability.logger.loggering import server_logger
class EntitiesEnhance():
    """Entity-based retrieval enhancement with in-memory result caching.

    For every query pair, entities are first recalled through
    ``retrieval_manager.entity_recall`` and the recalled entities are then fed
    to ``retrieval_manager.async_bfp_recall`` together with the background
    text. Both steps are memoized in per-instance dictionaries that are only
    emptied by :meth:`clear_cache`.
    """

    def __init__(self):
        # Results produced by the most recent entities_enhance_retrieval() call.
        self.bfp_result_lists = []
        # Cache of entity recall results (keyed by entity + search keywords).
        self._entity_recall_cache = {}
        # Cache of BFP recall results (keyed by entity + keywords + background).
        self._bfp_recall_cache = {}

    def _get_cache_key(self, entity: str, search_keywords: list, background: str = "") -> str:
        """Build the cache key for an entity / keywords / background triple.

        Keywords are sorted, so their order does not affect the key. The
        complete background text is part of the key: truncating it would let
        two different backgrounds with a common prefix share one cache entry
        and return the wrong recall result.

        Args:
            entity: Entity name being looked up.
            search_keywords: Auxiliary search keywords; may be empty or None.
            background: Background text; empty or None for lookups that do
                not depend on a background.

        Returns:
            A string key of the form ``entity::kw1|kw2::background``.
        """
        keywords_str = "|".join(sorted(search_keywords)) if search_keywords else ""
        return f"{entity}::{keywords_str}::{background or ''}"

    @staticmethod
    def _run_async(coro):
        """Run the coroutine ``coro`` to completion from synchronous code.

        When no event loop is running in the current thread the coroutine is
        executed directly with ``asyncio.run``. When a loop is already
        running, ``asyncio.run`` cannot be used in this thread, so the
        coroutine is executed in a worker thread and the caller blocks until
        it finishes.
        """
        # Only the loop probe is guarded: a RuntimeError raised by the
        # coroutine itself must propagate instead of being mistaken for
        # "no running loop" (which would re-run an already consumed coroutine).
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            return asyncio.run(coro)

        import concurrent.futures

        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            return executor.submit(asyncio.run, coro).result()

    @track_execution_time
    def entities_enhance_retrieval(self, query_pairs):
        """Run entity recall followed by BFP recall for every query pair.

        Args:
            query_pairs: Iterable of mappings, each providing the keys
                ``'entity'``, ``'search_keywords'`` and ``'background'``.

        Returns:
            A list with one BFP recall result per query pair, in input order.
            Every item of each result gets a ``'source_entity'`` entry set to
            the pair's entity. Cached results are returned by reference, so
            callers should not mutate them.

        Raises:
            KeyError: If a query pair lacks one of the required keys.
        """
        # Discard the results of any previous call.
        self.bfp_result_lists = []

        for query_pair in query_pairs:
            entity = query_pair['entity']
            search_keywords = query_pair['search_keywords']
            background = query_pair['background']
            server_logger.info(f"正在处理实体:{entity},辅助搜索词:{search_keywords},背景:{background}")

            # Entity recall does not depend on the background, so its cache
            # key is built without it.
            recall_cache_key = self._get_cache_key(entity, search_keywords)
            if recall_cache_key in self._entity_recall_cache:
                entity_list = self._entity_recall_cache[recall_cache_key]
                server_logger.info(f"[缓存命中] entity_recall: {entity}")
            else:
                entity_list = self._run_async(retrieval_manager.entity_recall(
                    entity,
                    search_keywords,
                    recall_top_k=5,  # number of primary entities to recall
                    max_results=5,  # return at most 5 entity texts in the end
                ))
                self._entity_recall_cache[recall_cache_key] = entity_list
                server_logger.info(f"[缓存存储] entity_recall: {entity}")

            # BFP recall additionally depends on the background text.
            bfp_cache_key = self._get_cache_key(entity, search_keywords, background)
            if bfp_cache_key in self._bfp_recall_cache:
                bfp_result = self._bfp_recall_cache[bfp_cache_key]
                server_logger.info(f"[缓存命中] bfp_recall: {entity}")
            else:
                # Background-enhanced (BFP) recall; top_k was lowered to 2 to
                # reduce the amount of context.
                bfp_result = self._run_async(
                    retrieval_manager.async_bfp_recall(entity_list, background, top_k=2)
                )
                self._bfp_recall_cache[bfp_cache_key] = bfp_result
                server_logger.info(f"[缓存存储] bfp_recall: {entity}")

            # Tag every result with the entity it was retrieved for.
            for result in bfp_result:
                result['source_entity'] = entity
            self.bfp_result_lists.append(bfp_result)

        return self.bfp_result_lists

    def clear_cache(self):
        """Empty both the entity recall cache and the BFP recall cache."""
        self._entity_recall_cache.clear()
        self._bfp_recall_cache.clear()
        server_logger.info("[缓存清理] 实体检索缓存已清空")
# Module-level shared instance; its recall caches persist across calls until
# clear_cache() is invoked on it.
entity_enhance = EntitiesEnhance()
|