|
@@ -4,6 +4,65 @@ import json
|
|
|
from pathlib import Path
|
|
from pathlib import Path
|
|
|
from typing import Any
|
|
from typing import Any
|
|
|
|
|
|
|
|
|
|
# Well-known column names used to recognise dataset layouts.
# The ordered tuples define lookup priority (earlier names win); the sets are
# derived from them so both views can never drift apart.
_PROMPT_PRIORITY = ("prompt", "question", "query", "text", "input")
_COMPLETION_PRIORITY = ("completion", "answer", "response", "target", "output")
_PROMPT_COLUMNS = set(_PROMPT_PRIORITY)
_COMPLETION_COLUMNS = set(_COMPLETION_PRIORITY)
_ALPACA_COLUMNS = {"instruction", "input", "output"}
# "input" is treated as optional for Alpaca rows, and on its own it is also a
# generic prompt column, so only these two names identify the Alpaca layout.
_ALPACA_REQUIRED_COLUMNS = {"instruction", "output"}
_SHAREGPT_COLUMNS = {"conversations"}
_DPO_COLUMNS = {"prompt", "chosen", "rejected"}
# Number of leading rows inspected when sniffing column names.
_DETECT_SAMPLE_ROWS = 5


def _first_present(item: dict, candidates) -> str:
    """Return ``str()`` of the first candidate column of *item* that is not None, else ``""``."""
    for col in candidates:
        value = item.get(col)
        if value is not None:
            return str(value)
    return ""


def apply_auto_template(item: dict, column_map: dict[str, Any]) -> dict:
    """Auto template: map a raw record onto ``prompt`` / ``completion``.

    Args:
        item: One raw dataset record.
        column_map: Mapping as produced by ``_detect_columns``. The keys
            ``"prompt_candidates"`` and ``"completion_candidates"`` hold
            column names in priority order; missing keys mean "no candidates".

    Returns:
        ``{"prompt": ..., "completion": ...}`` where each value is the
        stringified content of the first candidate column that exists in
        *item* with a non-None value, or ``""`` when none does.
    """
    return {
        "prompt": _first_present(item, column_map.get("prompt_candidates") or ()),
        "completion": _first_present(item, column_map.get("completion_candidates") or ()),
    }


def _detect_columns(raw_data: list[dict]) -> dict[str, Any]:
    """Inspect the first rows of a dataset and work out how to map its columns.

    Args:
        raw_data: List of raw records (dicts keyed by column name).

    Returns:
        A dict that always contains ``"template"``:

        * ``"raw"`` for empty input (with empty candidate lists),
        * ``"sharegpt"`` / ``"dpo"`` / ``"alpaca"`` when every identifying
          column of that layout is present,
        * ``"auto"`` otherwise, together with ``"prompt_candidates"`` and
          ``"completion_candidates"``: the dataset's actual column names,
          in priority order, matched case-insensitively.
    """
    if not raw_data:
        return {"prompt_candidates": [], "completion_candidates": [], "template": "raw"}

    # A dict (not a set) keeps first-seen order, so the result is deterministic
    # even when two columns differ only by case or surrounding whitespace.
    columns: dict = {}
    for row in raw_data[:_DETECT_SAMPLE_ROWS]:
        for key in row.keys():
            columns.setdefault(key, None)

    # A layout is only recognised when ALL of its identifying columns exist.
    # A plain intersection would misclassify e.g. {"prompt", "completion"} as DPO.
    if _SHAREGPT_COLUMNS.issubset(columns):
        return {"template": "sharegpt"}
    if _DPO_COLUMNS.issubset(columns):
        return {"template": "dpo"}
    if _ALPACA_REQUIRED_COLUMNS.issubset(columns):
        return {"template": "alpaca"}

    # Normalised name -> original column name; non-string keys cannot match.
    lower_cols: dict[str, str] = {}
    for col in columns:
        if isinstance(col, str):
            lower_cols.setdefault(col.strip().lower(), col)

    return {
        "template": "auto",
        "prompt_candidates": [lower_cols[n] for n in _PROMPT_PRIORITY if n in lower_cols],
        "completion_candidates": [lower_cols[n] for n in _COMPLETION_PRIORITY if n in lower_cols],
    }
|
|
|
|
|
+
|
|
|
|
|
|
|
|
def apply_alpaca_template(item: dict) -> dict:
|
|
def apply_alpaca_template(item: dict) -> dict:
|
|
|
"""Alpaca 模板: instruction + input -> output。"""
|
|
"""Alpaca 模板: instruction + input -> output。"""
|
|
@@ -86,26 +145,32 @@ def apply_rm_template(item: dict) -> dict:
|
|
|
|
|
|
|
|
# Dispatch table: task type -> template name -> function formatting one record.
# The "auto" entry for "sft" is intentionally None; preprocess_file treats that
# None as the signal to pick a formatter by auto-detecting the dataset columns.
TEMPLATE_MAP = {
    "sft": dict(
        auto=None,
        alpaca=apply_alpaca_template,
        sharegpt=apply_sharegpt_template,
        raw=apply_raw_template,
    ),
    "dpo": dict.fromkeys(("auto", "alpaca", "sharegpt", "raw"), apply_dpo_template),
    "kto": dict.fromkeys(("auto", "raw"), apply_kto_template),
    "orpo": dict.fromkeys(("auto", "alpaca", "raw"), apply_orpo_template),
    "rm": dict.fromkeys(("auto", "raw"), apply_rm_template),
    "ppo": dict.fromkeys(("auto", "raw"), apply_raw_template),
}
|
|
@@ -115,7 +180,7 @@ def preprocess_file(
|
|
|
input_path: str,
|
|
input_path: str,
|
|
|
output_path: str,
|
|
output_path: str,
|
|
|
task_type: str = "sft",
|
|
task_type: str = "sft",
|
|
|
- template: str = "alpaca",
|
|
|
|
|
|
|
+ template: str = "auto",
|
|
|
) -> list[dict[str, Any]]:
|
|
) -> list[dict[str, Any]]:
|
|
|
"""读取文件并应用模板,返回处理后的数据列表。"""
|
|
"""读取文件并应用模板,返回处理后的数据列表。"""
|
|
|
input_p = Path(input_path)
|
|
input_p = Path(input_path)
|
|
@@ -150,6 +215,22 @@ def preprocess_file(
|
|
|
templates = TEMPLATE_MAP.get(task_type, TEMPLATE_MAP["sft"])
|
|
templates = TEMPLATE_MAP.get(task_type, TEMPLATE_MAP["sft"])
|
|
|
apply_fn = templates.get(template, templates.get("raw", apply_raw_template))
|
|
apply_fn = templates.get(template, templates.get("raw", apply_raw_template))
|
|
|
|
|
|
|
|
|
|
+ # Auto 模板:自动检测列名
|
|
|
|
|
+ column_map = {}
|
|
|
|
|
+ if template == "auto" and apply_fn is None:
|
|
|
|
|
+ column_map = _detect_columns(raw_data)
|
|
|
|
|
+ detected = column_map.get("template", "raw")
|
|
|
|
|
+ if detected == "sharegpt":
|
|
|
|
|
+ apply_fn = apply_sharegpt_template
|
|
|
|
|
+ elif detected == "alpaca":
|
|
|
|
|
+ apply_fn = apply_alpaca_template
|
|
|
|
|
+ elif detected == "dpo":
|
|
|
|
|
+ apply_fn = apply_dpo_template
|
|
|
|
|
+ elif detected == "auto":
|
|
|
|
|
+ apply_fn = lambda item, cm=column_map: apply_auto_template(item, cm)
|
|
|
|
|
+ else:
|
|
|
|
|
+ apply_fn = apply_raw_template
|
|
|
|
|
+
|
|
|
# 应用模板
|
|
# 应用模板
|
|
|
processed = []
|
|
processed = []
|
|
|
for item in raw_data:
|
|
for item in raw_data:
|