from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:  # annotations only; peft itself is imported lazily inside each builder
    from peft import AdaLoraConfig, IA3Config, LoraConfig, PromptTuningConfig

# Explicit module-name lists substituted for the "all-linear" shorthand.
# NOTE(review): "linear" and "lm_head" are unusual LoRA targets (lm_head in particular
# attaches an adapter to the output projection), and recent PEFT releases accept the
# literal string "all-linear" natively. The list is kept as-is for backward
# compatibility — confirm it matches the target model architectures.
_LORA_ALL_LINEAR_TARGETS = ("linear", "lm_head", "q_proj", "v_proj", "k_proj", "o_proj")
# NOTE(review): confirm the target architectures actually contain a module named "ffn".
_IA3_ALL_LINEAR_TARGETS = ("k_proj", "v_proj", "ffn")

# Default text used to initialize the virtual tokens in text-initialized prompt tuning.
_DEFAULT_PROMPT_INIT_TEXT = "Classify the following text: "


def _resolve_target_modules(target_modules: Any, all_linear_targets: tuple[str, ...]) -> Any:
    """Expand the "all-linear" shorthand into an explicit list of module names.

    Any other value (a list of names, or a different string) is returned unchanged.
    A fresh list is built on every expansion so no caller shares a mutable object.
    """
    if isinstance(target_modules, str) and target_modules == "all-linear":
        return list(all_linear_targets)
    return target_modules


def _lora_kwargs(params: dict[str, Any]) -> dict[str, Any]:
    """Collect the LoraConfig keyword arguments shared by LoRA and QLoRA."""
    return {
        "r": params.get("lora_r", 16),
        "lora_alpha": params.get("lora_alpha", 32),
        "lora_dropout": params.get("lora_dropout", 0.05),
        "target_modules": _resolve_target_modules(
            params.get("lora_target_modules", "all-linear"), _LORA_ALL_LINEAR_TARGETS
        ),
    }


def _qlora_bnb_kwargs(params: dict[str, Any]) -> dict[str, Any]:
    """Build the bitsandbytes quantization keyword arguments for QLoRA.

    Raises:
        ValueError: if ``qlora_bits`` is not 4 or 8. Without this check any other
            value (e.g. 16 or the string "4") would leave both ``load_in_4bit`` and
            ``load_in_8bit`` False, silently disabling quantization.
    """
    bits = params.get("qlora_bits", 4)
    if bits not in (4, 8):
        raise ValueError(f"qlora_bits must be 4 or 8, got {bits!r}")
    return {
        "load_in_4bit": bits == 4,
        "load_in_8bit": bits == 8,
        # The bnb_4bit_* keys are always emitted; presumably they are only
        # relevant in 4-bit mode — verify against the consumer of this dict.
        "bnb_4bit_quant_type": params.get("qlora_type", "nf4"),
        "bnb_4bit_use_double_quant": params.get("qlora_double_quant", True),
        "bnb_4bit_compute_dtype": "float16",
    }


def build_lora_config(params: dict[str, Any]) -> "LoraConfig":
    """Return a ``peft.LoraConfig`` for a causal-LM task.

    Recognized keys (defaults): ``lora_r`` (16), ``lora_alpha`` (32),
    ``lora_dropout`` (0.05), ``lora_target_modules`` ("all-linear", expanded to an
    explicit module list).
    """
    from peft import LoraConfig, TaskType  # lazy: importing this module does not need peft

    return LoraConfig(**_lora_kwargs(params), task_type=TaskType.CAUSAL_LM)


def build_qlora_config(params: dict[str, Any]) -> "tuple[dict[str, Any], LoraConfig]":
    """Return ``(bnb_kwargs, lora_config)`` for QLoRA.

    ``bnb_kwargs`` is a plain dict of bitsandbytes quantization keyword arguments,
    not a config object. ``lora_config`` is a ``peft.LoraConfig`` built from the same
    keys as :func:`build_lora_config`.

    Extra recognized keys: ``qlora_bits`` (4; must be 4 or 8), ``qlora_type``
    ("nf4"), ``qlora_double_quant`` (True).

    Raises:
        ValueError: if ``qlora_bits`` is not 4 or 8.
    """
    # Validate first so a bad config fails fast, before peft is imported.
    bnb_params = _qlora_bnb_kwargs(params)

    from peft import LoraConfig, TaskType

    lora_cfg = LoraConfig(**_lora_kwargs(params), task_type=TaskType.CAUSAL_LM)
    return bnb_params, lora_cfg


def build_ia3_config(params: dict[str, Any]) -> "IA3Config":
    """Return a ``peft.IA3Config`` for a causal-LM task.

    Recognized keys (defaults): ``ia3_target_modules`` ("all-linear", expanded to
    k_proj/v_proj/ffn), ``ia3_feedforward_modules`` (None — the IA3Config default;
    when given it must be a subset of the target modules).
    """
    from peft import IA3Config, TaskType

    return IA3Config(
        target_modules=_resolve_target_modules(
            params.get("ia3_target_modules", "all-linear"), _IA3_ALL_LINEAR_TARGETS
        ),
        feedforward_modules=params.get("ia3_feedforward_modules"),
        task_type=TaskType.CAUSAL_LM,
    )


def build_adalora_config(params: dict[str, Any]) -> "AdaLoraConfig":
    """Return a ``peft.AdaLoraConfig`` for a causal-LM task.

    AdaLoRA starts each adapter at rank ``init_r`` and prunes the rank budget down
    toward ``target_r``, so ``target_r`` must not exceed ``init_r``. The defaults
    follow PEFT's own (init_r=12, target_r=8); the former defaults (8 and 16) were
    inverted.

    Recognized keys (defaults): ``adalora_init_r`` (12), ``adalora_target_r`` (8),
    ``adalora_beta1`` (0.85), ``adalora_beta2`` (0.85), ``adalora_total_step`` (None).

    Raises:
        ValueError: if ``adalora_target_r`` is greater than ``adalora_init_r``.
    """
    init_r = params.get("adalora_init_r", 12)
    target_r = params.get("adalora_target_r", 8)
    if target_r > init_r:
        raise ValueError(
            f"adalora_target_r ({target_r}) must not exceed adalora_init_r ({init_r})"
        )

    from peft import AdaLoraConfig, TaskType

    return AdaLoraConfig(
        init_r=init_r,
        target_r=target_r,
        beta1=params.get("adalora_beta1", 0.85),
        beta2=params.get("adalora_beta2", 0.85),
        # NOTE(review): recent PEFT versions reject AdaLoRA configs whose total_step
        # is None — callers should supply adalora_total_step (total training steps).
        total_step=params.get("adalora_total_step"),
        task_type=TaskType.CAUSAL_LM,
    )


def build_prefix_tuning_config(params: dict[str, Any]) -> "PromptTuningConfig":
    """Return a ``peft.PromptTuningConfig`` for a causal-LM task.

    Despite its name this builds *prompt tuning* (soft prompts), not PEFT's
    ``PrefixTuningConfig``; the name is kept for backward compatibility.

    Text initialization embeds ``prefix_init_text`` with a tokenizer, so TEXT mode is
    used only when ``prefix_tokenizer_name_or_path`` is provided. Otherwise the
    virtual tokens are randomly initialized and a ``UserWarning`` is emitted.

    Recognized keys (defaults): ``prefix_num_virtual_tokens`` (20),
    ``prefix_tokenizer_name_or_path`` (None), ``prefix_init_text``
    ("Classify the following text: ").
    """
    import warnings

    from peft import PromptTuningConfig, PromptTuningInit, TaskType

    num_virtual_tokens = params.get("prefix_num_virtual_tokens", 20)
    tokenizer_path = params.get("prefix_tokenizer_name_or_path")

    if tokenizer_path:
        return PromptTuningConfig(
            num_virtual_tokens=num_virtual_tokens,
            prompt_tuning_init=PromptTuningInit.TEXT,
            prompt_tuning_init_text=params.get("prefix_init_text", _DEFAULT_PROMPT_INIT_TEXT),
            tokenizer_name_or_path=tokenizer_path,
            task_type=TaskType.CAUSAL_LM,
        )

    warnings.warn(
        "prefix_tokenizer_name_or_path not set; text-initialized prompt tuning needs a "
        "tokenizer, falling back to random initialization of the virtual tokens.",
        stacklevel=2,
    )
    return PromptTuningConfig(
        num_virtual_tokens=num_virtual_tokens,
        prompt_tuning_init=PromptTuningInit.RANDOM,
        task_type=TaskType.CAUSAL_LM,
    )