Explorar el Código

修复dpo报错

lxylxy123321 hace 20 horas
padre
commit
397974ca55
Se ha modificado 1 fichero con 38 adiciones y 8 borrados
  1. 38 8
      backend/app/engines/text_engine.py

+ 38 - 8
backend/app/engines/text_engine.py

@@ -604,10 +604,10 @@ class TextEngine(BaseEngine):
         return tokenized_dataset
 
     def _load_dataset_dpo(self, dataset_path: str):
-        """加载 DPO 数据集,保留 prompt/chosen/rejected 原始文本,由 DPOTrainer 内部 tokenize。"""
+        """加载并 tokenize DPO 数据集,生成 TRL 0.9.x 需要的 prompt/chosen/rejected_input_ids。"""
         from datasets import Dataset as HFDataset
 
-        data = []
+        raw_data = []
         with open(dataset_path, "r", encoding="utf-8") as f:
             for line in f:
                 line = line.strip()
@@ -616,12 +616,42 @@ class TextEngine(BaseEngine):
                     prompt = item.get("prompt", item.get("instruction", item.get("input", "")))
                     chosen = item.get("chosen", item.get("positive", ""))
                     rejected = item.get("rejected", item.get("negative", ""))
-                    data.append({
-                        "prompt": str(prompt) if prompt else "",
-                        "chosen": str(chosen) if chosen else "",
-                        "rejected": str(rejected) if rejected else "",
-                    })
-        return HFDataset.from_list(data)
+                    if prompt and chosen and rejected:
+                        raw_data.append({
+                            "prompt": str(prompt),
+                            "chosen": str(chosen),
+                            "rejected": str(rejected),
+                        })
+
+        max_len = 2048
+
+        def tokenize_fn(examples):
+            prompts = examples.get("prompt", [])
+            chosens = examples.get("chosen", [])
+            rejecteds = examples.get("rejected", [])
+
+            prompt_tok = self._tokenizer(prompts, truncation=True, max_length=max_len, padding=False)
+            chosen_tok = self._tokenizer(
+                [p + c for p, c in zip(prompts, chosens)],
+                truncation=True, max_length=max_len, padding=False,
+            )
+            rejected_tok = self._tokenizer(
+                [p + r for p, r in zip(prompts, rejecteds)],
+                truncation=True, max_length=max_len, padding=False,
+            )
+
+            return {
+                "prompt_input_ids": prompt_tok["input_ids"],
+                "prompt_attention_mask": prompt_tok["attention_mask"],
+                "chosen_input_ids": chosen_tok["input_ids"],
+                "chosen_attention_mask": chosen_tok["attention_mask"],
+                "rejected_input_ids": rejected_tok["input_ids"],
+                "rejected_attention_mask": rejected_tok["attention_mask"],
+            }
+
+        hf_dataset = HFDataset.from_list(raw_data)
+        tokenized = hf_dataset.map(tokenize_fn, batched=True, remove_columns=hf_dataset.column_names)
+        return tokenized
 
 
 try: