|
- """
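- Async database configuration module.
- Provides async database connections, session management, and dependency injection.
- Uses the asyncpg driver for high-performance async database access.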
- """
- import os
- import time
- import logging
- from pathlib import Path
- from typing import List, Dict, Any, Type, Optional
- from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
- from sqlalchemy import text, inspect, event
- from sqlalchemy.dialects.postgresql import insert
- from dotenv import load_dotenv
- # Module logger, also used for the slow-query log
- logger = logging.getLogger(__name__)
- # Load environment variables from the .env file in the backend directory
- env_path = Path(__file__).parent.parent.parent / '.env'
- load_dotenv(dotenv_path=env_path)
- # Read the database settings from environment variables
- DB_HOST = os.getenv("DB_HOST", "localhost")
- DB_PORT = os.getenv("DB_PORT", "5432")
- DB_USER = os.getenv("DB_USER", "postgres")
- DB_PASSWORD = os.getenv("DB_PASSWORD", "")
- DB_NAME = os.getenv("DB_NAME", "model_square")
- DB_POOL_SIZE = int(os.getenv("DB_POOL_SIZE", "20"))
- DB_MAX_OVERFLOW = int(os.getenv("DB_MAX_OVERFLOW", "40"))
- DB_POOL_TIMEOUT = int(os.getenv("DB_POOL_TIMEOUT", "30"))
- DB_POOL_RECYCLE = int(os.getenv("DB_POOL_RECYCLE", "1800"))  # recycle connections after 30 minutes
- # Slow-query threshold in milliseconds
- SLOW_QUERY_THRESHOLD_MS = int(os.getenv("SLOW_QUERY_THRESHOLD_MS", "1000"))
- # Build the async database URL (asyncpg driver; the password is URL-encoded)
- from urllib.parse import quote_plus
- ASYNC_DATABASE_URL = f"postgresql+asyncpg://{DB_USER}:{quote_plus(DB_PASSWORD)}@{DB_HOST}:{DB_PORT}/{DB_NAME}"
- # Create the async database engine
- async_engine = create_async_engine(
-     ASYNC_DATABASE_URL,
-     pool_size=DB_POOL_SIZE,
-     max_overflow=DB_MAX_OVERFLOW,
-     pool_timeout=DB_POOL_TIMEOUT,
-     pool_recycle=DB_POOL_RECYCLE,
-     pool_pre_ping=True,  # check that a pooled connection is alive before using it
-     echo=False  # keep SQL echo logging off in production
- )
- # Slow-query logging event listeners
- @event.listens_for(async_engine.sync_engine, "before_cursor_execute")
- def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
-     """
-     Record the start time before a SQL statement executes.
-     """
-     # perf_counter() is monotonic, so it suits duration measurement better than time.time()
-     conn.info.setdefault("query_start_time", []).append(time.perf_counter())
- @event.listens_for(async_engine.sync_engine, "after_cursor_execute")
- def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
-     """
-     Check the elapsed time after a SQL statement executes and log a warning
-     when it exceeds the slow-query threshold.
-     """
-     start_time_list = conn.info.get("query_start_time", [])
-     if start_time_list:
-         total_time_ms = (time.perf_counter() - start_time_list.pop()) * 1000
-         if total_time_ms > SLOW_QUERY_THRESHOLD_MS:
-             # Truncate the SQL statement so the log line stays short
-             truncated_sql = statement[:500] + "..." if len(statement) > 500 else statement
-             # Truncate the parameters so the log line stays short
-             params_str = str(parameters)[:200] + "..." if len(str(parameters)) > 200 else str(parameters)
-             logger.warning(
-                 f"Slow query ({total_time_ms:.2f}ms): {truncated_sql} | params: {params_str}"
-             )
- # Create the async session factory
- AsyncSessionLocal = async_sessionmaker(
-     bind=async_engine,
-     class_=AsyncSession,
-     expire_on_commit=False,
-     autocommit=False,
-     autoflush=False
- )
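- async def get_async_db():
-     """
-     Dependency-injection function that provides an async database session.
-
-     Intended for FastAPI's dependency injection system. It manages the
-     lifecycle of the async session and makes sure the session is closed
-     when the request ends.
-
-     Yields:
-         AsyncSession: the async database session
-     """
-     async with AsyncSessionLocal() as session:
-         try:
-             yield session
-         finally:
-             await session.close()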
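- async def get_async_db_context():
-     """
-     Create an async database session for code that runs outside dependency
-     injection (background tasks, scheduled jobs, and so on).
-
-     This is a coroutine function, so the caller must await it to obtain the
-     session, and is then responsible for closing it (for example by using
-     the returned session in an `async with` block).
-
-     Returns:
-         AsyncSession: the async database session
-     """
-     return AsyncSessionLocal()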
- async def async_bulk_insert(
-     session: AsyncSession,
-     model: Type,
-     data_list: List[Dict[str, Any]],
-     batch_size: int = 1000
- ) -> int:
-     """
-     Async bulk insert.
-
-     Sends one multi-row PostgreSQL INSERT per batch to reduce database
-     round trips. The caller is responsible for committing the session.
-
-     Args:
-         session: async database session
-         model: SQLAlchemy model class
-         data_list: list of row dictionaries to insert
-         batch_size: number of rows per batch, 1000 by default
-
-     Returns:
-         int: number of rows submitted for insertion
-     """
-     if not data_list:
-         return 0
-
-     total_inserted = 0
-
-     # Process in batches so that no single INSERT grows too large
-     for i in range(0, len(data_list), batch_size):
-         batch = data_list[i:i + batch_size]
-         stmt = insert(model).values(batch)
-         await session.execute(stmt)
-         total_inserted += len(batch)
-
-     return total_inserted
- async def async_bulk_update(
-     session: AsyncSession,
-     model: Type,
-     data_list: List[Dict[str, Any]],
-     key_field: str = "id",
-     batch_size: int = 500
- ) -> int:
-     """
-     Async bulk update.
-
-     Uses PostgreSQL's INSERT ... ON CONFLICT DO UPDATE (an upsert), so rows
-     whose key does not exist yet are inserted rather than skipped.
-     key_field must be covered by a unique index or constraint.
-
-     Args:
-         session: async database session
-         model: SQLAlchemy model class
-         data_list: list of row dictionaries to update (each must contain key_field)
-         key_field: name of the field used to match rows, "id" by default
-         batch_size: number of rows per batch, 500 by default
-
-     Returns:
-         int: number of rows submitted for update
-     """
-     if not data_list:
-         return 0
-
-     total_updated = 0
-
-     # Collect all column names of the model (key_field is excluded below)
-     mapper = inspect(model)
-     all_columns = [c.key for c in mapper.columns]
-
-     # Process in batches
-     for i in range(0, len(data_list), batch_size):
-         batch = data_list[i:i + batch_size]
-
-         # Build the upsert statement
-         stmt = insert(model).values(batch)
-
-         # Columns to update: every model column except key_field that is
-         # present in the first row of the batch
-         update_columns = {
-             col: stmt.excluded[col]
-             for col in all_columns
-             if col != key_field and col in batch[0]
-         }
-
-         if update_columns:
-             stmt = stmt.on_conflict_do_update(
-                 index_elements=[key_field],
-                 set_=update_columns
-             )
-             await session.execute(stmt)
-             total_updated += len(batch)
-
-     return total_updated
- async def async_execute_raw(
-     session: AsyncSession,
-     sql: str,
-     params: Optional[Dict[str, Any]] = None
- ) -> Any:
-     """
-     Execute raw SQL asynchronously.
-
-     Intended for complex queries or database-specific operations.
-
-     Args:
-         session: async database session
-         sql: SQL statement
-         params: dictionary of bound SQL parameters
-
-     Returns:
-         Any: the result object returned by session.execute()
-     """
-     stmt = text(sql)
-     if params:
-         result = await session.execute(stmt, params)
-     else:
-         result = await session.execute(stmt)
-     return result
- async def async_bulk_insert_returning(
-     session: AsyncSession,
-     model: Type,
-     data_list: List[Dict[str, Any]],
-     returning_fields: Optional[List[str]] = None,
-     batch_size: int = 1000
- ) -> List[Dict[str, Any]]:
-     """
-     Async bulk insert that returns the inserted rows.
-
-     Args:
-         session: async database session
-         model: SQLAlchemy model class
-         data_list: list of row dictionaries to insert
-         returning_fields: names of the fields to return; all columns by default
-         batch_size: number of rows per batch
-
-     Returns:
-         List[Dict[str, Any]]: the inserted rows as dictionaries
-     """
-     if not data_list:
-         return []
-
-     results = []
-     mapper = inspect(model)
-
-     # Decide which fields to return. Names that are not model attributes are
-     # skipped, and field_names is kept aligned with return_columns so that
-     # the dict keys below match the returned values.
-     if returning_fields:
-         field_names = [f for f in returning_fields if hasattr(model, f)]
-         return_columns = [getattr(model, f) for f in field_names]
-     else:
-         field_names = [c.key for c in mapper.columns]
-         return_columns = list(mapper.columns)
-
-     # Process in batches
-     for i in range(0, len(data_list), batch_size):
-         batch = data_list[i:i + batch_size]
-         stmt = insert(model).values(batch).returning(*return_columns)
-         result = await session.execute(stmt)
-         rows = result.fetchall()
-
-         # Convert each returned row to a dictionary
-         for row in rows:
-             results.append(dict(zip(field_names, row)))
-
-     return results
|
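Two ideas in this module can be tried without a database: URL-encoding the password when assembling `ASYNC_DATABASE_URL`, and the slice-based batching used by the bulk helpers. The sketch below is a minimal illustration using only the standard library. `build_async_url` and `iter_batches` are hypothetical helper names introduced here for demonstration; they are not part of the module.

```python
from typing import Any, Dict, Iterator, List
from urllib.parse import quote_plus


def build_async_url(user: str, password: str, host: str, port: str, name: str) -> str:
    """Hypothetical helper: assemble the URL the same way ASYNC_DATABASE_URL is built."""
    # quote_plus escapes characters such as '@', '/' and ':' that would
    # otherwise be parsed as URL delimiters.
    return f"postgresql+asyncpg://{user}:{quote_plus(password)}@{host}:{port}/{name}"


def iter_batches(rows: List[Dict[str, Any]], batch_size: int) -> Iterator[List[Dict[str, Any]]]:
    """Hypothetical helper: yield consecutive slices, like the loops in the bulk helpers."""
    for i in range(0, len(rows), batch_size):
        yield rows[i:i + batch_size]


if __name__ == "__main__":
    # A password containing URL delimiters is encoded before being embedded.
    print(build_async_url("postgres", "p@ss/w:rd", "localhost", "5432", "model_square"))

    # Five rows with a batch size of 2 give batches of 2, 2 and 1 rows.
    rows = [{"id": n} for n in range(5)]
    print([len(batch) for batch in iter_batches(rows, 2)])
```

The last batch is simply shorter than `batch_size`, which is why the bulk helpers count `len(batch)` rather than adding `batch_size` on each iteration.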