csv_同实体却有的命中有的未命中.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. import pandas as pd
  2. import os
  3. # 与项目其他脚本保持一致:UTF-8 with BOM,便于 Excel 正常识别中文
  4. CHINESE_UTF8_SIG = "utf-8-sig"
  5. def filter_duplicate_entity_with_both_hit_values(input_file, output_file):
  6. """
  7. 筛选entity_name重复且对应eval_hit同时包含TRUE/FALSE的记录
  8. Args:
  9. input_file (str): 输入CSV文件路径
  10. output_file (str): 输出CSV文件路径
  11. """
  12. # 检查输入文件是否存在
  13. if not os.path.exists(input_file):
  14. print(f"错误:输入文件 {input_file} 不存在!")
  15. return
  16. try:
  17. # 1. 读取CSV文件
  18. df = pd.read_csv(input_file, encoding=CHINESE_UTF8_SIG)
  19. # 检查必要的列是否存在
  20. required_columns = ['entity_name', 'eval_hit']
  21. missing_columns = [col for col in required_columns if col not in df.columns]
  22. if missing_columns:
  23. print(f"错误:CSV文件缺少必要的列:{missing_columns}")
  24. return
  25. # 2. 标准化eval_hit列的值(统一大小写,去除空格)
  26. df['eval_hit'] = df['eval_hit'].astype(str).str.strip().str.upper()
  27. # 3. 按entity_name分组,检查每个分组是否同时包含TRUE和FALSE
  28. # 获取每个entity_name对应的唯一eval_hit值集合
  29. entity_hit_groups = df.groupby('entity_name')['eval_hit'].unique()
  30. # 筛选出同时包含TRUE和FALSE的entity_name
  31. target_entities = [
  32. entity for entity, hits in entity_hit_groups.items()
  33. if 'TRUE' in hits and 'FALSE' in hits
  34. ]
  35. if not target_entities:
  36. print("未找到符合条件的记录(entity_name重复且eval_hit包含TRUE/FALSE)")
  37. return
  38. # 4. 提取符合条件的所有记录
  39. filtered_df = df[df['entity_name'].isin(target_entities)]
  40. # 5. 保存到新CSV文件
  41. filtered_df.to_csv(output_file, index=False, encoding=CHINESE_UTF8_SIG)
  42. print(f"筛选完成!")
  43. print(f"- 符合条件的entity_name数量:{len(target_entities)}")
  44. print(f"- 提取的记录总数:{len(filtered_df)}")
  45. print(f"- 结果已保存到:{output_file}")
  46. except Exception as e:
  47. print(f"处理过程中出错:{str(e)}")
  48. # ===================== 主程序 =====================
  49. if __name__ == "__main__":
  50. # 配置文件路径(请修改为你的实际文件路径)
  51. INPUT_CSV = "rag_eval_results.csv" # 输入CSV文件路径
  52. OUTPUT_CSV = "filtered_data.csv" # 输出CSV文件路径
  53. # 执行筛选
  54. filter_duplicate_entity_with_both_hit_values(INPUT_CSV, OUTPUT_CSV)