__init__.py 2.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. from typing import Any
  2. def build_lora_config(params: dict[str, Any]):
  3. """返回实际的 peft.LoraConfig 对象。"""
  4. from peft import LoraConfig, TaskType
  5. target_modules = params.get("lora_target_modules", "all-linear")
  6. if isinstance(target_modules, str):
  7. if target_modules == "all-linear":
  8. target_modules = ["linear", "lm_head", "q_proj", "v_proj", "k_proj", "o_proj"]
  9. return LoraConfig(
  10. r=params.get("lora_r", 16),
  11. lora_alpha=params.get("lora_alpha", 32),
  12. lora_dropout=params.get("lora_dropout", 0.05),
  13. target_modules=target_modules,
  14. task_type=TaskType.CAUSAL_LM,
  15. )
  16. def build_qlora_config(params: dict[str, Any]):
  17. """返回 (bitsandbytes_config, peft.LoraConfig) 二元组。"""
  18. from peft import LoraConfig, TaskType
  19. bnb_params = {
  20. "load_in_4bit": params.get("qlora_bits", 4) == 4,
  21. "load_in_8bit": params.get("qlora_bits", 4) == 8,
  22. "bnb_4bit_quant_type": params.get("qlora_type", "nf4"),
  23. "bnb_4bit_use_double_quant": params.get("qlora_double_quant", True),
  24. "bnb_4bit_compute_dtype": "float16",
  25. }
  26. target_modules = params.get("lora_target_modules", "all-linear")
  27. if isinstance(target_modules, str) and target_modules == "all-linear":
  28. target_modules = ["linear", "lm_head", "q_proj", "v_proj", "k_proj", "o_proj"]
  29. lora_cfg = LoraConfig(
  30. r=params.get("lora_r", 16),
  31. lora_alpha=params.get("lora_alpha", 32),
  32. lora_dropout=params.get("lora_dropout", 0.05),
  33. target_modules=target_modules,
  34. task_type=TaskType.CAUSAL_LM,
  35. )
  36. return bnb_params, lora_cfg
  37. def build_ia3_config(params: dict[str, Any]):
  38. """返回实际的 peft.IA3Config 对象。"""
  39. from peft import IA3Config, TaskType
  40. target_modules = params.get("ia3_target_modules", "all-linear")
  41. if isinstance(target_modules, str) and target_modules == "all-linear":
  42. target_modules = ["k_proj", "v_proj", "ffn"]
  43. return IA3Config(
  44. target_modules=target_modules,
  45. task_type=TaskType.CAUSAL_LM,
  46. )
  47. def build_adalora_config(params: dict[str, Any]):
  48. """返回实际的 peft.AdaLoraConfig 对象。"""
  49. from peft import AdaLoraConfig, TaskType
  50. return AdaLoraConfig(
  51. init_r=params.get("adalora_init_r", 8),
  52. target_r=params.get("adalora_target_r", 16),
  53. beta1=params.get("adalora_beta1", 0.85),
  54. beta2=params.get("adalora_beta2", 0.85),
  55. task_type=TaskType.CAUSAL_LM,
  56. )
  57. def build_prefix_tuning_config(params: dict[str, Any]):
  58. """返回实际的 peft.PromptTuningConfig 对象。"""
  59. from peft import PromptTuningConfig, PromptTuningInit, TaskType
  60. return PromptTuningConfig(
  61. num_virtual_tokens=params.get("prefix_num_virtual_tokens", 20),
  62. prompt_tuning_init=PromptTuningInit.TEXT,
  63. prompt_tuning_init_text="Classify the following text: ",
  64. task_type=TaskType.CAUSAL_LM,
  65. )