Explorar el Código

修复dpo报错

lxylxy123321 hace 17 horas
padre
commit
5fa306dcd4
Se han modificado 1 ficheros con 8 adiciones y 0 borrados
  1. 8 0
      backend/app/engines/text_engine.py

+ 8 - 0
backend/app/engines/text_engine.py

@@ -296,6 +296,14 @@ class TextEngine(BaseEngine):
                 _ma.MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = {}
             from trl import DPOConfig, DPOTrainer
 
+            # 兼容 TRL 新版本 get_batch_samples 签名变化
+            # 某些版本调用方传了 extra 参数但方法不接受,patch 为接受任意参数
+            if hasattr(DPOTrainer, 'get_batch_samples'):
+                _orig_get_batch_samples = DPOTrainer.get_batch_samples
+                def _patched_get_batch_samples(self, batch, *args, **kwargs):
+                    return _orig_get_batch_samples(self, batch)
+                DPOTrainer.get_batch_samples = _patched_get_batch_samples
+
             # 兼容:当前版本 transformers.Trainer.__init__ 不接受 tokenizer/processing_class,
             # 但 DPOTrainer 内部会将这些参数透传给 Trainer,导致 TypeError。
             # 拦截 Trainer.__init__,弹出不认识的 kwargs。