| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778 |
- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
- """
- 修复旧的 enhanced_results,添加 source_entity 字段
- """
- import sys
- import os
- import json
# Make the project root importable: walk up three directory levels from this
# script's absolute path (the script lives two directories below the root).
project_root = os.path.abspath(__file__)
for _level in range(3):
    project_root = os.path.dirname(project_root)
del _level
sys.path.insert(0, project_root)
def _tag_results_with_entity(enhanced_results, query_pairs):
    """Set ``source_entity`` in place on results that lack it; return how many were set.

    Result group ``i`` of ``enhanced_results`` is tagged with
    ``query_pairs[i]["entity"]``, falling back to ``"query_<i>"`` when that key
    is missing. Groups with no matching query pair are reported and left as-is.
    Existing ``source_entity`` values are never overwritten.
    """
    total_added = 0
    for query_idx, results in enumerate(enhanced_results):
        if query_idx >= len(query_pairs):
            # More result groups than query pairs: report instead of skipping silently.
            print(f" ⚠️ 查询对 {query_idx}: 没有对应的 query_pair,跳过 {len(results)} 个结果")
            continue
        entity = query_pairs[query_idx].get('entity', f'query_{query_idx}')
        added = 0
        for result in results:
            # Keep any value that is already present, so re-running is harmless.
            if 'source_entity' not in result:
                result['source_entity'] = entity
                added += 1
        total_added += added
        # Report only the results that were actually modified.
        print(f" ✅ 查询对 {query_idx} (entity='{entity}'): 已为 {added} 个结果添加 source_entity")
    return total_added


def patch_enhanced_results(enhanced_path=None, pipeline_path=None):
    """Add a ``source_entity`` field to every result in an old enhanced_results file.

    The results are read from the enhanced-results JSON. The entity names come
    from ``steps["1_query_extract"]["output"]["query_pairs"]`` in the
    pipeline-data JSON. Result group ``i`` is tagged with the ``entity`` of query
    pair ``i``, and results that already carry ``source_entity`` are left
    untouched.

    Before overwriting, the current file is copied to ``<enhanced_path>.backup``.
    If no result needed patching, nothing is written. This keeps a re-run from
    replacing the backup of the original data with already-patched data.

    Args:
        enhanced_path: Enhanced-results JSON. Defaults to
            ``<project_root>/temp/entity_bfp_recall/enhance_with_parent_docs.json``.
        pipeline_path: Pipeline-data JSON. Defaults to
            ``<project_root>/temp/entity_bfp_recall/rag_pipeline_data.json``.

    Returns:
        The number of results that received a ``source_entity`` field (0 if
        none needed it). Returns ``None`` when the pipeline data has no query
        pairs.

    Raises:
        OSError / json.JSONDecodeError: if either input file is missing or not
        valid JSON (propagated unchanged, as before).
    """
    if enhanced_path is None:
        enhanced_path = os.path.join(project_root, "temp", "entity_bfp_recall", "enhance_with_parent_docs.json")
    if pipeline_path is None:
        pipeline_path = os.path.join(project_root, "temp", "entity_bfp_recall", "rag_pipeline_data.json")

    # Load both inputs.
    with open(enhanced_path, 'r', encoding='utf-8') as f:
        enhanced_results = json.load(f)
    with open(pipeline_path, 'r', encoding='utf-8') as f:
        pipeline_data = json.load(f)

    # Query pairs produced by the pipeline's "1_query_extract" step.
    query_extract_step = pipeline_data.get('steps', {}).get('1_query_extract', {})
    query_pairs = query_extract_step.get('output', {}).get('query_pairs', [])
    if not query_pairs:
        print("❌ 未找到 query_pairs 信息")
        return None

    print(f"✅ 找到 {len(query_pairs)} 个查询对:")
    for idx, qp in enumerate(query_pairs):
        entity = qp.get('entity', 'N/A')
        print(f" - 查询对 {idx}: entity='{entity}'")

    print("\n🔧 开始修复 enhanced_results...")
    total_added = _tag_results_with_entity(enhanced_results, query_pairs)

    if total_added == 0:
        # Nothing changed. Skip the backup and write so a re-run cannot replace
        # the backup of the original data with an already-patched copy.
        print("\nℹ️ 所有结果均已包含 source_entity,无需修改")
        return 0

    # Back up the on-disk file, which has not been modified yet, before overwriting it.
    import shutil
    backup_path = enhanced_path + ".backup"
    shutil.copy(enhanced_path, backup_path)
    print(f"\n✅ 原文件已备份到: {backup_path}")

    with open(enhanced_path, 'w', encoding='utf-8') as f:
        json.dump(enhanced_results, f, ensure_ascii=False, indent=4)
    print(f"✅ 修复后的数据已保存到: {enhanced_path}")
    print("\n🎉 现在可以重新运行 test_extract_modes.py 查看修复后的效果!")
    return total_added
if __name__ == "__main__":
    # Run the one-off migration, framed by 80-character "=" separator lines.
    banner = "=" * 80
    print("\n" + banner)
    print("🔧 修复 enhanced_results,添加 source_entity 字段")
    print(banner + "\n")

    patch_enhanced_results()

    print("\n" + banner)
    print("✅ 修复完成!")
    print(banner + "\n")
|