#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Patch a legacy ``enhanced_results`` JSON file by adding ``source_entity``.

Each top-level entry of the enhanced-results file is paired, by index, with
the query pair at the same position in the pipeline data file. Every result
in that entry that lacks a ``source_entity`` key receives the query pair's
``entity`` value.
"""

import sys
import os
import json
import shutil
from typing import Optional

# Put the project root (three directory levels above this file) on sys.path.
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.insert(0, project_root)


def _unused_backup_path(path: str) -> str:
    """Return ``path + ".backup"``, or a numbered variant if that name is taken."""
    candidate = path + ".backup"
    suffix = 1
    while os.path.exists(candidate):
        candidate = f"{path}.backup.{suffix}"
        suffix += 1
    return candidate


def patch_enhanced_results(enhanced_path: Optional[str] = None,
                           pipeline_path: Optional[str] = None) -> bool:
    """Add a ``source_entity`` field to results that do not have one yet.

    Args:
        enhanced_path: JSON file holding the enhanced results. Defaults to
            ``<project_root>/temp/entity_bfp_recall/enhance_with_parent_docs.json``.
        pipeline_path: JSON file holding the pipeline data, from which
            ``steps -> 1_query_extract -> output -> query_pairs`` is read.
            Defaults to
            ``<project_root>/temp/entity_bfp_recall/rag_pipeline_data.json``.

    Returns:
        ``False`` when the pipeline data contains no query pairs (nothing is
        written in that case); ``True`` otherwise, including the case where
        every result already carried ``source_entity`` and the file was left
        untouched.

    Raises:
        OSError: If either input file cannot be opened.
        json.JSONDecodeError: If either input file is not valid JSON.

    Side effects:
        When at least one result is patched, the original file is first copied
        to a backup next to it (never overwriting an earlier backup) and then
        rewritten in place.
    """
    data_dir = os.path.join(project_root, "temp", "entity_bfp_recall")
    if enhanced_path is None:
        enhanced_path = os.path.join(data_dir, "enhance_with_parent_docs.json")
    if pipeline_path is None:
        pipeline_path = os.path.join(data_dir, "rag_pipeline_data.json")

    # Load both inputs.
    with open(enhanced_path, 'r', encoding='utf-8') as f:
        enhanced_results = json.load(f)
    with open(pipeline_path, 'r', encoding='utf-8') as f:
        pipeline_data = json.load(f)

    # Extract the query pairs recorded by the query-extraction step.
    query_extract_step = pipeline_data.get('steps', {}).get('1_query_extract', {})
    query_pairs = query_extract_step.get('output', {}).get('query_pairs', [])

    if not query_pairs:
        print("❌ 未找到 query_pairs 信息")
        return False

    print(f"✅ 找到 {len(query_pairs)} 个查询对:")
    for idx, qp in enumerate(query_pairs):
        entity = qp.get('entity', 'N/A')
        print(f" - 查询对 {idx}: entity='{entity}'")

    # Tag every result with the entity of the query pair at the same index.
    # NOTE(review): assumes each entry of enhanced_results is a list of dicts;
    # confirm against the producer of this file.
    print("\n🔧 开始修复 enhanced_results...")
    total_patched = 0
    for query_idx, results in enumerate(enhanced_results):
        if query_idx >= len(query_pairs):
            # No query pair to take the entity from; report instead of
            # silently dropping the group.
            print(f" ⚠️ 查询对 {query_idx}: 没有对应的 query_pair,已跳过 {len(results)} 个结果")
            continue

        entity = query_pairs[query_idx].get('entity', f'query_{query_idx}')
        patched = 0
        for result in results:
            if 'source_entity' not in result:
                result['source_entity'] = entity
                patched += 1
        total_patched += patched
        # Report only the fields actually added, so a re-run does not claim
        # work it did not do.
        print(f" ✅ 查询对 {query_idx} (entity='{entity}'): 已为 {patched} 个结果添加 source_entity")

    if total_patched == 0:
        # Nothing changed: leave the file and any existing backup untouched.
        print("\nℹ️ 所有结果均已包含 source_entity,文件未改动")
        return True

    # Back up the original before overwriting it; never clobber an older
    # backup, since that one may be the only pristine copy.
    backup_path = _unused_backup_path(enhanced_path)
    shutil.copy(enhanced_path, backup_path)
    print(f"\n✅ 原文件已备份到: {backup_path}")

    with open(enhanced_path, 'w', encoding='utf-8') as f:
        json.dump(enhanced_results, f, ensure_ascii=False, indent=4)

    print(f"✅ 修复后的数据已保存到: {enhanced_path}")
    print("\n🎉 现在可以重新运行 test_extract_modes.py 查看修复后的效果!")
    return True


if __name__ == "__main__":
    print("\n" + "="*80)
    print("🔧 修复 enhanced_results,添加 source_entity 字段")
    print("="*80 + "\n")

    succeeded = patch_enhanced_results()

    print("\n" + "="*80)
    # Only announce completion when the patch step actually succeeded.
    print("✅ 修复完成!" if succeeded else "❌ 修复失败!")
    print("="*80 + "\n")