config.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. import os
  2. from functools import lru_cache
  3. from pathlib import Path
  4. from pydantic import Field, field_validator
  5. from pydantic_settings import (
  6. BaseSettings,
  7. EnvSettingsSource,
  8. DotEnvSettingsSource,
  9. SettingsConfigDict,
  10. )
  11. class _CommaListMixin:
  12. """Mixin for handling comma-separated list values in env/dotenv sources."""
  13. def decode_complex_value(self, field_name, field, value):
  14. if field_name == "backend_cors_origins" and isinstance(value, str):
  15. return [v.strip() for v in value.split(",") if v.strip()]
  16. return super().decode_complex_value(field_name, field, value)
  17. class DotEnvSourceWithCommaLists(_CommaListMixin, DotEnvSettingsSource):
  18. """Dotenv source that handles comma-separated lists."""
  19. class EnvSettingsSourceWithCommaLists(_CommaListMixin, EnvSettingsSource):
  20. """Env source that handles comma-separated lists."""
  21. class Settings(BaseSettings):
  22. model_config = SettingsConfigDict(
  23. env_file=str(Path(__file__).resolve().parent.parent / ".env"),
  24. env_file_encoding="utf-8",
  25. case_sensitive=False,
  26. extra="ignore",
  27. )
  28. @classmethod
  29. def settings_customise_sources(
  30. cls, settings_cls, init_settings, env_settings, dotenv_settings, file_secret_settings
  31. ):
  32. return (
  33. init_settings,
  34. EnvSettingsSourceWithCommaLists(settings_cls),
  35. DotEnvSourceWithCommaLists(settings_cls),
  36. file_secret_settings,
  37. )
  38. # --- 数据路径 ---
  39. data_dir: Path = Path("/home/ubuntu/Fine-tuning/backend/data")
  40. # --- HuggingFace / ModelScope ---
  41. hf_token: str = ""
  42. hf_endpoint: str = "https://huggingface.co"
  43. use_modelscope: bool = False
  44. modelscope_endpoint: str = "https://modelscope.cn"
  45. # --- GPU / 硬件 ---
  46. cuda_visible_devices: str = "0"
  47. max_memory_per_gpu: str = "0"
  48. use_unsloth: bool = False
  49. # --- 后端 ---
  50. backend_host: str = "0.0.0.0"
  51. backend_port: int = 8000
  52. backend_env: str = "production"
  53. backend_log_level: str = "INFO"
  54. backend_cors_origins: list[str] = ["http://183.220.37.46:5173"]
  55. # --- 数据库 ---
  56. database_url: str = "postgresql+asyncpg://finetune:finetune123@localhost:5432/finetuning"
  57. # --- 训练默认参数 ---
  58. default_peft_method: str = "lora"
  59. default_epochs: int = 3
  60. default_batch_size: int = 4
  61. default_gradient_accumulation: int = 4
  62. default_lr: float = 2e-4
  63. default_max_seq_length: int = 2048
  64. default_warmup_ratio: float = 0.05
  65. default_save_strategy: str = "epoch"
  66. default_eval_strategy: str = "epoch"
  67. default_eval_steps: int = 100
  68. # --- LoRA ---
  69. lora_r: int = 16
  70. lora_alpha: int = 32
  71. lora_dropout: float = 0.05
  72. lora_target_modules: str = "all-linear"
  73. # --- QLoRA ---
  74. qlora_bits: int = 4
  75. qlora_type: str = "nf4"
  76. qlora_double_quant: bool = True
  77. # --- 上传限制 ---
  78. max_upload_size_mb: int = 500
  79. allowed_dataset_formats: str = "jsonl,csv,parquet,json"
  80. # --- SSO 统一认证 ---
  81. sso_base_url: str = "http://192.168.92.61:8200"
  82. sso_client_id: str = "5bdce571-c092-45ff-a491-44a14a000426"
  83. sso_client_secret: str = "hmDeOtXZVbeo2AZ-x58yPssZLg4Tcb1W"
  84. sso_redirect_uri: str = "http://localhost:3000/auth/callback"
  85. sso_frontend_url: str = "http://localhost:3000"
  86. sso_scope: str = "email"
  87. sso_logout_redirect_url: str = "http://192.168.92.61:9200/login"
  88. jwt_secret_key: str = "change-me-in-production-use-a-long-random-string"
  89. jwt_algorithm: str = "HS256"
  90. jwt_access_expire_minutes: int = 20
  91. jwt_refresh_expire_hours: int = 24
  92. @field_validator("backend_cors_origins", mode="before")
  93. @classmethod
  94. def parse_cors_origins(cls, v):
  95. if isinstance(v, str):
  96. return [origin.strip() for origin in v.split(",") if origin.strip()]
  97. return v
  98. @property
  99. def models_dir(self) -> Path:
  100. return self.data_dir / "models"
  101. @property
  102. def adapters_dir(self) -> Path:
  103. return self.data_dir / "adapters"
  104. @property
  105. def uploads_dir(self) -> Path:
  106. return self.data_dir / "uploads"
  107. @property
  108. def processed_dir(self) -> Path:
  109. return self.data_dir / "processed"
  110. def ensure_dirs(self) -> None:
  111. self.data_dir.mkdir(parents=True, exist_ok=True)
  112. for d in [self.models_dir, self.adapters_dir, self.uploads_dir, self.processed_dir]:
  113. d.mkdir(parents=True, exist_ok=True)
  114. @lru_cache
  115. def get_settings() -> Settings:
  116. settings = Settings()
  117. settings.ensure_dirs()
  118. # 设置 HF 环境变量
  119. if settings.hf_token:
  120. os.environ["HF_TOKEN"] = settings.hf_token
  121. os.environ["HF_ENDPOINT"] = settings.hf_endpoint
  122. if settings.cuda_visible_devices:
  123. os.environ["CUDA_VISIBLE_DEVICES"] = settings.cuda_visible_devices
  124. return settings