|
|
@@ -334,13 +334,34 @@ class TextEngine(BaseEngine):
|
|
|
max_prompt_length=max_seq_length // 2,
|
|
|
)
|
|
|
|
|
|
- trainer = DPOTrainer(
|
|
|
+ # 自动检测 DPOTrainer 接受 tokenizer 的参数名(不同 TRL 版本不同)
|
|
|
+ import inspect
|
|
|
+ import trl as _trl
|
|
|
+ _dpo_sig = inspect.signature(DPOTrainer.__init__)
|
|
|
+ _dpo_params = set(_dpo_sig.parameters.keys())
|
|
|
+ logger.info(f"TRL version: {_trl.__version__}, DPOTrainer params: {sorted(_dpo_params)}")
|
|
|
+ if "processing_class" in _dpo_params:
|
|
|
+ _tok_kw = "processing_class"
|
|
|
+ elif "tokenizer" in _dpo_params:
|
|
|
+ _tok_kw = "tokenizer"
|
|
|
+ else:
|
|
|
+ _tok_kw = None
|
|
|
+ logger.warning(f"DPOTrainer 不接受 tokenizer 参数,可用参数: {sorted(_dpo_params)}")
|
|
|
+
|
|
|
+ _dpo_trainer_kwargs = dict(
|
|
|
model=self._model,
|
|
|
ref_model=ref_model,
|
|
|
args=DPOConfig(**base_trainer_kwargs),
|
|
|
train_dataset=dataset,
|
|
|
- processing_class=self._tokenizer,
|
|
|
)
|
|
|
+ if _tok_kw:
|
|
|
+ _dpo_trainer_kwargs[_tok_kw] = self._tokenizer
|
|
|
+
|
|
|
+ trainer = DPOTrainer(**_dpo_trainer_kwargs)
|
|
|
+ # 如果 DPOTrainer 不直接接受 tokenizer 参数,手动设置
|
|
|
+ if _tok_kw is None:
|
|
|
+ trainer.tokenizer = self._tokenizer
|
|
|
+ trainer.processing_class = self._tokenizer
|
|
|
|
|
|
# 修复 Qwen tokenizer bug(TRL #1073):
|
|
|
# tokenize 后 input_ids 末尾可能含 None,导致 collator 中
|