# async_database.py (8.2 KB)

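"""
Async database configuration module.

Provides async database connections, session management, and dependency
injection. Uses the asyncpg driver for high-performance asynchronous
database operations.
"""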
  6. import os
  7. import time
  8. import logging
  9. from pathlib import Path
  10. from typing import List, Dict, Any, Type, Optional
  11. from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
  12. from sqlalchemy import text, inspect, event
  13. from sqlalchemy.dialects.postgresql import insert
  14. from dotenv import load_dotenv
  15. # 配置慢查询日志记录器
  16. logger = logging.getLogger(__name__)
  17. # 加载环境变量(从backend目录的.env文件)
  18. env_path = Path(__file__).parent.parent.parent / '.env'
  19. load_dotenv(dotenv_path=env_path)
  20. # 从环境变量读取数据库配置
  21. DB_HOST = os.getenv("DB_HOST", "localhost")
  22. DB_PORT = os.getenv("DB_PORT", "5432")
  23. DB_USER = os.getenv("DB_USER", "postgres")
  24. DB_PASSWORD = os.getenv("DB_PASSWORD", "")
  25. DB_NAME = os.getenv("DB_NAME", "model_square")
  26. DB_POOL_SIZE = int(os.getenv("DB_POOL_SIZE", "20"))
  27. DB_MAX_OVERFLOW = int(os.getenv("DB_MAX_OVERFLOW", "40"))
  28. DB_POOL_TIMEOUT = int(os.getenv("DB_POOL_TIMEOUT", "30"))
  29. DB_POOL_RECYCLE = int(os.getenv("DB_POOL_RECYCLE", "1800")) # 30分钟回收连接
  30. # 慢查询阈值配置(毫秒)
  31. SLOW_QUERY_THRESHOLD_MS = int(os.getenv("SLOW_QUERY_THRESHOLD_MS", "1000"))
  32. # 构建异步数据库连接URL(使用 asyncpg 驱动,对密码进行URL编码)
  33. from urllib.parse import quote_plus
  34. ASYNC_DATABASE_URL = f"postgresql+asyncpg://{DB_USER}:{quote_plus(DB_PASSWORD)}@{DB_HOST}:{DB_PORT}/{DB_NAME}"
  35. # 创建异步数据库引擎
  36. async_engine = create_async_engine(
  37. ASYNC_DATABASE_URL,
  38. pool_size=DB_POOL_SIZE,
  39. max_overflow=DB_MAX_OVERFLOW,
  40. pool_timeout=DB_POOL_TIMEOUT,
  41. pool_recycle=DB_POOL_RECYCLE,
  42. pool_pre_ping=True, # 连接前检查连接是否有效
  43. echo=False # 生产环境关闭SQL日志
  44. )
  45. # 慢查询日志事件监听器
  46. @event.listens_for(async_engine.sync_engine, "before_cursor_execute")
  47. def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
  48. """
  49. 在SQL执行前记录开始时间
  50. """
  51. conn.info.setdefault("query_start_time", []).append(time.time())
  52. @event.listens_for(async_engine.sync_engine, "after_cursor_execute")
  53. def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
  54. """
  55. 在SQL执行后检查执行时间,超过阈值则记录慢查询日志
  56. """
  57. start_time_list = conn.info.get("query_start_time", [])
  58. if start_time_list:
  59. total_time_ms = (time.time() - start_time_list.pop()) * 1000
  60. if total_time_ms > SLOW_QUERY_THRESHOLD_MS:
  61. # 截断SQL语句避免日志过长
  62. truncated_sql = statement[:500] + "..." if len(statement) > 500 else statement
  63. # 截断参数避免日志过长
  64. params_str = str(parameters)[:200] + "..." if len(str(parameters)) > 200 else str(parameters)
  65. logger.warning(
  66. f"慢查询 ({total_time_ms:.2f}ms): {truncated_sql} | 参数: {params_str}"
  67. )
  68. # 创建异步会话工厂
  69. AsyncSessionLocal = async_sessionmaker(
  70. bind=async_engine,
  71. class_=AsyncSession,
  72. expire_on_commit=False,
  73. autocommit=False,
  74. autoflush=False
  75. )
  76. async def get_async_db():
  77. """
  78. 异步数据库会话依赖注入函数
  79. 用于FastAPI的依赖注入系统,自动管理异步数据库会话的生命周期
  80. 确保请求结束后会话被正确关闭
  81. Yields:
  82. AsyncSession: 异步数据库会话对象
  83. """
  84. async with AsyncSessionLocal() as session:
  85. try:
  86. yield session
  87. finally:
  88. await session.close()
  89. async def get_async_db_context():
  90. """
  91. 异步数据库会话上下文管理器
  92. 用于非依赖注入场景(如后台任务、定时任务等)
  93. Returns:
  94. AsyncSession: 异步数据库会话对象
  95. """
  96. return AsyncSessionLocal()
  97. async def async_bulk_insert(
  98. session: AsyncSession,
  99. model: Type,
  100. data_list: List[Dict[str, Any]],
  101. batch_size: int = 1000
  102. ) -> int:
  103. """
  104. 异步批量插入
  105. 使用 PostgreSQL 的批量插入优化,减少数据库往返次数
  106. Args:
  107. session: 异步数据库会话
  108. model: SQLAlchemy 模型类
  109. data_list: 要插入的数据列表
  110. batch_size: 每批次插入的记录数,默认1000
  111. Returns:
  112. int: 插入的记录数
  113. """
  114. if not data_list:
  115. return 0
  116. total_inserted = 0
  117. # 分批处理,避免单次插入数据量过大
  118. for i in range(0, len(data_list), batch_size):
  119. batch = data_list[i:i + batch_size]
  120. stmt = insert(model).values(batch)
  121. await session.execute(stmt)
  122. total_inserted += len(batch)
  123. return total_inserted
  124. async def async_bulk_update(
  125. session: AsyncSession,
  126. model: Type,
  127. data_list: List[Dict[str, Any]],
  128. key_field: str = "id",
  129. batch_size: int = 500
  130. ) -> int:
  131. """
  132. 异步批量更新
  133. 使用 PostgreSQL 的 ON CONFLICT 实现高效批量更新
  134. Args:
  135. session: 异步数据库会话
  136. model: SQLAlchemy 模型类
  137. data_list: 要更新的数据列表(必须包含 key_field)
  138. key_field: 用于匹配记录的字段名,默认为 "id"
  139. batch_size: 每批次更新的记录数,默认500
  140. Returns:
  141. int: 更新的记录数
  142. """
  143. if not data_list:
  144. return 0
  145. total_updated = 0
  146. # 获取模型的所有列名(排除主键用于更新)
  147. mapper = inspect(model)
  148. all_columns = [c.key for c in mapper.columns]
  149. # 分批处理
  150. for i in range(0, len(data_list), batch_size):
  151. batch = data_list[i:i + batch_size]
  152. # 构建 upsert 语句
  153. stmt = insert(model).values(batch)
  154. # 获取需要更新的列(排除 key_field)
  155. update_columns = {
  156. col: stmt.excluded[col]
  157. for col in all_columns
  158. if col != key_field and col in batch[0]
  159. }
  160. if update_columns:
  161. stmt = stmt.on_conflict_do_update(
  162. index_elements=[key_field],
  163. set_=update_columns
  164. )
  165. await session.execute(stmt)
  166. total_updated += len(batch)
  167. return total_updated
  168. async def async_execute_raw(
  169. session: AsyncSession,
  170. sql: str,
  171. params: Optional[Dict[str, Any]] = None
  172. ) -> Any:
  173. """
  174. 异步执行原生 SQL
  175. 用于执行复杂查询或特定的数据库操作
  176. Args:
  177. session: 异步数据库会话
  178. sql: SQL 语句
  179. params: SQL 参数字典
  180. Returns:
  181. Any: 查询结果
  182. """
  183. stmt = text(sql)
  184. if params:
  185. result = await session.execute(stmt, params)
  186. else:
  187. result = await session.execute(stmt)
  188. return result
  189. async def async_bulk_insert_returning(
  190. session: AsyncSession,
  191. model: Type,
  192. data_list: List[Dict[str, Any]],
  193. returning_fields: List[str] = None,
  194. batch_size: int = 1000
  195. ) -> List[Dict[str, Any]]:
  196. """
  197. 异步批量插入并返回插入的记录
  198. Args:
  199. session: 异步数据库会话
  200. model: SQLAlchemy 模型类
  201. data_list: 要插入的数据列表
  202. returning_fields: 需要返回的字段列表,默认返回所有字段
  203. batch_size: 每批次插入的记录数
  204. Returns:
  205. List[Dict[str, Any]]: 插入的记录列表
  206. """
  207. if not data_list:
  208. return []
  209. results = []
  210. mapper = inspect(model)
  211. # 确定返回的字段
  212. if returning_fields:
  213. return_columns = [getattr(model, f) for f in returning_fields if hasattr(model, f)]
  214. else:
  215. return_columns = [c for c in mapper.columns]
  216. # 分批处理
  217. for i in range(0, len(data_list), batch_size):
  218. batch = data_list[i:i + batch_size]
  219. stmt = insert(model).values(batch).returning(*return_columns)
  220. result = await session.execute(stmt)
  221. rows = result.fetchall()
  222. # 转换为字典列表
  223. for row in rows:
  224. if returning_fields:
  225. results.append(dict(zip(returning_fields, row)))
  226. else:
  227. results.append(dict(zip([c.key for c in mapper.columns], row)))
  228. return results