|
@@ -604,10 +604,10 @@ class TextEngine(BaseEngine):
|
|
|
return tokenized_dataset
|
|
return tokenized_dataset
|
|
|
|
|
|
|
|
def _load_dataset_dpo(self, dataset_path: str):
|
|
def _load_dataset_dpo(self, dataset_path: str):
|
|
|
- """加载 DPO 数据集,保留 prompt/chosen/rejected 原始文本,由 DPOTrainer 内部 tokenize。"""
|
|
|
|
|
|
|
+ """加载并 tokenize DPO 数据集,生成 TRL 0.9.x 需要的 prompt/chosen/rejected_input_ids。"""
|
|
|
from datasets import Dataset as HFDataset
|
|
from datasets import Dataset as HFDataset
|
|
|
|
|
|
|
|
- data = []
|
|
|
|
|
|
|
+ raw_data = []
|
|
|
with open(dataset_path, "r", encoding="utf-8") as f:
|
|
with open(dataset_path, "r", encoding="utf-8") as f:
|
|
|
for line in f:
|
|
for line in f:
|
|
|
line = line.strip()
|
|
line = line.strip()
|
|
@@ -616,12 +616,42 @@ class TextEngine(BaseEngine):
|
|
|
prompt = item.get("prompt", item.get("instruction", item.get("input", "")))
|
|
prompt = item.get("prompt", item.get("instruction", item.get("input", "")))
|
|
|
chosen = item.get("chosen", item.get("positive", ""))
|
|
chosen = item.get("chosen", item.get("positive", ""))
|
|
|
rejected = item.get("rejected", item.get("negative", ""))
|
|
rejected = item.get("rejected", item.get("negative", ""))
|
|
|
- data.append({
|
|
|
|
|
- "prompt": str(prompt) if prompt else "",
|
|
|
|
|
- "chosen": str(chosen) if chosen else "",
|
|
|
|
|
- "rejected": str(rejected) if rejected else "",
|
|
|
|
|
- })
|
|
|
|
|
- return HFDataset.from_list(data)
|
|
|
|
|
|
|
+ if prompt and chosen and rejected:
|
|
|
|
|
+ raw_data.append({
|
|
|
|
|
+ "prompt": str(prompt),
|
|
|
|
|
+ "chosen": str(chosen),
|
|
|
|
|
+ "rejected": str(rejected),
|
|
|
|
|
+ })
|
|
|
|
|
+
|
|
|
|
|
+ max_len = 2048
|
|
|
|
|
+
|
|
|
|
|
+ def tokenize_fn(examples):
|
|
|
|
|
+ prompts = examples.get("prompt", [])
|
|
|
|
|
+ chosens = examples.get("chosen", [])
|
|
|
|
|
+ rejecteds = examples.get("rejected", [])
|
|
|
|
|
+
|
|
|
|
|
+ prompt_tok = self._tokenizer(prompts, truncation=True, max_length=max_len, padding=False)
|
|
|
|
|
+ chosen_tok = self._tokenizer(
|
|
|
|
|
+ [p + c for p, c in zip(prompts, chosens)],
|
|
|
|
|
+ truncation=True, max_length=max_len, padding=False,
|
|
|
|
|
+ )
|
|
|
|
|
+ rejected_tok = self._tokenizer(
|
|
|
|
|
+ [p + r for p, r in zip(prompts, rejecteds)],
|
|
|
|
|
+ truncation=True, max_length=max_len, padding=False,
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ return {
|
|
|
|
|
+ "prompt_input_ids": prompt_tok["input_ids"],
|
|
|
|
|
+ "prompt_attention_mask": prompt_tok["attention_mask"],
|
|
|
|
|
+ "chosen_input_ids": chosen_tok["input_ids"],
|
|
|
|
|
+ "chosen_attention_mask": chosen_tok["attention_mask"],
|
|
|
|
|
+ "rejected_input_ids": rejected_tok["input_ids"],
|
|
|
|
|
+ "rejected_attention_mask": rejected_tok["attention_mask"],
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ hf_dataset = HFDataset.from_list(raw_data)
|
|
|
|
|
+ tokenized = hf_dataset.map(tokenize_fn, batched=True, remove_columns=hf_dataset.column_names)
|
|
|
|
|
+ return tokenized
|
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
|
try:
|