# test_entity_bfp_recall.py
# Script: entity-enhanced recall followed by BFP recall; the BFP result is saved as JSON.
import asyncio
import json
from pathlib import Path

from foundation.ai.rag.retrieval.retrieval import retrieval_manager
from foundation.observability.monitoring.time_statistics import track_execution_time
  5. entity = "架桥机"
  6. search_keywords = ["提梁机", "架桥设备", "造桥机"]
  7. background = "JQ220t-40m架桥机安装及拆除"
  8. @track_execution_time
  9. def main():
  10. print("="*60)
  11. print("实体增强召回测试")
  12. print("="*60)
  13. print(f"主实体: {entity}")
  14. print(f"辅助实体: {search_keywords}")
  15. print(f"背景信息: {background}")
  16. print("-"*60)
  17. # 使用新参数调用 entity_recall
  18. # recall_top_k=5: 每个实体召回5个结果
  19. # max_results=20: 最终返回最多20个实体文本
  20. entity_list = asyncio.run(retrieval_manager.entity_recall(
  21. entity,
  22. search_keywords,
  23. recall_top_k=5, # 每次单实体召回返回5个
  24. max_results=20 # 最终最多返回20个
  25. ))
  26. print(f"\n✅ 实体召回完成, 共召回 {len(entity_list)} 个实体")
  27. print(f"实体列表前5个: {entity_list[:5]}")
  28. # 使用 top_k 参数调用 async_bfp_recall
  29. # top_k=3: 二次重排后最多返回3个BFP文档
  30. bfp_result = asyncio.run(retrieval_manager.async_bfp_recall(
  31. entity_list,
  32. background,
  33. top_k=3
  34. ))
  35. print(f"\n✅ BFP召回完成, 共召回 {len(bfp_result)} 个文档")
  36. # 保存结果
  37. with open("temp/entity_bfp_recall/entity_bfp_recall.json", "w", encoding="utf-8") as f:
  38. json.dump(bfp_result, f, ensure_ascii=False, indent=4)
  39. print(f"\n✅ 结果已保存到: temp/entity_bfp_recall/entity_bfp_recall.json")
  40. print("="*60)
  41. if __name__ == "__main__":
  42. main()