|
|
@@ -328,9 +328,11 @@ class TextEngine(BaseEngine):
|
|
|
param.requires_grad = False
|
|
|
|
|
|
# TRL 0.9.x PPOConfig 只接受 PPO 专用参数,不支持 HuggingFace Trainer 参数
|
|
|
+ # mini_batch_size 必须满足:batch_size % (mini_batch_size * gradient_accumulation_steps) == 0
|
|
|
ppo_config = PPOConfig(
|
|
|
learning_rate=learning_rate,
|
|
|
batch_size=batch_size,
|
|
|
+ mini_batch_size=1,
|
|
|
gradient_accumulation_steps=gradient_accumulation,
|
|
|
ppo_epochs=ppo_epochs,
|
|
|
vf_coef=vf_coef,
|