| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253 |
- """数据预处理器:将不同格式的数据集转换为训练所需格式。"""
- import json
- from pathlib import Path
- from typing import Any
# Common column-name sets used to recognise dataset layouts. frozenset keeps
# these module-level constants immutable while still supporting the set
# operators (&, <=, -) applied to them.
# NOTE: _PROMPT_COLUMNS / _COMPLETION_COLUMNS are unordered; column detection
# keeps its own priority-ordered name lists.
_PROMPT_COLUMNS = frozenset({"prompt", "question", "query", "text", "input"})
_COMPLETION_COLUMNS = frozenset({"completion", "answer", "response", "target", "output"})
_ALPACA_COLUMNS = frozenset({"instruction", "input", "output"})
_SHAREGPT_COLUMNS = frozenset({"conversations"})
_DPO_COLUMNS = frozenset({"prompt", "chosen", "rejected"})
def apply_auto_template(item: dict, column_map: dict[str, Any]) -> dict:
    """Auto template: map one row to prompt/completion using detected column names.

    Args:
        item: One raw dataset row.
        column_map: Column-detection result. Its ``"prompt_candidates"`` and
            ``"completion_candidates"`` entries are lists of column names in
            priority order; a missing entry is treated as an empty list.

    Returns:
        ``{"prompt": str, "completion": str}``. Each value is taken from the
        first candidate column that is present in ``item`` with a non-None
        value (converted with ``str``), or ``""`` when no candidate matches.
    """

    def first_non_null(candidates: list[str]) -> str:
        for column in candidates:
            if column in item and item[column] is not None:
                return str(item[column])
        return ""

    return {
        "prompt": first_non_null(column_map.get("prompt_candidates", [])),
        "completion": first_non_null(column_map.get("completion_candidates", [])),
    }
def _detect_columns(raw_data: list[dict]) -> dict[str, Any]:
    """Scan the first rows of a dataset and infer which template fits it.

    Args:
        raw_data: Parsed dataset rows. Only the first 5 are inspected; rows
            that are not dicts and keys that are not strings are ignored.

    Returns:
        A dict whose ``"template"`` key is one of ``"sharegpt"``, ``"dpo"``,
        ``"alpaca"``, ``"auto"`` or, for empty input, ``"raw"``. For ``"auto"``
        (and the empty-input case) it also holds ``"prompt_candidates"`` and
        ``"completion_candidates"``: the dataset's actual column names, in
        priority order, that look like prompt / completion columns.
    """
    if not raw_data:
        return {"prompt_candidates": [], "completion_candidates": [], "template": "raw"}

    all_columns: set[str] = set()
    for row in raw_data[:5]:
        if isinstance(row, dict):
            # csv.DictReader stores surplus fields under a None key, and other
            # sources may use non-string keys; neither can be lower-cased.
            all_columns.update(key for key in row if isinstance(key, str))

    # A layout is chosen only when ALL of its required columns are present.
    # An "any overlap" test (set intersection) would e.g. treat every plain
    # prompt/completion dataset as DPO because of its "prompt" column, and
    # every input/output dataset as Alpaca.
    if _SHAREGPT_COLUMNS <= all_columns:
        return {"template": "sharegpt"}
    if _DPO_COLUMNS <= all_columns:
        return {"template": "dpo"}
    # "input" is optional in Alpaca-style data; instruction + output suffice.
    if (_ALPACA_COLUMNS - {"input"}) <= all_columns:
        return {"template": "alpaca"}

    # Case/whitespace-insensitive lookup: normalized name -> real column name.
    # Iterating in sorted order makes the winner among case variants stable.
    normalized = {column.lower().strip(): column for column in sorted(all_columns)}

    def candidates(names: tuple[str, ...]) -> list[str]:
        found = []
        for name in names:
            if name in all_columns:
                found.append(name)  # an exact match beats case variants
            elif name in normalized:
                found.append(normalized[name])
        return found

    return {
        "template": "auto",
        # "instruction" comes last: instruction-style rows without an "output"
        # column (e.g. instruction/response) are not routed to Alpaca above.
        "prompt_candidates": candidates(("prompt", "question", "query", "text", "input", "instruction")),
        "completion_candidates": candidates(("completion", "answer", "response", "target", "output")),
    }
def apply_alpaca_template(item: dict) -> dict:
    """Alpaca template: ``instruction`` (+ optional ``input``) -> ``output``.

    Missing or None fields become ``""``; other values are converted with
    ``str``. A non-empty input is appended to the instruction after a blank line.
    """

    def text(field: str) -> str:
        raw = item.get(field, "")
        return "" if raw is None else str(raw)

    instruction = text("instruction")
    input_text = text("input")
    completion = text("output")
    if not input_text:
        return {"prompt": instruction, "completion": completion}
    return {"prompt": f"{instruction}\n\n{input_text}", "completion": completion}
def apply_sharegpt_template(item: dict) -> dict:
    """ShareGPT template: flatten a ``conversations`` list into prompt + completion.

    Each turn is a dict using ShareGPT keys (``from`` / ``value``) or
    OpenAI-style keys (``role`` / ``content``). The completion is the last turn
    (at index >= 1) whose role is an assistant role such as ``gpt`` or
    ``assistant``; when no turn qualifies, the final turn is used. Every turn
    before the completion goes into the prompt: the first turn's content
    verbatim, later turns as ``"role: content"`` lines, joined by newlines.
    For a plain two-turn conversation this is simply (turn 0, turn 1).

    Returns ``{"prompt": "", "completion": ""}`` when ``conversations`` is
    missing/None, is a string or dict, or has fewer than two dict turns.
    """
    empty = {"prompt": "", "completion": ""}
    conversations = item.get("conversations")
    # Strings, bytes and dicts are iterable but are not lists of turns.
    if conversations is None or isinstance(conversations, (str, bytes, dict)):
        return empty
    turns = [turn for turn in conversations if isinstance(turn, dict)]
    if len(turns) < 2:
        return empty

    assistant_roles = {"gpt", "assistant", "model", "bot", "chatgpt"}

    def role_of(turn: dict) -> str:
        return str(turn.get("from", turn.get("role", "human")))

    def text_of(turn: dict) -> str:
        content = turn.get("value", turn.get("content", ""))
        return "" if content is None else str(content)

    # Train on the final assistant reply, keeping the earlier turns as context;
    # trailing non-assistant turns (e.g. an unanswered user message) are dropped.
    answer_idx = max(
        (i for i in range(1, len(turns)) if role_of(turns[i]).lower() in assistant_roles),
        default=len(turns) - 1,
    )
    prompt_lines = [text_of(turns[0])]
    prompt_lines += [f"{role_of(turn)}: {text_of(turn)}" for turn in turns[1:answer_idx]]
    return {"prompt": "\n".join(prompt_lines), "completion": text_of(turns[answer_idx])}
def apply_raw_template(item: dict) -> dict:
    """Raw template: take prompt/completion straight from well-known field names.

    The prompt is the first field among ``prompt``, ``text``, ``input``,
    ``question``, ``query`` whose value is not None; the completion is the
    first non-None field among ``completion``, ``output``, ``target``,
    ``answer``, ``response``. Values are converted with ``str``; ``""`` is used
    when no field qualifies. Null values are skipped rather than stringified,
    so a JSON ``null`` never becomes the literal text ``"None"``.
    """

    def first_present(keys: tuple[str, ...]) -> str:
        for key in keys:
            value = item.get(key)
            if value is not None:
                return str(value)
        return ""

    return {
        "prompt": first_present(("prompt", "text", "input", "question", "query")),
        "completion": first_present(("completion", "output", "target", "answer", "response")),
    }
def apply_dpo_template(item: dict) -> dict:
    """DPO template: build a prompt / chosen / rejected preference triple.

    Each field takes the value of the first alias key present in ``item``
    (even when that value is None or empty), falling back to ``""``.
    """

    def resolve(*aliases: str):
        # Fold from the last alias to the first: equivalent to nested dict.get() calls.
        found = ""
        for alias in reversed(aliases):
            found = item.get(alias, found)
        return found

    return {
        "prompt": resolve("prompt", "input", "question", "query"),
        "chosen": resolve("chosen", "positive", "answer"),
        "rejected": resolve("rejected", "negative"),
    }
def apply_kto_template(item: dict) -> dict:
    """KTO template: prompt + completion + boolean desirability label.

    ``prompt`` / ``completion`` take the value of the first alias key present
    in ``item`` (``""`` if none is). ``label`` defaults to True when absent.
    CSV input yields every cell as a string, so textual labels
    (``"true"/"false"``, ``"1"/"0"``, ``"yes"/"no"``, ``"y"/"n"``,
    case-insensitive) and numeric 0/1 are normalised to bools. Otherwise a
    string such as ``"False"`` would be written out as a truthy string.
    Unrecognised label values are passed through unchanged.
    """

    def first_of(keys: tuple[str, ...]) -> Any:
        # Same semantics as nested dict.get(...) calls: the first key present wins.
        value: Any = ""
        for key in reversed(keys):
            value = item.get(key, value)
        return value

    label = item.get("label", True)
    if isinstance(label, str):
        token = label.strip().lower()
        if token in {"true", "1", "yes", "y"}:
            label = True
        elif token in {"false", "0", "no", "n"}:
            label = False
    elif isinstance(label, (int, float)) and not isinstance(label, bool) and label in (0, 1):
        label = bool(label)

    return {
        "prompt": first_of(("prompt", "input", "question", "query")),
        "completion": first_of(("completion", "output", "answer", "response")),
        "label": label,
    }
def apply_orpo_template(item: dict) -> dict:
    """ORPO template: prompt / chosen / rejected, using the same aliases as DPO.

    For each output field the first alias key present in ``item`` wins (its
    value is kept as-is, even if None or empty); otherwise ``""``.
    """
    aliases = {
        "prompt": ("prompt", "input", "question", "query"),
        "chosen": ("chosen", "positive", "answer"),
        "rejected": ("rejected", "negative"),
    }
    result = {}
    for field, keys in aliases.items():
        value = ""
        for key in keys[::-1]:
            value = item.get(key, value)
        result[field] = value
    return result
def apply_rm_template(item: dict) -> dict:
    """Reward-modeling template: prompt plus a chosen/rejected response pair.

    Each field is read from the first alias key present in ``item`` (value kept
    as-is, even if None or empty), defaulting to ``""``.
    """

    def pick(keys: tuple[str, ...]):
        value = ""
        for key in reversed(keys):
            value = item.get(key, value)
        return value

    fields = (
        ("prompt", ("prompt", "input", "question", "query")),
        ("chosen", ("chosen", "positive", "answer")),
        ("rejected", ("rejected", "negative")),
    )
    return {name: pick(keys) for name, keys in fields}
# task_type -> {template name -> per-row template function}.
TEMPLATE_MAP = {
    "sft": {
        # None is a sentinel: preprocess_file detects the concrete template
        # from the dataset's columns.
        "auto": None,
        "alpaca": apply_alpaca_template,
        "sharegpt": apply_sharegpt_template,
        "raw": apply_raw_template,
    },
    "dpo": dict.fromkeys(("auto", "alpaca", "sharegpt", "raw"), apply_dpo_template),
    "kto": dict.fromkeys(("auto", "raw"), apply_kto_template),
    "orpo": dict.fromkeys(("auto", "alpaca", "raw"), apply_orpo_template),
    "rm": dict.fromkeys(("auto", "raw"), apply_rm_template),
    "ppo": dict.fromkeys(("auto", "raw"), apply_raw_template),
}
def _read_records(input_path: str) -> list[Any]:
    """Parse ``input_path`` (.jsonl, .json, .csv or .parquet) into a list of rows.

    Text formats are decoded as "utf-8-sig": it reads plain UTF-8 unchanged but
    also strips a leading byte-order mark (common in Excel-exported CSVs). With
    plain "utf-8" a BOM makes ``json.load`` fail and is glued onto the first
    CSV column name.

    Raises:
        ValueError: for any other file extension (``json.JSONDecodeError``, a
            ValueError subclass, is raised for malformed JSON/JSONL).
    """
    ext = Path(input_path).suffix.lower()
    if ext == ".jsonl":
        with open(input_path, "r", encoding="utf-8-sig") as f:
            return [json.loads(line) for line in f if line.strip()]
    if ext == ".json":
        with open(input_path, "r", encoding="utf-8-sig") as f:
            try:
                data = json.load(f)
            except json.JSONDecodeError:
                # Fall back to JSONL (one JSON object per line) saved as .json.
                f.seek(0)
                return [json.loads(line) for line in f if line.strip()]
        return data if isinstance(data, list) else [data]
    if ext == ".csv":
        import csv

        # newline="" is required by the csv module so quoted fields that
        # contain line breaks are parsed correctly.
        with open(input_path, "r", encoding="utf-8-sig", newline="") as f:
            return [dict(row) for row in csv.DictReader(f)]
    if ext == ".parquet":
        import pandas as pd

        return pd.read_parquet(input_path).to_dict(orient="records")
    raise ValueError(f"Unsupported format: {ext}")


def _resolve_apply_fn(raw_data: list[Any], task_type: str, template: str):
    """Return the per-row template function for ``task_type`` / ``template``.

    Unknown task types fall back to the "sft" table and unknown template names
    to that table's "raw" entry. A None entry (SFT "auto") means the template
    is detected from the data's columns with ``_detect_columns``.
    """
    templates = TEMPLATE_MAP.get(task_type, TEMPLATE_MAP["sft"])
    apply_fn = templates.get(template, templates.get("raw", apply_raw_template))
    if apply_fn is not None:
        return apply_fn

    column_map = _detect_columns(raw_data)
    detected = column_map.get("template", "raw")
    if detected == "auto":

        def apply_detected(item: dict) -> dict:
            return apply_auto_template(item, column_map)

        return apply_detected
    return {
        "sharegpt": apply_sharegpt_template,
        "alpaca": apply_alpaca_template,
        "dpo": apply_dpo_template,
    }.get(detected, apply_raw_template)


def _write_jsonl(records: list[dict[str, Any]], output_path: str) -> None:
    """Write ``records`` to ``output_path`` as UTF-8 JSON Lines, replacing any old file."""
    output_p = Path(output_path)
    output_p.parent.mkdir(parents=True, exist_ok=True)
    # Delete the old file first to avoid permission conflicts on overwrite.
    output_p.unlink(missing_ok=True)
    with open(output_p, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(record, ensure_ascii=False) + "\n" for record in records)


def preprocess_file(
    input_path: str,
    output_path: str,
    task_type: str = "sft",
    template: str = "auto",
) -> list[dict[str, Any]]:
    """Read a dataset file, apply a template to every row and write JSON Lines.

    Args:
        input_path: .jsonl, .json, .csv or .parquet file to read.
        output_path: Destination JSONL file. Parent directories are created and
            an existing file is replaced.
        task_type: Key into TEMPLATE_MAP ("sft", "dpo", "kto", "orpo", "rm",
            "ppo"); unknown values fall back to "sft".
        template: Template name within that task's table; unknown names fall
            back to "raw". For SFT, "auto" detects the layout from the columns.

    Returns:
        The processed rows that were written: those whose template result has
        a non-empty "prompt".

    Raises:
        ValueError: if the input file extension is unsupported.
    """
    raw_data = _read_records(input_path)
    apply_fn = _resolve_apply_fn(raw_data, task_type, template)

    processed: list[dict[str, Any]] = []
    failed = 0
    for item in raw_data:
        try:
            result = apply_fn(item)
        except Exception:
            # Best effort: a malformed row is skipped, but counted so the loss
            # is reported below instead of being silently swallowed.
            failed += 1
            continue
        if result.get("prompt"):
            processed.append(result)

    if failed:
        import logging

        logging.getLogger(__name__).warning(
            "%s: skipped %d of %d record(s) that the template could not process",
            input_path,
            failed,
            len(raw_data),
        )

    _write_jsonl(processed, output_path)
    return processed
|