|
@@ -323,7 +323,8 @@ class TextEngine(BaseEngine):
|
|
|
self._model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
|
self._model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
|
|
self._model, peft_config=peft_config,
|
|
self._model, peft_config=peft_config,
|
|
|
)
|
|
)
|
|
|
- self._model.print_trainable_parameters()
|
|
|
|
|
|
|
+ if hasattr(self._model, "print_trainable_parameters"):
|
|
|
|
|
+ self._model.print_trainable_parameters()
|
|
|
|
|
|
|
|
# TRL 0.9.x PPOConfig 只接受 PPO 专用参数,不支持 HuggingFace Trainer 参数
|
|
# TRL 0.9.x PPOConfig 只接受 PPO 专用参数,不支持 HuggingFace Trainer 参数
|
|
|
# mini_batch_size 必须满足:batch_size % (mini_batch_size * gradient_accumulation_steps) == 0
|
|
# mini_batch_size 必须满足:batch_size % (mini_batch_size * gradient_accumulation_steps) == 0
|