| 1234567891011121314151617181920212223242526272829303132333435363738394041 |
- from typing import Any
def build_lora_config(params: dict[str, Any]):
    """Assemble the LoRA settings dict from a flat parameter mapping.

    Each output key is looked up in ``params`` under its input name
    (``r`` comes from ``lora_r``, ``target_modules`` from
    ``lora_target_modules``); absent entries fall back to the defaults
    listed in the table below.

    Args:
        params: Flat mapping of user-supplied hyperparameters.

    Returns:
        A new dict with keys ``r``, ``lora_alpha``, ``lora_dropout`` and
        ``target_modules``, in that order.
    """
    # (output key, input key, fallback) -- tuple order fixes output key order.
    fields = (
        ("r", "lora_r", 16),
        ("lora_alpha", "lora_alpha", 32),
        ("lora_dropout", "lora_dropout", 0.05),
        ("target_modules", "lora_target_modules", "all-linear"),
    )
    return {
        out_key: params.get(in_key, fallback)
        for out_key, in_key, fallback in fields
    }
def build_qlora_config(params: dict[str, Any]):
    """Assemble the QLoRA settings dict from a flat parameter mapping.

    Quantization options are read from ``qlora_``-prefixed keys; the nested
    ``lora`` entry is produced by ``build_lora_config`` from the same mapping.

    Args:
        params: Flat mapping of user-supplied hyperparameters.

    Returns:
        A new dict with keys ``bits`` (default 4), ``qlora_type``
        (default ``"nf4"``), ``double_quant`` (default ``True``) and
        ``lora`` (the nested LoRA settings).
    """
    quant_bits = params.get("qlora_bits", 4)
    quant_type = params.get("qlora_type", "nf4")
    use_double_quant = params.get("qlora_double_quant", True)
    adapter_settings = build_lora_config(params)
    return {
        "bits": quant_bits,
        "qlora_type": quant_type,
        "double_quant": use_double_quant,
        "lora": adapter_settings,
    }
def build_ia3_config(params: dict[str, Any]):
    """Assemble the IA3 settings dict from a flat parameter mapping.

    Only ``target_modules`` is configurable; it is read from
    ``ia3_target_modules`` and defaults to ``"all-linear"``.

    Args:
        params: Flat mapping of user-supplied hyperparameters.

    Returns:
        A new dict with the single key ``target_modules``.
    """
    modules = params.get("ia3_target_modules", "all-linear")
    return dict(target_modules=modules)
def build_adalora_config(params: dict[str, Any]) -> dict[str, Any]:
    """Build the AdaLoRA config dict from a flat parameter mapping.

    AdaLoRA starts every adapter at rank ``init_r`` and prunes the rank
    budget toward an average of ``target_r`` during training, so
    ``target_r`` must not exceed ``init_r``.

    Args:
        params: Flat mapping of hyperparameters. Reads ``adalora_init_r``
            (default 16), ``adalora_target_r`` (default 8),
            ``adalora_beta1`` (default 0.85) and ``adalora_beta2``
            (default 0.85).

    Returns:
        A new dict with keys ``init_r``, ``target_r``, ``beta1`` and
        ``beta2``.

    Raises:
        ValueError: If both ranks are integers and ``target_r`` is greater
            than ``init_r``.
    """
    # The starting rank reuses the LoRA default rank (16). Pruning then
    # targets 8. The previous defaults (init_r=8, target_r=16) were inverted.
    init_r = params.get("adalora_init_r", 16)
    target_r = params.get("adalora_target_r", 8)
    # Only validate genuine integer ranks. Other values pass through
    # unchanged, as before, rather than raising TypeError here.
    if isinstance(init_r, int) and isinstance(target_r, int) and target_r > init_r:
        raise ValueError(
            f"adalora_target_r ({target_r}) must not exceed adalora_init_r "
            f"({init_r}): AdaLoRA prunes ranks from init_r down to target_r"
        )
    return {
        "init_r": init_r,
        "target_r": target_r,
        "beta1": params.get("adalora_beta1", 0.85),
        "beta2": params.get("adalora_beta2", 0.85),
    }
def build_prefix_tuning_config(params: dict[str, Any]):
    """Assemble the prefix-tuning settings dict from a flat parameter mapping.

    Args:
        params: Flat mapping of user-supplied hyperparameters. Reads
            ``prefix_num_virtual_tokens`` (default 20) and
            ``prefix_encoder_hidden_size`` (default 128).

    Returns:
        A new dict with keys ``num_virtual_tokens`` and
        ``encoder_hidden_size``, in that order.
    """
    config = {}
    config["num_virtual_tokens"] = params.get("prefix_num_virtual_tokens", 20)
    config["encoder_hidden_size"] = params.get("prefix_encoder_hidden_size", 128)
    return config
|