__init__.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
"""Data preprocessors: convert datasets in different formats into the format required for training."""
  2. import json
  3. from pathlib import Path
  4. from typing import Any
  5. def apply_alpaca_template(item: dict) -> dict:
  6. """Alpaca 模板: instruction + input -> output。"""
  7. instruction = item.get("instruction", "")
  8. input_text = item.get("input", "")
  9. output = item.get("output", "")
  10. # 确保所有值为字符串
  11. instruction = str(instruction) if instruction is not None else ""
  12. input_text = str(input_text) if input_text is not None else ""
  13. output = str(output) if output is not None else ""
  14. prompt = f"{instruction}\n\n{input_text}" if input_text else instruction
  15. return {"prompt": prompt, "completion": output}
  16. def apply_sharegpt_template(item: dict) -> dict:
  17. """ShareGPT 模板: conversations list -> formatted prompt + completion。"""
  18. conversations = item.get("conversations", [])
  19. if len(conversations) < 2:
  20. return {"prompt": "", "completion": ""}
  21. prompt_parts = []
  22. completion = ""
  23. for i, turn in enumerate(conversations):
  24. role = turn.get("from", turn.get("role", "human"))
  25. content = turn.get("value", turn.get("content", ""))
  26. if i == 0:
  27. prompt_parts.append(content)
  28. elif i == 1:
  29. completion = content
  30. break
  31. else:
  32. prompt_parts.append(f"{role}: {content}")
  33. prompt = "\n".join(prompt_parts)
  34. return {"prompt": prompt, "completion": completion}
  35. def apply_raw_template(item: dict) -> dict:
  36. """Raw 模板: 直接读取 prompt/text 和 completion/output 字段。"""
  37. prompt = item.get("prompt", item.get("text", item.get("input", "")))
  38. completion = item.get("completion", item.get("output", item.get("target", "")))
  39. return {"prompt": str(prompt), "completion": str(completion)}
  40. def apply_dpo_template(item: dict) -> dict:
  41. """DPO 模板: prompt + chosen + rejected。"""
  42. return {
  43. "prompt": item.get("prompt", item.get("input", "")),
  44. "chosen": item.get("chosen", item.get("positive", "")),
  45. "rejected": item.get("rejected", item.get("negative", "")),
  46. }
  47. def apply_kto_template(item: dict) -> dict:
  48. """KTO 模板: prompt + completion + label。"""
  49. return {
  50. "prompt": item.get("prompt", item.get("input", "")),
  51. "completion": item.get("completion", item.get("output", "")),
  52. "label": item.get("label", True),
  53. }
  54. def apply_orpo_template(item: dict) -> dict:
  55. """ORPO 模板: prompt + chosen + rejected (类似 DPO)。"""
  56. return {
  57. "prompt": item.get("prompt", item.get("input", "")),
  58. "chosen": item.get("chosen", item.get("positive", "")),
  59. "rejected": item.get("rejected", item.get("negative", "")),
  60. }
  61. def apply_rm_template(item: dict) -> dict:
  62. """Reward Modeling 模板: prompt + chosen + rejected。"""
  63. return {
  64. "prompt": item.get("prompt", item.get("input", "")),
  65. "chosen": item.get("chosen", item.get("positive", "")),
  66. "rejected": item.get("rejected", item.get("negative", "")),
  67. }
  68. TEMPLATE_MAP = {
  69. "sft": {
  70. "alpaca": apply_alpaca_template,
  71. "sharegpt": apply_sharegpt_template,
  72. "raw": apply_raw_template,
  73. },
  74. "dpo": {
  75. "alpaca": apply_dpo_template,
  76. "sharegpt": apply_dpo_template,
  77. "raw": apply_dpo_template,
  78. },
  79. "kto": {
  80. "raw": apply_kto_template,
  81. },
  82. "orpo": {
  83. "alpaca": apply_orpo_template,
  84. "raw": apply_orpo_template,
  85. },
  86. "rm": {
  87. "raw": apply_rm_template,
  88. },
  89. "ppo": {
  90. "raw": apply_raw_template,
  91. },
  92. }
  93. def preprocess_file(
  94. input_path: str,
  95. output_path: str,
  96. task_type: str = "sft",
  97. template: str = "alpaca",
  98. ) -> list[dict[str, Any]]:
  99. """读取文件并应用模板,返回处理后的数据列表。"""
  100. input_p = Path(input_path)
  101. ext = input_p.suffix.lower()
  102. # 读取原始数据
  103. if ext == ".jsonl":
  104. with open(input_path, "r", encoding="utf-8") as f:
  105. raw_data = [json.loads(line) for line in f if line.strip()]
  106. elif ext == ".json":
  107. with open(input_path, "r", encoding="utf-8") as f:
  108. try:
  109. data = json.load(f)
  110. raw_data = data if isinstance(data, list) else [data]
  111. except json.JSONDecodeError:
  112. # 回退到 JSONL 格式(每行一个 JSON 对象)
  113. f.seek(0)
  114. raw_data = [json.loads(line) for line in f if line.strip()]
  115. elif ext == ".csv":
  116. import csv
  117. with open(input_path, "r", encoding="utf-8") as f:
  118. reader = csv.DictReader(f)
  119. raw_data = [dict(row) for row in reader]
  120. elif ext == ".parquet":
  121. import pandas as pd
  122. df = pd.read_parquet(input_path)
  123. raw_data = df.to_dict(orient="records")
  124. else:
  125. raise ValueError(f"Unsupported format: {ext}")
  126. # 获取模板函数
  127. templates = TEMPLATE_MAP.get(task_type, TEMPLATE_MAP["sft"])
  128. apply_fn = templates.get(template, templates.get("raw", apply_raw_template))
  129. # 应用模板
  130. processed = []
  131. for item in raw_data:
  132. try:
  133. result = apply_fn(item)
  134. if result.get("prompt"):
  135. processed.append(result)
  136. except Exception:
  137. continue
  138. # 写入处理后的数据(先删旧文件避免权限冲突)
  139. output_p = Path(output_path)
  140. output_p.parent.mkdir(parents=True, exist_ok=True)
  141. if output_p.exists():
  142. output_p.unlink()
  143. with open(output_path, "w", encoding="utf-8") as f:
  144. for item in processed:
  145. f.write(json.dumps(item, ensure_ascii=False) + "\n")
  146. return processed