|
@@ -20,24 +20,14 @@ def build_lora_config(params: dict[str, Any]):
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_qlora_config(params: dict[str, Any]):
|
|
def build_qlora_config(params: dict[str, Any]):
|
|
|
- """返回 (bitsandbytes_config, peft.LoraConfig) 二元组。"""
|
|
|
|
|
|
|
+ """返回 peft.LoraConfig 对象(量化已在 load_model 中通过 HQQ 处理)。"""
|
|
|
from peft import LoraConfig, TaskType
|
|
from peft import LoraConfig, TaskType
|
|
|
- from transformers import BitsAndBytesConfig
|
|
|
|
|
- import torch
|
|
|
|
|
-
|
|
|
|
|
- bnb_params = BitsAndBytesConfig(
|
|
|
|
|
- load_in_4bit=params.get("qlora_bits", 4) == 4,
|
|
|
|
|
- load_in_8bit=params.get("qlora_bits", 4) == 8,
|
|
|
|
|
- bnb_4bit_quant_type=params.get("qlora_type", "nf4"),
|
|
|
|
|
- bnb_4bit_use_double_quant=params.get("qlora_double_quant", True),
|
|
|
|
|
- bnb_4bit_compute_dtype=torch.float16,
|
|
|
|
|
- )
|
|
|
|
|
|
|
|
|
|
target_modules = params.get("lora_target_modules", "all-linear")
|
|
target_modules = params.get("lora_target_modules", "all-linear")
|
|
|
if isinstance(target_modules, str) and target_modules == "all-linear":
|
|
if isinstance(target_modules, str) and target_modules == "all-linear":
|
|
|
target_modules = ["linear", "lm_head", "q_proj", "v_proj", "k_proj", "o_proj"]
|
|
target_modules = ["linear", "lm_head", "q_proj", "v_proj", "k_proj", "o_proj"]
|
|
|
|
|
|
|
|
- lora_cfg = LoraConfig(
|
|
|
|
|
|
|
+ return LoraConfig(
|
|
|
r=params.get("lora_r", 16),
|
|
r=params.get("lora_r", 16),
|
|
|
lora_alpha=params.get("lora_alpha", 32),
|
|
lora_alpha=params.get("lora_alpha", 32),
|
|
|
lora_dropout=params.get("lora_dropout", 0.05),
|
|
lora_dropout=params.get("lora_dropout", 0.05),
|
|
@@ -45,8 +35,6 @@ def build_qlora_config(params: dict[str, Any]):
|
|
|
task_type=TaskType.CAUSAL_LM,
|
|
task_type=TaskType.CAUSAL_LM,
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
- return bnb_params, lora_cfg
|
|
|
|
|
-
|
|
|
|
|
|
|
|
|
|
def build_adalora_config(params: dict[str, Any]):
|
|
def build_adalora_config(params: dict[str, Any]):
|
|
|
"""返回实际的 peft.AdaLoraConfig 对象。"""
|
|
"""返回实际的 peft.AdaLoraConfig 对象。"""
|