|
|
@@ -296,18 +296,17 @@ class TextEngine(BaseEngine):
|
|
|
_ma.MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = {}
|
|
|
from trl import DPOConfig, DPOTrainer
|
|
|
|
|
|
- # 兼容 TRL 新版本 get_batch_samples 签名变化
|
|
|
- # 某些版本调用方传了 extra 参数但方法不接受,patch 为接受任意参数
|
|
|
- if hasattr(DPOTrainer, 'get_batch_samples'):
|
|
|
- _orig_get_batch_samples = DPOTrainer.get_batch_samples
|
|
|
- def _patched_get_batch_samples(self, batch, *args, **kwargs):
|
|
|
- return _orig_get_batch_samples(self, batch)
|
|
|
- DPOTrainer.get_batch_samples = _patched_get_batch_samples
|
|
|
+ # 兼容 transformers 4.46.0+ 与 TRL 的 get_batch_samples 签名冲突:
|
|
|
+ # transformers.Trainer.get_batch_samples(epoch_iterator, num_batches)
|
|
|
+ # DPOTrainer.get_batch_samples(model, batch) — 签名不同导致调用崩溃
|
|
|
+ # 方案:用基类 Trainer 的实现替换掉 DPOTrainer 的不兼容覆盖
|
|
|
+ from transformers import Trainer as _HFTrainer
|
|
|
+ if hasattr(DPOTrainer, 'get_batch_samples') and hasattr(_HFTrainer, 'get_batch_samples'):
|
|
|
+ DPOTrainer.get_batch_samples = _HFTrainer.get_batch_samples
|
|
|
|
|
|
# 兼容:当前版本 transformers.Trainer.__init__ 不接受 tokenizer/processing_class,
|
|
|
# 但 DPOTrainer 内部会将这些参数透传给 Trainer,导致 TypeError。
|
|
|
# 拦截 Trainer.__init__,弹出不认识的 kwargs。
|
|
|
- from transformers import Trainer as _HFTrainer
|
|
|
if not getattr(_HFTrainer, "_patched_kwargs", False):
|
|
|
_orig_trainer_init = _HFTrainer.__init__
|
|
|
def _patched_trainer_init(self, *args, **kwargs):
|