|
|
@@ -296,6 +296,14 @@ class TextEngine(BaseEngine):
|
|
|
_ma.MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = {}
|
|
|
from trl import DPOConfig, DPOTrainer
|
|
|
|
|
|
+ # 兼容 TRL 新版本 get_batch_samples 签名变化
|
|
|
+ # 某些版本调用方传了 extra 参数但方法不接受,patch 为接受任意参数
|
|
|
+ if hasattr(DPOTrainer, 'get_batch_samples'):
|
|
|
+ _orig_get_batch_samples = DPOTrainer.get_batch_samples
|
|
|
+ def _patched_get_batch_samples(self, batch, *args, **kwargs):
|
|
|
+ return _orig_get_batch_samples(self, batch)
|
|
|
+ DPOTrainer.get_batch_samples = _patched_get_batch_samples
|
|
|
+
|
|
|
# 兼容:当前版本 transformers.Trainer.__init__ 不接受 tokenizer/processing_class,
|
|
|
# 但 DPOTrainer 内部会将这些参数透传给 Trainer,导致 TypeError。
|
|
|
# 拦截 Trainer.__init__,弹出不认识的 kwargs。
|