""" 异步数据库配置模块 提供异步数据库连接、会话管理和依赖注入功能 使用 asyncpg 驱动实现高性能异步数据库操作 """ import os import time import logging from pathlib import Path from typing import List, Dict, Any, Type, Optional from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker from sqlalchemy import text, inspect, event from sqlalchemy.dialects.postgresql import insert from dotenv import load_dotenv # 配置慢查询日志记录器 logger = logging.getLogger(__name__) # 加载环境变量(从backend目录的.env文件) env_path = Path(__file__).parent.parent.parent / '.env' load_dotenv(dotenv_path=env_path) # 从环境变量读取数据库配置 DB_HOST = os.getenv("DB_HOST", "localhost") DB_PORT = os.getenv("DB_PORT", "5432") DB_USER = os.getenv("DB_USER", "postgres") DB_PASSWORD = os.getenv("DB_PASSWORD", "") DB_NAME = os.getenv("DB_NAME", "model_square") DB_POOL_SIZE = int(os.getenv("DB_POOL_SIZE", "20")) DB_MAX_OVERFLOW = int(os.getenv("DB_MAX_OVERFLOW", "40")) DB_POOL_TIMEOUT = int(os.getenv("DB_POOL_TIMEOUT", "30")) DB_POOL_RECYCLE = int(os.getenv("DB_POOL_RECYCLE", "1800")) # 30分钟回收连接 # 慢查询阈值配置(毫秒) SLOW_QUERY_THRESHOLD_MS = int(os.getenv("SLOW_QUERY_THRESHOLD_MS", "1000")) # 构建异步数据库连接URL(使用 asyncpg 驱动,对密码进行URL编码) from urllib.parse import quote_plus ASYNC_DATABASE_URL = f"postgresql+asyncpg://{DB_USER}:{quote_plus(DB_PASSWORD)}@{DB_HOST}:{DB_PORT}/{DB_NAME}" # 创建异步数据库引擎 async_engine = create_async_engine( ASYNC_DATABASE_URL, pool_size=DB_POOL_SIZE, max_overflow=DB_MAX_OVERFLOW, pool_timeout=DB_POOL_TIMEOUT, pool_recycle=DB_POOL_RECYCLE, pool_pre_ping=True, # 连接前检查连接是否有效 echo=False # 生产环境关闭SQL日志 ) # 慢查询日志事件监听器 @event.listens_for(async_engine.sync_engine, "before_cursor_execute") def before_cursor_execute(conn, cursor, statement, parameters, context, executemany): """ 在SQL执行前记录开始时间 """ conn.info.setdefault("query_start_time", []).append(time.time()) @event.listens_for(async_engine.sync_engine, "after_cursor_execute") def after_cursor_execute(conn, cursor, statement, parameters, context, executemany): """ 在SQL执行后检查执行时间,超过阈值则记录慢查询日志 """ 
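    start_time_list = conn.info.get("query_start_time", [])
    if start_time_list:
        total_time_ms = (time.time() - start_time_list.pop()) * 1000
        if total_time_ms > SLOW_QUERY_THRESHOLD_MS:
            # Truncate the SQL statement so the log entry stays short
            truncated_sql = statement[:500] + "..." if len(statement) > 500 else statement
            # Truncate the parameters so the log entry stays short
            params_repr = str(parameters)
            params_str = params_repr[:200] + "..." if len(params_repr) > 200 else params_repr
            # Lazy %-style arguments: the message is only formatted if it is emitted
            logger.warning(
                "Slow query (%.2fms): %s | params: %s",
                total_time_ms, truncated_sql, params_str
            )


# Create the async session factory
AsyncSessionLocal = async_sessionmaker(
    bind=async_engine,
    class_=AsyncSession,
    expire_on_commit=False,
    autocommit=False,
    autoflush=False
)


async def get_async_db():
    """
    Async database session dependency.

    Intended for FastAPI's dependency injection system. It manages the
    lifecycle of the async session and ensures the session is closed
    once the request finishes.

    Yields:
        AsyncSession: the async database session
    """
    # `async with` closes the session on exit, so no explicit close() is needed
    async with AsyncSessionLocal() as session:
        yield session


async def get_async_db_context():
    """
    Create an async database session for use as a context manager.

    Intended for code outside dependency injection (background tasks,
    scheduled jobs, and so on). This is a coroutine function, so the call
    must be awaited before the session can be entered:

        async with await get_async_db_context() as session:
            ...

    Returns:
        AsyncSession: the async database session
    """
    return AsyncSessionLocal()


async def async_bulk_insert(
    session: AsyncSession,
    model: Type,
    data_list: List[Dict[str, Any]],
    batch_size: int = 1000
) -> int:
    """
    Async bulk insert.

    Sends multi-row PostgreSQL INSERT statements to reduce database round trips.
    The caller is responsible for committing the session.

    Args:
        session: async database session
        model: SQLAlchemy model class
        data_list: list of row dictionaries to insert
        batch_size: number of rows per batch, 1000 by default

    Returns:
        int: number of rows inserted
    """
    if not data_list:
        return 0

    total_inserted = 0

    # Process in batches so that no single INSERT grows too large
    for i in range(0, len(data_list), batch_size):
        batch = data_list[i:i + batch_size]
        stmt = insert(model).values(batch)
        await session.execute(stmt)
        total_inserted += len(batch)

    return total_inserted


async def async_bulk_update(
    session: AsyncSession,
    model: Type,
    data_list: List[Dict[str, Any]],
    key_field: str = "id",
    batch_size: int = 500
) -> int:
    """
    Async bulk update.

    Implemented as a PostgreSQL upsert (INSERT ... ON CONFLICT DO UPDATE),
    so rows whose key_field value does not exist yet are inserted rather
    than skipped. key_field must be covered by a unique or primary-key
    constraint. The columns to update are taken from the first row of each
    batch, so every row in a batch should carry the same keys.

    Args:
        session: async database session
        model: SQLAlchemy model class
        data_list: list of row dictionaries to update (each must contain key_field)
        key_field: name of the field used to match rows, "id" by default
        batch_size: number of rows per batch, 500 by default

    Returns:
        int: number of rows submitted (not the number the database actually changed)
    """
    if not data_list:
        return 0

    total_updated = 0

    # Collect all column names of the model (key_field is excluded below
    # when the update set is built)
    mapper = inspect(model)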
    all_columns = [c.key for c in mapper.columns]

    # Process in batches
    for i in range(0, len(data_list), batch_size):
        batch = data_list[i:i + batch_size]

        # Build the upsert statement
        stmt = insert(model).values(batch)

        # Columns to update: everything present in the first row except key_field
        update_columns = {
            col: stmt.excluded[col]
            for col in all_columns
            if col != key_field and col in batch[0]
        }

        # With no updatable columns the statement stays a plain INSERT,
        # which fails if a row with the same key already exists
        if update_columns:
            stmt = stmt.on_conflict_do_update(
                index_elements=[key_field],
                set_=update_columns
            )

        await session.execute(stmt)
        total_updated += len(batch)

    return total_updated


async def async_execute_raw(
    session: AsyncSession,
    sql: str,
    params: Optional[Dict[str, Any]] = None
) -> Any:
    """
    Execute raw SQL asynchronously.

    Intended for complex queries or database-specific operations.
    Pass values through params as bound parameters instead of
    formatting them into the SQL string.

    Args:
        session: async database session
        sql: SQL statement
        params: dictionary of SQL parameters

    Returns:
        Any: the result of the query
    """
    stmt = text(sql)
    if params:
        result = await session.execute(stmt, params)
    else:
        result = await session.execute(stmt)
    return result


async def async_bulk_insert_returning(
    session: AsyncSession,
    model: Type,
    data_list: List[Dict[str, Any]],
    returning_fields: Optional[List[str]] = None,
    batch_size: int = 1000
) -> List[Dict[str, Any]]:
    """
    Async bulk insert that returns the inserted rows.

    Args:
        session: async database session
        model: SQLAlchemy model class
        data_list: list of row dictionaries to insert
        returning_fields: names of the fields to return; all columns by default.
            Names that are not attributes of the model are ignored.
        batch_size: number of rows per batch

    Returns:
        List[Dict[str, Any]]: list of inserted rows
    """
    if not data_list:
        return []

    results = []
    mapper = inspect(model)

    # Decide which columns to return. The keys are filtered together with the
    # columns so that each key stays aligned with its value in the result rows.
    if returning_fields:
        return_keys = [f for f in returning_fields if hasattr(model, f)]
        return_columns = [getattr(model, f) for f in return_keys]
    else:
        return_columns = list(mapper.columns)
        return_keys = [c.key for c in return_columns]

    # Process in batches
    for i in range(0, len(data_list), batch_size):
        batch = data_list[i:i + batch_size]
        stmt = insert(model).values(batch).returning(*return_columns)
        result = await session.execute(stmt)
        rows = result.fetchall()

        # Convert each row to a dictionary
        for row in rows:
            results.append(dict(zip(return_keys, row)))

    return results