Bläddra i källkod

修复训练列名报错问题

lxylxy123321 6 dagar sedan
förälder
incheckning
c87965bdf3

+ 1 - 1
backend/app/engines/remote_train.py

@@ -122,7 +122,7 @@ async def run_training(job_id: str, model_id: str, model_type: str, dataset_path
         # 预处理
         processed_path = str(_PROCESSED_DIR / f"{job_id}_processed.jsonl")
         task_type = config.get("task_type", "sft")
-        template = config.get("dataset_template", "alpaca")
+        template = config.get("dataset_template", "auto")
         _remote_log(f"  task_type={task_type}, template={template}")
         _remote_log(f"  output_path={processed_path}")
 

+ 82 - 1
backend/app/preprocessors/__init__.py

@@ -4,6 +4,65 @@ import json
 from pathlib import Path
 from typing import Any
 
+# Common column-name sets used when inspecting a dataset's columns.
+# NOTE(review): _PROMPT_COLUMNS and _COMPLETION_COLUMNS are not referenced in
+# the code shown in this change; _detect_columns spells out the same names as
+# ordered lists (presumably because a set carries no priority order) -- confirm
+# whether both copies are needed.
+_PROMPT_COLUMNS = {"prompt", "question", "query", "text", "input"}
+_COMPLETION_COLUMNS = {"completion", "answer", "response", "target", "output"}
+_ALPACA_COLUMNS = {"instruction", "input", "output"}
+_SHAREGPT_COLUMNS = {"conversations"}
+_DPO_COLUMNS = {"prompt", "chosen", "rejected"}
+
+
+def apply_auto_template(item: dict, column_map: dict[str, Any]) -> dict:
+    """Auto template: build a prompt/completion pair from detected columns.
+
+    Args:
+        item: One raw dataset record.
+        column_map: Mapping that may carry ``"prompt_candidates"`` and
+            ``"completion_candidates"``, each an ordered list of column
+            names to try in turn. A missing key is treated as an empty list.
+
+    Returns:
+        ``{"prompt": ..., "completion": ...}``. Each value is ``str()`` of
+        the first candidate column present in ``item`` with a non-``None``
+        value, or ``""`` when no candidate matches.
+    """
+
+    def first_value(candidates_key: str) -> str:
+        # Candidates are tried in order; a column holding None is skipped so
+        # that a later candidate can still supply the text.
+        for col in column_map.get(candidates_key, []):
+            if item.get(col) is not None:
+                return str(item[col])
+        return ""
+
+    return {
+        "prompt": first_value("prompt_candidates"),
+        "completion": first_value("completion_candidates"),
+    }
+
+
+def _detect_columns(raw_data: list[dict]) -> dict[str, Any]:
+    """Scan the first rows of a dataset and infer its template and columns.
+
+    Args:
+        raw_data: Parsed dataset records; only the first 5 are inspected.
+
+    Returns:
+        A dict whose ``"template"`` entry is one of ``"raw"`` (empty input),
+        ``"sharegpt"``, ``"dpo"``, ``"alpaca"`` or ``"auto"``. For ``"raw"``
+        and ``"auto"`` it also carries ``"prompt_candidates"`` and
+        ``"completion_candidates"``: ordered lists of the dataset's actual
+        column names to try for each field.
+    """
+    if not raw_data:
+        return {"prompt_candidates": [], "completion_candidates": [], "template": "raw"}
+
+    # Union of the column names seen in the first 5 rows.
+    all_columns = set()
+    for item in raw_data[:5]:
+        all_columns.update(item.keys())
+
+    # Template detection uses exact column names. A template matches only when
+    # ALL of its required columns are present: a plain intersection test would
+    # classify any dataset with a lone "prompt" column as DPO and any dataset
+    # with an "input" or "output" column as Alpaca, so the generic mapping
+    # below would never run for those columns.
+    if _SHAREGPT_COLUMNS <= all_columns:
+        return {"template": "sharegpt"}
+    if _DPO_COLUMNS <= all_columns:
+        return {"template": "dpo"}
+    # "input" is treated as optional so instruction/output-only data still matches.
+    if (_ALPACA_COLUMNS - {"input"}) <= all_columns:
+        return {"template": "alpaca"}
+
+    # Case-insensitive lookup: normalised name -> actual column name. Sorting
+    # makes the winner deterministic when two columns differ only by case.
+    lower_cols = {c.lower().strip(): c for c in sorted(all_columns)}
+
+    # Candidate columns for prompt and completion, in priority order.
+    prompt_candidates = [
+        lower_cols[c]
+        for c in ("prompt", "question", "query", "text", "input")
+        if c in lower_cols
+    ]
+    completion_candidates = [
+        lower_cols[c]
+        for c in ("completion", "answer", "response", "target", "output")
+        if c in lower_cols
+    ]
+
+    return {
+        "template": "auto",
+        "prompt_candidates": prompt_candidates,
+        "completion_candidates": completion_candidates,
+    }
+
 
 def apply_alpaca_template(item: dict) -> dict:
     """Alpaca 模板: instruction + input -> output。"""
@@ -86,26 +145,32 @@ def apply_rm_template(item: dict) -> dict:
 
 TEMPLATE_MAP = {
     "sft": {
+        "auto": None,  # Sentinel: preprocess_file picks the template via _detect_columns
         "alpaca": apply_alpaca_template,
         "sharegpt": apply_sharegpt_template,
         "raw": apply_raw_template,
     },
     "dpo": {
+        "auto": apply_dpo_template,
         "alpaca": apply_dpo_template,
         "sharegpt": apply_dpo_template,
         "raw": apply_dpo_template,
     },
     "kto": {
+        "auto": apply_kto_template,
         "raw": apply_kto_template,
     },
     "orpo": {
+        "auto": apply_orpo_template,
         "alpaca": apply_orpo_template,
         "raw": apply_orpo_template,
     },
     "rm": {
+        "auto": apply_rm_template,
         "raw": apply_rm_template,
     },
     "ppo": {
+        "auto": apply_raw_template,
         "raw": apply_raw_template,
     },
 }
@@ -115,7 +180,7 @@ def preprocess_file(
     input_path: str,
     output_path: str,
     task_type: str = "sft",
-    template: str = "alpaca",
+    template: str = "auto",
 ) -> list[dict[str, Any]]:
     """读取文件并应用模板,返回处理后的数据列表。"""
     input_p = Path(input_path)
@@ -150,6 +215,22 @@ def preprocess_file(
     templates = TEMPLATE_MAP.get(task_type, TEMPLATE_MAP["sft"])
     apply_fn = templates.get(template, templates.get("raw", apply_raw_template))
 
+    # Auto 模板:自动检测列名
+    column_map = {}
+    if template == "auto" and apply_fn is None:
+        column_map = _detect_columns(raw_data)
+        detected = column_map.get("template", "raw")
+        if detected == "sharegpt":
+            apply_fn = apply_sharegpt_template
+        elif detected == "alpaca":
+            apply_fn = apply_alpaca_template
+        elif detected == "dpo":
+            apply_fn = apply_dpo_template
+        elif detected == "auto":
+            apply_fn = lambda item, cm=column_map: apply_auto_template(item, cm)
+        else:
+            apply_fn = apply_raw_template
+
     # 应用模板
     processed = []
     for item in raw_data:

+ 1 - 1
backend/app/services/training_service.py

@@ -21,7 +21,7 @@ async def create_training_job(config: dict[str, Any]) -> dict[str, Any]:
     dataset_id = config.get("dataset_id", "")
     peft_method = config.get("peft_method", "lora")
     task_type = config.get("task_type", "sft")
-    dataset_template = config.get("dataset_template", "alpaca")
+    dataset_template = config.get("dataset_template", "auto")
 
     # 写入数据库
     record = TrainingJobModel(

+ 3 - 2
frontend/src/pages/Training.tsx

@@ -25,6 +25,7 @@ const TASK_TYPES = [
 ]
 
 const DATASET_TEMPLATES = [
+  { value: 'auto', label: 'Auto (自动检测)' },
   { value: 'alpaca', label: 'Alpaca (instruction/input/output)' },
   { value: 'sharegpt', label: 'ShareGPT (conversations)' },
   { value: 'raw', label: 'Raw (text 字段)' },
@@ -311,9 +312,9 @@ export function Training() {
   const [modelId, setModelId] = useState('')
   const [modelType, setModelType] = useState('text')
   const [datasetId, setDatasetId] = useState('')
-  const [peftMethod, setPeftMethod] = useState('qlora')
+  const [peftMethod, setPeftMethod] = useState('lora')
   const [taskType, setTaskType] = useState('sft')
-  const [template, setTemplate] = useState('alpaca')
+  const [template, setTemplate] = useState('auto')
   const [epochs, setEpochs] = useState(3)
   const [batchSize, setBatchSize] = useState(4)
   const [lr, setLr] = useState('2e-4')