|
|
@@ -327,6 +327,7 @@ class TextEngine(BaseEngine):
|
|
|
for param in ref_model.parameters():
|
|
|
param.requires_grad = False
|
|
|
|
|
|
+ # TRL 0.9.x PPOConfig 只接受 PPO 专用参数,不支持 HuggingFace Trainer 参数
|
|
|
ppo_config = PPOConfig(
|
|
|
learning_rate=learning_rate,
|
|
|
batch_size=batch_size,
|
|
|
@@ -334,22 +335,14 @@ class TextEngine(BaseEngine):
|
|
|
ppo_epochs=ppo_epochs,
|
|
|
vf_coef=vf_coef,
|
|
|
init_kl_coef=kl_coef,
|
|
|
- response_length=response_length,
|
|
|
- output_dir=output_dir,
|
|
|
- logging_steps=10,
|
|
|
- save_strategy=save_strategy,
|
|
|
- fp16=True,
|
|
|
- report_to="none",
|
|
|
- dataloader_num_workers=4,
|
|
|
- dataloader_pin_memory=False,
|
|
|
)
|
|
|
|
|
|
trainer = PPOTrainer(
|
|
|
config=ppo_config,
|
|
|
model=self._model,
|
|
|
ref_model=ref_model,
|
|
|
- processing_class=self._tokenizer,
|
|
|
- train_dataset=ppo_dataset,
|
|
|
+ tokenizer=self._tokenizer,
|
|
|
+ dataset=ppo_dataset,
|
|
|
)
|
|
|
|
|
|
dataloader = trainer.dataloader
|