|
|
@@ -272,9 +272,6 @@ class TextEngine(BaseEngine):
|
|
|
hf_dataset = HFDataset.from_list(data)
|
|
|
|
|
|
def tokenize_fn(batch):
|
|
|
- # batched=True: each value in batch is a list of samples.
|
|
|
- # Some individual values may themselves be lists/dicts (e.g. from
|
|
|
- # Alpaca template producing list values) — coerce each to string.
|
|
|
def _to_str(v):
|
|
|
if isinstance(v, (list, dict)):
|
|
|
return json.dumps(v, ensure_ascii=False)
|
|
|
@@ -296,7 +293,12 @@ class TextEngine(BaseEngine):
|
|
|
tokenized["labels"] = list(tokenized["input_ids"])
|
|
|
return tokenized
|
|
|
|
|
|
- return hf_dataset.map(tokenize_fn, batched=True)
|
|
|
+ tokenized_dataset = hf_dataset.map(
|
|
|
+ tokenize_fn,
|
|
|
+ batched=True,
|
|
|
+ remove_columns=["prompt", "completion"],
|
|
|
+ )
|
|
|
+ return tokenized_dataset
|
|
|
|
|
|
|
|
|
class _ProgressCallback:
|