#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Manually exercise the two modes of ``extract_first_result``.

Reads the RAG retrieval artifacts under ``<project_root>/temp/entity_bfp_recall``,
runs ``extract_first_result`` in ``best_overall`` and ``best_per_entity`` modes,
prints a summary of each, compares against the saved
``extract_first_result.json`` and writes both results to
``test_extract_modes.json``.

This is a script, not a pytest module: run it directly with ``python``.
"""
import sys
import os
import json

# Make the project root (three directory levels above this file) importable.
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.insert(0, project_root)

from core.construction_review.component.infrastructure.parent_tool import extract_first_result  # noqa: E402

# The ``test_mode_*`` helpers below take positional data arguments; if pytest
# collected this module it would look those parameters up as fixtures and
# error out. Opt the whole module out of pytest collection.
__test__ = False

# Directory holding the retrieval artifacts this script reads and writes.
RESULT_DIR = os.path.join(project_root, "temp", "entity_bfp_recall")


def _result_path(file_name):
    """Return the absolute path of ``file_name`` inside ``RESULT_DIR``."""
    return os.path.join(RESULT_DIR, file_name)


def _read_json(path):
    """Load and return the JSON document stored at ``path`` (read as UTF-8)."""
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def load_enhanced_results():
    """Load the parent-doc-enhanced retrieval results.

    Prints one summary line per non-empty query entry.

    Returns:
        The parsed contents of ``enhance_with_parent_docs.json`` (iterated here
        as a sequence of per-query result lists whose items are dicts —
        presumably one entry per query pair; TODO confirm against the RAG
        pipeline), or ``None`` if the file does not exist.
    """
    result_path = _result_path("enhance_with_parent_docs.json")
    if not os.path.exists(result_path):
        print(f"❌ 文件不存在: {result_path}")
        print("请先运行 RAG 检索生成该文件")
        return None

    enhanced_results = _read_json(result_path)

    print(f"✅ 成功加载增强结果,共 {len(enhanced_results)} 个查询对")
    for idx, results in enumerate(enhanced_results):
        if results:
            first_entity = results[0].get('source_entity', f'query_{idx}')
            print(f" - 查询对 {idx}: entity='{first_entity}', {len(results)} 个结果")

    return enhanced_results


def _query_pairs_from_enhanced(enhanced_results):
    """Build minimal query pairs, one per entry of ``enhanced_results``.

    An entry is produced for every index, including empty result lists, so
    that ``query_pairs[i]`` stays aligned with ``enhanced_results[i]``
    (assumes ``extract_first_result`` pairs the two by index — TODO confirm).
    """
    query_pairs = []
    for idx, results in enumerate(enhanced_results):
        # Prefer the entity recorded on the first hit; fall back to query_N,
        # which is also used when the query returned no results at all.
        entity = results[0].get('source_entity', f'query_{idx}') if results else f'query_{idx}'
        query_pairs.append({
            'entity': entity,
            'search_keywords': [],
            'background': ''
        })
    return query_pairs


def load_query_pairs(enhanced_results=None):
    """Load the query pairs that produced the enhanced results.

    Source priority:
      1. ``steps['1_query_extract']['output']['query_pairs']`` in
         ``rag_pipeline_data.json``, if present and non-empty.
      2. Otherwise, minimal pairs (``entity`` only) derived from the enhanced
         results.

    Args:
        enhanced_results: Already-loaded enhanced results to derive the
            fallback from. When ``None``, ``enhance_with_parent_docs.json`` is
            read from disk instead.

    Returns:
        A list of query-pair dicts, or ``None`` if neither the pipeline data nor
        the enhanced results are available.
    """
    pipeline_path = _result_path("rag_pipeline_data.json")
    if os.path.exists(pipeline_path):
        pipeline_data = _read_json(pipeline_path)

        query_extract_step = pipeline_data.get('steps', {}).get('1_query_extract', {})
        query_pairs = query_extract_step.get('output', {}).get('query_pairs', [])

        if query_pairs:
            print(f"✅ 从 rag_pipeline_data.json 加载了 {len(query_pairs)} 个查询对")
            for idx, qp in enumerate(query_pairs):
                print(f" - 查询对 {idx}: entity='{qp.get('entity', 'N/A')}'")
            return query_pairs

        # The file exists but holds no query pairs; don't claim it is missing.
        print("⚠️ rag_pipeline_data.json 中没有 query_pairs,尝试从 enhanced_results 提取")
    else:
        print("⚠️ 未找到 rag_pipeline_data.json,尝试从 enhanced_results 提取")

    if enhanced_results is None:
        result_path = _result_path("enhance_with_parent_docs.json")
        if not os.path.exists(result_path):
            return None
        enhanced_results = _read_json(result_path)

    return _query_pairs_from_enhanced(enhanced_results)


def test_mode_best_overall(enhanced_results, query_pairs):
    """Run ``extract_first_result`` in ``best_overall`` mode and print a summary.

    Returns:
        The value returned by ``extract_first_result``; it is read here as a
        dict with ``file_name``, ``source_entity``, ``bfp_rerank_score``,
        ``text_content`` and ``retrieval_status`` keys.
    """
    print("\n" + "=" * 80)
    print("📊 测试模式1: best_overall (全局最优)")
    print("=" * 80)

    result = extract_first_result(enhanced_results, query_pairs, mode='best_overall')

    # ``or`` also covers keys that are present but hold None.
    score = result.get('bfp_rerank_score') or 0.0
    text_content = result.get('text_content') or ''

    print("\n✅ 返回结果:")
    print(f" - file_name: {result.get('file_name', 'N/A')}")
    print(f" - source_entity: {result.get('source_entity', 'N/A')}")
    print(f" - bfp_rerank_score: {score:.6f}")
    print(f" - text_content 长度: {len(text_content)}")
    print(f" - retrieval_status: {result.get('retrieval_status', 'N/A')}")

    # Preview of the first 200 characters of the selected text.
    print(f"\n - 文本预览: {text_content[:200]}...")

    return result


def test_mode_best_per_entity(enhanced_results, query_pairs):
    """Run ``extract_first_result`` in ``best_per_entity`` mode and print a summary.

    Returns:
        The value returned by ``extract_first_result``; it is read here as a
        dict with ``total_entities``, ``retrieval_status`` and an
        ``entity_results`` mapping of entity name -> best result dict.
    """
    print("\n" + "=" * 80)
    print("📊 测试模式2: best_per_entity (分实体最优)")
    print("=" * 80)

    result = extract_first_result(enhanced_results, query_pairs, mode='best_per_entity')

    print("\n✅ 返回结果:")
    print(f" - total_entities: {result.get('total_entities', 0)}")
    print(f" - retrieval_status: {result.get('retrieval_status', 'N/A')}")

    entity_results = result.get('entity_results', {})
    print("\n📋 各实体最优结果:")
    for entity_name, entity_result in entity_results.items():
        # ``or`` also covers keys that are present but hold None.
        score = entity_result.get('bfp_rerank_score') or 0.0
        file_name = entity_result.get('file_name', 'N/A')
        text_len = len(entity_result.get('text_content') or '')

        print(f"\n 🎯 实体: {entity_name}")
        print(f" - score: {score:.6f}")
        print(f" - file_name: {file_name}")
        print(f" - text_length: {text_len}")

    return result


def compare_with_current_result():
    """Print the key fields of the previously saved ``extract_first_result.json``.

    Prints a warning and returns if the file does not exist.
    """
    print("\n" + "=" * 80)
    print("📂 对比当前保存的结果")
    print("=" * 80)

    result_path = _result_path("extract_first_result.json")
    if not os.path.exists(result_path):
        print("⚠️ 当前没有保存的 extract_first_result.json")
        return

    current_result = _read_json(result_path)

    print("\n当前保存的结果:")
    print(f" - file_name: {current_result.get('file_name', 'N/A')}")
    print(f" - retrieval_status: {current_result.get('retrieval_status', 'N/A')}")
    print(f" - bfp_rerank_score: {current_result.get('bfp_rerank_score', 'N/A')}")
    print(f" - source_entity: {current_result.get('source_entity', 'N/A')}")


def _print_recommendation():
    """Print guidance on which ``extract_first_result`` mode to choose."""
    print("\n" + "=" * 80)
    print("📝 建议使用哪种模式?")
    print("=" * 80)
    print("\n模式1 (best_overall):")
    print(" - 适用场景: 只需要一个最相关的结果")
    print(" - 优点: 返回全局最优的结果")
    print(" - 缺点: 可能丢失其他实体的有用信息")
    print("\n模式2 (best_per_entity):")
    print(" - 适用场景: 需要保留所有查询对的最优结果")
    print(" - 优点: 保留各实体的最优结果,信息更全面")
    print(" - 缺点: 返回结构更复杂,需要后续处理")
    print("\n💡 如果审查需要针对不同实体分别检查,建议使用 mode='best_per_entity'")
    print("=" * 80 + "\n")


def main():
    """Run both modes, compare with the saved result and persist the output.

    Exits with status 1 if the enhanced results are missing or empty.
    """
    print("\n" + "=" * 80)
    print("🚀 开始测试 extract_first_result 的两种模式")
    print("=" * 80)

    enhanced_results = load_enhanced_results()
    if enhanced_results is None:
        # load_enhanced_results has already explained the missing file.
        sys.exit(1)
    if not enhanced_results:
        print("❌ 增强结果为空,无法测试")
        sys.exit(1)

    # Reuse the already-loaded results for the fallback instead of re-reading.
    query_pairs = load_query_pairs(enhanced_results)

    result1 = test_mode_best_overall(enhanced_results, query_pairs)
    result2 = test_mode_best_per_entity(enhanced_results, query_pairs)

    compare_with_current_result()

    test_output_path = _result_path("test_extract_modes.json")
    with open(test_output_path, 'w', encoding='utf-8') as f:
        json.dump({
            'best_overall': result1,
            'best_per_entity': result2
        }, f, ensure_ascii=False, indent=4)

    print(f"\n✅ 测试完成,结果已保存到: {test_output_path}")

    _print_recommendation()


if __name__ == "__main__":
    main()