__init__.py 1.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041
  1. from typing import Any
  def build_lora_config(params: dict[str, Any]):
      """Assemble LoRA adapter settings from a flat parameter mapping.

      Each setting is read from a ``lora_``-prefixed key of *params*; any key
      that is absent falls back to a fixed default (r=16, alpha=32,
      dropout=0.05, target modules ``"all-linear"``).

      Args:
          params: Flat mapping of hyper-parameters. Unrelated keys are ignored.

      Returns:
          A new dict with keys ``r``, ``lora_alpha``, ``lora_dropout`` and
          ``target_modules``, in that order.
      """
      lookup = params.get
      rank = lookup("lora_r", 16)
      alpha = lookup("lora_alpha", 32)
      dropout = lookup("lora_dropout", 0.05)
      modules = lookup("lora_target_modules", "all-linear")
      return dict(
          r=rank,
          lora_alpha=alpha,
          lora_dropout=dropout,
          target_modules=modules,
      )
  def build_qlora_config(params: dict[str, Any]):
      """Assemble QLoRA settings: quantization options plus nested LoRA settings.

      Quantization options are read from ``qlora_``-prefixed keys of *params*
      (defaults: 4 bits, type ``"nf4"``, double quantization enabled). The
      ``"lora"`` entry is produced by ``build_lora_config`` from the same
      mapping, so LoRA keys in *params* apply here too.

      Args:
          params: Flat mapping of hyper-parameters. Unrelated keys are ignored.

      Returns:
          A new dict with keys ``bits``, ``qlora_type``, ``double_quant`` and
          ``lora``, in that order.
      """
      config: dict[str, Any] = {}
      config["bits"] = params.get("qlora_bits", 4)
      config["qlora_type"] = params.get("qlora_type", "nf4")
      config["double_quant"] = params.get("qlora_double_quant", True)
      config["lora"] = build_lora_config(params)
      return config
  def build_ia3_config(params: dict[str, Any]):
      """Assemble IA3 adapter settings from a flat parameter mapping.

      Only the target modules are configurable, via the
      ``ia3_target_modules`` key (default ``"all-linear"``).

      Args:
          params: Flat mapping of hyper-parameters. Unrelated keys are ignored.

      Returns:
          A new dict with the single key ``target_modules``.
      """
      target_modules = params.get("ia3_target_modules", "all-linear")
      return {"target_modules": target_modules}
  def build_adalora_config(params: dict[str, Any]):
      """Assemble AdaLoRA adapter settings from a flat parameter mapping.

      Per the PEFT ``AdaLoraConfig`` documentation, AdaLoRA starts each adapted
      matrix at rank ``init_r`` and reallocates/prunes the rank budget toward
      the average rank ``target_r`` during training. The initial rank should
      therefore be at least the target rank. The defaults (init 16, target 8)
      respect that ordering. The previous defaults were inverted (init 8,
      target 16), which made the budget schedule grow instead of prune.

      Args:
          params: Flat mapping of hyper-parameters. Reads ``adalora_init_r``,
              ``adalora_target_r``, ``adalora_beta1`` and ``adalora_beta2``;
              unrelated keys are ignored. Explicitly supplied values are
              passed through unchanged.

      Returns:
          A new dict with keys ``init_r``, ``target_r``, ``beta1`` and
          ``beta2``, in that order.
      """
      return {
          # Default initial rank must be >= default target rank (see docstring).
          "init_r": params.get("adalora_init_r", 16),
          "target_r": params.get("adalora_target_r", 8),
          # Smoothing coefficients for AdaLoRA's importance scores; 0.85 matches
          # the PEFT defaults.
          "beta1": params.get("adalora_beta1", 0.85),
          "beta2": params.get("adalora_beta2", 0.85),
      }
  def build_prefix_tuning_config(params: dict[str, Any]):
      """Assemble prefix-tuning settings from a flat parameter mapping.

      Args:
          params: Flat mapping of hyper-parameters. Reads
              ``prefix_num_virtual_tokens`` (default 20) and
              ``prefix_encoder_hidden_size`` (default 128); unrelated keys
              are ignored.

      Returns:
          A new dict with keys ``num_virtual_tokens`` and
          ``encoder_hidden_size``, in that order.
      """
      # (output key, input key, fallback) — order fixes the result's key order.
      fields = (
          ("num_virtual_tokens", "prefix_num_virtual_tokens", 20),
          ("encoder_hidden_size", "prefix_encoder_hidden_size", 128),
      )
      return {
          out_key: params.get(in_key, fallback)
          for out_key, in_key, fallback in fields
      }