| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157 |
- """
- MySQL数据库异步连接管理
- """
- import os
- import pymysql
- from urllib.parse import urlparse
- from typing import Optional, AsyncGenerator
- import logging
- from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
- from sqlalchemy.orm import DeclarativeBase
- # 导入配置
- from app.core.config import config_handler
logger = logging.getLogger(__name__)

# Lazily-initialised module singletons: get_engine() and
# get_async_session_local() fill them on first use, close_db() resets them.
# NOTE(review): _db_pool is never assigned anything but None in this module.
_db_pool = _engine = _async_session_local = None
def get_db_config() -> dict:
    """Build PyMySQL connection keyword arguments from ``DATABASE_URL``.

    Reads ``DATABASE_URL`` from the ``admin_app`` config section and maps
    its components onto ``pymysql.connect`` keyword arguments. Missing
    components fall back to host ``localhost``, port ``3306``, user
    ``root``, an empty password and database ``sso_db``.

    Returns:
        A dict suitable for ``pymysql.connect(**config)``. It always uses
        the ``utf8mb4`` charset and ``DictCursor`` (rows come back as dicts).

    Raises:
        ValueError: If ``DATABASE_URL`` is not configured, or if its port
            is not a valid integer (raised by ``urlparse(...).port``).
    """
    from urllib.parse import unquote

    database_url = config_handler.get("admin_app", "DATABASE_URL", "")
    if not database_url:
        raise ValueError("DATABASE_URL配置未设置")

    parsed = urlparse(database_url)

    # urlparse keeps percent-escapes in the userinfo part. Decode them so a
    # password written as "p%40ss" reaches MySQL as "p@ss". SQLAlchemy's URL
    # parsing (used by get_engine) also decodes username and password, so
    # both code paths now authenticate with the same credentials.
    user = unquote(parsed.username) if parsed.username else 'root'
    password = unquote(parsed.password) if parsed.password else ''

    # Strip the single leading "/". A bare "/" previously produced an empty
    # database name; fall back to the default in that case as well.
    database = parsed.path[1:] or 'sso_db'

    return {
        'host': parsed.hostname or 'localhost',
        'port': parsed.port or 3306,
        'user': user,
        'password': password,
        'database': database,
        'charset': 'utf8mb4',
        'cursorclass': pymysql.cursors.DictCursor,
    }
def get_db_connection():
    """Open a synchronous PyMySQL connection for simple queries.

    This is best-effort. Any failure, whether from building the config
    (e.g. ``DATABASE_URL`` missing) or from connecting, is logged and
    ``None`` is returned instead of raising, so callers must check the
    result. A returned connection is not closed here; the caller owns it.
    """
    try:
        connection = pymysql.connect(**get_db_config())
    except Exception as e:
        logger.error(f"数据库连接失败: {e}")
        connection = None
    return connection
- # ==================== SQLAlchemy 异步支持 ====================
class Base(DeclarativeBase):
    """Declarative base class that every ORM model in the app subclasses.

    Its ``metadata`` collects the table definitions of all subclasses.
    """
def get_engine():
    """Return the shared SQLAlchemy async engine, creating it on first call.

    The URL is read from the ``DATABASE_URL`` key of the ``admin_app``
    config section. ``create_async_engine`` only accepts async drivers.
    A URL whose scheme is plain ``mysql`` or names a synchronous MySQL
    driver (``mysql+pymysql``, ``mysql+mysqldb``, ``mysql+mysqlconnector``)
    is therefore rewritten to ``mysql+aiomysql``. Any other scheme is used
    as-is. ``DATABASE_ECHO`` (default False) turns on SQL statement logging.

    Returns:
        The cached async engine. Later calls return the same object until
        close_db() clears it.

    Raises:
        ValueError: If ``DATABASE_URL`` is not configured.
    """
    global _engine
    if _engine is not None:
        return _engine

    database_url = config_handler.get("admin_app", "DATABASE_URL", "")
    if not database_url:
        raise ValueError("DATABASE_URL配置未设置")

    # The same URL also feeds get_db_config() (sync PyMySQL), so it may well
    # name a sync driver. Previously only bare "mysql://" was converted, and
    # e.g. "mysql+pymysql://..." made create_async_engine fail.
    sync_mysql_schemes = (
        'mysql',
        'mysql+pymysql',
        'mysql+mysqldb',
        'mysql+mysqlconnector',
    )
    scheme, sep, rest = database_url.partition('://')
    if sep and scheme in sync_mysql_schemes:
        database_url = f'mysql+aiomysql://{rest}'

    database_echo = config_handler.get_bool("admin_app", "DATABASE_ECHO", False)

    # NOTE(review): get_db_config() forces charset utf8mb4 for the sync path,
    # but this engine only gets a charset if DATABASE_URL carries
    # "?charset=...". Confirm the async driver's default matches.
    _engine = create_async_engine(
        database_url,
        echo=database_echo,
        pool_pre_ping=True,   # check connections on checkout, replace stale ones
        pool_recycle=300,     # recycle connections older than 5 minutes
        pool_size=10,
        max_overflow=20,      # at most pool_size + max_overflow = 30 connections
        connect_args={
            "connect_timeout": 10,  # forwarded to the DBAPI connect() call
        },
    )
    logger.info("SQLAlchemy 异步引擎已创建")
    return _engine
def get_async_session_local():
    """Return the shared ``async_sessionmaker``, building it on first call.

    The factory is bound to the engine returned by get_engine(). It produces
    ``AsyncSession`` objects that keep loaded attribute values after commit
    (``expire_on_commit=False``) and do not autoflush before queries
    (``autoflush=False``).
    """
    global _async_session_local
    if _async_session_local is not None:
        return _async_session_local

    _async_session_local = async_sessionmaker(
        get_engine(),
        class_=AsyncSession,
        expire_on_commit=False,
        autocommit=False,
        autoflush=False,
    )
    logger.info("异步会话工厂已创建")
    return _async_session_local
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield one database session, for use as a dependency-injection provider.

    If an exception is thrown back into the generator while the session is
    in use, the session's transaction is rolled back and the exception is
    re-raised. Nothing is committed here; callers commit explicitly.

    The session is closed by the ``async with`` block on every exit path.
    The previous extra ``finally: await session.close()`` was redundant and
    has been removed.
    """
    session_factory = get_async_session_local()
    async with session_factory() as session:
        try:
            yield session
        except Exception:
            await session.rollback()
            raise
async def init_db():
    """Check database connectivity and create the async engine at startup.

    Opens, then immediately closes, a synchronous PyMySQL connection built
    from get_db_config() as a connectivity probe. It then eagerly creates
    the SQLAlchemy async engine via get_engine(). Automatic table creation
    is left disabled.

    Raises:
        Exception: Whatever the probe or engine creation raised (for example
            ValueError when DATABASE_URL is missing). It is logged and
            re-raised unchanged.
    """
    import asyncio

    def _probe_connection(config: dict) -> None:
        # Connect and disconnect at once; raises if the server is unreachable.
        pymysql.connect(**config).close()

    try:
        config = get_db_config()
        logger.info(f"初始化数据库连接: {config['host']}:{config['port']}/{config['database']}")
        # pymysql.connect is blocking. Run it in the default thread pool so
        # the event loop is not stalled while the connection is attempted.
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, _probe_connection, config)
        logger.info("数据库连接测试成功")

        # The engine is otherwise created lazily. Building it here surfaces
        # URL/driver problems at startup instead of on the first request.
        get_engine()

        # To create all tables on startup, enable:
        # async with get_engine().begin() as conn:
        #     await conn.run_sync(Base.metadata.create_all)

        logger.info("数据库初始化成功")
    except Exception as e:
        logger.error(f"数据库初始化失败: {e}")
        raise
async def close_db():
    """Dispose of the async engine's connection pool and reset module state.

    This is best-effort: an error while disposing the engine is logged, not
    raised. When an engine exists, the cached engine and session factory
    are cleared whether or not disposal succeeded. A later get_engine() or
    get_async_session_local() call then builds fresh ones. ``_db_pool`` is
    always reset to None.
    """
    global _db_pool, _engine, _async_session_local

    engine = _engine
    if engine:
        try:
            await engine.dispose()
            logger.info("数据库连接池已关闭")
        except Exception as e:
            logger.error(f"关闭数据库连接池失败: {e}")
        finally:
            # Drop the cached objects even if dispose() failed.
            _engine = None
            _async_session_local = None

    _db_pool = None
|