| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283 |
- import pandas as pd
- import os
- import random
- # 与项目其他脚本保持一致:UTF-8 with BOM,便于 Excel 正常识别中文
- CHINESE_UTF8_SIG = "utf-8-sig"
- # 设置随机种子(可选,保证结果可复现)
- random.seed(42)
- def deduplicate_entity_keep_true(input_file, output_file):
- """
- 对CSV数据去重,entity_name列去重,优先保留eval_hit为TRUE的记录
-
- Args:
- input_file (str): 输入CSV文件路径
- output_file (str): 输出CSV文件路径
- """
- # 检查输入文件是否存在
- if not os.path.exists(input_file):
- print(f"错误:输入文件 {input_file} 不存在!")
- return
-
- try:
- # 1. 读取CSV文件
- df = pd.read_csv(input_file, encoding=CHINESE_UTF8_SIG)
-
- # 检查必要的列是否存在
- required_columns = ['entity_name', 'eval_hit']
- missing_columns = [col for col in required_columns if col not in df.columns]
- if missing_columns:
- print(f"错误:CSV文件缺少必要的列:{missing_columns}")
- return
-
- # 2. 标准化eval_hit列的值(统一大小写,去除空格)
- df['eval_hit'] = df['eval_hit'].astype(str).str.strip().str.upper()
-
- # 存储去重后的结果
- deduplicated_rows = []
-
- # 3. 按entity_name分组处理
- for entity_name, group in df.groupby('entity_name'):
- # 分离该分组下TRUE和FALSE的记录
- true_records = group[group['eval_hit'] == 'TRUE']
- false_records = group[group['eval_hit'] == 'FALSE']
-
- if not true_records.empty:
- # 有TRUE值:随机选1条TRUE记录保留
- selected_row = true_records.sample(n=1, random_state=random.randint(1, 1000))
- elif not false_records.empty:
- # 只有FALSE值:随机选1条FALSE记录保留
- selected_row = false_records.sample(n=1, random_state=random.randint(1, 1000))
- else:
- # 无有效eval_hit值(理论上不会出现)
- print(f"警告:entity_name={entity_name} 无有效TRUE/FALSE值,跳过")
- continue
-
- deduplicated_rows.append(selected_row)
-
- # 4. 合并所有选中的行并保存
- if deduplicated_rows:
- result_df = pd.concat(deduplicated_rows, ignore_index=True)
- # 保存到新CSV
- result_df.to_csv(output_file, index=False, encoding=CHINESE_UTF8_SIG)
-
- print(f"去重完成!")
- print(f"- 原始记录总数:{len(df)}")
- print(f"- 去重后记录总数:{len(result_df)}")
- print(f"- 结果已保存到:{output_file}")
- else:
- print("未找到可处理的有效记录!")
-
- except Exception as e:
- print(f"处理过程中出错:{str(e)}")
- # ===================== 主程序 =====================
- if __name__ == "__main__":
- # 配置文件路径(请修改为你的实际文件路径)
- INPUT_CSV = "rag_eval_results.csv" # 输入CSV文件路径
- OUTPUT_CSV = "deduplicated_data.csv" # 输出CSV文件路径
-
- # 执行去重
- deduplicate_entity_keep_true(INPUT_CSV, OUTPUT_CSV)
|