Procházet zdrojové kódy

修复dpo训练tensor报错

lxylxy123321 před 17 hodinami
rodič
revize
4493b662da
1 změnil soubory, kde provedl 16 přidání a 7 odebrání
  1. 16 7
      backend/app/engines/text_engine.py

+ 16 - 7
backend/app/engines/text_engine.py

@@ -400,6 +400,7 @@ class TextEngine(BaseEngine):
             # torch.tensor([..., None], dtype=int64) 崩溃。
             # 直接 monkey-patch DPODataCollatorWithPadding.__call__(即崩溃点),
             # 在原始逻辑执行前清洗 features 中的 None 值。
+            # 注意:input_ids 和 attention_mask 必须同步截断,否则 tensor 长度不匹配。
             try:
                 from trl.trainer.utils import DPODataCollatorWithPadding as _DC
                 if not getattr(_DC, "_patched_none_filter", False):
@@ -409,13 +410,21 @@ class TextEngine(BaseEngine):
                         for ex in features:
                             for k in list(ex.keys()):
                                 v = ex[k]
-                                if isinstance(v, list):
-                                    if v and isinstance(v[0], list):
-                                        ex[k] = [[x for x in seq if x is not None] for seq in v]
-                                    else:
-                                        ex[k] = [x for x in v if x is not None]
-                                elif v is None:
-                                    ex[k] = []
+                                if isinstance(v, list) and v and not isinstance(v[0], list):
+                                    # 一维 list(如 input_ids / attention_mask / labels)
+                                    cleaned = [x for x in v if x is not None]
+                                    ex[k] = cleaned
+                            # 同步截断:确保同一组序列(如 chosen_input_ids / chosen_attention_mask)长度一致
+                            for prefix in ("prompt_", "chosen_", "rejected_"):
+                                ids_key = f"{prefix}input_ids"
+                                mask_key = f"{prefix}attention_mask"
+                                labels_key = f"{prefix}labels"
+                                if ids_key in ex and isinstance(ex[ids_key], list):
+                                    target_len = len(ex[ids_key])
+                                    if mask_key in ex and isinstance(ex[mask_key], list):
+                                        ex[mask_key] = ex[mask_key][:target_len]
+                                    if labels_key in ex and isinstance(ex[labels_key], list):
+                                        ex[labels_key] = ex[labels_key][:target_len]
                         return _orig_dc_call(self_dc, features)
 
                     _DC.__call__ = _dc_call_clean