__init__.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
  1. """数据预处理器:将不同格式的数据集转换为训练所需格式。"""
  2. import json
  3. from pathlib import Path
  4. from typing import Any
  5. # 常见列名映射
  6. _PROMPT_COLUMNS = {"prompt", "question", "query", "text", "input"}
  7. _COMPLETION_COLUMNS = {"completion", "answer", "response", "target", "output"}
  8. _ALPACA_COLUMNS = {"instruction", "input", "output"}
  9. _SHAREGPT_COLUMNS = {"conversations"}
  10. _DPO_COLUMNS = {"prompt", "chosen", "rejected"}
  11. def apply_auto_template(item: dict, column_map: dict[str, str]) -> dict:
  12. """Auto 模板:根据实际列名自动映射。"""
  13. prompt = ""
  14. completion = ""
  15. # 先找 prompt 列
  16. for col in column_map.get("prompt_candidates", []):
  17. if col in item and item[col] is not None:
  18. prompt = str(item[col])
  19. break
  20. # 再找 completion 列
  21. for col in column_map.get("completion_candidates", []):
  22. if col in item and item[col] is not None:
  23. completion = str(item[col])
  24. break
  25. return {"prompt": prompt, "completion": completion}
  26. def _detect_columns(raw_data: list[dict]) -> dict[str, list[str]]:
  27. """扫描数据集前几行,自动检测列名并返回映射关系。"""
  28. if not raw_data:
  29. return {"prompt_candidates": [], "completion_candidates": [], "template": "raw"}
  30. # 取前 5 行扫描
  31. sample = raw_data[:5]
  32. all_columns = set()
  33. for item in sample:
  34. all_columns.update(item.keys())
  35. lower_cols = {c.lower().strip(): c for c in all_columns}
  36. # 检测模板类型
  37. if _SHAREGPT_COLUMNS & all_columns:
  38. return {"template": "sharegpt"}
  39. if _DPO_COLUMNS & all_columns:
  40. return {"template": "dpo"}
  41. if _ALPACA_COLUMNS & all_columns:
  42. return {"template": "alpaca"}
  43. # 查找 prompt 和 completion 候选列
  44. prompt_candidates = [lower_cols.get(c) for c in ["prompt", "question", "query", "text", "input"] if lower_cols.get(c)]
  45. completion_candidates = [lower_cols.get(c) for c in ["completion", "answer", "response", "target", "output"] if lower_cols.get(c)]
  46. return {
  47. "template": "auto",
  48. "prompt_candidates": prompt_candidates,
  49. "completion_candidates": completion_candidates,
  50. }
  51. def apply_alpaca_template(item: dict) -> dict:
  52. """Alpaca 模板: instruction + input -> output。"""
  53. instruction = item.get("instruction", "")
  54. input_text = item.get("input", "")
  55. output = item.get("output", "")
  56. # 确保所有值为字符串
  57. instruction = str(instruction) if instruction is not None else ""
  58. input_text = str(input_text) if input_text is not None else ""
  59. output = str(output) if output is not None else ""
  60. prompt = f"{instruction}\n\n{input_text}" if input_text else instruction
  61. return {"prompt": prompt, "completion": output}
  62. def apply_sharegpt_template(item: dict) -> dict:
  63. """ShareGPT 模板: conversations list -> formatted prompt + completion。"""
  64. conversations = item.get("conversations", [])
  65. if len(conversations) < 2:
  66. return {"prompt": "", "completion": ""}
  67. prompt_parts = []
  68. completion = ""
  69. for i, turn in enumerate(conversations):
  70. role = turn.get("from", turn.get("role", "human"))
  71. content = turn.get("value", turn.get("content", ""))
  72. if i == 0:
  73. prompt_parts.append(content)
  74. elif i == 1:
  75. completion = content
  76. break
  77. else:
  78. prompt_parts.append(f"{role}: {content}")
  79. prompt = "\n".join(prompt_parts)
  80. return {"prompt": prompt, "completion": completion}
  81. def apply_raw_template(item: dict) -> dict:
  82. """Raw 模板: 直接读取 prompt/instruction/text 和 completion/output 字段。"""
  83. prompt = item.get("prompt", item.get("instruction", item.get("text", item.get("input", item.get("question", item.get("query", ""))))))
  84. completion = item.get("completion", item.get("output", item.get("target", item.get("answer", item.get("response", "")))))
  85. return {"prompt": str(prompt), "completion": str(completion)}
  86. def apply_dpo_template(item: dict) -> dict:
  87. """DPO 模板: prompt + chosen + rejected。"""
  88. prompt = item.get("prompt", item.get("instruction", item.get("input", item.get("question", item.get("query", "")))))
  89. chosen = item.get("chosen", item.get("positive", item.get("answer", "")))
  90. rejected = item.get("rejected", item.get("negative", ""))
  91. # 先处理列表型字段(如 ShareGPT messages 列表),再转字符串
  92. # 必须在 str() 之前判断,否则 list 会被 str() 转成 Python repr 字符串
  93. if isinstance(prompt, list):
  94. prompt = "\n".join(
  95. str(x.get("content", x)) if isinstance(x, dict) else str(x)
  96. for x in prompt
  97. if x is not None
  98. )
  99. elif prompt is not None:
  100. prompt = str(prompt)
  101. else:
  102. prompt = ""
  103. if isinstance(chosen, list):
  104. chosen = "\n".join(
  105. str(x.get("content", x)) if isinstance(x, dict) else str(x)
  106. for x in chosen
  107. if x is not None
  108. )
  109. elif chosen is not None:
  110. chosen = str(chosen)
  111. else:
  112. chosen = ""
  113. if isinstance(rejected, list):
  114. rejected = "\n".join(
  115. str(x.get("content", x)) if isinstance(x, dict) else str(x)
  116. for x in rejected
  117. if x is not None
  118. )
  119. elif rejected is not None:
  120. rejected = str(rejected)
  121. else:
  122. rejected = ""
  123. return {"prompt": prompt, "chosen": chosen, "rejected": rejected}
  124. TEMPLATE_MAP = {
  125. "sft": {
  126. "auto": None, # 特殊处理:自动检测
  127. "alpaca": apply_alpaca_template,
  128. "sharegpt": apply_sharegpt_template,
  129. "raw": apply_raw_template,
  130. },
  131. "dpo": {
  132. "auto": apply_dpo_template,
  133. "alpaca": apply_dpo_template,
  134. "sharegpt": apply_dpo_template,
  135. "raw": apply_dpo_template,
  136. },
  137. "ppo": {
  138. "auto": apply_raw_template,
  139. "alpaca": apply_alpaca_template,
  140. "sharegpt": apply_sharegpt_template,
  141. "raw": apply_raw_template,
  142. },
  143. }
  144. def preprocess_file(
  145. input_path: str,
  146. output_path: str,
  147. task_type: str = "sft",
  148. template: str = "auto",
  149. ) -> list[dict[str, Any]]:
  150. """读取文件并应用模板,返回处理后的数据列表。"""
  151. input_p = Path(input_path)
  152. ext = input_p.suffix.lower()
  153. # 读取原始数据
  154. if ext == ".jsonl":
  155. with open(input_path, "r", encoding="utf-8") as f:
  156. raw_data = [json.loads(line) for line in f if line.strip()]
  157. elif ext == ".json":
  158. with open(input_path, "r", encoding="utf-8") as f:
  159. try:
  160. data = json.load(f)
  161. raw_data = data if isinstance(data, list) else [data]
  162. except json.JSONDecodeError:
  163. # 回退到 JSONL 格式(每行一个 JSON 对象)
  164. f.seek(0)
  165. raw_data = [json.loads(line) for line in f if line.strip()]
  166. elif ext == ".csv":
  167. import csv
  168. with open(input_path, "r", encoding="utf-8") as f:
  169. reader = csv.DictReader(f)
  170. raw_data = [dict(row) for row in reader]
  171. elif ext == ".parquet":
  172. import pandas as pd
  173. df = pd.read_parquet(input_path)
  174. raw_data = df.to_dict(orient="records")
  175. else:
  176. raise ValueError(f"Unsupported format: {ext}")
  177. # 获取模板函数
  178. templates = TEMPLATE_MAP.get(task_type, TEMPLATE_MAP["sft"])
  179. apply_fn = templates.get(template, templates.get("raw", apply_raw_template))
  180. # Auto 模板:自动检测列名
  181. column_map = {}
  182. if template == "auto" and apply_fn is None:
  183. column_map = _detect_columns(raw_data)
  184. detected = column_map.get("template", "raw")
  185. if detected == "sharegpt":
  186. apply_fn = apply_sharegpt_template
  187. elif detected == "alpaca":
  188. apply_fn = apply_alpaca_template
  189. elif detected == "dpo":
  190. apply_fn = apply_dpo_template
  191. elif detected == "auto":
  192. apply_fn = lambda item, cm=column_map: apply_auto_template(item, cm)
  193. else:
  194. apply_fn = apply_raw_template
  195. # 应用模板
  196. processed = []
  197. for item in raw_data:
  198. try:
  199. result = apply_fn(item)
  200. # DPO/偏好类任务需要同时保留 prompt/chosen/rejected,
  201. # 仅按 prompt 过滤会把合法偏好样本误删,最终导致空 batch 进入 collator 报错。
  202. if "chosen" in result or "rejected" in result:
  203. if result.get("prompt") and (result.get("chosen") or result.get("rejected")):
  204. processed.append(result)
  205. continue
  206. if result.get("prompt"):
  207. processed.append(result)
  208. except Exception:
  209. continue
  210. # 写入处理后的数据(先删旧文件避免权限冲突)
  211. output_p = Path(output_path)
  212. output_p.parent.mkdir(parents=True, exist_ok=True)
  213. if output_p.exists():
  214. output_p.unlink()
  215. with open(output_path, "w", encoding="utf-8") as f:
  216. for item in processed:
  217. f.write(json.dumps(item, ensure_ascii=False) + "\n")
  218. return processed