"""Application settings, read from environment variables and a ``.env`` file."""

import json
import os
from functools import lru_cache
from pathlib import Path
from typing import Any

from pydantic import Field, field_validator
from pydantic_settings import (
    BaseSettings,
    DotEnvSettingsSource,
    EnvSettingsSource,
    SettingsConfigDict,
)


def _parse_origin_list(raw: str) -> list[str]:
    """Turn a raw string setting into a list of non-empty, stripped items.

    Two spellings are accepted:

    * a JSON array, e.g. ``'["http://a", "http://b"]'``;
    * a comma-separated string, e.g. ``'http://a, http://b'``.

    A value that starts with ``[`` but is not valid JSON falls back to the
    comma split, so malformed input is handled the way it always was.
    """
    text = raw.strip()
    if text.startswith("["):
        try:
            decoded = json.loads(text)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            decoded = None
        if isinstance(decoded, list):
            items = (str(item).strip() for item in decoded)
            return [item for item in items if item]
    return [part.strip() for part in text.split(",") if part.strip()]


class _CommaListMixin:
    """Mixin for handling comma-separated list values in env/dotenv sources."""

    def decode_complex_value(self, field_name: str, field: Any, value: Any) -> Any:
        """Decode ``backend_cors_origins`` leniently; defer everything else.

        Only string values of the ``backend_cors_origins`` field are handled
        here. Every other field (and any non-string value) goes to the next
        class in the MRO, i.e. the real settings source.
        """
        if field_name == "backend_cors_origins" and isinstance(value, str):
            return _parse_origin_list(value)
        return super().decode_complex_value(field_name, field, value)


class DotEnvSourceWithCommaLists(_CommaListMixin, DotEnvSettingsSource):
    """Dotenv source that handles comma-separated lists."""


class EnvSettingsSourceWithCommaLists(_CommaListMixin, EnvSettingsSource):
    """Env source that handles comma-separated lists."""


class Settings(BaseSettings):
    """All runtime configuration of the fine-tuning backend.

    Field names are matched case-insensitively and unknown keys are ignored
    (see ``model_config``).
    """

    model_config = SettingsConfigDict(
        # The .env file is looked up one directory above this module's directory.
        env_file=str(Path(__file__).resolve().parent.parent / ".env"),
        env_file_encoding="utf-8",
        case_sensitive=False,
        extra="ignore",
    )

    @classmethod
    def settings_customise_sources(
        cls, settings_cls, init_settings, env_settings, dotenv_settings, file_secret_settings
    ):
        """Return the settings sources in the order they are listed below.

        The stock ``env_settings`` and ``dotenv_settings`` objects passed in
        are not used; the comma-list-aware subclasses are instantiated in
        their place.
        """
        return (
            init_settings,
            EnvSettingsSourceWithCommaLists(settings_cls),
            DotEnvSourceWithCommaLists(settings_cls),
            file_secret_settings,
        )

    # --- Data paths ---
    data_dir: Path = Path("/home/ubuntu/Fine-tuning/backend/data")

    # --- HuggingFace / ModelScope ---
    hf_token: str = ""
    hf_endpoint: str = "https://huggingface.co"
    use_modelscope: bool = False
    modelscope_endpoint: str = "https://modelscope.cn"

    # --- GPU / hardware ---
    cuda_visible_devices: str = "0"
    max_memory_per_gpu: str = "0"
    use_unsloth: bool = False

    # --- Backend ---
    backend_host: str = "0.0.0.0"
    backend_port: int = 8000
    backend_env: str = "production"
    backend_log_level: str = "INFO"
    backend_cors_origins: list[str] = ["http://183.220.37.46:5173"]

    # --- Database ---
    # NOTE(review): the default URL embeds a password; prefer supplying
    # DATABASE_URL through the environment / .env instead of source control.
    database_url: str = "postgresql+asyncpg://finetune:finetune123@localhost:5432/finetuning"

    # --- Training defaults ---
    default_peft_method: str = "lora"
    default_epochs: int = 3
    default_batch_size: int = 4
    default_gradient_accumulation: int = 4
    default_lr: float = 2e-4
    default_max_seq_length: int = 2048
    default_warmup_ratio: float = 0.05
    default_save_strategy: str = "epoch"
    default_eval_strategy: str = "epoch"
    default_eval_steps: int = 100

    # --- LoRA ---
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05
    lora_target_modules: str = "all-linear"

    # --- QLoRA ---
    qlora_bits: int = 4
    qlora_type: str = "nf4"
    qlora_double_quant: bool = True

    # --- Upload limits ---
    max_upload_size_mb: int = 500
    allowed_dataset_formats: str = "jsonl,csv,parquet,json"

    # --- SSO (single sign-on) ---
    # NOTE(review): the client secret and the JWT signing key below are
    # committed as defaults. Values are left unchanged to preserve behaviour,
    # but they should be rotated and provided via the environment.
    sso_base_url: str = "http://192.168.92.61:8200"
    sso_client_id: str = "5bdce571-c092-45ff-a491-44a14a000426"
    sso_client_secret: str = "hmDeOtXZVbeo2AZ-x58yPssZLg4Tcb1W"
    sso_redirect_uri: str = "http://183.220.37.46:23423/auth/callback"
    sso_frontend_url: str = "http://183.220.37.46:23423"
    sso_scope: str = "email"
    sso_logout_redirect_url: str = "http://192.168.92.61:9200/login"
    jwt_secret_key: str = "change-me-in-production-use-a-long-random-string"
    jwt_algorithm: str = "HS256"
    jwt_access_expire_minutes: int = 20
    jwt_refresh_expire_hours: int = 24

    @field_validator("backend_cors_origins", mode="before")
    @classmethod
    def parse_cors_origins(cls, v):
        """Accept a comma-separated or JSON-array string for the origins list.

        Non-string input (for example an already-built list) is returned
        untouched for the regular field validation to handle.
        """
        if isinstance(v, str):
            return _parse_origin_list(v)
        return v

    @property
    def models_dir(self) -> Path:
        """``data_dir / "models"``."""
        return self.data_dir / "models"

    @property
    def adapters_dir(self) -> Path:
        """``data_dir / "adapters"``."""
        return self.data_dir / "adapters"

    @property
    def uploads_dir(self) -> Path:
        """``data_dir / "uploads"``."""
        return self.data_dir / "uploads"

    @property
    def processed_dir(self) -> Path:
        """``data_dir / "processed"``."""
        return self.data_dir / "processed"

    def ensure_dirs(self) -> None:
        """Create ``data_dir`` and its four sub-directories if they are missing."""
        self.data_dir.mkdir(parents=True, exist_ok=True)
        for d in [self.models_dir, self.adapters_dir, self.uploads_dir, self.processed_dir]:
            d.mkdir(parents=True, exist_ok=True)


@lru_cache
def get_settings() -> Settings:
    """Build the settings once, create the data directories and export env vars.

    The result is memoised by ``lru_cache``, so the directory creation and the
    ``os.environ`` writes below happen on the first call only.
    """
    settings = Settings()
    settings.ensure_dirs()

    # Export the HF environment variables.
    if settings.hf_token:
        os.environ["HF_TOKEN"] = settings.hf_token
    # NOTE(review): assumed to be set regardless of whether a token is
    # configured — confirm against the intended behaviour.
    os.environ["HF_ENDPOINT"] = settings.hf_endpoint

    if settings.cuda_visible_devices:
        os.environ["CUDA_VISIBLE_DEVICES"] = settings.cuda_visible_devices

    return settings