Parcourir la source

修复dpo问题

lxylxy123321 il y a 18 heures
Parent
commit
af57073c98
1 fichiers modifiés avec 14 ajouts et 11 suppressions
  1. 14 11
      backend/app/engines/text_engine.py

+ 14 - 11
backend/app/engines/text_engine.py

@@ -324,27 +324,30 @@ class TextEngine(BaseEngine):
             # 修复 Qwen tokenizer bug:tokenize 后 input_ids 末尾可能追加 None
             # 导致 DPODataCollatorWithPadding 中 torch.tensor([...None...], dtype=int64) 报错
             # 参考: https://github.com/huggingface/trl/issues/1073
-            if not getattr(self._tokenizer, "_patched_none_filter", False):
-                _orig_tok_call = self._tokenizer.__class__.__call__
-                def _call_filter_none(self_tok, *args, **kwargs):
-                    result = _orig_tok_call(self_tok, *args, **kwargs)
+            #
+            # 注意:不能用 types.MethodType 绑定到实例上,因为 Python 的特殊方法查找
+            # (如 obj() → type(obj).__call__(obj))会跳过实例属性,直接查类。
+            # 必须在类级别替换 __call__。
+            _tok_cls = type(self._tokenizer)
+            if not getattr(_tok_cls, "_patched_none_filter", False):
+                _orig_cls_call = _tok_cls.__call__
+
+                def _call_filter_none(cls_self, *args, **kwargs):
+                    result = _orig_cls_call(cls_self, *args, **kwargs)
                     if isinstance(result, dict) and "input_ids" in result:
                         ids = result["input_ids"]
                         if isinstance(ids, list) and ids:
                             if isinstance(ids[0], list):
-                                # batched 输入:input_ids 是二维 list
                                 result["input_ids"] = [
                                     [x for x in seq if x is not None] for seq in ids
                                 ]
                             else:
-                                # 单条输入:input_ids 是一维 list,过滤 None
                                 result["input_ids"] = [x for x in ids if x is not None]
                     return result
-                # 绑定到实例(通过 type 避免 MRO 问题)
-                import types
-                self._tokenizer.__call__ = types.MethodType(_call_filter_none, self._tokenizer)
-                self._tokenizer._patched_none_filter = True
-                logger.info("Patched tokenizer to filter None values from input_ids (Qwen workaround)")
+
+                _tok_cls.__call__ = _call_filter_none
+                _tok_cls._patched_none_filter = True
+                logger.info(f"Patched {_tok_cls.__name__}.__call__ to filter None from input_ids (Qwen workaround)")
 
             # 显式创建 reference model 并冻结,避免 AdaLora 多 adapter 冲突
             ref_model = deepcopy(self._model)