entities_enhance.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. import json
  2. import asyncio
  3. from foundation.observability.monitoring.time_statistics import track_execution_time
  4. from foundation.ai.rag.retrieval.retrieval import retrieval_manager
  5. from foundation.observability.logger.loggering import server_logger
  6. class EntitiesEnhance():
  7. def __init__(self):
  8. self.save_path = "temp\entity_bfp_recall\entity_bfp_recall.json"
  9. self.bfp_result_lists = []
  10. @track_execution_time
  11. def entities_enhance_retrieval(self, query_pairs):
  12. def run_async(coro):
  13. """在合适的环境中运行异步函数"""
  14. try:
  15. loop = asyncio.get_running_loop()
  16. import concurrent.futures
  17. with concurrent.futures.ThreadPoolExecutor() as executor:
  18. future = executor.submit(asyncio.run, coro)
  19. return future.result()
  20. except RuntimeError:
  21. return asyncio.run(coro)
  22. # 清空之前的结果
  23. self.bfp_result_lists = []
  24. for query_pair in query_pairs:
  25. entity = query_pair['entity']
  26. search_keywords = query_pair['search_keywords']
  27. background = query_pair['background']
  28. server_logger.info(f"正在处理实体:{entity},辅助搜索词:{search_keywords},背景:{background}")
  29. entity_list = run_async(retrieval_manager.entity_recall(
  30. entity,
  31. search_keywords,
  32. recall_top_k=5, # 主实体返回数量
  33. max_results=5 # 最终最多返回20个实体文本
  34. ))
  35. # BFP背景增强召回
  36. bfp_result = run_async(retrieval_manager.async_bfp_recall(entity_list, background, top_k=3))
  37. # 为每个结果添加实体信息
  38. for result in bfp_result:
  39. result['source_entity'] = entity
  40. self.bfp_result_lists.append(bfp_result)
  41. self.test_file(self.bfp_result_lists, seve=True)
  42. return self.bfp_result_lists
  43. def test_file(self,bfp_result,seve = False):
  44. if seve:
  45. with open(self.save_path, "w", encoding="utf-8") as f:
  46. json.dump(bfp_result, f, ensure_ascii=False, indent=4)
  47. entity_enhance = EntitiesEnhance()