csv_去重.py 3.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. import pandas as pd
  2. import os
  3. import random
  4. # 与项目其他脚本保持一致:UTF-8 with BOM,便于 Excel 正常识别中文
  5. CHINESE_UTF8_SIG = "utf-8-sig"
  6. # 设置随机种子(可选,保证结果可复现)
  7. random.seed(42)
  8. def deduplicate_entity_keep_true(input_file, output_file):
  9. """
  10. 对CSV数据去重,entity_name列去重,优先保留eval_hit为TRUE的记录
  11. Args:
  12. input_file (str): 输入CSV文件路径
  13. output_file (str): 输出CSV文件路径
  14. """
  15. # 检查输入文件是否存在
  16. if not os.path.exists(input_file):
  17. print(f"错误:输入文件 {input_file} 不存在!")
  18. return
  19. try:
  20. # 1. 读取CSV文件
  21. df = pd.read_csv(input_file, encoding=CHINESE_UTF8_SIG)
  22. # 检查必要的列是否存在
  23. required_columns = ['entity_name', 'eval_hit']
  24. missing_columns = [col for col in required_columns if col not in df.columns]
  25. if missing_columns:
  26. print(f"错误:CSV文件缺少必要的列:{missing_columns}")
  27. return
  28. # 2. 标准化eval_hit列的值(统一大小写,去除空格)
  29. df['eval_hit'] = df['eval_hit'].astype(str).str.strip().str.upper()
  30. # 存储去重后的结果
  31. deduplicated_rows = []
  32. # 3. 按entity_name分组处理
  33. for entity_name, group in df.groupby('entity_name'):
  34. # 分离该分组下TRUE和FALSE的记录
  35. true_records = group[group['eval_hit'] == 'TRUE']
  36. false_records = group[group['eval_hit'] == 'FALSE']
  37. if not true_records.empty:
  38. # 有TRUE值:随机选1条TRUE记录保留
  39. selected_row = true_records.sample(n=1, random_state=random.randint(1, 1000))
  40. elif not false_records.empty:
  41. # 只有FALSE值:随机选1条FALSE记录保留
  42. selected_row = false_records.sample(n=1, random_state=random.randint(1, 1000))
  43. else:
  44. # 无有效eval_hit值(理论上不会出现)
  45. print(f"警告:entity_name={entity_name} 无有效TRUE/FALSE值,跳过")
  46. continue
  47. deduplicated_rows.append(selected_row)
  48. # 4. 合并所有选中的行并保存
  49. if deduplicated_rows:
  50. result_df = pd.concat(deduplicated_rows, ignore_index=True)
  51. # 保存到新CSV
  52. result_df.to_csv(output_file, index=False, encoding=CHINESE_UTF8_SIG)
  53. print(f"去重完成!")
  54. print(f"- 原始记录总数:{len(df)}")
  55. print(f"- 去重后记录总数:{len(result_df)}")
  56. print(f"- 结果已保存到:{output_file}")
  57. else:
  58. print("未找到可处理的有效记录!")
  59. except Exception as e:
  60. print(f"处理过程中出错:{str(e)}")
  61. # ===================== 主程序 =====================
  62. if __name__ == "__main__":
  63. # 配置文件路径(请修改为你的实际文件路径)
  64. INPUT_CSV = "rag_eval_results.csv" # 输入CSV文件路径
  65. OUTPUT_CSV = "deduplicated_data.csv" # 输出CSV文件路径
  66. # 执行去重
  67. deduplicate_entity_keep_true(INPUT_CSV, OUTPUT_CSV)