| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197 |
- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
- """
- 测试 extract_first_result 的两种模式
- """
- import sys
- import os
- import json
# Make the project root importable: take this file's absolute path, strip
# three trailing components with dirname, and put the result at the front of
# sys.path so the absolute project import that follows can be resolved.
_this_file = os.path.abspath(__file__)
_package_dir = os.path.dirname(os.path.dirname(_this_file))
project_root = os.path.dirname(_package_dir)
sys.path.insert(0, project_root)
- from core.construction_review.component.infrastructure.parent_tool import extract_first_result
def load_enhanced_results():
    """Load the enhanced retrieval results saved by an earlier RAG run.

    Reads ``temp/entity_bfp_recall/enhance_with_parent_docs.json`` below
    ``project_root`` and prints a one-line summary for every non-empty entry.

    Returns:
        The decoded JSON content, or ``None`` when the file does not exist.
    """
    source_file = os.path.join(project_root, "temp", "entity_bfp_recall", "enhance_with_parent_docs.json")

    if os.path.exists(source_file):
        with open(source_file, 'r', encoding='utf-8') as handle:
            loaded = json.load(handle)

        print(f"✅ 成功加载增强结果,共 {len(loaded)} 个查询对")
        for position, hits in enumerate(loaded):
            # Empty entries have no first hit to describe, so skip them.
            if not hits:
                continue
            # Label the entry with its first hit's entity, else its index.
            label = hits[0].get('source_entity', f'query_{position}')
            print(f" - 查询对 {position}: entity='{label}', {len(hits)} 个结果")
        return loaded

    # Nothing to load: tell the user how to produce the file.
    print(f"❌ 文件不存在: {source_file}")
    print("请先运行 RAG 检索生成该文件")
    return None
def load_query_pairs():
    """Load the query pairs that belong to the enhanced retrieval results.

    Two sources under ``project_root/temp/entity_bfp_recall`` are tried:

    1. ``rag_pipeline_data.json``: the list stored at
       ``steps -> 1_query_extract -> output -> query_pairs``.
    2. ``enhance_with_parent_docs.json`` (fallback): one simplified pair is
       built per non-empty result list, using the first hit's
       ``source_entity`` (or ``query_<index>``) as the entity, with empty
       ``search_keywords`` and ``background``.

    Returns:
        A list of query-pair dicts, or ``None`` when the fallback file does
        not exist either.
    """
    data_dir = os.path.join(project_root, "temp", "entity_bfp_recall")

    # Preferred source: the pipeline dump.
    pipeline_path = os.path.join(data_dir, "rag_pipeline_data.json")
    if os.path.exists(pipeline_path):
        with open(pipeline_path, 'r', encoding='utf-8') as f:
            pipeline_data = json.load(f)
        # The query pairs live inside the output of the extraction step.
        query_extract_step = pipeline_data.get('steps', {}).get('1_query_extract', {})
        query_pairs = query_extract_step.get('output', {}).get('query_pairs', [])
        if query_pairs:
            print(f"✅ 从 rag_pipeline_data.json 加载了 {len(query_pairs)} 个查询对")
            for idx, qp in enumerate(query_pairs):
                print(f" - 查询对 {idx}: entity='{qp.get('entity', 'N/A')}'")
            return query_pairs
        # The file exists but holds no query pairs. Report that instead of
        # claiming the file is missing, which would mislead debugging.
        print("⚠️ rag_pipeline_data.json 中没有 query_pairs,尝试从 enhanced_results 提取")
    else:
        print("⚠️ 未找到 rag_pipeline_data.json,尝试从 enhanced_results 提取")

    # Fallback: derive the entity information from the enhanced results.
    result_path = os.path.join(data_dir, "enhance_with_parent_docs.json")
    if not os.path.exists(result_path):
        return None
    with open(result_path, 'r', encoding='utf-8') as f:
        enhanced_results = json.load(f)

    # Build simplified query pairs.
    # NOTE(review): empty result lists are skipped, so this list can be
    # shorter than enhanced_results. Confirm that extract_first_result does
    # not pair the two by index.
    query_pairs = []
    for idx, results in enumerate(enhanced_results):
        if results:
            # Prefer source_entity, fall back to query_<index>.
            entity = results[0].get('source_entity', f'query_{idx}')
            query_pairs.append({
                'entity': entity,
                'search_keywords': [],
                'background': '',
            })
    return query_pairs
def test_mode_best_overall(enhanced_results, query_pairs):
    """Run ``extract_first_result`` in ``best_overall`` mode and print the outcome.

    Args:
        enhanced_results: Enhanced retrieval results, passed through unchanged.
        query_pairs: Query pairs, passed through unchanged.

    Returns:
        Whatever ``extract_first_result`` returned.
    """
    divider = "=" * 80
    print("\n" + divider)
    print("📊 测试模式1: best_overall (全局最优)")
    print(divider)

    best = extract_first_result(enhanced_results, query_pairs, mode='best_overall')

    # Each field is read right before it is printed, so output order and any
    # failure point stay the same as a straight sequence of prints.
    print("\n✅ 返回结果:")
    file_name = best.get('file_name', 'N/A')
    print(f" - file_name: {file_name}")
    source_entity = best.get('source_entity', 'N/A')
    print(f" - source_entity: {source_entity}")
    score = best.get('bfp_rerank_score', 0.0)
    print(f" - bfp_rerank_score: {score:.6f}")
    content_length = len(best.get('text_content', ''))
    print(f" - text_content 长度: {content_length}")
    status = best.get('retrieval_status', 'N/A')
    print(f" - retrieval_status: {status}")

    # Show only the first 200 characters of the text as a preview.
    snippet = best.get('text_content', '')[:200]
    print(f"\n - 文本预览: {snippet}...")
    return best
def test_mode_best_per_entity(enhanced_results, query_pairs):
    """Run ``extract_first_result`` in ``best_per_entity`` mode and print the outcome.

    Args:
        enhanced_results: Enhanced retrieval results, passed through unchanged.
        query_pairs: Query pairs, passed through unchanged.

    Returns:
        Whatever ``extract_first_result`` returned.
    """
    divider = "=" * 80
    print("\n" + divider)
    print("📊 测试模式2: best_per_entity (分实体最优)")
    print(divider)

    outcome = extract_first_result(enhanced_results, query_pairs, mode='best_per_entity')

    print("\n✅ 返回结果:")
    total = outcome.get('total_entities', 0)
    print(f" - total_entities: {total}")
    status = outcome.get('retrieval_status', 'N/A')
    print(f" - retrieval_status: {status}")

    per_entity = outcome.get('entity_results', {})
    print("\n📋 各实体最优结果:")
    for name, best in per_entity.items():
        # Read all three fields before printing anything for this entity.
        best_score = best.get('bfp_rerank_score', 0.0)
        best_file = best.get('file_name', 'N/A')
        content_length = len(best.get('text_content', ''))
        print(f"\n 🎯 实体: {name}")
        print(f" - score: {best_score:.6f}")
        print(f" - file_name: {best_file}")
        print(f" - text_length: {content_length}")
    return outcome
def compare_with_current_result():
    """Print the key fields of the currently saved ``extract_first_result.json``.

    Looks for the file in ``temp/entity_bfp_recall`` below ``project_root``.
    When it is absent, a warning is printed and nothing else happens.
    Returns ``None`` in both cases.
    """
    divider = "=" * 80
    print("\n" + divider)
    print("📂 对比当前保存的结果")
    print(divider)

    saved_file = os.path.join(project_root, "temp", "entity_bfp_recall", "extract_first_result.json")
    if not os.path.exists(saved_file):
        print("⚠️ 当前没有保存的 extract_first_result.json")
        return

    with open(saved_file, 'r', encoding='utf-8') as handle:
        saved = json.load(handle)

    print("\n当前保存的结果:")
    # Each printed label equals the key it reads, so one loop covers all four.
    for field in ('file_name', 'retrieval_status', 'bfp_rerank_score', 'source_entity'):
        print(f" - {field}: {saved.get(field, 'N/A')}")
if __name__ == "__main__":
    # Script entry point: load the saved retrieval data, run both extraction
    # modes, show the previously saved result, then store both new results.
    banner = "=" * 80
    print("\n" + banner)
    print("🚀 开始测试 extract_first_result 的两种模式")
    print(banner)

    # Load the data; stop with exit code 1 when there is nothing to test.
    enhanced_results = load_enhanced_results()
    if not enhanced_results:
        sys.exit(1)
    query_pairs = load_query_pairs()

    # Mode 1, then mode 2.
    result1 = test_mode_best_overall(enhanced_results, query_pairs)
    result2 = test_mode_best_per_entity(enhanced_results, query_pairs)

    # Show what is currently saved on disk for comparison.
    compare_with_current_result()

    # Save both results side by side in one JSON file.
    test_output_path = os.path.join(project_root, "temp", "entity_bfp_recall", "test_extract_modes.json")
    combined = {
        'best_overall': result1,
        'best_per_entity': result2,
    }
    with open(test_output_path, 'w', encoding='utf-8') as f:
        json.dump(combined, f, ensure_ascii=False, indent=4)
    print(f"\n✅ 测试完成,结果已保存到: {test_output_path}")

    # Closing guidance on which mode to choose, printed line by line.
    guidance = (
        "\n" + banner,
        "📝 建议使用哪种模式?",
        banner,
        "\n模式1 (best_overall):",
        " - 适用场景: 只需要一个最相关的结果",
        " - 优点: 返回全局最优的结果",
        " - 缺点: 可能丢失其他实体的有用信息",
        "\n模式2 (best_per_entity):",
        " - 适用场景: 需要保留所有查询对的最优结果",
        " - 优点: 保留各实体的最优结果,信息更全面",
        " - 缺点: 返回结构更复杂,需要后续处理",
        "\n💡 如果审查需要针对不同实体分别检查,建议使用 mode='best_per_entity'",
        banner + "\n",
    )
    for line in guidance:
        print(line)
|