|
@@ -324,27 +324,30 @@ class TextEngine(BaseEngine):
|
|
|
# 修复 Qwen tokenizer bug:tokenize 后 input_ids 末尾可能追加 None
|
|
# 修复 Qwen tokenizer bug:tokenize 后 input_ids 末尾可能追加 None
|
|
|
# 导致 DPODataCollatorWithPadding 中 torch.tensor([...None...], dtype=int64) 报错
|
|
# 导致 DPODataCollatorWithPadding 中 torch.tensor([...None...], dtype=int64) 报错
|
|
|
# 参考: https://github.com/huggingface/trl/issues/1073
|
|
# 参考: https://github.com/huggingface/trl/issues/1073
|
|
|
- if not getattr(self._tokenizer, "_patched_none_filter", False):
|
|
|
|
|
- _orig_tok_call = self._tokenizer.__class__.__call__
|
|
|
|
|
- def _call_filter_none(self_tok, *args, **kwargs):
|
|
|
|
|
- result = _orig_tok_call(self_tok, *args, **kwargs)
|
|
|
|
|
|
|
+ #
|
|
|
|
|
+ # 注意:不能用 types.MethodType 绑定到实例上,因为 Python 的特殊方法查找
|
|
|
|
|
+ # (如 obj() → type(obj).__call__(obj))会跳过实例属性,直接查类。
|
|
|
|
|
+ # 必须在类级别替换 __call__。
|
|
|
|
|
+ _tok_cls = type(self._tokenizer)
|
|
|
|
|
+ if not getattr(_tok_cls, "_patched_none_filter", False):
|
|
|
|
|
+ _orig_cls_call = _tok_cls.__call__
|
|
|
|
|
+
|
|
|
|
|
+ def _call_filter_none(cls_self, *args, **kwargs):
|
|
|
|
|
+ result = _orig_cls_call(cls_self, *args, **kwargs)
|
|
|
if isinstance(result, dict) and "input_ids" in result:
|
|
if isinstance(result, dict) and "input_ids" in result:
|
|
|
ids = result["input_ids"]
|
|
ids = result["input_ids"]
|
|
|
if isinstance(ids, list) and ids:
|
|
if isinstance(ids, list) and ids:
|
|
|
if isinstance(ids[0], list):
|
|
if isinstance(ids[0], list):
|
|
|
- # batched 输入:input_ids 是二维 list
|
|
|
|
|
result["input_ids"] = [
|
|
result["input_ids"] = [
|
|
|
[x for x in seq if x is not None] for seq in ids
|
|
[x for x in seq if x is not None] for seq in ids
|
|
|
]
|
|
]
|
|
|
else:
|
|
else:
|
|
|
- # 单条输入:input_ids 是一维 list,过滤 None
|
|
|
|
|
result["input_ids"] = [x for x in ids if x is not None]
|
|
result["input_ids"] = [x for x in ids if x is not None]
|
|
|
return result
|
|
return result
|
|
|
- # 绑定到实例(通过 type 避免 MRO 问题)
|
|
|
|
|
- import types
|
|
|
|
|
- self._tokenizer.__call__ = types.MethodType(_call_filter_none, self._tokenizer)
|
|
|
|
|
- self._tokenizer._patched_none_filter = True
|
|
|
|
|
- logger.info("Patched tokenizer to filter None values from input_ids (Qwen workaround)")
|
|
|
|
|
|
|
+
|
|
|
|
|
+ _tok_cls.__call__ = _call_filter_none
|
|
|
|
|
+ _tok_cls._patched_none_filter = True
|
|
|
|
|
+ logger.info(f"Patched {_tok_cls.__name__}.__call__ to filter None from input_ids (Qwen workaround)")
|
|
|
|
|
|
|
|
# 显式创建 reference model 并冻结,避免 AdaLora 多 adapter 冲突
|
|
# 显式创建 reference model 并冻结,避免 AdaLora 多 adapter 冲突
|
|
|
ref_model = deepcopy(self._model)
|
|
ref_model = deepcopy(self._model)
|