소스 검색

修复dpo报错

lxylxy123321 1 시간 전
부모
커밋
852c668625
1개의 변경된 파일, 7개의 추가 그리고 8개의 삭제
  1. 7 8
      backend/app/engines/text_engine.py

+ 7 - 8
backend/app/engines/text_engine.py

@@ -296,18 +296,17 @@ class TextEngine(BaseEngine):
                 _ma.MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = {}
             from trl import DPOConfig, DPOTrainer
 
-            # 兼容 TRL 新版本 get_batch_samples 签名变化
-            # 某些版本调用方传了 extra 参数但方法不接受,patch 为接受任意参数
-            if hasattr(DPOTrainer, 'get_batch_samples'):
-                _orig_get_batch_samples = DPOTrainer.get_batch_samples
-                def _patched_get_batch_samples(self, batch, *args, **kwargs):
-                    return _orig_get_batch_samples(self, batch)
-                DPOTrainer.get_batch_samples = _patched_get_batch_samples
+            # 兼容 transformers 4.46.0+ 与 TRL 的 get_batch_samples 签名冲突:
+            # transformers.Trainer.get_batch_samples(epoch_iterator, num_batches)
+            # DPOTrainer.get_batch_samples(model, batch) — 签名不同导致调用崩溃
+            # 方案:用基类 Trainer 的实现替换掉 DPOTrainer 的不兼容覆盖
+            from transformers import Trainer as _HFTrainer
+            if hasattr(DPOTrainer, 'get_batch_samples') and hasattr(_HFTrainer, 'get_batch_samples'):
+                DPOTrainer.get_batch_samples = _HFTrainer.get_batch_samples
 
             # 兼容:当前版本 transformers.Trainer.__init__ 不接受 tokenizer/processing_class,
             # 但 DPOTrainer 内部会将这些参数透传给 Trainer,导致 TypeError。
             # 拦截 Trainer.__init__,弹出不认识的 kwargs。
-            from transformers import Trainer as _HFTrainer
             if not getattr(_HFTrainer, "_patched_kwargs", False):
                 _orig_trainer_init = _HFTrainer.__init__
                 def _patched_trainer_init(self, *args, **kwargs):