import re
from typing import List, Tuple, Any, Optional, Dict

from mysql.connector import Error
from foundation.observability.logger.loggering import server_logger
from foundation.utils.common import handler_err
from foundation.database.base.sql.async_mysql_conn_pool import AsyncMySQLPool
import aiomysql


# Allowed unquoted identifier: letters/digits/underscore (Unicode-aware via \w)
# and '$', with at most one "schema.name" qualifier. Anything else (spaces,
# quotes, backticks, operators, ';', ...) is rejected.
_IDENTIFIER_RE = re.compile(r"[\w$]+(?:\.[\w$]+)?")


def _quote_identifier(name: str) -> str:
    """Validate a table/column name and return it backtick-quoted.

    Identifiers cannot be sent as query parameters, so they end up in the SQL
    text itself. Whitelisting them here prevents SQL injection through table
    names or dict keys. Quoting also makes reserved words usable as names.

    Raises:
        ValueError: if ``name`` is not a plain, optionally schema-qualified,
            identifier.
    """
    if not isinstance(name, str) or not _IDENTIFIER_RE.fullmatch(name):
        raise ValueError(f"非法的标识符: {name!r}")
    return ".".join(f"`{part}`" for part in name.split("."))


def _build_assignments(fields: Dict) -> Tuple[str, List]:
    """Build a SET clause.

    Returns a pair: the clause text ("`a` = %s, `b` = %s") and its
    parameter values, in the same order.
    """
    clause = ", ".join(f"{_quote_identifier(field)} = %s" for field in fields)
    return clause, list(fields.values())


def _build_conditions(conditions: Dict) -> Tuple[str, List]:
    """Build an AND-joined WHERE clause and its parameter values.

    A ``None`` value becomes ``IS NULL``, because ``col = NULL`` never matches.
    """
    parts: List[str] = []
    values: List[Any] = []
    for field, value in conditions.items():
        column = _quote_identifier(field)
        if value is None:
            parts.append(f"{column} IS NULL")
        else:
            parts.append(f"{column} = %s")
            values.append(value)
    return " AND ".join(parts), values


class AsyncBaseDAO:
    """Base class for asynchronous database access objects."""

    def __init__(self, db_pool: AsyncMySQLPool):
        self.db_pool = db_pool

    async def execute_query(self, query: str, params: Optional[Tuple] = None) -> bool:
        """Execute a write statement.

        Whether the statement is committed depends on ``db_pool.get_cursor()``,
        which is not visible here (TODO confirm).

        Args:
            query: SQL using ``%s`` placeholders.
            params: Placeholder values; ``None`` means no parameters.

        Returns:
            bool: Always True on success. Errors are logged and re-raised.
        """
        try:
            async with self.db_pool.get_cursor() as cursor:
                await cursor.execute(query, params or ())
                return True
        except Exception as err:
            handler_err(logger=server_logger, err=err, err_name="执行查询失败")
            raise

    async def fetch_all(self, query: str, params: Optional[Tuple] = None) -> List[Dict]:
        """Run a query and return every row. Errors are logged and re-raised."""
        try:
            async with self.db_pool.get_cursor() as cursor:
                await cursor.execute(query, params or ())
                return await cursor.fetchall()
        except Exception as err:
            handler_err(logger=server_logger, err=err, err_name="查询数据失败")
            raise

    async def fetch_one(self, query: str, params: Optional[Tuple] = None) -> Optional[Dict]:
        """Run a query and return the first row, or None if there is none.

        Errors are logged and re-raised.
        """
        try:
            async with self.db_pool.get_cursor() as cursor:
                await cursor.execute(query, params or ())
                return await cursor.fetchone()
        except Exception as err:
            handler_err(logger=server_logger, err=err, err_name="查询单条数据失败")
            raise

    async def fetch_scalar(self, query: str, params: Optional[Tuple] = None) -> Any:
        """Return the first column of the first row, or None when there is no row.

        Assumes ``get_cursor()`` yields dict rows (as ``fetch_one``'s return
        type suggests), so "first column" means the first value in the dict.
        """
        result = await self.fetch_one(query, params)
        if not result:
            return None
        return next(iter(result.values()))

    async def execute_many(self, query: str, params_list: List[Tuple]) -> bool:
        """Execute ``query`` once per parameter tuple in ``params_list``.

        Returns:
            bool: Always True on success. Errors are logged and re-raised.
        """
        try:
            async with self.db_pool.get_cursor() as cursor:
                await cursor.executemany(query, params_list)
                return True
        except Exception as err:
            handler_err(logger=server_logger, err=err, err_name="批量执行失败")
            raise

    async def update_record(self, table: str, updates: Dict, conditions: Dict) -> bool:
        """Update the rows of ``table`` that match every equality in ``conditions``.

        Args:
            table: Table name; may be qualified as ``schema.table``.
            updates: Columns to set and their new values,
                e.g. ``{'name': 'new name', 'age': 25}``.
            conditions: Column equalities ANDed together,
                e.g. ``{'id': 1, 'status': 'active'}``. A ``None`` value
                matches NULL.

        Returns:
            bool: True when the statement executed.

        Raises:
            ValueError: if ``updates``/``conditions`` is empty, or a table or
                column name is not a plain identifier.
        """
        if not updates:
            raise ValueError("更新字段不能为空")
        if not conditions:
            raise ValueError("更新条件不能为空")

        try:
            set_clause, set_values = _build_assignments(updates)
            where_clause, where_values = _build_conditions(conditions)
            sql = f"UPDATE {_quote_identifier(table)} SET {set_clause} WHERE {where_clause}"
            return await self.execute_query(sql, tuple(set_values + where_values))
        except Exception as err:
            handler_err(logger=server_logger, err=err, err_name="更新记录失败")
            raise

    async def update_by_id(self, table: str, record_id: int, updates: Dict) -> bool:
        """Update the row whose ``id`` column equals ``record_id``.

        Args:
            table: Table name.
            record_id: Value matched against the ``id`` column.
            updates: Columns to set and their new values.

        Returns:
            bool: True when the statement executed.
        """
        return await self.update_record(table, updates, {'id': record_id})

    async def update_with_condition(self, table: str, updates: Dict, where_sql: str,
                                    params: Optional[Tuple] = None) -> bool:
        """Update rows using a caller-supplied WHERE clause.

        Args:
            table: Table name.
            updates: Columns to set and their new values.
            where_sql: WHERE condition using ``%s`` placeholders. It is inserted
                verbatim, so it must never contain untrusted text.
            params: Values for the placeholders in ``where_sql``. Any sequence
                (tuple or list) is accepted.

        Returns:
            bool: True when the statement executed.

        Raises:
            ValueError: if ``updates`` is empty or a name is not a plain identifier.
        """
        if not updates:
            raise ValueError("更新字段不能为空")

        try:
            set_clause, set_values = _build_assignments(updates)
            sql = f"UPDATE {_quote_identifier(table)} SET {set_clause} WHERE {where_sql}"
            # tuple(...) on both sides so a list for ``params`` does not raise TypeError.
            all_params = tuple(set_values) + tuple(params or ())
            return await self.execute_query(sql, all_params)
        except Exception as err:
            handler_err(logger=server_logger, err=err, err_name="条件更新失败")
            raise

    async def batch_update(self, table: str, updates_list: List[Dict], id_field: str = 'id') -> bool:
        """Update several rows by key inside a single transaction.

        Either every update is committed or, on any error, the transaction is
        rolled back.

        Args:
            table: Table name.
            updates_list: One dict per row. Each dict holds ``id_field`` plus
                the columns to set. Entries with nothing besides the id are
                skipped.
            id_field: Name of the key column. Defaults to ``'id'``.

        Returns:
            bool: True when the transaction committed.

        Raises:
            ValueError: if ``updates_list`` is empty, an entry lacks
                ``id_field``, or a name is not a plain identifier.
        """
        if not updates_list:
            raise ValueError("更新数据列表不能为空")
        # Validate every entry before touching the database, so a malformed
        # entry can never leave earlier rows already updated.
        for update_data in updates_list:
            if id_field not in update_data:
                raise ValueError(f"更新数据中缺少{id_field}字段")

        try:
            table_sql = _quote_identifier(table)
            id_sql = _quote_identifier(id_field)
            async with self.db_pool.get_connection() as conn:
                # Start the transaction explicitly. With autocommit enabled,
                # each UPDATE would otherwise commit on its own.
                await conn.begin()
                try:
                    async with conn.cursor(aiomysql.DictCursor) as cursor:
                        for update_data in updates_list:
                            update_fields = {k: v for k, v in update_data.items() if k != id_field}
                            if not update_fields:
                                continue
                            set_clause, set_values = _build_assignments(update_fields)
                            sql = f"UPDATE {table_sql} SET {set_clause} WHERE {id_sql} = %s"
                            await cursor.execute(sql, (*set_values, update_data[id_field]))
                    await conn.commit()
                except BaseException:
                    # Never return a connection with a half-applied transaction
                    # to the pool. BaseException also covers task cancellation.
                    try:
                        await conn.rollback()
                    except Exception as rollback_err:
                        handler_err(logger=server_logger, err=rollback_err, err_name="回滚事务失败")
                    raise
            return True
        except Exception as err:
            handler_err(logger=server_logger, err=err, err_name="批量更新失败")
            raise


class TestTabDAO(AsyncBaseDAO):
    """Asynchronous data access object for the ``test_tab`` user table."""

    async def insert_user(self, name: str, email: str, age: int) -> int:
        """Insert a user row and return the cursor's ``lastrowid``."""
        insert_sql = "INSERT INTO test_tab (name, email, age) VALUES (%s, %s, %s)"
        try:
            async with self.db_pool.get_cursor() as cursor:
                await cursor.execute(insert_sql, (name, email, age))
                return cursor.lastrowid
        except Exception as err:
            handler_err(logger=server_logger, err=err, err_name="插入用户失败")
            raise

    async def get_user_by_id(self, user_id: int) -> Optional[Dict]:
        """Return the row with this id if its status is 'active', else None."""
        query = "SELECT * FROM test_tab WHERE id = %s AND status = 'active'"
        return await self.fetch_one(query, (user_id,))

    async def get_all_users(self) -> List[Dict]:
        """Return all rows whose status is 'active', newest ``created_at`` first."""
        query = "SELECT * FROM test_tab WHERE status = 'active' ORDER BY created_at DESC"
        return await self.fetch_all(query)