import pandas as pd import os # 与项目其他脚本保持一致:UTF-8 with BOM,便于 Excel 正常识别中文 CHINESE_UTF8_SIG = "utf-8-sig" def filter_duplicate_entity_with_both_hit_values(input_file, output_file): """ 筛选entity_name重复且对应eval_hit同时包含TRUE/FALSE的记录 Args: input_file (str): 输入CSV文件路径 output_file (str): 输出CSV文件路径 """ # 检查输入文件是否存在 if not os.path.exists(input_file): print(f"错误:输入文件 {input_file} 不存在!") return try: # 1. 读取CSV文件 df = pd.read_csv(input_file, encoding=CHINESE_UTF8_SIG) # 检查必要的列是否存在 required_columns = ['entity_name', 'eval_hit'] missing_columns = [col for col in required_columns if col not in df.columns] if missing_columns: print(f"错误:CSV文件缺少必要的列:{missing_columns}") return # 2. 标准化eval_hit列的值(统一大小写,去除空格) df['eval_hit'] = df['eval_hit'].astype(str).str.strip().str.upper() # 3. 按entity_name分组,检查每个分组是否同时包含TRUE和FALSE # 获取每个entity_name对应的唯一eval_hit值集合 entity_hit_groups = df.groupby('entity_name')['eval_hit'].unique() # 筛选出同时包含TRUE和FALSE的entity_name target_entities = [ entity for entity, hits in entity_hit_groups.items() if 'TRUE' in hits and 'FALSE' in hits ] if not target_entities: print("未找到符合条件的记录(entity_name重复且eval_hit包含TRUE/FALSE)") return # 4. 提取符合条件的所有记录 filtered_df = df[df['entity_name'].isin(target_entities)] # 5. 保存到新CSV文件 filtered_df.to_csv(output_file, index=False, encoding=CHINESE_UTF8_SIG) print(f"筛选完成!") print(f"- 符合条件的entity_name数量:{len(target_entities)}") print(f"- 提取的记录总数:{len(filtered_df)}") print(f"- 结果已保存到:{output_file}") except Exception as e: print(f"处理过程中出错:{str(e)}") # ===================== 主程序 ===================== if __name__ == "__main__": # 配置文件路径(请修改为你的实际文件路径) INPUT_CSV = "rag_eval_results.csv" # 输入CSV文件路径 OUTPUT_CSV = "filtered_data.csv" # 输出CSV文件路径 # 执行筛选 filter_duplicate_entity_with_both_hit_values(INPUT_CSV, OUTPUT_CSV)