| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061 |
- from typing import Any
def build_lora_config(params: dict[str, Any]):
    """Build a ``peft.LoraConfig`` for causal-LM fine-tuning from *params*.

    Keys read from ``params`` (all optional):

    * ``lora_target_modules`` -- default ``"all-linear"``. That exact string
      is replaced by a fixed list of module names; any other value (another
      string, a list, ...) is handed to ``LoraConfig`` unchanged.
    * ``lora_r`` -- default ``16``.
    * ``lora_alpha`` -- default ``32``.
    * ``lora_dropout`` -- default ``0.05``.

    Returns:
        The constructed ``peft.LoraConfig`` with ``task_type`` set to
        ``TaskType.CAUSAL_LM``.
    """
    # Function-scope import: peft is only imported when this builder runs.
    from peft import LoraConfig, TaskType

    modules = params.get("lora_target_modules", "all-linear")
    if isinstance(modules, str) and modules == "all-linear":
        # NOTE(review): recent peft releases appear to accept "all-linear"
        # natively; here it is expanded to an explicit name list instead --
        # confirm this list is the intended set of modules.
        modules = ["linear", "lm_head", "q_proj", "v_proj", "k_proj", "o_proj"]

    rank = params.get("lora_r", 16)
    alpha = params.get("lora_alpha", 32)
    dropout = params.get("lora_dropout", 0.05)

    return LoraConfig(
        r=rank,
        lora_alpha=alpha,
        lora_dropout=dropout,
        target_modules=modules,
        task_type=TaskType.CAUSAL_LM,
    )
def build_qlora_config(params: dict[str, Any]):
    """Build the quantization config and the LoRA config used for QLoRA.

    Keys read from ``params`` for the quantization half (all optional):

    * ``qlora_bits`` -- default ``4``. ``4`` sets ``load_in_4bit``; ``8``
      sets ``load_in_8bit``.
    * ``qlora_type`` -- default ``"nf4"``; passed as ``bnb_4bit_quant_type``.
    * ``qlora_double_quant`` -- default ``True``; passed as
      ``bnb_4bit_use_double_quant``.

    The compute dtype is fixed to ``torch.float16``.

    The LoRA half is produced by ``build_lora_config(params)``, so it reads
    the same ``lora_*`` keys with the same defaults as that function instead
    of carrying a second copy of that logic.

    Returns:
        A ``(BitsAndBytesConfig, LoraConfig)`` 2-tuple.
    """
    # Function-scope imports: the heavy dependencies are only imported when
    # this builder runs.
    import torch
    from transformers import BitsAndBytesConfig

    # Read once so both flags are derived from the same value.
    # NOTE(review): any value other than 4 or 8 leaves both flags False,
    # which presumably means no quantization is applied -- confirm whether
    # such values should be rejected instead.
    bits = params.get("qlora_bits", 4)

    bnb_params = BitsAndBytesConfig(
        load_in_4bit=bits == 4,
        load_in_8bit=bits == 8,
        bnb_4bit_quant_type=params.get("qlora_type", "nf4"),
        bnb_4bit_use_double_quant=params.get("qlora_double_quant", True),
        bnb_4bit_compute_dtype=torch.float16,
    )

    # Delegate to the plain LoRA builder so the two cannot drift apart.
    lora_cfg = build_lora_config(params)
    return bnb_params, lora_cfg
def build_adalora_config(params: dict[str, Any]):
    """Build a ``peft.AdaLoraConfig`` for causal-LM fine-tuning from *params*.

    Keys read from ``params`` (all optional):

    * ``adalora_init_r`` -- default ``8``; passed as ``init_r``.
    * ``adalora_target_r`` -- default ``16``; passed as ``target_r``.
    * ``adalora_beta1`` -- default ``0.85``; passed as ``beta1``.
    * ``adalora_beta2`` -- default ``0.85``; passed as ``beta2``.

    No other ``AdaLoraConfig`` argument is set here apart from ``task_type``.

    Returns:
        The constructed ``peft.AdaLoraConfig`` with ``task_type`` set to
        ``TaskType.CAUSAL_LM``.
    """
    # Function-scope import: peft is only imported when this builder runs.
    from peft import AdaLoraConfig, TaskType

    # NOTE(review): the default init_r (8) is lower than the default
    # target_r (16). AdaLoRA is usually described as starting at init_r and
    # pruning down to target_r -- confirm these defaults are intended.
    initial_rank = params.get("adalora_init_r", 8)
    target_rank = params.get("adalora_target_r", 16)
    first_beta = params.get("adalora_beta1", 0.85)
    second_beta = params.get("adalora_beta2", 0.85)

    config = AdaLoraConfig(
        init_r=initial_rank,
        target_r=target_rank,
        beta1=first_beta,
        beta2=second_beta,
        task_type=TaskType.CAUSAL_LM,
    )
    return config
|