config.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
"""
Application configuration module.

Loads application settings (MySQL database, OAuth/SSO and server options)
from an environment-specific YAML file. The file is selected by the APP_ENV
environment variable: prod, test or dev (dev is used by default).
"""
  6. import os
  7. import logging
  8. import yaml
  9. from pathlib import Path
  10. logger = logging.getLogger(__name__)
  11. def get_config_path() -> Path:
  12. """
  13. 根据 APP_ENV 环境变量获取配置文件路径
  14. APP_ENV=prod -> config.prod.yaml
  15. APP_ENV=dev -> config.dev.yaml
  16. 默认 -> config.dev.yaml
  17. """
  18. app_env = os.getenv("APP_ENV", "").lower()
  19. base_path = Path(__file__).parent
  20. if app_env == "prod":
  21. config_file = base_path / "config/config.prod.yaml"
  22. logger.info("使用生产环境配置: config.prod.yaml")
  23. elif app_env == "test":
  24. config_file = base_path / "config/config.test.yaml"
  25. logger.info("使用测试环境配置: config.test.yaml")
  26. elif app_env == "dev":
  27. config_file = base_path / "config/config.dev.yaml"
  28. logger.info("使用开发环境配置: config.dev.yaml")
  29. else:
  30. print("默认使用开发环境")
  31. config_file = base_path / "config/config.dev.yaml"
  32. if app_env:
  33. logger.warning(f"未知的 APP_ENV 值: {app_env},使用默认 config.dev.yaml")
  34. return config_file
  35. class Settings:
  36. """Application settings loaded from config YAML."""
  37. def __init__(self):
  38. """Load configuration from YAML file."""
  39. config_path = get_config_path()
  40. if not config_path.exists():
  41. raise FileNotFoundError(f"配置文件不存在: {config_path}")
  42. with open(config_path, 'r', encoding='utf-8') as f:
  43. config = yaml.safe_load(f)
  44. self.APP_ENV = os.getenv("APP_ENV", "default").lower()
  45. print(f"[Config] APP_ENV={self.APP_ENV}, 配置文件={config_path}")
  46. # Database Settings (MySQL only)
  47. db_config = config.get('database', {})
  48. mysql_config = db_config.get('mysql', {})
  49. self.MYSQL_HOST = mysql_config.get('host', 'localhost')
  50. self.MYSQL_PORT = mysql_config.get('port', 3306)
  51. self.MYSQL_USER = mysql_config.get('user', 'root')
  52. self.MYSQL_PASSWORD = mysql_config.get('password', '')
  53. self.MYSQL_DATABASE = mysql_config.get('database', 'annotation_platform')
  54. # OAuth/SSO Settings
  55. oauth_config = config.get('oauth', {})
  56. self.OAUTH_ENABLED = oauth_config.get('enabled', False)
  57. self.OAUTH_BASE_URL = oauth_config.get('base_url', '')
  58. self.OAUTH_CLIENT_ID = oauth_config.get('client_id', '')
  59. self.OAUTH_CLIENT_SECRET = oauth_config.get('client_secret', '')
  60. self.OAUTH_REDIRECT_URI = oauth_config.get('redirect_uri', '')
  61. self.OAUTH_SCOPE = oauth_config.get('scope', 'profile email')
  62. # OAuth Endpoints
  63. self.OAUTH_AUTHORIZE_ENDPOINT = oauth_config.get('authorize_endpoint', '/oauth/authorize')
  64. self.OAUTH_TOKEN_ENDPOINT = oauth_config.get('token_endpoint', '/oauth/token')
  65. self.OAUTH_USERINFO_ENDPOINT = oauth_config.get('userinfo_endpoint', '/oauth/userinfo')
  66. self.OAUTH_REVOKE_ENDPOINT = oauth_config.get('revoke_endpoint', '/oauth/revoke')
  67. # Token Cache TTL (seconds)
  68. self.TOKEN_CACHE_TTL = oauth_config.get('token_cache_ttl', 300)
  69. # Server Settings
  70. server_config = config.get('server', {})
  71. self.SERVER_HOST = server_config.get('host', '0.0.0.0')
  72. self.SERVER_PORT = server_config.get('port', 8000)
  73. self.SERVER_RELOAD = server_config.get('reload', True)
  74. # Create settings instance
  75. settings = Settings()