|
|
@@ -288,6 +288,14 @@ class TextEngine(BaseEngine):
|
|
|
_HFTrainer.__init__ = _patched_trainer_init
|
|
|
_HFTrainer._patched_kwargs = True
|
|
|
|
|
|
# Compatibility shim: newer transformers Trainer versions pass an extra
# `device` argument when calling get_batch_samples. Wrap the method once
# (guarded by the `_patched_gbs` class flag) so that extra arguments are
# forwarded only when the wrapped implementation can actually accept them.
if not getattr(DPOTrainer, "_patched_gbs", False):
    import functools
    import inspect

    _orig_gbs = DPOTrainer.get_batch_samples
    try:
        # Inspect once at patch time; reused on every call below.
        _orig_gbs_sig = inspect.signature(_orig_gbs)
    except (TypeError, ValueError):
        # Signature not introspectable: always drop the extra arguments.
        _orig_gbs_sig = None

    @functools.wraps(_orig_gbs)
    def _patched_gbs(self, epoch_iterator, num_batches, *args, **kwargs):
        """Call the original get_batch_samples, tolerating extra arguments.

        Args:
            epoch_iterator: Passed through unchanged to the original method.
            num_batches: Passed through unchanged to the original method.
            *args: Extra positional arguments (e.g. ``device``) supplied by
                the caller.
            **kwargs: Extra keyword arguments supplied by the caller.

        Returns:
            Whatever the original ``get_batch_samples`` returns.
        """
        if _orig_gbs_sig is not None and (args or kwargs):
            try:
                _orig_gbs_sig.bind(
                    self, epoch_iterator, num_batches, *args, **kwargs
                )
            except TypeError:
                # The original does not accept the extras; drop them below.
                pass
            else:
                # The original accepts (or requires) the extras, so discarding
                # them here could raise a missing-argument TypeError.
                return _orig_gbs(
                    self, epoch_iterator, num_batches, *args, **kwargs
                )
        return _orig_gbs(self, epoch_iterator, num_batches)

    DPOTrainer.get_batch_samples = _patched_gbs
    # Class-level flag so repeated runs do not wrap the method again.
    DPOTrainer._patched_gbs = True
|
|
|
+
|
|
|
# Explicitly create the reference model and freeze it to avoid AdaLora multi-adapter conflicts
|
|
|
ref_model = deepcopy(self._model)
|
|
|
ref_model.eval()
|