Explorar o código

修复ppo报错

lxylxy123321 hai 23 horas
pai
achega
34d10f36cd
Modificouse 1 ficheiro con 21 adicións e 0 borrados
  1. 21 0
      backend/app/engines/text_engine.py

+ 21 - 0
backend/app/engines/text_engine.py

@@ -409,6 +409,23 @@ class TextEngine(BaseEngine):
 
                 reward_model = _HeuristicRewardModel(self._tokenizer, _compute_heuristic_reward)
 
+            # ---- 构建 value_model(价值函数模型,新版 PPOTrainer 必需)----
+            value_model = None
+            if "value_model" in trainer_params:
+                from transformers import AutoModelForSequenceClassification
+                # PEFT 包装后 config._name_or_path 仍指向 base model
+                base_model_path = getattr(
+                    peft_config, "base_model_name_or_path", None
+                ) or self._model.config._name_or_path
+                value_model = AutoModelForSequenceClassification.from_pretrained(
+                    base_model_path,
+                    num_labels=1,
+                    torch_dtype=torch.float16,
+                )
+                value_model.to(self._model.device)
+                value_model.eval()
+                logger.info(f"已加载 value_model from {base_model_path}")
+
             # ---- 构建 PPOTrainer ----
             trainer_kwargs = dict(
                 model=self._model,
@@ -427,6 +444,10 @@ class TextEngine(BaseEngine):
             if "reward_model" in trainer_params:
                 trainer_kwargs["reward_model"] = reward_model
 
+            # 新版 PPOTrainer 需要 value_model
+            if value_model is not None:
+                trainer_kwargs["value_model"] = value_model
+
             logger.info(f"PPOTrainer 可用参数: {sorted(trainer_params)}")
             trainer = PPOTrainer(**trainer_kwargs)