lxylxy123321 1 giorno fa
parent
commit
0cae43e09c
1 ha cambiato i file con 3 aggiunte e 10 eliminazioni
  1. 3 10
      backend/app/engines/text_engine.py

+ 3 - 10
backend/app/engines/text_engine.py

@@ -327,6 +327,7 @@ class TextEngine(BaseEngine):
             for param in ref_model.parameters():
                 param.requires_grad = False
 
+            # TRL 0.9.x PPOConfig 只接受 PPO 专用参数,不支持 HuggingFace Trainer 参数
             ppo_config = PPOConfig(
                 learning_rate=learning_rate,
                 batch_size=batch_size,
@@ -334,22 +335,14 @@ class TextEngine(BaseEngine):
                 ppo_epochs=ppo_epochs,
                 vf_coef=vf_coef,
                 init_kl_coef=kl_coef,
-                response_length=response_length,
-                output_dir=output_dir,
-                logging_steps=10,
-                save_strategy=save_strategy,
-                fp16=True,
-                report_to="none",
-                dataloader_num_workers=4,
-                dataloader_pin_memory=False,
             )
 
             trainer = PPOTrainer(
                 config=ppo_config,
                 model=self._model,
                 ref_model=ref_model,
-                processing_class=self._tokenizer,
-                train_dataset=ppo_dataset,
+                tokenizer=self._tokenizer,
+                dataset=ppo_dataset,
             )
 
             dataloader = trainer.dataloader