| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153 |
- import os
- from functools import lru_cache
- from pathlib import Path
- from pydantic import Field, field_validator
- from pydantic_settings import (
- BaseSettings,
- EnvSettingsSource,
- DotEnvSettingsSource,
- SettingsConfigDict,
- )
- class _CommaListMixin:
- """Mixin for handling comma-separated list values in env/dotenv sources."""
- def decode_complex_value(self, field_name, field, value):
- if field_name == "backend_cors_origins" and isinstance(value, str):
- return [v.strip() for v in value.split(",") if v.strip()]
- return super().decode_complex_value(field_name, field, value)
class DotEnvSourceWithCommaLists(_CommaListMixin, DotEnvSettingsSource):
    """``.env``-file settings source with comma-separated list support.

    Behaves like pydantic-settings' ``DotEnvSettingsSource`` except that
    complex-value decoding goes through ``_CommaListMixin`` first.
    """
class EnvSettingsSourceWithCommaLists(_CommaListMixin, EnvSettingsSource):
    """Process-environment settings source with comma-separated list support.

    Behaves like pydantic-settings' ``EnvSettingsSource`` except that
    complex-value decoding goes through ``_CommaListMixin`` first.
    """
class Settings(BaseSettings):
    # Application configuration. Sources, highest priority first: keyword
    # arguments, process environment, the ``.env`` file in the parent of this
    # module's directory, then file-based secrets. Names are matched
    # case-insensitively and unknown keys are ignored.
    model_config = SettingsConfigDict(
        env_file=str(Path(__file__).resolve().parent.parent / ".env"),
        env_file_encoding="utf-8",
        case_sensitive=False,
        extra="ignore",
    )

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls,
        init_settings,
        env_settings,
        dotenv_settings,
        file_secret_settings,
    ):
        """Replace the env and dotenv sources with comma-list-aware versions.

        Earlier entries in the returned tuple take precedence. The default
        ``env_settings`` / ``dotenv_settings`` sources are discarded and new
        ones are built from ``settings_cls``.
        """
        # NOTE(review): since the replacement sources are built from
        # ``settings_cls`` alone, per-instance overrides such as
        # ``Settings(_env_file=...)`` are presumably not honoured — confirm
        # whether any caller relies on them.
        env_source = EnvSettingsSourceWithCommaLists(settings_cls)
        dotenv_source = DotEnvSourceWithCommaLists(settings_cls)
        return init_settings, env_source, dotenv_source, file_secret_settings

    # --- Data paths ---
    # NOTE(review): machine-specific absolute default; override via DATA_DIR.
    data_dir: Path = Path("/home/ubuntu/Fine-tuning/backend/data")

    # --- HuggingFace / ModelScope ---
    hf_token: str = ""
    hf_endpoint: str = "https://huggingface.co"
    use_modelscope: bool = False
    modelscope_endpoint: str = "https://modelscope.cn"

    # --- GPU / hardware ---
    cuda_visible_devices: str = "0"
    max_memory_per_gpu: str = "0"
    use_unsloth: bool = False

    # --- Backend ---
    backend_host: str = "0.0.0.0"
    backend_port: int = 8000
    backend_env: str = "production"
    backend_log_level: str = "INFO"
    backend_cors_origins: list[str] = ["http://183.220.37.46:5173"]

    # --- Database ---
    # NOTE(review): default URL embeds a database password in source control.
    database_url: str = "postgresql+asyncpg://finetune:finetune123@localhost:5432/finetuning"

    # --- Training defaults ---
    default_peft_method: str = "lora"
    default_epochs: int = 3
    default_batch_size: int = 4
    default_gradient_accumulation: int = 4
    default_lr: float = 2e-4
    default_max_seq_length: int = 2048
    default_warmup_ratio: float = 0.05
    default_save_strategy: str = "epoch"
    default_eval_strategy: str = "epoch"
    default_eval_steps: int = 100

    # --- LoRA ---
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05
    lora_target_modules: str = "all-linear"

    # --- QLoRA ---
    qlora_bits: int = 4
    qlora_type: str = "nf4"
    qlora_double_quant: bool = True

    # --- Upload limits ---
    max_upload_size_mb: int = 500
    allowed_dataset_formats: str = "jsonl,csv,parquet,json"

    # --- SSO (unified authentication) ---
    sso_base_url: str = "http://192.168.92.61:8200"
    sso_client_id: str = "5bdce571-c092-45ff-a491-44a14a000426"
    # NOTE(review): a client secret is hard-coded as the default; it should
    # come only from the environment / .env and be rotated.
    sso_client_secret: str = "hmDeOtXZVbeo2AZ-x58yPssZLg4Tcb1W"
    sso_redirect_uri: str = "http://localhost:3000/auth/callback"
    sso_frontend_url: str = "http://localhost:3000"
    sso_scope: str = "email"
    sso_logout_redirect_url: str = "http://192.168.92.61:9200/login"
    # NOTE(review): placeholder signing key while backend_env defaults to
    # "production" — tokens are forgeable unless JWT_SECRET_KEY is set.
    jwt_secret_key: str = "change-me-in-production-use-a-long-random-string"
    jwt_algorithm: str = "HS256"
    jwt_access_expire_minutes: int = 20
    jwt_refresh_expire_hours: int = 24

    @field_validator("backend_cors_origins", mode="before")
    @classmethod
    def parse_cors_origins(cls, v):
        """Turn a comma-separated string into a list of trimmed, non-empty origins.

        Non-string values are returned unchanged for normal validation.
        """
        if not isinstance(v, str):
            return v
        pieces = (piece.strip() for piece in v.split(","))
        return [piece for piece in pieces if piece]

    @property
    def models_dir(self) -> Path:
        """Return ``data_dir / "models"``."""
        return self.data_dir.joinpath("models")

    @property
    def adapters_dir(self) -> Path:
        """Return ``data_dir / "adapters"``."""
        return self.data_dir.joinpath("adapters")

    @property
    def uploads_dir(self) -> Path:
        """Return ``data_dir / "uploads"``."""
        return self.data_dir.joinpath("uploads")

    @property
    def processed_dir(self) -> Path:
        """Return ``data_dir / "processed"``."""
        return self.data_dir.joinpath("processed")

    def ensure_dirs(self) -> None:
        """Create ``data_dir`` and its four subdirectories if they are missing."""
        targets = (
            self.data_dir,
            self.models_dir,
            self.adapters_dir,
            self.uploads_dir,
            self.processed_dir,
        )
        for target in targets:
            target.mkdir(parents=True, exist_ok=True)
@lru_cache
def get_settings() -> Settings:
    """Return the cached ``Settings`` instance, initialising it on first use.

    The first call does three things:

    - loads ``Settings``;
    - creates the data directories via ``ensure_dirs()``;
    - writes HuggingFace and CUDA values into ``os.environ``.

    ``lru_cache`` returns that same instance on every later call.
    """
    cfg = Settings()
    cfg.ensure_dirs()

    # Mirror selected settings into the process environment. HF_TOKEN and
    # CUDA_VISIBLE_DEVICES are only written when non-empty; HF_ENDPOINT always is.
    exported: dict[str, str] = {}
    if cfg.hf_token:
        exported["HF_TOKEN"] = cfg.hf_token
    exported["HF_ENDPOINT"] = cfg.hf_endpoint
    if cfg.cuda_visible_devices:
        # NOTE(review): presumably only effective if this runs before CUDA is
        # initialised in this process — confirm call order at startup.
        exported["CUDA_VISIBLE_DEVICES"] = cfg.cuda_visible_devices
    os.environ.update(exported)

    return cfg
|