test_rag.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. """
  4. 测试多阶段召回功能
  5. """
  6. import sys
  7. import os
  8. import time
  9. from foundation.ai.rag.retrieval.retrieval import retrieval_manager
  10. from foundation.observability.logger.loggering import review_logger as logger
  11. def test_multi_stage_recall(collection_name,query):
  12. """
  13. 测试多阶段召回
  14. """
  15. try:
  16. start_time = time.time()
  17. results = retrieval_manager.multi_stage_recall(
  18. collection_name=collection_name,
  19. query_text=query,
  20. hybrid_top_k=10,
  21. top_k=5,
  22. )
  23. logger.info(f"返回结果:{results}")
  24. end_time = time.time()
  25. elapsed_time = end_time - start_time
  26. print(f"[OK] 召回完成,耗时: {elapsed_time:.2f}秒")
  27. print(f"[OK] 返回结果数量: {len(results)}")
  28. except Exception as e:
  29. print(f"[ERROR] 多阶段召回测试失败: {str(e)}")
  30. def test_hybrid_search_recall(collection_name,query):
  31. """
  32. 测试混合召回
  33. """
  34. try:
  35. start_time = time.time()
  36. results = retrieval_manager.hybrid_search_recall(
  37. collection_name=collection_name,
  38. query_text=query,
  39. top_k=1,
  40. ranker_type="weighted",
  41. dense_weight=0.7,
  42. sparse_weight=0.3
  43. )
  44. logger.info(f"返回结果:{results}")
  45. end_time = time.time()
  46. elapsed_time = end_time - start_time
  47. print(f"[OK] 召回完成,耗时: {elapsed_time:.2f}秒")
  48. print(f"[OK] 召回结果数量: {len(results)}")
  49. return results
  50. except Exception as e:
  51. print(f"[ERROR] 混合召回测试失败: {str(e)}")
  52. def main():
  53. """
  54. 主测试函数
  55. """
  56. collection_name = "first_bfp_collection_test"
  57. query = "起重小车轨道,起重量小于 320t的分段拼接桁架梁每段梁上小车轨道不允许有接缝(允许焊为一体),拼接 处高低差≤2mm、间隙≤4mm、侧向错位≤2mm,非焊接连接轨道端部加挡铁,其他梁轨道接头高低差≤1mm、间隙≤2mm、侧向错位≤1mm,正轨箱形梁及半偏轨箱形梁轨道接缝应放 在筋板上允差≤15mm,两端最短轨道长度≥1.5m且端部加挡"
  58. # 测试多路召回
  59. logger.info("开始测试多路召回...")
  60. test_multi_stage_recall(collection_name,query=query)
  61. # # 测试混合召回
  62. # logger.info("开始测试混合召回...")
  63. # test_hybrid_search_recall(collection_name="first_bfp_collection_entity",query=query)
  64. if __name__ == "__main__":
  65. main()