|
|
@@ -277,6 +277,17 @@ class TextEngine(BaseEngine):
|
|
|
_ma.MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = {}
|
|
|
from trl import DPOConfig, DPOTrainer
|
|
|
|
|
|
+ # 兼容旧版 transformers:Trainer.__init__ 不接受 tokenizer/processing_class
|
|
|
+ from transformers import Trainer as _HFTrainer
|
|
|
+ _orig_trainer_init = _HFTrainer.__init__
|
|
|
+ if not getattr(_HFTrainer, "_patched_kwargs", False):
|
|
|
+ def _patched_trainer_init(self, *args, **kwargs):
|
|
|
+ kwargs.pop("tokenizer", None)
|
|
|
+ kwargs.pop("processing_class", None)
|
|
|
+ _orig_trainer_init(self, *args, **kwargs)
|
|
|
+ _HFTrainer.__init__ = _patched_trainer_init
|
|
|
+ _HFTrainer._patched_kwargs = True
|
|
|
+
|
|
|
# 显式创建 reference model 并冻结,避免 AdaLora 多 adapter 冲突
|
|
|
ref_model = deepcopy(self._model)
|
|
|
ref_model.eval()
|