run_tests.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
"""
Agent-driven test runner for the RAG pipeline.
Runs every test sample, evaluates the results, and generates a report.
"""
  5. import sys
  6. import os
  7. # 确保项目根目录在路径中
  8. project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
  9. if project_root not in sys.path:
  10. sys.path.insert(0, project_root)
  11. from utils_test.RAG_Pipeline_Test.test_data import TEST_SAMPLES
  12. from utils_test.RAG_Pipeline_Test.rag_pipeline_runner import RAGPipelineRunner
  13. from utils_test.RAG_Pipeline_Test.rag_evaluator import RAGEvaluator
  14. def main():
  15. print("=" * 70)
  16. print("RAG 管线 Agent 驱动测试")
  17. print("=" * 70)
  18. print(f"测试样本数: {len(TEST_SAMPLES)}")
  19. print()
  20. # 初始化
  21. print("[1/3] 初始化管线执行器和评估器...")
  22. runner = RAGPipelineRunner()
  23. evaluator = RAGEvaluator()
  24. print(" 初始化完成")
  25. print()
  26. # 执行管线
  27. print("[2/3] 执行 RAG 管线...")
  28. results = runner.run_batch(TEST_SAMPLES)
  29. print()
  30. # 评估
  31. print("[3/3] 评估结果...")
  32. evaluations = []
  33. for i, (result, sample) in enumerate(zip(results, TEST_SAMPLES)):
  34. print(f" 评估样本 {i+1}/{len(results)}: {result.chunk_id}")
  35. ev = evaluator.evaluate_sample(result, sample["content"])
  36. evaluations.append(ev)
  37. print(f" 总分: {ev.overall_score:.1f}/5.0 [{ev.overall_status}]")
  38. print(f" {ev.analysis}")
  39. print()
  40. # 生成报告
  41. report = evaluator.generate_report(evaluations)
  42. # 先保存报告到文件(避免打印编码问题导致丢失)
  43. report_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "reports")
  44. os.makedirs(report_dir, exist_ok=True)
  45. report_path = os.path.join(report_dir, "rag_pipeline_test_report.md")
  46. with open(report_path, "w", encoding="utf-8") as f:
  47. f.write(report)
  48. print(f"\n报告已保存: {report_path}")
  49. # 输出报告
  50. print("=" * 70)
  51. try:
  52. print(report)
  53. except UnicodeEncodeError:
  54. safe_report = report.replace("✅", "[PASS]").replace("⚠️", "[WARN]").replace("❌", "[FAIL]")
  55. print(safe_report)
  56. print("=" * 70)
  57. # 返回汇总
  58. pass_count = sum(1 for ev in evaluations if ev.overall_status == "PASS")
  59. warn_count = sum(1 for ev in evaluations if ev.overall_status == "WARN")
  60. fail_count = sum(1 for ev in evaluations if ev.overall_status == "FAIL")
  61. print(f"\n汇总: {pass_count} PASS / {warn_count} WARN / {fail_count} FAIL")
  62. return evaluations
  63. if __name__ == "__main__":
  64. main()