entities_enhance.py 3.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. import json
  2. import asyncio
  3. from foundation.observability.monitoring.time_statistics import track_execution_time
  4. from foundation.ai.rag.retrieval.retrieval import retrieval_manager
  5. from foundation.observability.logger.loggering import server_logger
  6. class EntitiesEnhance():
  7. def __init__(self):
  8. self.bfp_result_lists = []
  9. self._entity_recall_cache = {} # 实体检索结果缓存
  10. self._bfp_recall_cache = {} # BFP召回结果缓存
  11. def _get_cache_key(self, entity: str, search_keywords: list, background: str = "") -> str:
  12. """生成缓存键"""
  13. keywords_str = "|".join(sorted(search_keywords)) if search_keywords else ""
  14. return f"{entity}::{keywords_str}::{background[:50]}"
  15. @track_execution_time
  16. def entities_enhance_retrieval(self, query_pairs):
  17. def run_async(coro):
  18. """在合适的环境中运行异步函数"""
  19. try:
  20. loop = asyncio.get_running_loop()
  21. import concurrent.futures
  22. with concurrent.futures.ThreadPoolExecutor() as executor:
  23. future = executor.submit(asyncio.run, coro)
  24. return future.result()
  25. except RuntimeError:
  26. return asyncio.run(coro)
  27. # 清空之前的结果
  28. self.bfp_result_lists = []
  29. for query_pair in query_pairs:
  30. entity = query_pair['entity']
  31. search_keywords = query_pair['search_keywords']
  32. background = query_pair['background']
  33. server_logger.info(f"正在处理实体:{entity},辅助搜索词:{search_keywords},背景:{background}")
  34. # 检查 entity_recall 缓存
  35. recall_cache_key = self._get_cache_key(entity, search_keywords)
  36. if recall_cache_key in self._entity_recall_cache:
  37. entity_list = self._entity_recall_cache[recall_cache_key]
  38. server_logger.info(f"[缓存命中] entity_recall: {entity}")
  39. else:
  40. entity_list = run_async(retrieval_manager.entity_recall(
  41. entity,
  42. search_keywords,
  43. recall_top_k=5, # 主实体返回数量
  44. max_results=5 # 最终最多返回5个实体文本
  45. ))
  46. self._entity_recall_cache[recall_cache_key] = entity_list
  47. server_logger.info(f"[缓存存储] entity_recall: {entity}")
  48. # 检查 bfp_recall 缓存
  49. bfp_cache_key = self._get_cache_key(entity, search_keywords, background)
  50. if bfp_cache_key in self._bfp_recall_cache:
  51. bfp_result = self._bfp_recall_cache[bfp_cache_key]
  52. server_logger.info(f"[缓存命中] bfp_recall: {entity}")
  53. else:
  54. # BFP背景增强召回
  55. bfp_result = run_async(retrieval_manager.async_bfp_recall(entity_list, background, top_k=2)) # 降低到2,减少上下文量
  56. self._bfp_recall_cache[bfp_cache_key] = bfp_result
  57. server_logger.info(f"[缓存存储] bfp_recall: {entity}")
  58. # 为每个结果添加实体信息
  59. for result in bfp_result:
  60. result['source_entity'] = entity
  61. self.bfp_result_lists.append(bfp_result)
  62. return self.bfp_result_lists
  63. def clear_cache(self):
  64. """清空缓存"""
  65. self._entity_recall_cache.clear()
  66. self._bfp_recall_cache.clear()
  67. server_logger.info("[缓存清理] 实体检索缓存已清空")
  68. entity_enhance = EntitiesEnhance()