lxylxy123321 2 napja
szülő
commit
e6f2e0d49f
1 módosított fájl, 6 hozzáadás és 0 törlés
  1. 6 0
      backend/app/peft/__init__.py

+ 6 - 0
backend/app/peft/__init__.py

@@ -40,6 +40,11 @@ def build_adalora_config(params: dict[str, Any]):
     """返回实际的 peft.AdaLoraConfig 对象。"""
     from peft import AdaLoraConfig, TaskType
 
+    target_modules = params.get("lora_target_modules", "all-linear")
+    if isinstance(target_modules, str):
+        if target_modules == "all-linear":
+            target_modules = ["linear", "lm_head", "q_proj", "v_proj", "k_proj", "o_proj"]
+
     # total_step 必须由外部传入,AdaLoraConfig 的 __post_init__ 会校验 > 0
     # 如果没有传入,给一个较大的默认值(10000),train() 中会重新覆盖
     total_step = params.get("total_step", 10000)
@@ -49,6 +54,7 @@ def build_adalora_config(params: dict[str, Any]):
         target_r=params.get("adalora_target_r", 16),
         beta1=params.get("adalora_beta1", 0.85),
         beta2=params.get("adalora_beta2", 0.85),
+        target_modules=target_modules,
         task_type=TaskType.CAUSAL_LM,
         total_step=total_step,
     )