"""Data preprocessor: convert datasets in various formats into the format used for training.

Supported inputs are ``.jsonl``, ``.json``, ``.csv`` and ``.parquet`` files. Each row is
mapped through a template (alpaca / sharegpt / raw / dpo, or auto-detected from the column
names) and the converted rows are written out as JSON Lines.
"""

import functools
import json
import logging
from pathlib import Path
from typing import Any, Callable

_logger = logging.getLogger(__name__)

# Common column names, grouped by the role they usually play.
_PROMPT_COLUMNS = {"prompt", "question", "query", "text", "input"}
_COMPLETION_COLUMNS = {"completion", "answer", "response", "target", "output"}
_ALPACA_COLUMNS = {"instruction", "input", "output"}
_SHAREGPT_COLUMNS = {"conversations"}
_DPO_COLUMNS = {"prompt", "chosen", "rejected"}

# Priority order used by auto-detection when several candidate columns are present.
_PROMPT_PRIORITY = ("prompt", "question", "query", "text", "input")
_COMPLETION_PRIORITY = ("completion", "answer", "response", "target", "output")

# Columns that identify a format on their own. Generic names such as "prompt", "input" and
# "output" are excluded: plain prompt/completion datasets use them too, so matching on them
# would misclassify such datasets as DPO or Alpaca.
_DPO_MARKER_COLUMNS = _DPO_COLUMNS - _PROMPT_COLUMNS  # {"chosen", "rejected"}
_ALPACA_MARKER_COLUMNS = _ALPACA_COLUMNS - _PROMPT_COLUMNS - _COMPLETION_COLUMNS  # {"instruction"}

# Speaker labels (compared lower-cased) treated as the model's reply in ShareGPT conversations.
_ASSISTANT_ROLES = {"gpt", "assistant", "model", "bot", "chatgpt"}


def _as_text(value: Any) -> str:
    """Return ``str(value)``, mapping ``None`` to the empty string."""
    return "" if value is None else str(value)


def _first_value(item: dict, keys: tuple[str, ...]) -> Any:
    """Return the value of the first key in *keys* whose value in *item* is not ``None``.

    Returns ``""`` when none of the keys holds a value, so a column that exists but is
    null falls through to the next candidate instead of becoming the text ``"None"``.
    """
    for key in keys:
        value = item.get(key)
        if value is not None:
            return value
    return ""


def apply_auto_template(item: dict, column_map: dict[str, Any]) -> dict:
    """Auto template: map a row using the column names found by ``_detect_columns``.

    Args:
        item: One dataset row.
        column_map: Mapping with ``"prompt_candidates"`` / ``"completion_candidates"``
            lists of column names, in priority order.

    Returns:
        ``{"prompt": ..., "completion": ...}`` taken from the first candidate column of
        each kind whose value is not ``None``; ``""`` when no candidate has a value.
    """
    prompt = ""
    completion = ""
    for col in column_map.get("prompt_candidates", []):
        if col in item and item[col] is not None:
            prompt = str(item[col])
            break
    for col in column_map.get("completion_candidates", []):
        if col in item and item[col] is not None:
            completion = str(item[col])
            break
    return {"prompt": prompt, "completion": completion}


def _detect_columns(raw_data: list[dict]) -> dict[str, Any]:
    """Scan the first rows of a dataset and decide how to map its columns.

    Returns a dict whose ``"template"`` is one of ``"sharegpt"``, ``"dpo"``, ``"alpaca"``,
    ``"auto"`` or ``"raw"`` (empty dataset). For ``"auto"`` (and ``"raw"``) it also holds
    ``"prompt_candidates"`` / ``"completion_candidates"``: the matching column names, in
    priority order, matched case-insensitively.
    """
    if not raw_data:
        return {"prompt_candidates": [], "completion_candidates": [], "template": "raw"}

    # Only the first 5 rows are scanned. Non-dict rows and non-string keys (e.g. the None
    # key csv.DictReader uses for surplus fields) are ignored rather than crashing here.
    all_columns: set[str] = set()
    for item in raw_data[:5]:
        if isinstance(item, dict):
            all_columns.update(key for key in item if isinstance(key, str))
    # Sorted so that the winner is deterministic when two columns differ only in case.
    lower_cols = {col.lower().strip(): col for col in sorted(all_columns)}

    # Format detection is case-sensitive because the templates read these exact keys.
    if _SHAREGPT_COLUMNS & all_columns:
        return {"template": "sharegpt"}
    if _DPO_MARKER_COLUMNS & all_columns:
        return {"template": "dpo"}
    if _ALPACA_MARKER_COLUMNS & all_columns:
        return {"template": "alpaca"}

    prompt_candidates = [lower_cols[name] for name in _PROMPT_PRIORITY if name in lower_cols]
    completion_candidates = [
        lower_cols[name] for name in _COMPLETION_PRIORITY if name in lower_cols
    ]
    return {
        "template": "auto",
        "prompt_candidates": prompt_candidates,
        "completion_candidates": completion_candidates,
    }


def apply_alpaca_template(item: dict) -> dict:
    """Alpaca template: ``instruction`` (+ optional ``input``) -> ``output``.

    The prompt is the instruction, followed by a blank line and the input when the input
    is non-empty. Missing or ``None`` fields become ``""``.
    """
    instruction = _as_text(item.get("instruction", ""))
    input_text = _as_text(item.get("input", ""))
    output = _as_text(item.get("output", ""))
    prompt = f"{instruction}\n\n{input_text}" if input_text else instruction
    return {"prompt": prompt, "completion": output}


def apply_sharegpt_template(item: dict) -> dict:
    """ShareGPT template: ``conversations`` list -> prompt + completion.

    The completion is the first assistant turn (role in ``_ASSISTANT_ROLES``) after the
    opening turn. Every turn before it, e.g. a system message followed by the user
    question, is joined with newlines to form the prompt. When no turn carries a
    recognised assistant role, the first turn is the prompt and the second the completion.

    Each turn is read with ``from``/``value`` keys, falling back to ``role``/``content``.
    Conversations with fewer than two turns yield an empty prompt and completion.
    """
    conversations = item.get("conversations") or []
    if len(conversations) < 2:
        return {"prompt": "", "completion": ""}

    prompt_parts: list[str] = []
    for index, turn in enumerate(conversations):
        role = str(turn.get("from", turn.get("role", "human"))).lower()
        content = _as_text(turn.get("value", turn.get("content", "")))
        # index > 0: a reply needs at least one preceding turn to serve as the prompt.
        if index > 0 and role in _ASSISTANT_ROLES:
            return {"prompt": "\n".join(prompt_parts), "completion": content}
        prompt_parts.append(content)

    # No recognised assistant turn: fall back to a positional first/second pairing.
    return {"prompt": prompt_parts[0], "completion": prompt_parts[1]}


def apply_raw_template(item: dict) -> dict:
    """Raw template: read the first populated prompt-like and completion-like field.

    Prompt keys are tried in the order prompt, text, input, question, query; completion
    keys in the order completion, output, target, answer, response. A key whose value is
    ``None`` is skipped; ``""`` is used when nothing matches.
    """
    prompt = _first_value(item, ("prompt", "text", "input", "question", "query"))
    completion = _first_value(item, ("completion", "output", "target", "answer", "response"))
    return {"prompt": str(prompt), "completion": str(completion)}


def apply_dpo_template(item: dict) -> dict:
    """DPO template: prompt + chosen + rejected.

    Falls back to common alternative column names (instruction/input/question/query for
    the prompt, positive/answer for chosen, negative for rejected). A key whose value is
    ``None`` is skipped; missing fields become ``""``.
    """
    prompt = _first_value(item, ("prompt", "instruction", "input", "question", "query"))
    chosen = _first_value(item, ("chosen", "positive", "answer"))
    rejected = _first_value(item, ("rejected", "negative"))
    return {"prompt": str(prompt), "chosen": str(chosen), "rejected": str(rejected)}


# task_type -> template name -> row conversion function.
TEMPLATE_MAP = {
    "sft": {
        "auto": None,  # special case: resolved per file by column auto-detection
        "alpaca": apply_alpaca_template,
        "sharegpt": apply_sharegpt_template,
        "raw": apply_raw_template,
    },
    "dpo": {
        "auto": apply_dpo_template,
        "alpaca": apply_dpo_template,
        "sharegpt": apply_dpo_template,
        "raw": apply_dpo_template,
    },
    "ppo": {
        "auto": apply_raw_template,
        "alpaca": apply_alpaca_template,
        "sharegpt": apply_sharegpt_template,
        "raw": apply_raw_template,
    },
}


def _load_records(input_path: str) -> list[Any]:
    """Read *input_path* into a list of rows, choosing the parser from the file suffix."""
    ext = Path(input_path).suffix.lower()
    # utf-8-sig transparently strips a leading byte-order mark (common in Excel/Windows
    # exports) that would otherwise corrupt the first CSV header or break json parsing.
    if ext == ".jsonl":
        with open(input_path, "r", encoding="utf-8-sig") as f:
            return [json.loads(line) for line in f if line.strip()]
    if ext == ".json":
        with open(input_path, "r", encoding="utf-8-sig") as f:
            try:
                data = json.load(f)
            except json.JSONDecodeError:
                # Fall back to JSONL (one JSON object per line) under a .json name.
                f.seek(0)
                return [json.loads(line) for line in f if line.strip()]
        return data if isinstance(data, list) else [data]
    if ext == ".csv":
        import csv

        # newline="" is required by the csv module so quoted fields containing line
        # breaks are parsed correctly.
        with open(input_path, "r", encoding="utf-8-sig", newline="") as f:
            return [dict(row) for row in csv.DictReader(f)]
    if ext == ".parquet":
        import pandas as pd

        return pd.read_parquet(input_path).to_dict(orient="records")
    raise ValueError(f"Unsupported format: {ext}")


def _resolve_apply_fn(
    task_type: str, template: str, raw_data: list[Any]
) -> Callable[[dict], dict]:
    """Pick the row conversion function, auto-detecting the format when required."""
    # Unknown task types fall back to "sft"; unknown template names fall back to "raw".
    templates = TEMPLATE_MAP.get(task_type, TEMPLATE_MAP["sft"])
    apply_fn = templates.get(template, templates.get("raw", apply_raw_template))
    if apply_fn is not None:
        return apply_fn

    column_map = _detect_columns(raw_data)
    detected = column_map.get("template", "raw")
    if detected == "auto":
        return functools.partial(apply_auto_template, column_map=column_map)
    detected_fns = {
        "sharegpt": apply_sharegpt_template,
        "alpaca": apply_alpaca_template,
        "dpo": apply_dpo_template,
    }
    return detected_fns.get(detected, apply_raw_template)


def _write_jsonl(output_path: str, rows: list[dict[str, Any]]) -> None:
    """Write *rows* to *output_path* as UTF-8 JSON Lines, creating parent directories."""
    output_p = Path(output_path)
    output_p.parent.mkdir(parents=True, exist_ok=True)
    # Delete any previous output first (per the original author: avoids permission conflicts).
    if output_p.exists():
        output_p.unlink()
    with open(output_p, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(row, ensure_ascii=False) + "\n" for row in rows)


def preprocess_file(
    input_path: str,
    output_path: str,
    task_type: str = "sft",
    template: str = "auto",
) -> list[dict[str, Any]]:
    """Read *input_path*, convert each row with the selected template and write JSONL.

    Args:
        input_path: Source dataset; the parser is chosen from the suffix
            (``.jsonl``, ``.json``, ``.csv`` or ``.parquet``).
        output_path: Destination JSONL file. Parent directories are created and an
            existing file is replaced.
        task_type: Key into ``TEMPLATE_MAP`` (``"sft"``, ``"dpo"``, ``"ppo"``); unknown
            values fall back to ``"sft"``.
        template: Template name. ``"auto"`` with ``task_type="sft"`` detects the format
            from the column names; unknown names fall back to ``"raw"``.

    Returns:
        The converted rows that were written, i.e. those with a non-empty ``"prompt"``.

    Raises:
        ValueError: If the input suffix is not supported.
        json.JSONDecodeError: If a JSON Lines record is malformed.
    """
    raw_data = _load_records(input_path)
    apply_fn = _resolve_apply_fn(task_type, template, raw_data)

    processed: list[dict[str, Any]] = []
    failures = 0
    for item in raw_data:
        try:
            result = apply_fn(item)
        except Exception:
            # Best effort: one malformed row must not abort the file, but it is counted
            # and reported instead of vanishing silently.
            failures += 1
            _logger.debug("Template failed on a row of %s", input_path, exc_info=True)
            continue
        if result.get("prompt"):
            processed.append(result)

    if failures:
        _logger.warning(
            "Skipped %d of %d rows in %s because the template raised an error",
            failures,
            len(raw_data),
            input_path,
        )

    _write_jsonl(output_path, processed)
    return processed