浏览代码

修复ppo报错

lxylxy123321 1 天之前
父节点
当前提交
a91c8c230e
共有 1 个文件被更改,包括 33 次插入4 次删除
  1. 33 4
      backend/app/engines/text_engine.py

+ 33 - 4
backend/app/engines/text_engine.py

@@ -318,8 +318,39 @@ class TextEngine(BaseEngine):
             reward_model_path = training_args.get("reward_model_path")
             reward_type = training_args.get("reward_type", "heuristic")
 
-            # PPO 专用:仅 tokenize prompt
-            ppo_dataset = self._tokenize_dataset_ppo(dataset_path, max_seq_length, response_length)
+            import inspect
+
+            # 检测 PPOTrainer 版本(新版无 step 方法,使用标准 Trainer API)
+            trainer_sig = inspect.signature(PPOTrainer.__init__)
+            trainer_params = set(trainer_sig.parameters.keys())
+            is_new_ppo = "step" not in dir(PPOTrainer)
+
+            # ---- 准备数据集 ----
+            if is_new_ppo:
+                # 新版 TRL (1.4.0+):PPOTrainer 自己处理 tokenization 和生成,
+                # 需要传入带 "prompt" 列的原始文本数据集
+                import json as _json
+                from datasets import Dataset as HFDataset
+
+                raw_data = []
+                with open(dataset_path, "r", encoding="utf-8") as f:
+                    for line in f:
+                        line = line.strip()
+                        if line:
+                            item = _json.loads(line)
+                            if "prompt" not in item:
+                                item["prompt"] = item.get("question", item.get("query", item.get("text", item.get("input", ""))))
+                            if isinstance(item["prompt"], (list, dict)):
+                                item["prompt"] = _json.dumps(item["prompt"], ensure_ascii=False)
+                            item["prompt"] = str(item["prompt"])
+                            raw_data.append(item)
+
+                ppo_dataset = HFDataset.from_list(raw_data)
+                logger.info(f"新版 PPOTrainer: 加载原始文本数据集,共 {len(ppo_dataset)} 条")
+            else:
+                # 旧版 TRL:需要预先 tokenize
+                ppo_dataset = self._tokenize_dataset_ppo(dataset_path, max_seq_length, response_length)
+                logger.info(f"旧版 PPOTrainer: 加载 tokenize 数据集,共 {len(ppo_dataset)} 条")
 
             # Reference 模型(冻结,用于 KL 惩罚)
             ref_model = deepcopy(self._model)
@@ -328,8 +359,6 @@ class TextEngine(BaseEngine):
                 param.requires_grad = False
 
             # 兼容不同版本的 TRL PPOConfig 参数名变化
-            # TRL 0.12+ 中 ppo_epochs -> num_ppo_epochs, kl_ctl -> init_kl_coef, vf_coef 被移除
-            import inspect
             ppo_config_sig = inspect.signature(PPOConfig.__init__)
             ppo_config_params = set(ppo_config_sig.parameters.keys())