| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219 |
- from typing import List, Tuple, Any, Optional, Dict
- from mysql.connector import Error
- from foundation.logger.loggering import server_logger
- from foundation.utils.common import handler_err
- from foundation.base.mysql.async_mysql_conn_pool import AsyncMySQLPool
- import aiomysql
class AsyncBaseDAO:
    """Asynchronous base class for data access objects.

    Wraps an ``AsyncMySQLPool`` and offers generic helpers for reads, writes
    and UPDATE statements. Every helper logs failures through ``handler_err``
    and then re-raises the original exception.

    NOTE(review): only *values* are parameterized. Table names, column names
    (the keys of ``updates`` / ``conditions``), ``id_field`` and ``where_sql``
    are interpolated directly into the SQL text, so they must come from
    trusted code, never from user input.
    """

    def __init__(self, db_pool: AsyncMySQLPool):
        # Pool exposing the ``get_cursor()`` and ``get_connection()`` async
        # context managers used below.
        self.db_pool = db_pool

    @staticmethod
    def _build_set_clause(updates: Dict) -> Tuple[str, List]:
        """Return ``("a = %s, b = %s", [va, vb])`` for the given field/value dict."""
        set_clause = ", ".join(f"{field} = %s" for field in updates)
        return set_clause, list(updates.values())

    async def execute_query(self, query: str, params: Optional[Tuple] = None) -> bool:
        """Execute a write statement.

        Args:
            query: SQL text with ``%s`` placeholders.
            params: Positional parameters; ``None`` means no parameters.

        Returns:
            bool: ``True`` once the statement has executed without error.

        Raises:
            Exception: Any driver error, re-raised after being logged.
        """
        try:
            async with self.db_pool.get_cursor() as cursor:
                await cursor.execute(query, params or ())
                return True
        except Exception as err:
            handler_err(logger=server_logger, err=err, err_name="执行查询失败")
            raise

    async def fetch_all(self, query: str, params: Optional[Tuple] = None) -> List[Dict]:
        """Run a SELECT and return every row produced by ``cursor.fetchall()``.

        Raises:
            Exception: Any driver error, re-raised after being logged.
        """
        try:
            async with self.db_pool.get_cursor() as cursor:
                await cursor.execute(query, params or ())
                return await cursor.fetchall()
        except Exception as err:
            handler_err(logger=server_logger, err=err, err_name="查询数据失败")
            raise

    async def fetch_one(self, query: str, params: Optional[Tuple] = None) -> Optional[Dict]:
        """Run a SELECT and return the first row from ``cursor.fetchone()``.

        Raises:
            Exception: Any driver error, re-raised after being logged.
        """
        try:
            async with self.db_pool.get_cursor() as cursor:
                await cursor.execute(query, params or ())
                return await cursor.fetchone()
        except Exception as err:
            handler_err(logger=server_logger, err=err, err_name="查询单条数据失败")
            raise

    async def fetch_scalar(self, query: str, params: Optional[Tuple] = None) -> Any:
        """Return the first value of the first row, or ``None`` if there is no row.

        Assumes ``get_cursor()`` yields dict rows (the ``.values()`` call
        requires it) and that dict order follows column order.
        """
        result = await self.fetch_one(query, params)
        return next(iter(result.values())) if result else None

    async def execute_many(self, query: str, params_list: List[Tuple]) -> bool:
        """Execute *query* once per parameter tuple via ``cursor.executemany``.

        Returns:
            bool: ``True`` once all rows have executed without error.

        Raises:
            Exception: Any driver error, re-raised after being logged.
        """
        try:
            async with self.db_pool.get_cursor() as cursor:
                await cursor.executemany(query, params_list)
                return True
        except Exception as err:
            handler_err(logger=server_logger, err=err, err_name="批量执行失败")
            raise

    async def update_record(self, table: str, updates: Dict, conditions: Dict) -> bool:
        """Generic UPDATE with equality conditions joined by ``AND``.

        Args:
            table: Table name (interpolated, must be trusted).
            updates: Fields and new values, e.g. ``{'name': 'new', 'age': 25}``.
            conditions: Equality filters, e.g. ``{'id': 1, 'status': 'active'}``.

        Returns:
            bool: ``True`` if the statement executed successfully.

        Raises:
            ValueError: If ``updates`` or ``conditions`` is empty. An empty
                WHERE would otherwise update every row.
        """
        if not updates:
            raise ValueError("更新字段不能为空")

        if not conditions:
            raise ValueError("更新条件不能为空")

        try:
            set_clause, set_values = self._build_set_clause(updates)

            where_clause = " AND ".join(f"{field} = %s" for field in conditions)
            where_values = list(conditions.values())

            sql = f"UPDATE {table} SET {set_clause} WHERE {where_clause}"
            return await self.execute_query(sql, tuple(set_values + where_values))

        except Exception as err:
            handler_err(logger=server_logger, err=err, err_name="更新记录失败")
            raise

    async def update_by_id(self, table: str, record_id: int, updates: Dict) -> bool:
        """Update the row whose ``id`` column equals *record_id*.

        Args:
            table: Table name.
            record_id: Value matched against the ``id`` column.
            updates: Fields and new values.

        Returns:
            bool: ``True`` if the statement executed successfully.
        """
        return await self.update_record(table, updates, {'id': record_id})

    async def update_with_condition(self, table: str, updates: Dict, where_sql: str,
                                    params: Optional[Tuple] = None) -> bool:
        """UPDATE using a caller-supplied WHERE fragment.

        Args:
            table: Table name.
            updates: Fields and new values.
            where_sql: WHERE condition text with ``%s`` placeholders
                (interpolated verbatim, must be trusted).
            params: Values for the placeholders in *where_sql*; any sequence.

        Returns:
            bool: ``True`` if the statement executed successfully.

        Raises:
            ValueError: If ``updates`` is empty.
        """
        if not updates:
            raise ValueError("更新字段不能为空")

        try:
            set_clause, set_values = self._build_set_clause(updates)
            sql = f"UPDATE {table} SET {set_clause} WHERE {where_sql}"

            # SET values come first, matching placeholder order in the SQL.
            # tuple(...) also accepts a list for *params*.
            all_params = tuple(set_values) + tuple(params or ())

            return await self.execute_query(sql, all_params)

        except Exception as err:
            handler_err(logger=server_logger, err=err, err_name="条件更新失败")
            raise

    @staticmethod
    async def _safe_rollback(conn) -> None:
        """Roll back *conn*, logging rather than raising if rollback fails.

        This keeps the error that triggered the rollback as the one that
        propagates.
        """
        try:
            await conn.rollback()
        except Exception as rb_err:
            handler_err(logger=server_logger, err=rb_err, err_name="批量更新回滚失败")

    async def batch_update(self, table: str, updates_list: List[Dict], id_field: str = 'id') -> bool:
        """Update many rows, keyed by *id_field*, in a single transaction.

        Args:
            table: Table name.
            updates_list: One dict per row. Each must contain *id_field*; every
                other key is a column to set. Items with no other keys are skipped.
            id_field: Name of the key column, ``'id'`` by default.

        Returns:
            bool: ``True`` once the transaction has been committed.

        Raises:
            ValueError: If ``updates_list`` is empty or any item lacks
                *id_field*. This is checked before any SQL runs.
            Exception: Any driver error. The transaction is rolled back, and
                the error is logged and re-raised.
        """
        if not updates_list:
            raise ValueError("更新数据列表不能为空")

        # Validate everything up front so a malformed item cannot leave
        # earlier rows already modified.
        for update_data in updates_list:
            if id_field not in update_data:
                raise ValueError(f"更新数据中缺少{id_field}字段")

        try:
            async with self.db_pool.get_connection() as conn:
                # Explicit BEGIN makes the batch atomic even if the pool's
                # connections run in autocommit mode.
                await conn.begin()
                try:
                    async with conn.cursor(aiomysql.DictCursor) as cursor:
                        for update_data in updates_list:
                            update_fields = {k: v for k, v in update_data.items() if k != id_field}
                            if not update_fields:
                                continue

                            set_clause, set_values = self._build_set_clause(update_fields)
                            sql = f"UPDATE {table} SET {set_clause} WHERE {id_field} = %s"
                            await cursor.execute(sql, set_values + [update_data[id_field]])

                        await conn.commit()
                except Exception:
                    # Without this the half-applied transaction would go back
                    # to the pool still open and could be committed later.
                    await self._safe_rollback(conn)
                    raise
            return True

        except Exception as err:
            handler_err(logger=server_logger, err=err, err_name="批量更新失败")
            raise
class TestTabDAO(AsyncBaseDAO):
    """Async data-access object for the ``test_tab`` user table."""

    async def insert_user(self, name: str, email: str, age: int) -> int:
        """Insert one user row and return the cursor's ``lastrowid``.

        Raises:
            Exception: Any driver error, re-raised after being logged.
        """
        sql = "INSERT INTO test_tab (name, email, age) VALUES (%s, %s, %s)"
        values = (name, email, age)
        try:
            async with self.db_pool.get_cursor() as cur:
                await cur.execute(sql, values)
                new_id = cur.lastrowid
                return new_id
        except Exception as exc:
            handler_err(logger=server_logger, err=exc, err_name="插入用户失败")
            raise

    async def get_user_by_id(self, user_id: int) -> Optional[Dict]:
        """Return the row with this id whose status is 'active', if there is one."""
        sql = "SELECT * FROM test_tab WHERE id = %s AND status = 'active'"
        lookup = (user_id,)
        row = await self.fetch_one(sql, lookup)
        return row

    async def get_all_users(self) -> List[Dict]:
        """Return all rows with status 'active', newest ``created_at`` first."""
        sql = "SELECT * FROM test_tab WHERE status = 'active' ORDER BY created_at DESC"
        rows = await self.fetch_all(sql)
        return rows
|