  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. """
  4. 测试 extract_first_result 的两种模式
  5. """
  6. import sys
  7. import os
  8. import json
  9. # 添加项目根目录到路径
  10. project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
  11. sys.path.insert(0, project_root)
  12. from core.construction_review.component.infrastructure.parent_tool import extract_first_result
  13. def load_enhanced_results():
  14. """加载增强后的检索结果"""
  15. result_path = os.path.join(project_root, "temp", "entity_bfp_recall", "enhance_with_parent_docs.json")
  16. if not os.path.exists(result_path):
  17. print(f"❌ 文件不存在: {result_path}")
  18. print("请先运行 RAG 检索生成该文件")
  19. return None
  20. with open(result_path, 'r', encoding='utf-8') as f:
  21. enhanced_results = json.load(f)
  22. print(f"✅ 成功加载增强结果,共 {len(enhanced_results)} 个查询对")
  23. for idx, results in enumerate(enhanced_results):
  24. if results:
  25. first_entity = results[0].get('source_entity', f'query_{idx}')
  26. print(f" - 查询对 {idx}: entity='{first_entity}', {len(results)} 个结果")
  27. return enhanced_results
  28. def load_query_pairs():
  29. """加载查询对"""
  30. # 优先从 rag_pipeline_data.json 读取
  31. pipeline_path = os.path.join(project_root, "temp", "entity_bfp_recall", "rag_pipeline_data.json")
  32. if os.path.exists(pipeline_path):
  33. with open(pipeline_path, 'r', encoding='utf-8') as f:
  34. pipeline_data = json.load(f)
  35. # 从 steps 中提取 query_pairs
  36. query_extract_step = pipeline_data.get('steps', {}).get('1_query_extract', {})
  37. query_pairs = query_extract_step.get('output', {}).get('query_pairs', [])
  38. if query_pairs:
  39. print(f"✅ 从 rag_pipeline_data.json 加载了 {len(query_pairs)} 个查询对")
  40. for idx, qp in enumerate(query_pairs):
  41. print(f" - 查询对 {idx}: entity='{qp.get('entity', 'N/A')}'")
  42. return query_pairs
  43. # 降级:从 enhanced_results 中提取 entity 信息
  44. print("⚠️ 未找到 rag_pipeline_data.json,尝试从 enhanced_results 提取")
  45. result_path = os.path.join(project_root, "temp", "entity_bfp_recall", "enhance_with_parent_docs.json")
  46. if not os.path.exists(result_path):
  47. return None
  48. with open(result_path, 'r', encoding='utf-8') as f:
  49. enhanced_results = json.load(f)
  50. # 构造简化的 query_pairs
  51. query_pairs = []
  52. for idx, results in enumerate(enhanced_results):
  53. if results:
  54. # 优先使用 source_entity,回退到 query_N
  55. entity = results[0].get('source_entity', f'query_{idx}')
  56. query_pairs.append({
  57. 'entity': entity,
  58. 'search_keywords': [],
  59. 'background': ''
  60. })
  61. return query_pairs
  62. def test_mode_best_overall(enhanced_results, query_pairs):
  63. """测试模式1: 全局最优"""
  64. print("\n" + "="*80)
  65. print("📊 测试模式1: best_overall (全局最优)")
  66. print("="*80)
  67. result = extract_first_result(enhanced_results, query_pairs, mode='best_overall')
  68. print(f"\n✅ 返回结果:")
  69. print(f" - file_name: {result.get('file_name', 'N/A')}")
  70. print(f" - source_entity: {result.get('source_entity', 'N/A')}")
  71. print(f" - bfp_rerank_score: {result.get('bfp_rerank_score', 0.0):.6f}")
  72. print(f" - text_content 长度: {len(result.get('text_content', ''))}")
  73. print(f" - retrieval_status: {result.get('retrieval_status', 'N/A')}")
  74. # 显示文本内容预览
  75. text_preview = result.get('text_content', '')[:200]
  76. print(f"\n - 文本预览: {text_preview}...")
  77. return result
  78. def test_mode_best_per_entity(enhanced_results, query_pairs):
  79. """测试模式2: 分实体最优"""
  80. print("\n" + "="*80)
  81. print("📊 测试模式2: best_per_entity (分实体最优)")
  82. print("="*80)
  83. result = extract_first_result(enhanced_results, query_pairs, mode='best_per_entity')
  84. print(f"\n✅ 返回结果:")
  85. print(f" - total_entities: {result.get('total_entities', 0)}")
  86. print(f" - retrieval_status: {result.get('retrieval_status', 'N/A')}")
  87. entity_results = result.get('entity_results', {})
  88. print(f"\n📋 各实体最优结果:")
  89. for entity_name, entity_result in entity_results.items():
  90. score = entity_result.get('bfp_rerank_score', 0.0)
  91. file_name = entity_result.get('file_name', 'N/A')
  92. text_len = len(entity_result.get('text_content', ''))
  93. print(f"\n 🎯 实体: {entity_name}")
  94. print(f" - score: {score:.6f}")
  95. print(f" - file_name: {file_name}")
  96. print(f" - text_length: {text_len}")
  97. return result
  98. def compare_with_current_result():
  99. """对比当前 extract_first_result.json 的结果"""
  100. print("\n" + "="*80)
  101. print("📂 对比当前保存的结果")
  102. print("="*80)
  103. result_path = os.path.join(project_root, "temp", "entity_bfp_recall", "extract_first_result.json")
  104. if not os.path.exists(result_path):
  105. print("⚠️ 当前没有保存的 extract_first_result.json")
  106. return
  107. with open(result_path, 'r', encoding='utf-8') as f:
  108. current_result = json.load(f)
  109. print(f"\n当前保存的结果:")
  110. print(f" - file_name: {current_result.get('file_name', 'N/A')}")
  111. print(f" - retrieval_status: {current_result.get('retrieval_status', 'N/A')}")
  112. print(f" - bfp_rerank_score: {current_result.get('bfp_rerank_score', 'N/A')}")
  113. print(f" - source_entity: {current_result.get('source_entity', 'N/A')}")
  114. if __name__ == "__main__":
  115. print("\n" + "="*80)
  116. print("🚀 开始测试 extract_first_result 的两种模式")
  117. print("="*80)
  118. # 加载数据
  119. enhanced_results = load_enhanced_results()
  120. if not enhanced_results:
  121. sys.exit(1)
  122. query_pairs = load_query_pairs()
  123. # 测试模式1
  124. result1 = test_mode_best_overall(enhanced_results, query_pairs)
  125. # 测试模式2
  126. result2 = test_mode_best_per_entity(enhanced_results, query_pairs)
  127. # 对比当前结果
  128. compare_with_current_result()
  129. # 保存测试结果
  130. test_output_path = os.path.join(project_root, "temp", "entity_bfp_recall", "test_extract_modes.json")
  131. with open(test_output_path, 'w', encoding='utf-8') as f:
  132. json.dump({
  133. 'best_overall': result1,
  134. 'best_per_entity': result2
  135. }, f, ensure_ascii=False, indent=4)
  136. print(f"\n✅ 测试完成,结果已保存到: {test_output_path}")
  137. print("\n" + "="*80)
  138. print("📝 建议使用哪种模式?")
  139. print("="*80)
  140. print("\n模式1 (best_overall):")
  141. print(" - 适用场景: 只需要一个最相关的结果")
  142. print(" - 优点: 返回全局最优的结果")
  143. print(" - 缺点: 可能丢失其他实体的有用信息")
  144. print("\n模式2 (best_per_entity):")
  145. print(" - 适用场景: 需要保留所有查询对的最优结果")
  146. print(" - 优点: 保留各实体的最优结果,信息更全面")
  147. print(" - 缺点: 返回结构更复杂,需要后续处理")
  148. print("\n💡 如果审查需要针对不同实体分别检查,建议使用 mode='best_per_entity'")
  149. print("="*80 + "\n")