async_mysql_connection.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. """
  2. MySQL数据库异步连接管理
  3. """
  4. import os
  5. import pymysql
  6. from urllib.parse import urlparse
  7. from typing import Optional, AsyncGenerator
  8. import logging
  9. from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
  10. from sqlalchemy.orm import DeclarativeBase
  11. # 导入配置
  12. from app.core.config import config_handler
  13. logger = logging.getLogger(__name__)
  14. _db_pool = None
  15. _engine = None
  16. _async_session_local = None
  17. def get_db_config() -> dict:
  18. """获取数据库配置"""
  19. database_url = config_handler.get("admin_app", "DATABASE_URL", "")
  20. if not database_url:
  21. raise ValueError("DATABASE_URL配置未设置")
  22. parsed = urlparse(database_url)
  23. config = {
  24. 'host': parsed.hostname or 'localhost',
  25. 'port': parsed.port or 3306,
  26. 'user': parsed.username or 'root',
  27. 'password': parsed.password or '',
  28. 'database': parsed.path[1:] if parsed.path else 'sso_db',
  29. 'charset': 'utf8mb4',
  30. 'cursorclass': pymysql.cursors.DictCursor
  31. }
  32. return config
  33. def get_db_connection():
  34. """获取数据库连接(同步方式,用于简单查询)"""
  35. try:
  36. config = get_db_config()
  37. connection = pymysql.connect(**config)
  38. return connection
  39. except Exception as e:
  40. logger.error(f"数据库连接失败: {e}")
  41. return None
  42. # ==================== SQLAlchemy 异步支持 ====================
  43. class Base(DeclarativeBase):
  44. """数据库模型基类"""
  45. pass
  46. def get_engine():
  47. """获取 SQLAlchemy 异步引擎"""
  48. global _engine
  49. if _engine is None:
  50. database_url = config_handler.get("admin_app", "DATABASE_URL", "")
  51. if not database_url:
  52. raise ValueError("DATABASE_URL配置未设置")
  53. # 将 mysql:// 转换为 mysql+aiomysql://
  54. if database_url.startswith('mysql://'):
  55. database_url = database_url.replace('mysql://', 'mysql+aiomysql://', 1)
  56. database_echo = config_handler.get_bool("admin_app", "DATABASE_ECHO", False)
  57. _engine = create_async_engine(
  58. database_url,
  59. echo=database_echo,
  60. pool_pre_ping=True,
  61. pool_recycle=300, # 每5分钟回收连接(合理)
  62. pool_size=10,
  63. max_overflow=20,
  64. connect_args={
  65. "connect_timeout": 10
  66. }
  67. )
  68. logger.info(f"SQLAlchemy 异步引擎已创建")
  69. return _engine
  70. def get_async_session_local():
  71. """获取异步会话工厂"""
  72. global _async_session_local
  73. if _async_session_local is None:
  74. engine = get_engine()
  75. _async_session_local = async_sessionmaker(
  76. engine,
  77. class_=AsyncSession,
  78. expire_on_commit=False,
  79. autocommit=False,
  80. autoflush=False,
  81. )
  82. logger.info("异步会话工厂已创建")
  83. return _async_session_local
  84. async def get_db() -> AsyncGenerator[AsyncSession, None]:
  85. """获取数据库会话(依赖注入使用)"""
  86. async_session_local = get_async_session_local()
  87. async with async_session_local() as session:
  88. try:
  89. yield session
  90. except Exception:
  91. await session.rollback()
  92. raise
  93. finally:
  94. await session.close()
  95. async def init_db():
  96. """初始化数据库连接池"""
  97. global _db_pool
  98. try:
  99. # 测试同步连接
  100. config = get_db_config()
  101. logger.info(f"初始化数据库连接: {config['host']}:{config['port']}/{config['database']}")
  102. conn = pymysql.connect(**config)
  103. conn.close()
  104. logger.info("数据库连接测试成功")
  105. # 初始化异步引擎
  106. engine = get_engine()
  107. # 创建所有表(如果需要)
  108. # async with engine.begin() as conn:
  109. # await conn.run_sync(Base.metadata.create_all)
  110. logger.info("数据库初始化成功")
  111. except Exception as e:
  112. logger.error(f"数据库初始化失败: {e}")
  113. raise
  114. async def close_db():
  115. """关闭数据库连接池"""
  116. global _db_pool, _engine, _async_session_local
  117. if _engine:
  118. try:
  119. await _engine.dispose()
  120. logger.info("数据库连接池已关闭")
  121. except Exception as e:
  122. logger.error(f"关闭数据库连接池失败: {e}")
  123. finally:
  124. _engine = None
  125. _async_session_local = None
  126. _db_pool = None