lxylxy123321 17 часов назад
Родитель
Сommit
383056165f
1 измененных файлов с 14 добавлено и 15 удалено
  1. 14 15
      backend/app/engines/text_engine.py

+ 14 - 15
backend/app/engines/text_engine.py

@@ -212,8 +212,10 @@ class TextEngine(BaseEngine):
         if hasattr(peft_config, "init_r") and hasattr(peft_config, "target_r"):
             peft_config.total_step = max_steps
 
-        self._model = get_peft_model(self._model, peft_config)
-        self._model.print_trainable_parameters()
+        # PPO 需要先用 AutoModelForCausalLMWithValueHead 包装,再应用 PEFT(后面单独处理)
+        if task_type != "ppo":
+            self._model = get_peft_model(self._model, peft_config)
+            self._model.print_trainable_parameters()
 
         output_dir = str(settings.adapters_dir / job_id)
 
@@ -302,14 +304,8 @@ class TextEngine(BaseEngine):
                 processing_class=self._tokenizer,
             )
         elif task_type == "ppo":
-            from copy import deepcopy
-
             import torch
-            # 兼容新版 TRL(PPO 移到了 experimental 子模块)和旧版 TRL
-            try:
-                from trl.experimental.ppo import PPOConfig, PPOTrainer
-            except ImportError:
-                from trl import PPOConfig, PPOTrainer
+            from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer
 
             ppo_epochs = training_args.get("ppo_epochs", 4)
             vf_coef = training_args.get("vf_coef", 0.1)
@@ -321,11 +317,13 @@ class TextEngine(BaseEngine):
             # PPO 专用:仅 tokenize prompt
             ppo_dataset = self._tokenize_dataset_ppo(dataset_path, max_seq_length, response_length)
 
-            # Reference 模型(冻结,用于 KL 惩罚)
-            ref_model = deepcopy(self._model)
-            ref_model.eval()
-            for param in ref_model.parameters():
-                param.requires_grad = False
+            # PPO 需要 AutoModelForCausalLMWithValueHead(添加 value head 用于评估动作价值)
+            # 通过 peft_config 参数让 TRL 内部处理 PEFT 包装,返回的对象是 PreTrainedModelWrapper
+            # 不能用 get_peft_model(会产生 PeftModel,PPOTrainer 不认)
+            self._model = AutoModelForCausalLMWithValueHead.from_pretrained(
+                self._model, peft_config=peft_config,
+            )
+            self._model.print_trainable_parameters()
 
             # TRL 0.9.x PPOConfig 只接受 PPO 专用参数,不支持 HuggingFace Trainer 参数
             # mini_batch_size 必须满足:batch_size % (mini_batch_size * gradient_accumulation_steps) == 0
@@ -339,10 +337,11 @@ class TextEngine(BaseEngine):
                 init_kl_coef=kl_coef,
             )
 
+            # ref_model=None 让 PPOTrainer 自动创建冻结的 reference model(用于 KL 惩罚)
             trainer = PPOTrainer(
                 config=ppo_config,
                 model=self._model,
-                ref_model=ref_model,
+                ref_model=None,
                 tokenizer=self._tokenizer,
                 dataset=ppo_dataset,
             )