lxylxy123321 пре 1 дан
родитељ
комит
0f6bc04dad
1 измењених фајлова са 23 додато и 2 уклоњено
  1. 23 2
      backend/app/engines/text_engine.py

+ 23 - 2
backend/app/engines/text_engine.py

@@ -334,13 +334,34 @@ class TextEngine(BaseEngine):
                 max_prompt_length=max_seq_length // 2,
             )
 
-            trainer = DPOTrainer(
+            # 自动检测 DPOTrainer 接受 tokenizer 的参数名(不同 TRL 版本不同)
+            import inspect
+            import trl as _trl
+            _dpo_sig = inspect.signature(DPOTrainer.__init__)
+            _dpo_params = set(_dpo_sig.parameters.keys())
+            logger.info(f"TRL version: {_trl.__version__}, DPOTrainer params: {sorted(_dpo_params)}")
+            if "processing_class" in _dpo_params:
+                _tok_kw = "processing_class"
+            elif "tokenizer" in _dpo_params:
+                _tok_kw = "tokenizer"
+            else:
+                _tok_kw = None
+                logger.warning(f"DPOTrainer 不接受 tokenizer 参数,可用参数: {sorted(_dpo_params)}")
+
+            _dpo_trainer_kwargs = dict(
                 model=self._model,
                 ref_model=ref_model,
                 args=DPOConfig(**base_trainer_kwargs),
                 train_dataset=dataset,
-                processing_class=self._tokenizer,
             )
+            if _tok_kw:
+                _dpo_trainer_kwargs[_tok_kw] = self._tokenizer
+
+            trainer = DPOTrainer(**_dpo_trainer_kwargs)
+            # 如果 DPOTrainer 不直接接受 tokenizer 参数,手动设置
+            if _tok_kw is None:
+                trainer.tokenizer = self._tokenizer
+                trainer.processing_class = self._tokenizer
 
             # 修复 Qwen tokenizer bug(TRL #1073):
             # tokenize 后 input_ids 末尾可能含 None,导致 collator 中