|
|
@@ -400,6 +400,7 @@ class TextEngine(BaseEngine):
|
|
|
# torch.tensor([..., None], dtype=int64) 崩溃。
|
|
|
# 直接 monkey-patch DPODataCollatorWithPadding.__call__(即崩溃点),
|
|
|
# 在原始逻辑执行前清洗 features 中的 None 值。
|
|
|
+ # 注意:input_ids 和 attention_mask 必须同步截断,否则 tensor 长度不匹配。
|
|
|
try:
|
|
|
from trl.trainer.utils import DPODataCollatorWithPadding as _DC
|
|
|
if not getattr(_DC, "_patched_none_filter", False):
|
|
|
@@ -409,13 +410,21 @@ class TextEngine(BaseEngine):
|
|
|
for ex in features:
|
|
|
for k in list(ex.keys()):
|
|
|
v = ex[k]
|
|
|
- if isinstance(v, list):
|
|
|
- if v and isinstance(v[0], list):
|
|
|
- ex[k] = [[x for x in seq if x is not None] for seq in v]
|
|
|
- else:
|
|
|
- ex[k] = [x for x in v if x is not None]
|
|
|
- elif v is None:
|
|
|
- ex[k] = []
|
|
|
+ if isinstance(v, list) and v and not isinstance(v[0], list):
|
|
|
+ # 一维 list(如 input_ids / attention_mask / labels)
|
|
|
+ cleaned = [x for x in v if x is not None]
|
|
|
+ ex[k] = cleaned
|
|
|
+ # 同步截断:确保同一组序列(如 chosen_input_ids / chosen_attention_mask)长度一致
|
|
|
+ for prefix in ("prompt_", "chosen_", "rejected_"):
|
|
|
+ ids_key = f"{prefix}input_ids"
|
|
|
+ mask_key = f"{prefix}attention_mask"
|
|
|
+ labels_key = f"{prefix}labels"
|
|
|
+ if ids_key in ex and isinstance(ex[ids_key], list):
|
|
|
+ target_len = len(ex[ids_key])
|
|
|
+ if mask_key in ex and isinstance(ex[mask_key], list):
|
|
|
+ ex[mask_key] = ex[mask_key][:target_len]
|
|
|
+ if labels_key in ex and isinstance(ex[labels_key], list):
|
|
|
+ ex[labels_key] = ex[labels_key][:target_len]
|
|
|
return _orig_dc_call(self_dc, features)
|
|
|
|
|
|
_DC.__call__ = _dc_call_clean
|