config.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
import os
from functools import lru_cache
from pathlib import Path
from typing import Union

from pydantic import Field, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
  6. class Settings(BaseSettings):
  7. model_config = SettingsConfigDict(
  8. env_file=str(Path(__file__).resolve().parents[1] / ".env"),
  9. env_file_encoding="utf-8",
  10. case_sensitive=False,
  11. extra="ignore",
  12. )
  13. # --- 数据路径 ---
  14. data_dir: Path = Path("/root/Fine-tuning/backend/data")
  15. # --- HuggingFace / ModelScope ---
  16. hf_token: str = ""
  17. hf_endpoint: str = "https://huggingface.co"
  18. use_modelscope: bool = False
  19. modelscope_endpoint: str = "https://modelscope.cn"
  20. # --- GPU / 硬件 ---
  21. cuda_visible_devices: str = "0"
  22. max_memory_per_gpu: str = "0"
  23. use_unsloth: bool = False
  24. # --- 后端 ---
  25. backend_host: str = "0.0.0.0"
  26. backend_port: int = 8000
  27. backend_env: str = "production"
  28. backend_log_level: str = "INFO"
  29. backend_cors_origins: list[str] = ["http://192.168.91.253:5173"]
  30. # --- 数据库 ---
  31. database_url: str = "sqlite+aiosqlite:///root/Fine-tuning/backend/data/finetuning.db"
  32. # --- 训练默认参数 ---
  33. default_peft_method: str = "lora"
  34. default_epochs: int = 3
  35. default_batch_size: int = 4
  36. default_gradient_accumulation: int = 4
  37. default_lr: float = 2e-4
  38. default_max_seq_length: int = 2048
  39. default_warmup_ratio: float = 0.05
  40. default_save_strategy: str = "epoch"
  41. default_eval_strategy: str = "epoch"
  42. default_eval_steps: int = 100
  43. # --- LoRA ---
  44. lora_r: int = 16
  45. lora_alpha: int = 32
  46. lora_dropout: float = 0.05
  47. lora_target_modules: str = "all-linear"
  48. # --- QLoRA ---
  49. qlora_bits: int = 4
  50. qlora_type: str = "nf4"
  51. qlora_double_quant: bool = True
  52. # --- 上传限制 ---
  53. max_upload_size_mb: int = 500
  54. allowed_dataset_formats: str = "jsonl,csv,parquet,json"
  55. @field_validator("backend_cors_origins", mode="before")
  56. @classmethod
  57. def parse_cors_origins(cls, v):
  58. if isinstance(v, str):
  59. return [origin.strip() for origin in v.split(",") if origin.strip()]
  60. return v
  61. @property
  62. def models_dir(self) -> Path:
  63. return self.data_dir / "models"
  64. @property
  65. def adapters_dir(self) -> Path:
  66. return self.data_dir / "adapters"
  67. @property
  68. def uploads_dir(self) -> Path:
  69. return self.data_dir / "uploads"
  70. @property
  71. def processed_dir(self) -> Path:
  72. return self.data_dir / "processed"
  73. def ensure_dirs(self) -> None:
  74. for d in [self.models_dir, self.adapters_dir, self.uploads_dir, self.processed_dir]:
  75. d.mkdir(parents=True, exist_ok=True)
  76. @lru_cache
  77. def get_settings() -> Settings:
  78. settings = Settings()
  79. settings.ensure_dirs()
  80. # 设置 HF 环境变量
  81. if settings.hf_token:
  82. os.environ["HF_TOKEN"] = settings.hf_token
  83. os.environ["HF_ENDPOINT"] = settings.hf_endpoint
  84. if settings.cuda_visible_devices:
  85. os.environ["CUDA_VISIBLE_DEVICES"] = settings.cuda_visible_devices
  86. return settings