lxylxy123321 1 день назад
Родитель
Сommit
0d10c166f4
1 измененных файлов с 8 добавлено и 0 удалено
  1. 8 0
      backend/app/engines/text_engine.py

+ 8 - 0
backend/app/engines/text_engine.py

@@ -288,6 +288,14 @@ class TextEngine(BaseEngine):
                 _HFTrainer.__init__ = _patched_trainer_init
                 _HFTrainer._patched_kwargs = True
 
+            # 兼容:新版 transformers Trainer 调用 get_batch_samples 时多传了 device 参数
+            if not getattr(DPOTrainer, "_patched_gbs", False):
+                _orig_gbs = DPOTrainer.get_batch_samples
+                def _patched_gbs(self, epoch_iterator, num_batches, *args, **kwargs):
+                    return _orig_gbs(self, epoch_iterator, num_batches)
+                DPOTrainer.get_batch_samples = _patched_gbs
+                DPOTrainer._patched_gbs = True
+
             # 显式创建 reference model 并冻结,避免 AdaLora 多 adapter 冲突
             ref_model = deepcopy(self._model)
             ref_model.eval()