"""MySQL database connection management (sync PyMySQL helpers + SQLAlchemy async engine/sessions)."""
import asyncio
import os
import pymysql
from urllib.parse import unquote, urlparse
from typing import Optional, AsyncGenerator
import logging

from sqlalchemy.ext.asyncio import (
    AsyncEngine,
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)
from sqlalchemy.orm import DeclarativeBase

# Import project configuration
from app.core.config import config_handler

logger = logging.getLogger(__name__)

# Kept for backward compatibility: nothing in this module ever assigns a pool to it;
# close_db() resets it to None.
_db_pool = None
# Lazily created SQLAlchemy async engine (see get_engine).
_engine = None
# Lazily created async session factory (see get_async_session_local).
_async_session_local = None

# Database name used when DATABASE_URL has no (or an empty) path component.
_DEFAULT_DATABASE = 'sso_db'

# Scheme prefixes naming a synchronous MySQL driver; get_engine() rewrites them
# to the aiomysql async driver because create_async_engine needs an async DBAPI.
_SYNC_MYSQL_PREFIXES = ("mysql+pymysql://", "mysql://")
_ASYNC_MYSQL_PREFIX = "mysql+aiomysql://"


def _get_database_url() -> str:
    """Return the configured ``admin_app.DATABASE_URL``.

    Raises:
        ValueError: if the setting is missing or empty.
    """
    database_url = config_handler.get("admin_app", "DATABASE_URL", "")
    if not database_url:
        raise ValueError("DATABASE_URL配置未设置")
    return database_url


def _to_async_url(database_url: str) -> str:
    """Rewrite a sync MySQL URL scheme to ``mysql+aiomysql://``; other URLs pass through."""
    for sync_prefix in _SYNC_MYSQL_PREFIXES:
        if database_url.startswith(sync_prefix):
            return _ASYNC_MYSQL_PREFIX + database_url[len(sync_prefix):]
    return database_url


def get_db_config() -> dict:
    """Build ``pymysql.connect`` keyword arguments from DATABASE_URL.

    Missing URL parts fall back to host ``localhost``, port 3306, user ``root``,
    an empty password and database ``sso_db``. The user name, password and
    database name are percent-decoded, which is how SQLAlchemy interprets the
    same URL, so the sync and async connections use identical credentials.

    Returns:
        dict of connection arguments (rows are returned as dicts via DictCursor).

    Raises:
        ValueError: if DATABASE_URL is not configured, or its port is invalid.
    """
    parsed = urlparse(_get_database_url())
    # parsed.path is "/dbname"; "" or "/" means no database was given.
    database = unquote(parsed.path[1:]) if parsed.path else ''
    return {
        'host': parsed.hostname or 'localhost',
        'port': parsed.port or 3306,
        'user': unquote(parsed.username) if parsed.username else 'root',
        'password': unquote(parsed.password) if parsed.password else '',
        'database': database or _DEFAULT_DATABASE,
        'charset': 'utf8mb4',
        'cursorclass': pymysql.cursors.DictCursor,
    }


def get_db_connection() -> "Optional[pymysql.connections.Connection]":
    """Open a synchronous PyMySQL connection (for simple queries).

    Best effort: any failure (missing config, connection error) is logged and
    ``None`` is returned instead of raising. The caller must close the
    connection.
    """
    try:
        return pymysql.connect(**get_db_config())
    except Exception as e:
        logger.error("数据库连接失败: %s", e)
        return None


# ==================== SQLAlchemy async support ====================

class Base(DeclarativeBase):
    """Declarative base class for ORM models."""


def get_engine() -> AsyncEngine:
    """Return the module-wide SQLAlchemy async engine, creating it on first call.

    ``mysql://`` and ``mysql+pymysql://`` URLs are rewritten to use aiomysql.
    SQL echo is controlled by ``admin_app.DATABASE_ECHO``.

    Raises:
        ValueError: if DATABASE_URL is not configured.
    """
    global _engine
    if _engine is None:
        database_url = _to_async_url(_get_database_url())
        database_echo = config_handler.get_bool("admin_app", "DATABASE_ECHO", False)
        _engine = create_async_engine(
            database_url,
            echo=database_echo,
            pool_pre_ping=True,  # check connections are alive before handing them out
            pool_recycle=300,  # recycle connections older than 300 s (5 minutes)
            pool_size=10,
            max_overflow=20,
            connect_args={"connect_timeout": 10},
        )
        logger.info("SQLAlchemy 异步引擎已创建")
    return _engine


def get_async_session_local() -> "async_sessionmaker[AsyncSession]":
    """Return the module-wide async session factory, creating it on first call.

    Sessions are bound to get_engine(), do not expire objects on commit, and do
    not autoflush.
    """
    global _async_session_local
    if _async_session_local is None:
        _async_session_local = async_sessionmaker(
            get_engine(),
            class_=AsyncSession,
            expire_on_commit=False,
            autocommit=False,
            autoflush=False,
        )
        logger.info("异步会话工厂已创建")
    return _async_session_local


async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield a database session (for dependency injection).

    If an exception propagates back into the generator while the session is in
    use, the session is rolled back and the exception re-raised. Leaving the
    ``async with`` block closes the session in every case.
    """
    session_factory = get_async_session_local()
    async with session_factory() as session:
        try:
            yield session
        except Exception:
            await session.rollback()
            raise


def _check_sync_connection(config: dict) -> None:
    """Open and immediately close a PyMySQL connection (blocking) to test connectivity."""
    conn = pymysql.connect(**config)
    conn.close()


async def init_db():
    """Test database connectivity and create the async engine.

    The blocking PyMySQL test connection runs in the event loop's default
    executor so it does not stall other coroutines.

    Raises:
        Exception: whatever the config lookup, connection test or engine
            creation raised (it is logged first).
    """
    try:
        # Test a synchronous connection
        config = get_db_config()
        logger.info(
            "初始化数据库连接: %s:%s/%s",
            config['host'], config['port'], config['database'],
        )
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, _check_sync_connection, config)
        logger.info("数据库连接测试成功")

        # Create the async engine eagerly so configuration errors surface at startup.
        get_engine()

        # To create all tables at startup, uncomment:
        # async with get_engine().begin() as conn:
        #     await conn.run_sync(Base.metadata.create_all)

        logger.info("数据库初始化成功")
    except Exception as e:
        logger.error("数据库初始化失败: %s", e)
        raise


async def close_db():
    """Dispose of the async engine's connection pool and reset the module globals.

    Disposal errors are logged, not raised. The globals are reset even when
    disposal fails, so a later get_engine() call builds a fresh engine.
    """
    global _db_pool, _engine, _async_session_local
    if _engine is None:
        return
    try:
        await _engine.dispose()
        logger.info("数据库连接池已关闭")
    except Exception as e:
        logger.error("关闭数据库连接池失败: %s", e)
    finally:
        _engine = None
        _async_session_local = None
        _db_pool = None