|
@@ -40,6 +40,11 @@ def build_adalora_config(params: dict[str, Any]):
|
|
|
"""返回实际的 peft.AdaLoraConfig 对象。"""
|
|
"""返回实际的 peft.AdaLoraConfig 对象。"""
|
|
|
from peft import AdaLoraConfig, TaskType
|
|
from peft import AdaLoraConfig, TaskType
|
|
|
|
|
|
|
|
|
|
+ target_modules = params.get("lora_target_modules", "all-linear")
|
|
|
|
|
+ if isinstance(target_modules, str):
|
|
|
|
|
+ if target_modules == "all-linear":
|
|
|
|
|
+ target_modules = ["linear", "lm_head", "q_proj", "v_proj", "k_proj", "o_proj"]
|
|
|
|
|
+
|
|
|
# total_step 必须由外部传入,AdaLoraConfig 的 __post_init__ 会校验 > 0
|
|
# total_step 必须由外部传入,AdaLoraConfig 的 __post_init__ 会校验 > 0
|
|
|
# 如果没有传入,给一个较大的默认值(10000),train() 中会重新覆盖
|
|
# 如果没有传入,给一个较大的默认值(10000),train() 中会重新覆盖
|
|
|
total_step = params.get("total_step", 10000)
|
|
total_step = params.get("total_step", 10000)
|
|
@@ -49,6 +54,7 @@ def build_adalora_config(params: dict[str, Any]):
|
|
|
target_r=params.get("adalora_target_r", 16),
|
|
target_r=params.get("adalora_target_r", 16),
|
|
|
beta1=params.get("adalora_beta1", 0.85),
|
|
beta1=params.get("adalora_beta1", 0.85),
|
|
|
beta2=params.get("adalora_beta2", 0.85),
|
|
beta2=params.get("adalora_beta2", 0.85),
|
|
|
|
|
+ target_modules=target_modules,
|
|
|
task_type=TaskType.CAUSAL_LM,
|
|
task_type=TaskType.CAUSAL_LM,
|
|
|
total_step=total_step,
|
|
total_step=total_step,
|
|
|
)
|
|
)
|