test_extract_modes.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. """
  4. 测试 extract_first_result 的两种模式
  5. """
  6. import sys
  7. import os
  8. import json
  9. project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
  10. from core.construction_review.component.infrastructure.parent_tool import extract_first_result
  11. def load_enhanced_results():
  12. """加载增强后的检索结果"""
  13. result_path = os.path.join(project_root, "temp", "entity_bfp_recall", "enhance_with_parent_docs.json")
  14. if not os.path.exists(result_path):
  15. print(f"❌ 文件不存在: {result_path}")
  16. print("请先运行 RAG 检索生成该文件")
  17. return None
  18. with open(result_path, 'r', encoding='utf-8') as f:
  19. enhanced_results = json.load(f)
  20. print(f"✅ 成功加载增强结果,共 {len(enhanced_results)} 个查询对")
  21. for idx, results in enumerate(enhanced_results):
  22. if results:
  23. first_entity = results[0].get('source_entity', f'query_{idx}')
  24. print(f" - 查询对 {idx}: entity='{first_entity}', {len(results)} 个结果")
  25. return enhanced_results
  26. def load_query_pairs():
  27. """加载查询对"""
  28. # 优先从 rag_pipeline_data.json 读取
  29. pipeline_path = os.path.join(project_root, "temp", "entity_bfp_recall", "rag_pipeline_data.json")
  30. if os.path.exists(pipeline_path):
  31. with open(pipeline_path, 'r', encoding='utf-8') as f:
  32. pipeline_data = json.load(f)
  33. # 从 steps 中提取 query_pairs
  34. query_extract_step = pipeline_data.get('steps', {}).get('1_query_extract', {})
  35. query_pairs = query_extract_step.get('output', {}).get('query_pairs', [])
  36. if query_pairs:
  37. print(f"✅ 从 rag_pipeline_data.json 加载了 {len(query_pairs)} 个查询对")
  38. for idx, qp in enumerate(query_pairs):
  39. print(f" - 查询对 {idx}: entity='{qp.get('entity', 'N/A')}'")
  40. return query_pairs
  41. # 降级:从 enhanced_results 中提取 entity 信息
  42. print("⚠️ 未找到 rag_pipeline_data.json,尝试从 enhanced_results 提取")
  43. result_path = os.path.join(project_root, "temp", "entity_bfp_recall", "enhance_with_parent_docs.json")
  44. if not os.path.exists(result_path):
  45. return None
  46. with open(result_path, 'r', encoding='utf-8') as f:
  47. enhanced_results = json.load(f)
  48. # 构造简化的 query_pairs
  49. query_pairs = []
  50. for idx, results in enumerate(enhanced_results):
  51. if results:
  52. # 优先使用 source_entity,回退到 query_N
  53. entity = results[0].get('source_entity', f'query_{idx}')
  54. query_pairs.append({
  55. 'entity': entity,
  56. 'search_keywords': [],
  57. 'background': ''
  58. })
  59. return query_pairs
  60. def test_mode_best_overall(enhanced_results, query_pairs):
  61. """测试模式1: 全局最优"""
  62. print("\n" + "="*80)
  63. print("📊 测试模式1: best_overall (全局最优)")
  64. print("="*80)
  65. result = extract_first_result(enhanced_results, query_pairs, mode='best_overall')
  66. print(f"\n✅ 返回结果:")
  67. print(f" - file_name: {result.get('file_name', 'N/A')}")
  68. print(f" - source_entity: {result.get('source_entity', 'N/A')}")
  69. print(f" - bfp_rerank_score: {result.get('bfp_rerank_score', 0.0):.6f}")
  70. print(f" - text_content 长度: {len(result.get('text_content', ''))}")
  71. print(f" - retrieval_status: {result.get('retrieval_status', 'N/A')}")
  72. # 显示文本内容预览
  73. text_preview = result.get('text_content', '')[:200]
  74. print(f"\n - 文本预览: {text_preview}...")
  75. return result
  76. def test_mode_best_per_entity(enhanced_results, query_pairs):
  77. """测试模式2: 分实体最优"""
  78. print("\n" + "="*80)
  79. print("📊 测试模式2: best_per_entity (分实体最优)")
  80. print("="*80)
  81. result = extract_first_result(enhanced_results, query_pairs, mode='best_per_entity')
  82. print(f"\n✅ 返回结果:")
  83. print(f" - total_entities: {result.get('total_entities', 0)}")
  84. print(f" - retrieval_status: {result.get('retrieval_status', 'N/A')}")
  85. entity_results = result.get('entity_results', {})
  86. print(f"\n📋 各实体最优结果:")
  87. for entity_name, entity_result in entity_results.items():
  88. score = entity_result.get('bfp_rerank_score', 0.0)
  89. file_name = entity_result.get('file_name', 'N/A')
  90. text_len = len(entity_result.get('text_content', ''))
  91. print(f"\n 🎯 实体: {entity_name}")
  92. print(f" - score: {score:.6f}")
  93. print(f" - file_name: {file_name}")
  94. print(f" - text_length: {text_len}")
  95. return result
  96. def compare_with_current_result():
  97. """对比当前 extract_first_result.json 的结果"""
  98. print("\n" + "="*80)
  99. print("📂 对比当前保存的结果")
  100. print("="*80)
  101. result_path = os.path.join(project_root, "temp", "entity_bfp_recall", "extract_first_result.json")
  102. if not os.path.exists(result_path):
  103. print("⚠️ 当前没有保存的 extract_first_result.json")
  104. return
  105. with open(result_path, 'r', encoding='utf-8') as f:
  106. current_result = json.load(f)
  107. print(f"\n当前保存的结果:")
  108. print(f" - file_name: {current_result.get('file_name', 'N/A')}")
  109. print(f" - retrieval_status: {current_result.get('retrieval_status', 'N/A')}")
  110. print(f" - bfp_rerank_score: {current_result.get('bfp_rerank_score', 'N/A')}")
  111. print(f" - source_entity: {current_result.get('source_entity', 'N/A')}")
  112. if __name__ == "__main__":
  113. print("\n" + "="*80)
  114. print("🚀 开始测试 extract_first_result 的两种模式")
  115. print("="*80)
  116. # 加载数据
  117. enhanced_results = load_enhanced_results()
  118. if not enhanced_results:
  119. sys.exit(1)
  120. query_pairs = load_query_pairs()
  121. # 测试模式1
  122. result1 = test_mode_best_overall(enhanced_results, query_pairs)
  123. # 测试模式2
  124. result2 = test_mode_best_per_entity(enhanced_results, query_pairs)
  125. # 对比当前结果
  126. compare_with_current_result()
  127. # 保存测试结果
  128. test_output_path = os.path.join(project_root, "temp", "entity_bfp_recall", "test_extract_modes.json")
  129. with open(test_output_path, 'w', encoding='utf-8') as f:
  130. json.dump({
  131. 'best_overall': result1,
  132. 'best_per_entity': result2
  133. }, f, ensure_ascii=False, indent=4)
  134. print(f"\n✅ 测试完成,结果已保存到: {test_output_path}")
  135. print("\n" + "="*80)
  136. print("📝 建议使用哪种模式?")
  137. print("="*80)
  138. print("\n模式1 (best_overall):")
  139. print(" - 适用场景: 只需要一个最相关的结果")
  140. print(" - 优点: 返回全局最优的结果")
  141. print(" - 缺点: 可能丢失其他实体的有用信息")
  142. print("\n模式2 (best_per_entity):")
  143. print(" - 适用场景: 需要保留所有查询对的最优结果")
  144. print(" - 优点: 保留各实体的最优结果,信息更全面")
  145. print(" - 缺点: 返回结构更复杂,需要后续处理")
  146. print("\n💡 如果审查需要针对不同实体分别检查,建议使用 mode='best_per_entity'")
  147. print("="*80 + "\n")