patch_enhanced_results.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. """
  4. 修复旧的 enhanced_results,添加 source_entity 字段
  5. """
  6. import sys
  7. import os
  8. import json
  9. # 添加项目根目录到路径
  10. project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
  11. sys.path.insert(0, project_root)
  12. def patch_enhanced_results():
  13. """为旧的 enhanced_results 添加 source_entity 字段"""
  14. enhanced_path = os.path.join(project_root, "temp", "entity_bfp_recall", "enhance_with_parent_docs.json")
  15. pipeline_path = os.path.join(project_root, "temp", "entity_bfp_recall", "rag_pipeline_data.json")
  16. # 加载数据
  17. with open(enhanced_path, 'r', encoding='utf-8') as f:
  18. enhanced_results = json.load(f)
  19. with open(pipeline_path, 'r', encoding='utf-8') as f:
  20. pipeline_data = json.load(f)
  21. # 提取 query_pairs
  22. query_extract_step = pipeline_data.get('steps', {}).get('1_query_extract', {})
  23. query_pairs = query_extract_step.get('output', {}).get('query_pairs', [])
  24. if not query_pairs:
  25. print("❌ 未找到 query_pairs 信息")
  26. return
  27. print(f"✅ 找到 {len(query_pairs)} 个查询对:")
  28. for idx, qp in enumerate(query_pairs):
  29. entity = qp.get('entity', 'N/A')
  30. print(f" - 查询对 {idx}: entity='{entity}'")
  31. # 为每个结果添加 source_entity 字段
  32. print(f"\n🔧 开始修复 enhanced_results...")
  33. for query_idx, results in enumerate(enhanced_results):
  34. if query_idx < len(query_pairs):
  35. entity = query_pairs[query_idx].get('entity', f'query_{query_idx}')
  36. for result in results:
  37. if 'source_entity' not in result:
  38. result['source_entity'] = entity
  39. print(f" ✅ 查询对 {query_idx} (entity='{entity}'): 已为 {len(results)} 个结果添加 source_entity")
  40. # 保存修复后的数据
  41. backup_path = enhanced_path + ".backup"
  42. import shutil
  43. shutil.copy(enhanced_path, backup_path)
  44. print(f"\n✅ 原文件已备份到: {backup_path}")
  45. with open(enhanced_path, 'w', encoding='utf-8') as f:
  46. json.dump(enhanced_results, f, ensure_ascii=False, indent=4)
  47. print(f"✅ 修复后的数据已保存到: {enhanced_path}")
  48. print("\n🎉 现在可以重新运行 test_extract_modes.py 查看修复后的效果!")
  49. if __name__ == "__main__":
  50. print("\n" + "="*80)
  51. print("🔧 修复 enhanced_results,添加 source_entity 字段")
  52. print("="*80 + "\n")
  53. patch_enhanced_results()
  54. print("\n" + "="*80)
  55. print("✅ 修复完成!")
  56. print("="*80 + "\n")