"""Application settings loaded from defaults, a ``.env`` file and environment variables."""

import json
import os
from functools import lru_cache
from pathlib import Path
from typing import Any

from pydantic import Field, field_validator
from pydantic_settings import (
    BaseSettings,
    DotEnvSettingsSource,
    EnvSettingsSource,
    PydanticBaseSettingsSource,
    SettingsConfigDict,
)

# ``list[str]`` fields that may be written as a comma-separated string
# (``a,b,c``) in addition to the JSON-array form (``["a","b","c"]``) that
# pydantic-settings expects for complex fields by default.
_COMMA_LIST_FIELDS = frozenset({"backend_cors_origins"})


def _parse_str_list(value: str) -> list[str]:
    """Parse *value* as a JSON array if it looks like one, else as a comma list.

    For the comma form, surrounding whitespace is stripped from every item and
    blank items are dropped, so ``" a, ,b "`` yields ``["a", "b"]`` and ``""``
    yields ``[]``.

    Raises:
        json.JSONDecodeError: if *value* starts with ``[`` but is not valid JSON.
    """
    # A leading "[" means the caller used pydantic-settings' native JSON form;
    # splitting that on commas would leave brackets and quotes in the items.
    if value.lstrip().startswith("["):
        return json.loads(value)
    return [item.strip() for item in value.split(",") if item.strip()]


class _CommaListDecodeMixin:
    """Let the fields in ``_COMMA_LIST_FIELDS`` be given as comma-separated strings.

    pydantic-settings calls ``decode_complex_value`` (JSON decoding) on values
    of complex fields such as ``list[str]`` before model validation, so a plain
    ``a,b`` string would otherwise fail. This mixin intercepts those fields and
    parses them with ``_parse_str_list``; all other fields fall through to the
    base source's behavior.
    """

    def decode_complex_value(self, field_name: str, field: Any, value: Any) -> Any:
        if field_name in _COMMA_LIST_FIELDS and isinstance(value, str):
            return _parse_str_list(value)
        return super().decode_complex_value(field_name, field, value)


class EnvSettingsSourceWithCommaLists(_CommaListDecodeMixin, EnvSettingsSource):
    """Process-environment source that accepts comma-separated lists."""


class DotEnvSettingsSourceWithCommaLists(_CommaListDecodeMixin, DotEnvSettingsSource):
    """``.env`` file source that accepts comma-separated lists."""


class Settings(BaseSettings):
    """Runtime configuration for the fine-tuning backend.

    Values are resolved in the order returned by ``settings_customise_sources``
    (earlier sources win), falling back to the class defaults below. Keys are
    matched case-insensitively and unknown keys are ignored.
    """

    model_config = SettingsConfigDict(
        # The .env file lives in the parent of this module's directory.
        env_file=str(Path(__file__).resolve().parent.parent / ".env"),
        env_file_encoding="utf-8",
        case_sensitive=False,
        extra="ignore",
    )

    @classmethod
    def settings_customise_sources(
        cls, settings_cls, init_settings, env_settings, dotenv_settings, file_secret_settings
    ) -> tuple[PydanticBaseSettingsSource, ...]:
        """Return the settings sources, highest priority first.

        The stock env and ``.env`` sources are replaced with comma-list-aware
        equivalents so ``BACKEND_CORS_ORIGINS=a,b`` works from either place.
        The ``.env`` source reuses the file path and encoding already resolved
        on ``dotenv_settings`` so per-instance overrides are preserved.
        """
        dotenv_source = DotEnvSettingsSourceWithCommaLists(
            settings_cls,
            env_file=dotenv_settings.env_file,
            env_file_encoding=dotenv_settings.env_file_encoding,
        )
        # NOTE(review): .env is ranked above process environment variables,
        # the reverse of the pydantic-settings default order — confirm intended.
        return (
            init_settings,
            dotenv_source,
            EnvSettingsSourceWithCommaLists(settings_cls),
            file_secret_settings,
        )

    # --- Data paths ---
    # NOTE(review): hard-coded deployment path; database_url below repeats it
    # and does not follow data_dir if only DATA_DIR is overridden.
    data_dir: Path = Path("/root/Fine-tuning/backend/data")

    # --- HuggingFace / ModelScope ---
    hf_token: str = ""
    hf_endpoint: str = "https://huggingface.co"
    use_modelscope: bool = False
    modelscope_endpoint: str = "https://modelscope.cn"

    # --- GPU / hardware ---
    cuda_visible_devices: str = "0"
    max_memory_per_gpu: str = "0"
    use_unsloth: bool = False

    # --- Backend ---
    backend_host: str = "0.0.0.0"
    backend_port: int = 8000
    backend_env: str = "production"
    backend_log_level: str = "INFO"
    backend_cors_origins: list[str] = ["http://192.168.91.253:5173"]

    # --- Database ---
    # Four slashes: SQLAlchemy treats "sqlite:///x" as the relative path "x",
    # and "sqlite:////x" as the absolute path "/x".
    database_url: str = "sqlite+aiosqlite:////root/Fine-tuning/backend/data/finetuning.db"

    # --- Training defaults ---
    default_peft_method: str = "lora"
    default_epochs: int = 3
    default_batch_size: int = 4
    default_gradient_accumulation: int = 4
    default_lr: float = 2e-4
    default_max_seq_length: int = 2048
    default_warmup_ratio: float = 0.05
    default_save_strategy: str = "epoch"
    default_eval_strategy: str = "epoch"
    default_eval_steps: int = 100

    # --- LoRA ---
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05
    lora_target_modules: str = "all-linear"

    # --- QLoRA ---
    qlora_bits: int = 4
    qlora_type: str = "nf4"
    qlora_double_quant: bool = True

    # --- Upload limits ---
    max_upload_size_mb: int = 500
    allowed_dataset_formats: str = "jsonl,csv,parquet,json"

    @field_validator("backend_cors_origins", mode="before")
    @classmethod
    def parse_cors_origins(cls, v: Any) -> Any:
        """Accept a comma-separated or JSON-array string (e.g. from init kwargs)."""
        if isinstance(v, str):
            return _parse_str_list(v)
        return v

    @property
    def models_dir(self) -> Path:
        """Directory under ``data_dir`` named ``models``."""
        return self.data_dir / "models"

    @property
    def adapters_dir(self) -> Path:
        """Directory under ``data_dir`` named ``adapters``."""
        return self.data_dir / "adapters"

    @property
    def uploads_dir(self) -> Path:
        """Directory under ``data_dir`` named ``uploads``."""
        return self.data_dir / "uploads"

    @property
    def processed_dir(self) -> Path:
        """Directory under ``data_dir`` named ``processed``."""
        return self.data_dir / "processed"

    def ensure_dirs(self) -> None:
        """Create ``data_dir`` and its subdirectories if they do not exist."""
        self.data_dir.mkdir(parents=True, exist_ok=True)
        for d in [self.models_dir, self.adapters_dir, self.uploads_dir, self.processed_dir]:
            d.mkdir(parents=True, exist_ok=True)


@lru_cache
def get_settings() -> Settings:
    """Build, prepare and cache the process-wide ``Settings`` instance.

    The first call creates the data directories and writes ``HF_TOKEN`` (only
    when non-empty), ``HF_ENDPOINT`` and ``CUDA_VISIBLE_DEVICES`` (only when
    non-empty) into ``os.environ``. Later calls return the cached instance
    without repeating those side effects.
    """
    settings = Settings()
    settings.ensure_dirs()

    # Export Hugging Face settings into the process environment. An empty
    # hf_token leaves any pre-existing HF_TOKEN untouched.
    if settings.hf_token:
        os.environ["HF_TOKEN"] = settings.hf_token
    os.environ["HF_ENDPOINT"] = settings.hf_endpoint

    # NOTE(review): presumably this must run before any CUDA context is
    # created for the device mask to take effect — confirm call order.
    if settings.cuda_visible_devices:
        os.environ["CUDA_VISIBLE_DEVICES"] = settings.cuda_visible_devices
    return settings