| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374 |
- """
- 数据库配置模块
- 提供数据库连接、会话管理和依赖注入功能
- """
- import os
- from pathlib import Path
- from sqlalchemy import create_engine
- from sqlalchemy.ext.declarative import declarative_base
- from sqlalchemy.orm import sessionmaker
- from dotenv import load_dotenv
- # 加载环境变量(从backend目录的.env文件)
- env_path = Path(__file__).parent.parent / '.env'
- load_dotenv(dotenv_path=env_path)
- # 从环境变量读取数据库配置
- DB_HOST = os.getenv("DB_HOST", "localhost")
- DB_PORT = os.getenv("DB_PORT", "5432")
- DB_USER = os.getenv("DB_USER", "postgres")
- DB_PASSWORD = os.getenv("DB_PASSWORD", "")
- DB_NAME = os.getenv("DB_NAME", "model_square")
- # 连接池配置说明:
- # gunicorn 多进程模式下,每个 worker 都有独立连接池
- # 总连接数 = workers × (pool_size + max_overflow)
- # 服务器 2核 → workers ≈ 5,pool_size=5 → 总连接 ≈ 25,不超过 PostgreSQL 默认 100
- # 如果升级到 4核 → workers ≈ 9,pool_size=5 → 总连接 ≈ 45,仍然安全
- DB_POOL_SIZE = int(os.getenv("DB_POOL_SIZE", "5"))
- DB_MAX_OVERFLOW = int(os.getenv("DB_MAX_OVERFLOW", "10"))
- DB_POOL_TIMEOUT = int(os.getenv("DB_POOL_TIMEOUT", "10")) # 等待连接超时缩短到 10s,快速失败
- DB_POOL_RECYCLE = int(os.getenv("DB_POOL_RECYCLE", "1800"))
- # 构建数据库连接URL(对密码进行URL编码,处理特殊字符如@)
- from urllib.parse import quote_plus
- DATABASE_URL = f"postgresql://{DB_USER}:{quote_plus(DB_PASSWORD)}@{DB_HOST}:{DB_PORT}/{DB_NAME}"
- # 创建数据库引擎
- engine = create_engine(
- DATABASE_URL,
- pool_size=DB_POOL_SIZE,
- max_overflow=DB_MAX_OVERFLOW,
- pool_timeout=DB_POOL_TIMEOUT,
- pool_recycle=DB_POOL_RECYCLE, # 定期回收连接,防止远程数据库断开空闲连接
- pool_pre_ping=True,
- echo=False
- )
- # 创建会话工厂
- SessionLocal = sessionmaker(
- autocommit=False,
- autoflush=False,
- bind=engine
- )
- # 创建基类
- Base = declarative_base()
- def get_db():
- """
- 数据库会话依赖注入函数
-
- 用于FastAPI的依赖注入系统,自动管理数据库会话的生命周期
- 确保请求结束后会话被正确关闭
-
- Yields:
- Session: 数据库会话对象
- """
- db = SessionLocal()
- try:
- yield db
- finally:
- db.close()
|