瀏覽代碼

修复 DPO 报错

lxylxy123321 19 小時之前
父節點
當前提交
a8d133631d
共有 1 個文件被更改,包括 11 次插入和 0 次刪除
  1. 11 0
      backend/app/engines/text_engine.py

+ 11 - 0
backend/app/engines/text_engine.py

@@ -277,6 +277,17 @@ class TextEngine(BaseEngine):
                 _ma.MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = {}
             from trl import DPOConfig, DPOTrainer
 
+            # 兼容旧版 transformers:其 Trainer.__init__ 不接受 tokenizer/processing_class,因此在转发给原始 __init__ 之前丢弃这两个关键字参数(注意:该补丁对进程内所有 Trainer 实例全局生效)
+            from transformers import Trainer as _HFTrainer
+            _orig_trainer_init = _HFTrainer.__init__
+            if not getattr(_HFTrainer, "_patched_kwargs", False):
+                def _patched_trainer_init(self, *args, **kwargs):
+                    kwargs.pop("tokenizer", None)
+                    kwargs.pop("processing_class", None)
+                    _orig_trainer_init(self, *args, **kwargs)
+                _HFTrainer.__init__ = _patched_trainer_init
+                _HFTrainer._patched_kwargs = True
+
             # 显式创建 reference model 并冻结,避免 AdaLora 多 adapter 冲突
             ref_model = deepcopy(self._model)
             ref_model.eval()