test_rag.py 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. """
  4. 测试多阶段召回功能
  5. """
  6. import sys
  7. import os
  8. import time
  9. sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
  10. from foundation.ai.rag.retrieval.retrieval import retrieval_manager
  11. from foundation.observability.logger.loggering import server_logger as logger
  12. def test_multi_stage_recall(collection_name,query):
  13. """
  14. 测试多阶段召回
  15. """
  16. try:
  17. start_time = time.time()
  18. results = retrieval_manager.multi_stage_recall(
  19. collection_name=collection_name,
  20. query_text=query,
  21. hybrid_top_k=10,
  22. top_k=5,
  23. )
  24. logger.info(f"返回结果:{results}")
  25. end_time = time.time()
  26. elapsed_time = end_time - start_time
  27. print(f"[OK] 召回完成,耗时: {elapsed_time:.2f}秒")
  28. print(f"[OK] 返回结果数量: {len(results)}")
  29. except Exception as e:
  30. print(f"[ERROR] 多阶段召回测试失败: {str(e)}")
  31. def test_hybrid_search_recall(collection_name,query):
  32. """
  33. 测试混合召回
  34. """
  35. try:
  36. start_time = time.time()
  37. results = retrieval_manager.hybrid_search_recall(
  38. collection_name=collection_name,
  39. query_text=query,
  40. top_k=1,
  41. ranker_type="weighted",
  42. dense_weight=0.7,
  43. sparse_weight=0.3
  44. )
  45. logger.info(f"返回结果:{results}")
  46. end_time = time.time()
  47. elapsed_time = end_time - start_time
  48. print(f"[OK] 召回完成,耗时: {elapsed_time:.2f}秒")
  49. print(f"[OK] 召回结果数量: {len(results)}")
  50. return results
  51. except Exception as e:
  52. print(f"[ERROR] 混合召回测试失败: {str(e)}")
  53. def main():
  54. """
  55. 主测试函数
  56. """
  57. collection_name = "first_bfp_collection_test"
  58. query = "起重小车轨道,起重量小于 320t的分段拼接桁架梁每段梁上小车轨道不允许有接缝(允许焊为一体),拼接 处高低差≤2mm、间隙≤4mm、侧向错位≤2mm,非焊接连接轨道端部加挡铁,其他梁轨道接头高低差≤1mm、间隙≤2mm、侧向错位≤1mm,正轨箱形梁及半偏轨箱形梁轨道接缝应放 在筋板上允差≤15mm,两端最短轨道长度≥1.5m且端部加挡"
  59. # 测试多路召回
  60. logger.info("开始测试多路召回...")
  61. test_multi_stage_recall(collection_name,query=query)
  62. # # 测试混合召回
  63. # logger.info("开始测试混合召回...")
  64. # test_hybrid_search_recall(collection_name="first_bfp_collection_entity",query=query)
  65. if __name__ == "__main__":
  66. main()