| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152 |
import asyncio
import json
from pathlib import Path

from foundation.ai.rag.retrieval.retrieval import retrieval_manager
from foundation.observability.monitoring.time_statistics import track_execution_time
# Query inputs for the entity-enhanced recall demo.

# Background text passed to async_bfp_recall
# ("installation and dismantling of the JQ220t-40m bridge-erecting machine").
background = "JQ220t-40m架桥机安装及拆除"

# Primary entity passed first to entity_recall ("bridge-erecting machine").
entity = "架桥机"

# Auxiliary entities passed to entity_recall alongside the primary one
# ("beam lifter", "bridge-erection equipment", "bridge-building machine").
search_keywords = [
    "提梁机",
    "架桥设备",
    "造桥机",
]
# Destination of the BFP recall results, relative to the current working directory.
_OUTPUT_PATH = Path("temp/entity_bfp_recall/entity_bfp_recall.json")


async def _run_recall_pipeline():
    """Run entity recall followed by BFP recall on a single event loop.

    Both awaits share one event loop, so any async resources the retrieval
    manager might create during the first call remain usable in the second.
    NOTE(review): whether retrieval_manager actually holds loop-bound clients
    is not visible here; sharing one loop is safe either way.

    Returns:
        Whatever ``retrieval_manager.async_bfp_recall`` returns (it is later
        serialized with ``json``).
    """
    # recall_top_k=5: each single-entity recall returns 5 results.
    # max_results=20: at most 20 entity texts are returned overall.
    entity_list = await retrieval_manager.entity_recall(
        entity,
        search_keywords,
        recall_top_k=5,
        max_results=20,
    )
    print(f"\n✅ 实体召回完成, 共召回 {len(entity_list)} 个实体")
    print(f"实体列表前5个: {entity_list[:5]}")

    # top_k=3: at most 3 BFP documents are kept after the second re-ranking pass.
    bfp_result = await retrieval_manager.async_bfp_recall(
        entity_list,
        background,
        top_k=3,
    )
    print(f"\n✅ BFP召回完成, 共召回 {len(bfp_result)} 个文档")
    return bfp_result


def _save_json(data, path):
    """Write *data* to *path* as pretty-printed UTF-8 JSON.

    Missing parent directories are created first. The payload is serialized
    before the file is touched, so a serialization error does not leave a
    truncated file behind.
    """
    payload = json.dumps(data, ensure_ascii=False, indent=4)
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(payload, encoding="utf-8")


@track_execution_time
def main():
    """Run the entity-enhanced recall demo and save the BFP results to disk.

    Prints the query inputs, runs entity recall then BFP recall, and writes
    the BFP results as JSON to ``_OUTPUT_PATH``.
    """
    print("=" * 60)
    print("实体增强召回测试")
    print("=" * 60)
    print(f"主实体: {entity}")
    print(f"辅助实体: {search_keywords}")
    print(f"背景信息: {background}")
    print("-" * 60)

    # One asyncio.run() for the whole pipeline: each call creates and then
    # closes a fresh event loop, and the asyncio docs recommend calling it once.
    bfp_result = asyncio.run(_run_recall_pipeline())

    # Save the results; the output directory may not exist yet.
    _save_json(bfp_result, _OUTPUT_PATH)
    print(f"\n✅ 结果已保存到: {_OUTPUT_PATH.as_posix()}")
    print("=" * 60)


if __name__ == "__main__":
    main()
|