patch_enhanced_results.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. """
  4. 修复旧的 enhanced_results,添加 source_entity 字段
  5. """
  6. import sys
  7. import os
  8. import json
  9. project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
  10. def patch_enhanced_results():
  11. """为旧的 enhanced_results 添加 source_entity 字段"""
  12. enhanced_path = os.path.join(project_root, "temp", "entity_bfp_recall", "enhance_with_parent_docs.json")
  13. pipeline_path = os.path.join(project_root, "temp", "entity_bfp_recall", "rag_pipeline_data.json")
  14. # 加载数据
  15. with open(enhanced_path, 'r', encoding='utf-8') as f:
  16. enhanced_results = json.load(f)
  17. with open(pipeline_path, 'r', encoding='utf-8') as f:
  18. pipeline_data = json.load(f)
  19. # 提取 query_pairs
  20. query_extract_step = pipeline_data.get('steps', {}).get('1_query_extract', {})
  21. query_pairs = query_extract_step.get('output', {}).get('query_pairs', [])
  22. if not query_pairs:
  23. print("❌ 未找到 query_pairs 信息")
  24. return
  25. print(f"✅ 找到 {len(query_pairs)} 个查询对:")
  26. for idx, qp in enumerate(query_pairs):
  27. entity = qp.get('entity', 'N/A')
  28. print(f" - 查询对 {idx}: entity='{entity}'")
  29. # 为每个结果添加 source_entity 字段
  30. print(f"\n🔧 开始修复 enhanced_results...")
  31. for query_idx, results in enumerate(enhanced_results):
  32. if query_idx < len(query_pairs):
  33. entity = query_pairs[query_idx].get('entity', f'query_{query_idx}')
  34. for result in results:
  35. if 'source_entity' not in result:
  36. result['source_entity'] = entity
  37. print(f" ✅ 查询对 {query_idx} (entity='{entity}'): 已为 {len(results)} 个结果添加 source_entity")
  38. # 保存修复后的数据
  39. backup_path = enhanced_path + ".backup"
  40. import shutil
  41. shutil.copy(enhanced_path, backup_path)
  42. print(f"\n✅ 原文件已备份到: {backup_path}")
  43. with open(enhanced_path, 'w', encoding='utf-8') as f:
  44. json.dump(enhanced_results, f, ensure_ascii=False, indent=4)
  45. print(f"✅ 修复后的数据已保存到: {enhanced_path}")
  46. print("\n🎉 现在可以重新运行 test_extract_modes.py 查看修复后的效果!")
  47. if __name__ == "__main__":
  48. print("\n" + "="*80)
  49. print("🔧 修复 enhanced_results,添加 source_entity 字段")
  50. print("="*80 + "\n")
  51. patch_enhanced_results()
  52. print("\n" + "="*80)
  53. print("✅ 修复完成!")
  54. print("="*80 + "\n")