|
|
@@ -374,82 +374,127 @@ class TextEngine(BaseEngine):
|
|
|
|
|
|
ppo_config = PPOConfig(**ppo_config_kwargs)
|
|
|
|
|
|
- trainer = PPOTrainer(
|
|
|
- config=ppo_config,
|
|
|
+ # 兼容不同版本的 PPOTrainer 参数名(config vs args)
|
|
|
+ trainer_sig = inspect.signature(PPOTrainer.__init__)
|
|
|
+ trainer_params = set(trainer_sig.parameters.keys())
|
|
|
+
|
|
|
+ # ---- 加载奖励模型 ----
|
|
|
+ reward_model = None
|
|
|
+ if reward_type == "model" and reward_model_path:
|
|
|
+ from transformers import AutoModelForSequenceClassification
|
|
|
+ reward_model = AutoModelForSequenceClassification.from_pretrained(
|
|
|
+ reward_model_path, device_map={"": 0}
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ # 启发式奖励:包装成 nn.Module 以兼容新版 PPOTrainer 的 reward_model 参数
|
|
|
+ class _HeuristicRewardModel(torch.nn.Module):
|
|
|
+ """将启发式奖励函数包装为 reward model,供新版 PPOTrainer 使用。"""
|
|
|
+
|
|
|
+ def __init__(self, tokenizer, reward_func):
|
|
|
+ super().__init__()
|
|
|
+ self.tokenizer = tokenizer
|
|
|
+ self.reward_func = reward_func
|
|
|
+ # 需要一个 dummy 参数让 Trainer 识别为有效的 Module
|
|
|
+ self._dummy = torch.nn.Parameter(torch.zeros(1))
|
|
|
+
|
|
|
+ def forward(self, input_ids=None, attention_mask=None, **kwargs):
|
|
|
+ texts = [
|
|
|
+ self.tokenizer.decode(ids, skip_special_tokens=True)
|
|
|
+ for ids in input_ids
|
|
|
+ ]
|
|
|
+ rewards = self.reward_func(texts, texts)
|
|
|
+ return type("RewardOutput", (), {
|
|
|
+ "logits": torch.tensor(rewards, dtype=torch.float32, device=input_ids.device).unsqueeze(-1)
|
|
|
+ })()
|
|
|
+
|
|
|
+ reward_model = _HeuristicRewardModel(self._tokenizer, _compute_heuristic_reward)
|
|
|
+
|
|
|
+ # ---- 构建 PPOTrainer ----
|
|
|
+ trainer_kwargs = dict(
|
|
|
model=self._model,
|
|
|
ref_model=ref_model,
|
|
|
processing_class=self._tokenizer,
|
|
|
train_dataset=ppo_dataset,
|
|
|
)
|
|
|
|
|
|
- dataloader = trainer.dataloader
|
|
|
- total_steps = len(dataloader) * epochs
|
|
|
- step_count = 0
|
|
|
-
|
|
|
- for epoch in range(epochs):
|
|
|
- for batch in dataloader:
|
|
|
- step_count += 1
|
|
|
- query_tensors = batch["input_ids"]
|
|
|
-
|
|
|
- # 生成回答
|
|
|
- response_tensors = []
|
|
|
- for query in query_tensors:
|
|
|
- query_tensor = torch.tensor(query).unsqueeze(0).to(self._model.device)
|
|
|
- gen_output = self._model.generate(
|
|
|
- query_tensor,
|
|
|
- max_new_tokens=response_length,
|
|
|
- do_sample=True,
|
|
|
- top_p=0.9,
|
|
|
- temperature=0.7,
|
|
|
- )
|
|
|
- response_tensors.append(gen_output[0][query_tensor.shape[-1]:])
|
|
|
-
|
|
|
- # 解码文本用于奖励计算
|
|
|
- responses_text = [
|
|
|
- self._tokenizer.decode(r, skip_special_tokens=True)
|
|
|
- for r in response_tensors
|
|
|
- ]
|
|
|
- prompts_text = [
|
|
|
- self._tokenizer.decode(q, skip_special_tokens=True)
|
|
|
- for q in query_tensors
|
|
|
- ]
|
|
|
-
|
|
|
- # 计算奖励
|
|
|
- if reward_type == "model" and reward_model_path:
|
|
|
- from transformers import AutoModelForSequenceClassification
|
|
|
-
|
|
|
- reward_model = AutoModelForSequenceClassification.from_pretrained(
|
|
|
- reward_model_path, device_map={"": 0}
|
|
|
- )
|
|
|
- reward_inputs = [p + r for p, r in zip(prompts_text, responses_text)]
|
|
|
- tokenized = self._tokenizer(
|
|
|
- reward_inputs, return_tensors="pt", padding=True, truncation=True
|
|
|
- ).to(self._model.device)
|
|
|
- with torch.no_grad():
|
|
|
- rewards = reward_model(**tokenized).logits.squeeze(-1).tolist()
|
|
|
- else:
|
|
|
- rewards = _compute_heuristic_reward(prompts_text, responses_text)
|
|
|
-
|
|
|
- reward_tensors = [torch.tensor(r, device=self._model.device) for r in rewards]
|
|
|
-
|
|
|
- # PPO 更新
|
|
|
- stats = trainer.step(query_tensors, response_tensors, reward_tensors)
|
|
|
-
|
|
|
- # 报告进度
|
|
|
- if step_count % 10 == 0:
|
|
|
- for cb in (all_callbacks or []):
|
|
|
- if hasattr(cb, "on_log"):
|
|
|
- cb.on_log(
|
|
|
- SimpleNamespace(),
|
|
|
- SimpleNamespace(
|
|
|
- epoch=epoch, global_step=step_count, max_steps=total_steps
|
|
|
- ),
|
|
|
- None,
|
|
|
- logs={
|
|
|
- "loss": stats.get("ppo/loss/total", 0),
|
|
|
- "learning_rate": stats.get("ppo/learning_rate", learning_rate),
|
|
|
- },
|
|
|
- )
|
|
|
+ # 新版叫 args,旧版叫 config
|
|
|
+ if "args" in trainer_params:
|
|
|
+ trainer_kwargs["args"] = ppo_config
|
|
|
+ elif "config" in trainer_params:
|
|
|
+ trainer_kwargs["config"] = ppo_config
|
|
|
+
|
|
|
+ # 新版 PPOTrainer 支持 reward_model 参数
|
|
|
+ if "reward_model" in trainer_params:
|
|
|
+ trainer_kwargs["reward_model"] = reward_model
|
|
|
+
|
|
|
+ logger.info(f"PPOTrainer 可用参数: {sorted(trainer_params)}")
|
|
|
+ trainer = PPOTrainer(**trainer_kwargs)
|
|
|
+
|
|
|
+ # ---- 训练 ----
|
|
|
+ if hasattr(trainer, "step"):
|
|
|
+ # 旧版 TRL:手动循环 + trainer.step()
|
|
|
+ dataloader = trainer.dataloader
|
|
|
+ total_steps = len(dataloader) * epochs
|
|
|
+ step_count = 0
|
|
|
+
|
|
|
+ for epoch in range(epochs):
|
|
|
+ for batch in dataloader:
|
|
|
+ step_count += 1
|
|
|
+ query_tensors = batch["input_ids"]
|
|
|
+
|
|
|
+ response_tensors = []
|
|
|
+ for query in query_tensors:
|
|
|
+ query_tensor = torch.tensor(query).unsqueeze(0).to(self._model.device)
|
|
|
+ gen_output = self._model.generate(
|
|
|
+ query_tensor,
|
|
|
+ max_new_tokens=response_length,
|
|
|
+ do_sample=True,
|
|
|
+ top_p=0.9,
|
|
|
+ temperature=0.7,
|
|
|
+ )
|
|
|
+ response_tensors.append(gen_output[0][query_tensor.shape[-1]:])
|
|
|
+
|
|
|
+ responses_text = [
|
|
|
+ self._tokenizer.decode(r, skip_special_tokens=True)
|
|
|
+ for r in response_tensors
|
|
|
+ ]
|
|
|
+ prompts_text = [
|
|
|
+ self._tokenizer.decode(q, skip_special_tokens=True)
|
|
|
+ for q in query_tensors
|
|
|
+ ]
|
|
|
+
|
|
|
+ if reward_type == "model" and reward_model_path:
|
|
|
+ reward_inputs = [p + r for p, r in zip(prompts_text, responses_text)]
|
|
|
+ tokenized = self._tokenizer(
|
|
|
+ reward_inputs, return_tensors="pt", padding=True, truncation=True
|
|
|
+ ).to(self._model.device)
|
|
|
+ with torch.no_grad():
|
|
|
+ rewards = reward_model(**tokenized).logits.squeeze(-1).tolist()
|
|
|
+ else:
|
|
|
+ rewards = _compute_heuristic_reward(prompts_text, responses_text)
|
|
|
+
|
|
|
+ reward_tensors = [torch.tensor(r, device=self._model.device) for r in rewards]
|
|
|
+ stats = trainer.step(query_tensors, response_tensors, reward_tensors)
|
|
|
+
|
|
|
+ if step_count % 10 == 0:
|
|
|
+ for cb in (all_callbacks or []):
|
|
|
+ if hasattr(cb, "on_log"):
|
|
|
+ cb.on_log(
|
|
|
+ SimpleNamespace(),
|
|
|
+ SimpleNamespace(
|
|
|
+ epoch=epoch, global_step=step_count, max_steps=total_steps
|
|
|
+ ),
|
|
|
+ None,
|
|
|
+ logs={
|
|
|
+ "loss": stats.get("ppo/loss/total", 0),
|
|
|
+ "learning_rate": stats.get("ppo/learning_rate", learning_rate),
|
|
|
+ },
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ # 新版 TRL (>=1.0):标准 Trainer API,直接 train()
|
|
|
+ for cb in (all_callbacks or []):
|
|
|
+ trainer.add_callback(cb)
|
|
|
+ trainer.train()
|
|
|
|
|
|
self._model.save_pretrained(output_dir)
|
|
|
self._tokenizer.save_pretrained(output_dir)
|