__init__.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. from typing import Any
  2. def build_lora_config(params: dict[str, Any]):
  3. """返回实际的 peft.LoraConfig 对象。"""
  4. from peft import LoraConfig, TaskType
  5. target_modules = params.get("lora_target_modules", "all-linear")
  6. if isinstance(target_modules, str):
  7. if target_modules == "all-linear":
  8. target_modules = ["linear", "lm_head", "q_proj", "v_proj", "k_proj", "o_proj"]
  9. return LoraConfig(
  10. r=params.get("lora_r", 16),
  11. lora_alpha=params.get("lora_alpha", 32),
  12. lora_dropout=params.get("lora_dropout", 0.05),
  13. target_modules=target_modules,
  14. task_type=TaskType.CAUSAL_LM,
  15. )
  16. def build_qlora_config(params: dict[str, Any]):
  17. """返回 peft.LoraConfig 对象(量化已在 load_model 中通过 HQQ 处理)。"""
  18. from peft import LoraConfig, TaskType
  19. target_modules = params.get("lora_target_modules", "all-linear")
  20. if isinstance(target_modules, str) and target_modules == "all-linear":
  21. target_modules = ["linear", "lm_head", "q_proj", "v_proj", "k_proj", "o_proj"]
  22. return LoraConfig(
  23. r=params.get("lora_r", 16),
  24. lora_alpha=params.get("lora_alpha", 32),
  25. lora_dropout=params.get("lora_dropout", 0.05),
  26. target_modules=target_modules,
  27. task_type=TaskType.CAUSAL_LM,
  28. )
  29. def build_adalora_config(params: dict[str, Any]):
  30. """返回实际的 peft.AdaLoraConfig 对象。"""
  31. from peft import AdaLoraConfig, TaskType
  32. target_modules = params.get("lora_target_modules", "all-linear")
  33. if isinstance(target_modules, str):
  34. if target_modules == "all-linear":
  35. target_modules = ["linear", "lm_head", "q_proj", "v_proj", "k_proj", "o_proj"]
  36. # total_step 必须由外部传入,AdaLoraConfig 的 __post_init__ 会校验 > 0
  37. # 如果没有传入,给一个较大的默认值(10000),train() 中会重新覆盖
  38. total_step = params.get("total_step", 10000)
  39. return AdaLoraConfig(
  40. init_r=params.get("adalora_init_r", 8),
  41. target_r=params.get("adalora_target_r", 16),
  42. beta1=params.get("adalora_beta1", 0.85),
  43. beta2=params.get("adalora_beta2", 0.85),
  44. target_modules=target_modules,
  45. task_type=TaskType.CAUSAL_LM,
  46. total_step=total_step,
  47. )