|
|
@@ -318,8 +318,39 @@ class TextEngine(BaseEngine):
|
|
|
reward_model_path = training_args.get("reward_model_path")
|
|
|
reward_type = training_args.get("reward_type", "heuristic")
|
|
|
|
|
|
- # PPO 专用:仅 tokenize prompt
|
|
|
- ppo_dataset = self._tokenize_dataset_ppo(dataset_path, max_seq_length, response_length)
|
|
|
+ import inspect
|
|
|
+
|
|
|
+ # 检测 PPOTrainer 版本(新版无 step 方法,使用标准 Trainer API)
|
|
|
+ trainer_sig = inspect.signature(PPOTrainer.__init__)
|
|
|
+ trainer_params = set(trainer_sig.parameters.keys())
|
|
|
+ is_new_ppo = "step" not in dir(PPOTrainer)
|
|
|
+
|
|
|
+ # ---- 准备数据集 ----
|
|
|
+ if is_new_ppo:
|
|
|
+ # 新版 TRL (1.4.0+):PPOTrainer 自己处理 tokenization 和生成,
|
|
|
+ # 需要传入带 "prompt" 列的原始文本数据集
|
|
|
+ import json as _json
|
|
|
+ from datasets import Dataset as HFDataset
|
|
|
+
|
|
|
+ raw_data = []
|
|
|
+ with open(dataset_path, "r", encoding="utf-8") as f:
|
|
|
+ for line in f:
|
|
|
+ line = line.strip()
|
|
|
+ if line:
|
|
|
+ item = _json.loads(line)
|
|
|
+ if "prompt" not in item:
|
|
|
+ item["prompt"] = item.get("question", item.get("query", item.get("text", item.get("input", ""))))
|
|
|
+ if isinstance(item["prompt"], (list, dict)):
|
|
|
+ item["prompt"] = _json.dumps(item["prompt"], ensure_ascii=False)
|
|
|
+ item["prompt"] = str(item["prompt"])
|
|
|
+ raw_data.append(item)
|
|
|
+
|
|
|
+ ppo_dataset = HFDataset.from_list(raw_data)
|
|
|
+ logger.info(f"新版 PPOTrainer: 加载原始文本数据集,共 {len(ppo_dataset)} 条")
|
|
|
+ else:
|
|
|
+ # 旧版 TRL:需要预先 tokenize
|
|
|
+ ppo_dataset = self._tokenize_dataset_ppo(dataset_path, max_seq_length, response_length)
|
|
|
+ logger.info(f"旧版 PPOTrainer: 加载 tokenize 数据集,共 {len(ppo_dataset)} 条")
|
|
|
|
|
|
# Reference 模型(冻结,用于 KL 惩罚)
|
|
|
ref_model = deepcopy(self._model)
|
|
|
@@ -328,8 +359,6 @@ class TextEngine(BaseEngine):
|
|
|
param.requires_grad = False
|
|
|
|
|
|
# 兼容不同版本的 TRL PPOConfig 参数名变化
|
|
|
- # TRL 0.12+ 中 ppo_epochs -> num_ppo_epochs, kl_ctl -> init_kl_coef, vf_coef 被移除
|
|
|
- import inspect
|
|
|
ppo_config_sig = inspect.signature(PPOConfig.__init__)
|
|
|
ppo_config_params = set(ppo_config_sig.parameters.keys())
|
|
|
|