config.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
import os
from functools import lru_cache
from pathlib import Path

from pydantic import Field, field_validator
from pydantic_settings import (
    BaseSettings,
    DotEnvSettingsSource,
    EnvSettingsSource,
    SettingsConfigDict,
)
  10. class EnvSettingsSourceWithCommaLists(EnvSettingsSource):
  11. """Override decode_complex_value to handle comma-separated lists."""
  12. def decode_complex_value(self, field_name, field, value):
  13. if field_name == "backend_cors_origins" and isinstance(value, str):
  14. return [v.strip() for v in value.split(",") if v.strip()]
  15. return super().decode_complex_value(field_name, field, value)
  16. class Settings(BaseSettings):
  17. model_config = SettingsConfigDict(
  18. env_file=str(Path(__file__).resolve().parent.parent / ".env"),
  19. env_file_encoding="utf-8",
  20. case_sensitive=False,
  21. extra="ignore",
  22. )
  23. @classmethod
  24. def settings_customise_sources(
  25. cls, settings_cls, init_settings, env_settings, dotenv_settings, file_secret_settings
  26. ):
  27. return (
  28. init_settings,
  29. dotenv_settings,
  30. EnvSettingsSourceWithCommaLists(settings_cls),
  31. file_secret_settings,
  32. )
  33. # --- 数据路径 ---
  34. data_dir: Path = Path("/root/Fine-tuning/backend/data")
  35. # --- HuggingFace / ModelScope ---
  36. hf_token: str = ""
  37. hf_endpoint: str = "https://huggingface.co"
  38. use_modelscope: bool = False
  39. modelscope_endpoint: str = "https://modelscope.cn"
  40. # --- GPU / 硬件 ---
  41. cuda_visible_devices: str = "0"
  42. max_memory_per_gpu: str = "0"
  43. use_unsloth: bool = False
  44. # --- 后端 ---
  45. backend_host: str = "0.0.0.0"
  46. backend_port: int = 8000
  47. backend_env: str = "production"
  48. backend_log_level: str = "INFO"
  49. backend_cors_origins: list[str] = ["http://192.168.91.253:5173"]
  50. # --- 数据库 ---
  51. database_url: str = "postgresql+asyncpg://finetune:finetune123@localhost:5432/finetuning"
  52. # --- 训练默认参数 ---
  53. default_peft_method: str = "lora"
  54. default_epochs: int = 3
  55. default_batch_size: int = 4
  56. default_gradient_accumulation: int = 4
  57. default_lr: float = 2e-4
  58. default_max_seq_length: int = 2048
  59. default_warmup_ratio: float = 0.05
  60. default_save_strategy: str = "epoch"
  61. default_eval_strategy: str = "epoch"
  62. default_eval_steps: int = 100
  63. # --- LoRA ---
  64. lora_r: int = 16
  65. lora_alpha: int = 32
  66. lora_dropout: float = 0.05
  67. lora_target_modules: str = "all-linear"
  68. # --- QLoRA ---
  69. qlora_bits: int = 4
  70. qlora_type: str = "nf4"
  71. qlora_double_quant: bool = True
  72. # --- 上传限制 ---
  73. max_upload_size_mb: int = 500
  74. allowed_dataset_formats: str = "jsonl,csv,parquet,json"
  75. # --- 分布式计算节点 ---
  76. compute_node_host: str = "" # 算力节点 IP,为空则本地执行
  77. compute_node_ssh_port: int = 22
  78. compute_node_ssh_user: str = "root"
  79. compute_node_ssh_password: str = "" # SSH 密码(与密钥二选一)
  80. compute_node_ssh_key: str = "" # SSH 私钥路径
  81. compute_node_docker_container: str = "finetune-trainer" # 算力节点上的训练容器名
  82. compute_node_python: str = "/opt/conda/bin/python"
  83. compute_node_workdir: str = "/root/Fine-tuning/backend"
  84. compute_node_remote_data_dir: str = "/root/Fine-tuning/backend/data"
  85. compute_node_remote_env: str = "production"
  86. compute_node_ssh_timeout: int = 300 # SSH 命令超时(秒)
  87. @field_validator("backend_cors_origins", mode="before")
  88. @classmethod
  89. def parse_cors_origins(cls, v):
  90. if isinstance(v, str):
  91. return [origin.strip() for origin in v.split(",") if origin.strip()]
  92. return v
  93. @property
  94. def models_dir(self) -> Path:
  95. return self.data_dir / "models"
  96. @property
  97. def adapters_dir(self) -> Path:
  98. return self.data_dir / "adapters"
  99. @property
  100. def uploads_dir(self) -> Path:
  101. return self.data_dir / "uploads"
  102. @property
  103. def processed_dir(self) -> Path:
  104. return self.data_dir / "processed"
  105. @property
  106. def use_remote_compute(self) -> bool:
  107. """是否启用远程算力节点。"""
  108. return bool(self.compute_node_host)
  109. def ensure_dirs(self) -> None:
  110. self.data_dir.mkdir(parents=True, exist_ok=True)
  111. for d in [self.models_dir, self.adapters_dir, self.uploads_dir, self.processed_dir]:
  112. d.mkdir(parents=True, exist_ok=True)
  113. @lru_cache
  114. def get_settings() -> Settings:
  115. settings = Settings()
  116. settings.ensure_dirs()
  117. # 设置 HF 环境变量
  118. if settings.hf_token:
  119. os.environ["HF_TOKEN"] = settings.hf_token
  120. os.environ["HF_ENDPOINT"] = settings.hf_endpoint
  121. if settings.cuda_visible_devices:
  122. os.environ["CUDA_VISIBLE_DEVICES"] = settings.cuda_visible_devices
  123. return settings