|
@@ -305,7 +305,11 @@ class TextEngine(BaseEngine):
|
|
|
from copy import deepcopy
|
|
from copy import deepcopy
|
|
|
|
|
|
|
|
import torch
|
|
import torch
|
|
|
- from trl import PPOConfig, PPOTrainer
|
|
|
|
|
|
|
+ # 兼容新版 TRL(PPO 移到了 experimental 子模块)和旧版 TRL
|
|
|
|
|
+ try:
|
|
|
|
|
+ from trl.experimental.ppo import PPOConfig, PPOTrainer
|
|
|
|
|
+ except ImportError:
|
|
|
|
|
+ from trl import PPOConfig, PPOTrainer
|
|
|
|
|
|
|
|
ppo_epochs = training_args.get("ppo_epochs", 4)
|
|
ppo_epochs = training_args.get("ppo_epochs", 4)
|
|
|
vf_coef = training_args.get("vf_coef", 0.1)
|
|
vf_coef = training_args.get("vf_coef", 0.1)
|