Explorar o código

修复tokenize_fn嵌套值导致tensor创建失败

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
lxylxy123321 hai 1 semana
pai
achega
4d02c36ad8
Modificouse 1 ficheiro con 27 adicións e 8 borrados
  1. 27 8
      backend/app/engines/text_engine.py

+ 27 - 8
backend/app/engines/text_engine.py

@@ -257,18 +257,37 @@ class TextEngine(BaseEngine):
             for line in f:
                 line = line.strip()
                 if line:
-                    data.append(json.loads(line))
+                    item = json.loads(line)
+                    # 确保 prompt 和 completion 是字符串
+                    if "prompt" in item:
+                        if isinstance(item["prompt"], (list, dict)):
+                            item["prompt"] = json.dumps(item["prompt"], ensure_ascii=False)
+                        item["prompt"] = str(item["prompt"])
+                    if "completion" in item:
+                        if isinstance(item["completion"], (list, dict)):
+                            item["completion"] = json.dumps(item["completion"], ensure_ascii=False)
+                        item["completion"] = str(item["completion"])
+                    data.append(item)
 
         hf_dataset = HFDataset.from_list(data)
 
         def tokenize_fn(batch):
-            prompts = batch.get("prompt", [""] * len(data))
-            completions = batch.get("completion", [""] * len(data))
-
-            if isinstance(prompts, str):
-                prompts = [prompts]
-            if isinstance(completions, str):
-                completions = [completions]
+            # batched=True: each value in batch is a list of samples.
+            # Some individual values may themselves be lists/dicts (e.g. from
+            # Alpaca template producing list values) — coerce each to string.
+            def _to_str(v):
+                if isinstance(v, (list, dict)):
+                    return json.dumps(v, ensure_ascii=False)
+                return str(v) if v is not None else ""
+
+            raw_prompts = batch.get("prompt", [])
+            raw_completions = batch.get("completion", [])
+
+            prompts = [_to_str(v) for v in raw_prompts]
+            completions = [_to_str(v) for v in raw_completions]
+
+            if not prompts:
+                return {"input_ids": [], "attention_mask": [], "labels": []}
 
             full_texts = [f"{p}\n{c}" for p, c in zip(prompts, completions)]
             tokenized = self._tokenizer(