|
|
@@ -0,0 +1,376 @@
|
|
|
+from typing import List, Tuple, Any, Optional, Dict
|
|
|
+from mysql.connector import Error
|
|
|
+from foundation.logger.loggering import server_logger
|
|
|
+from foundation.utils.common import handler_err
|
|
|
+from foundation.base.mysql.async_mysql_conn_pool import AsyncMySQLPool
|
|
|
+import aiomysql
|
|
|
+
|
|
|
+class AsyncBaseDAO:
|
|
|
+ """异步数据库访问基类"""
|
|
|
+
|
|
|
+ def __init__(self, db_pool: AsyncMySQLPool):
|
|
|
+ self.db_pool = db_pool
|
|
|
+
|
|
|
+
|
|
|
+ async def execute_query(self, query: str, params: Tuple = None) -> bool:
|
|
|
+ """执行写操作"""
|
|
|
+ try:
|
|
|
+ async with self.db_pool.get_cursor() as cursor:
|
|
|
+ await cursor.execute(query, params or ())
|
|
|
+ return True
|
|
|
+ except Exception as err:
|
|
|
+ handler_err(logger=server_logger, err=err ,err_name="执行查询失败")
|
|
|
+ raise
|
|
|
+
|
|
|
+ async def fetch_all(self, query: str, params: Tuple = None) -> List[Dict]:
|
|
|
+ """查询多条记录"""
|
|
|
+ try:
|
|
|
+ async with self.db_pool.get_cursor() as cursor:
|
|
|
+ await cursor.execute(query, params or ())
|
|
|
+ return await cursor.fetchall()
|
|
|
+ except Exception as err:
|
|
|
+ handler_err(logger=server_logger, err=err ,err_name="查询数据失败")
|
|
|
+ raise
|
|
|
+
|
|
|
+ async def fetch_one(self, query: str, params: Tuple = None) -> Optional[Dict]:
|
|
|
+ """查询单条记录"""
|
|
|
+ try:
|
|
|
+ async with self.db_pool.get_cursor() as cursor:
|
|
|
+ await cursor.execute(query, params or ())
|
|
|
+ return await cursor.fetchone()
|
|
|
+ except Exception as err:
|
|
|
+ handler_err(logger=server_logger, err=err ,err_name="查询单条数据失败")
|
|
|
+ raise
|
|
|
+
|
|
|
+ async def fetch_scalar(self, query: str, params: Tuple = None) -> Any:
|
|
|
+ """查询单个值"""
|
|
|
+ result = await self.fetch_one(query, params)
|
|
|
+ return list(result.values())[0] if result else None
|
|
|
+
|
|
|
+ async def execute_many(self, query: str, params_list: List[Tuple]) -> bool:
|
|
|
+ """批量执行"""
|
|
|
+ try:
|
|
|
+ async with self.db_pool.get_cursor() as cursor:
|
|
|
+ await cursor.executemany(query, params_list)
|
|
|
+ return True
|
|
|
+ except Exception as err:
|
|
|
+ handler_err(logger=server_logger, err=err ,err_name="批量执行失败")
|
|
|
+ raise
|
|
|
+
|
|
|
+ async def update_record(self, table: str, updates: Dict, conditions: Dict) -> bool:
|
|
|
+ """
|
|
|
+ 通用更新记录方法
|
|
|
+
|
|
|
+ Args:
|
|
|
+ table: 表名
|
|
|
+ updates: 要更新的字段和值,如 {'name': '新名字', 'age': 25}
|
|
|
+ conditions: 更新条件,如 {'id': 1, 'status': 'active'}
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ bool: 更新是否成功
|
|
|
+ """
|
|
|
+ if not updates:
|
|
|
+ raise ValueError("更新字段不能为空")
|
|
|
+
|
|
|
+ if not conditions:
|
|
|
+ raise ValueError("更新条件不能为空")
|
|
|
+
|
|
|
+ try:
|
|
|
+ # 构建 SET 子句
|
|
|
+ set_clause = ", ".join([f"{field} = %s" for field in updates.keys()])
|
|
|
+ set_values = list(updates.values())
|
|
|
+
|
|
|
+ # 构建 WHERE 子句
|
|
|
+ where_clause = " AND ".join([f"{field} = %s" for field in conditions.keys()])
|
|
|
+ where_values = list(conditions.values())
|
|
|
+
|
|
|
+ # 构建完整 SQL
|
|
|
+ sql = f"UPDATE {table} SET {set_clause} WHERE {where_clause}"
|
|
|
+ params = set_values + where_values
|
|
|
+
|
|
|
+ return await self.execute_query(sql, tuple(params))
|
|
|
+
|
|
|
+ except Exception as err:
|
|
|
+ handler_err(logger=server_logger, err=err, err_name="更新记录失败")
|
|
|
+ raise
|
|
|
+
|
|
|
+ async def update_by_id(self, table: str, record_id: int, updates: Dict) -> bool:
|
|
|
+ """
|
|
|
+ 根据ID更新记录
|
|
|
+
|
|
|
+ Args:
|
|
|
+ table: 表名
|
|
|
+ record_id: 记录ID
|
|
|
+ updates: 要更新的字段和值
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ bool: 更新是否成功
|
|
|
+ """
|
|
|
+ return await self.update_record(table, updates, {'id': record_id})
|
|
|
+
|
|
|
+ async def update_with_condition(self, table: str, updates: Dict, where_sql: str, params: Tuple = None) -> bool:
|
|
|
+ """
|
|
|
+ 使用自定义WHERE条件更新记录
|
|
|
+
|
|
|
+ Args:
|
|
|
+ table: 表名
|
|
|
+ updates: 要更新的字段和值
|
|
|
+ where_sql: WHERE条件SQL
|
|
|
+ params: WHERE条件参数
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ bool: 更新是否成功
|
|
|
+ """
|
|
|
+ if not updates:
|
|
|
+ raise ValueError("更新字段不能为空")
|
|
|
+
|
|
|
+ try:
|
|
|
+ # 构建 SET 子句
|
|
|
+ set_clause = ", ".join([f"{field} = %s" for field in updates.keys()])
|
|
|
+ set_values = list(updates.values())
|
|
|
+
|
|
|
+ # 构建完整 SQL
|
|
|
+ sql = f"UPDATE {table} SET {set_clause} WHERE {where_sql}"
|
|
|
+
|
|
|
+ # 合并参数
|
|
|
+ all_params = tuple(set_values) + (params if params else ())
|
|
|
+
|
|
|
+ return await self.execute_query(sql, all_params)
|
|
|
+
|
|
|
+ except Exception as err:
|
|
|
+ handler_err(logger=server_logger, err=err, err_name="条件更新失败")
|
|
|
+ raise
|
|
|
+
|
|
|
+ async def batch_update(self, table: str, updates_list: List[Dict], id_field: str = 'id') -> bool:
|
|
|
+ """
|
|
|
+ 批量更新记录(根据ID)
|
|
|
+
|
|
|
+ Args:
|
|
|
+ table: 表名
|
|
|
+ updates_list: 更新数据列表,每个元素包含id和要更新的字段
|
|
|
+ id_field: ID字段名,默认为'id'
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ bool: 批量更新是否成功
|
|
|
+ """
|
|
|
+ if not updates_list:
|
|
|
+ raise ValueError("更新数据列表不能为空")
|
|
|
+
|
|
|
+ try:
|
|
|
+ # 使用事务确保批量操作的原子性
|
|
|
+ async with self.db_pool.get_connection() as conn:
|
|
|
+ async with conn.cursor(aiomysql.DictCursor) as cursor:
|
|
|
+ for update_data in updates_list:
|
|
|
+ if id_field not in update_data:
|
|
|
+ raise ValueError(f"更新数据中缺少{id_field}字段")
|
|
|
+
|
|
|
+ record_id = update_data[id_field]
|
|
|
+ # 从更新数据中移除ID字段
|
|
|
+ update_fields = {k: v for k, v in update_data.items() if k != id_field}
|
|
|
+
|
|
|
+ if not update_fields:
|
|
|
+ continue
|
|
|
+
|
|
|
+ # 构建SET子句
|
|
|
+ set_clause = ", ".join([f"{field} = %s" for field in update_fields.keys()])
|
|
|
+ set_values = list(update_fields.values())
|
|
|
+
|
|
|
+ # 执行更新
|
|
|
+ sql = f"UPDATE {table} SET {set_clause} WHERE {id_field} = %s"
|
|
|
+ params = set_values + [record_id]
|
|
|
+
|
|
|
+ await cursor.execute(sql, params)
|
|
|
+
|
|
|
+ # 提交事务
|
|
|
+ await conn.commit()
|
|
|
+ return True
|
|
|
+
|
|
|
+ except Exception as err:
|
|
|
+ handler_err(logger=server_logger, err=err, err_name="批量更新失败")
|
|
|
+ raise
|
|
|
+
|
|
|
+
|
|
|
+class TestTabDAO(AsyncBaseDAO):
|
|
|
+ """异步用户数据访问对象"""
|
|
|
+
|
|
|
+
|
|
|
+ async def insert_user(self, name: str, email: str, age: int) -> int:
|
|
|
+ """插入用户"""
|
|
|
+ insert_sql = "INSERT INTO test_tab (name, email, age) VALUES (%s, %s, %s)"
|
|
|
+ try:
|
|
|
+ async with self.db_pool.get_cursor() as cursor:
|
|
|
+ await cursor.execute(insert_sql, (name, email, age))
|
|
|
+ return cursor.lastrowid
|
|
|
+ except Exception as err:
|
|
|
+ handler_err(logger=server_logger, err=err ,err_name="插入用户失败")
|
|
|
+ raise
|
|
|
+
|
|
|
+ async def get_user_by_id(self, user_id: int) -> Optional[Dict]:
|
|
|
+ """根据ID获取用户"""
|
|
|
+ query = "SELECT * FROM test_tab WHERE id = %s AND status = 'active'"
|
|
|
+ return await self.fetch_one(query, (user_id,))
|
|
|
+
|
|
|
+ async def get_all_users(self) -> List[Dict]:
|
|
|
+ """获取所有用户"""
|
|
|
+ query = "SELECT * FROM test_tab WHERE status = 'active' ORDER BY created_at DESC"
|
|
|
+ return await self.fetch_all(query)
|
|
|
+
|
|
|
+
|
|
|
+ async def get_users_by_condition(self, conditions: Dict) -> List[Dict]:
|
|
|
+ """根据条件查询用户"""
|
|
|
+ if not conditions:
|
|
|
+ return await self.get_all_users()
|
|
|
+
|
|
|
+ try:
|
|
|
+ where_clause = " AND ".join([f"{field} = %s" for field in conditions.keys()])
|
|
|
+ where_values = list(conditions.values())
|
|
|
+
|
|
|
+ query = f"SELECT * FROM test_tab WHERE {where_clause} AND status = 'active' ORDER BY created_at DESC"
|
|
|
+ return await self.fetch_all(query, tuple(where_values))
|
|
|
+
|
|
|
+ except Exception as err:
|
|
|
+ handler_err(logger=server_logger, err=err, err_name="条件查询用户失败")
|
|
|
+ raise
|
|
|
+
|
|
|
+ # ========== 修改方法 ==========
|
|
|
+
|
|
|
+ async def update_user(self, user_id: int, **updates) -> bool:
|
|
|
+ """
|
|
|
+ 更新用户信息
|
|
|
+
|
|
|
+ Args:
|
|
|
+ user_id: 用户ID
|
|
|
+ **updates: 要更新的字段,如 name='新名字', age=25, email='new@email.com'
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ bool: 更新是否成功
|
|
|
+ """
|
|
|
+ if not updates:
|
|
|
+ server_logger.warning("没有提供更新字段")
|
|
|
+ return False
|
|
|
+
|
|
|
+ # 过滤允许更新的字段
|
|
|
+ allowed_fields = {'name', 'email', 'age', 'status'}
|
|
|
+ valid_updates = {k: v for k, v in updates.items() if k in allowed_fields}
|
|
|
+
|
|
|
+ if not valid_updates:
|
|
|
+ server_logger.warning("没有有效的更新字段")
|
|
|
+ return False
|
|
|
+
|
|
|
+ try:
|
|
|
+ return await self.update_by_id('test_tab', user_id, valid_updates)
|
|
|
+
|
|
|
+ except Exception as err:
|
|
|
+ handler_err(logger=server_logger, err=err, err_name="更新用户失败")
|
|
|
+ raise
|
|
|
+
|
|
|
+ async def update_user_by_email(self, email: str, **updates) -> bool:
|
|
|
+ """
|
|
|
+ 根据邮箱更新用户信息
|
|
|
+
|
|
|
+ Args:
|
|
|
+ email: 用户邮箱
|
|
|
+ **updates: 要更新的字段
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ bool: 更新是否成功
|
|
|
+ """
|
|
|
+ if not updates:
|
|
|
+ server_logger.warning("没有提供更新字段")
|
|
|
+ return False
|
|
|
+
|
|
|
+ # 过滤允许更新的字段
|
|
|
+ allowed_fields = {'name', 'age', 'status'}
|
|
|
+ valid_updates = {k: v for k, v in updates.items() if k in allowed_fields}
|
|
|
+
|
|
|
+ if not valid_updates:
|
|
|
+ server_logger.warning("没有有效的更新字段")
|
|
|
+ return False
|
|
|
+
|
|
|
+ try:
|
|
|
+ return await self.update_record('test_tab', valid_updates, {'email': email})
|
|
|
+
|
|
|
+ except Exception as err:
|
|
|
+ handler_err(logger=server_logger, err=err, err_name="根据邮箱更新用户失败")
|
|
|
+ raise
|
|
|
+
|
|
|
+ async def update_user_status(self, user_id: int, status: str) -> bool:
|
|
|
+ """
|
|
|
+ 更新用户状态
|
|
|
+
|
|
|
+ Args:
|
|
|
+ user_id: 用户ID
|
|
|
+ status: 状态值 ('active' 或 'inactive')
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ bool: 更新是否成功
|
|
|
+ """
|
|
|
+ if status not in ('active', 'inactive'):
|
|
|
+ raise ValueError("状态值必须是 'active' 或 'inactive'")
|
|
|
+
|
|
|
+ try:
|
|
|
+ return await self.update_user(user_id, status=status)
|
|
|
+
|
|
|
+ except Exception as err:
|
|
|
+ handler_err(logger=server_logger, err=err, err_name="更新用户状态失败")
|
|
|
+ raise
|
|
|
+
|
|
|
+ async def batch_update_users(self, updates_list: List[Dict]) -> bool:
|
|
|
+ """
|
|
|
+ 批量更新用户信息
|
|
|
+
|
|
|
+ Args:
|
|
|
+ updates_list: 更新数据列表,每个元素必须包含id字段
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ bool: 批量更新是否成功
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ return await self.batch_update('test_tab', updates_list, 'id')
|
|
|
+
|
|
|
+ except Exception as err:
|
|
|
+ handler_err(logger=server_logger, err=err, err_name="批量更新用户失败")
|
|
|
+ raise
|
|
|
+
|
|
|
+ async def update_users_age_range(self, min_age: int, max_age: int, updates: Dict) -> bool:
|
|
|
+ """
|
|
|
+ 更新年龄范围内的用户
|
|
|
+
|
|
|
+ Args:
|
|
|
+ min_age: 最小年龄
|
|
|
+ max_age: 最大年龄
|
|
|
+ updates: 要更新的字段
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ bool: 更新是否成功
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ where_sql = "age BETWEEN %s AND %s AND status = 'active'"
|
|
|
+ params = (min_age, max_age)
|
|
|
+
|
|
|
+ return await self.update_with_condition('test_tab', updates, where_sql, params)
|
|
|
+
|
|
|
+ except Exception as err:
|
|
|
+ handler_err(logger=server_logger, err=err, err_name="更新年龄范围用户失败")
|
|
|
+ raise
|
|
|
+
|
|
|
+ async def increment_user_age(self, user_id: int, increment: int = 1) -> bool:
|
|
|
+ """
|
|
|
+ 增加用户年龄
|
|
|
+
|
|
|
+ Args:
|
|
|
+ user_id: 用户ID
|
|
|
+ increment: 增加的值,默认为1
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ bool: 更新是否成功
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ sql = "UPDATE test_tab SET age = age + %s WHERE id = %s AND status = 'active'"
|
|
|
+ return await self.execute_query(sql, (increment, user_id))
|
|
|
+
|
|
|
+ except Exception as err:
|
|
|
+ handler_err(logger=server_logger, err=err, err_name="增加用户年龄失败")
|
|
|
+ raise
|
|
|
+
|
|
|
+
|