| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376 |
- from typing import List, Tuple, Any, Optional, Dict
- from mysql.connector import Error
- from foundation.logger.loggering import server_logger
- from foundation.utils.common import handler_err
- from foundation.base.mysql.async_mysql_conn_pool import AsyncMySQLPool
- import aiomysql
- class AsyncBaseDAO:
- """异步数据库访问基类"""
-
- def __init__(self, db_pool: AsyncMySQLPool):
- self.db_pool = db_pool
-
-
- async def execute_query(self, query: str, params: Tuple = None) -> bool:
- """执行写操作"""
- try:
- async with self.db_pool.get_cursor() as cursor:
- await cursor.execute(query, params or ())
- return True
- except Exception as err:
- handler_err(logger=server_logger, err=err ,err_name="执行查询失败")
- raise
-
- async def fetch_all(self, query: str, params: Tuple = None) -> List[Dict]:
- """查询多条记录"""
- try:
- async with self.db_pool.get_cursor() as cursor:
- await cursor.execute(query, params or ())
- return await cursor.fetchall()
- except Exception as err:
- handler_err(logger=server_logger, err=err ,err_name="查询数据失败")
- raise
-
- async def fetch_one(self, query: str, params: Tuple = None) -> Optional[Dict]:
- """查询单条记录"""
- try:
- async with self.db_pool.get_cursor() as cursor:
- await cursor.execute(query, params or ())
- return await cursor.fetchone()
- except Exception as err:
- handler_err(logger=server_logger, err=err ,err_name="查询单条数据失败")
- raise
-
- async def fetch_scalar(self, query: str, params: Tuple = None) -> Any:
- """查询单个值"""
- result = await self.fetch_one(query, params)
- return list(result.values())[0] if result else None
-
- async def execute_many(self, query: str, params_list: List[Tuple]) -> bool:
- """批量执行"""
- try:
- async with self.db_pool.get_cursor() as cursor:
- await cursor.executemany(query, params_list)
- return True
- except Exception as err:
- handler_err(logger=server_logger, err=err ,err_name="批量执行失败")
- raise
- async def update_record(self, table: str, updates: Dict, conditions: Dict) -> bool:
- """
- 通用更新记录方法
-
- Args:
- table: 表名
- updates: 要更新的字段和值,如 {'name': '新名字', 'age': 25}
- conditions: 更新条件,如 {'id': 1, 'status': 'active'}
-
- Returns:
- bool: 更新是否成功
- """
- if not updates:
- raise ValueError("更新字段不能为空")
-
- if not conditions:
- raise ValueError("更新条件不能为空")
-
- try:
- # 构建 SET 子句
- set_clause = ", ".join([f"{field} = %s" for field in updates.keys()])
- set_values = list(updates.values())
-
- # 构建 WHERE 子句
- where_clause = " AND ".join([f"{field} = %s" for field in conditions.keys()])
- where_values = list(conditions.values())
-
- # 构建完整 SQL
- sql = f"UPDATE {table} SET {set_clause} WHERE {where_clause}"
- params = set_values + where_values
-
- return await self.execute_query(sql, tuple(params))
-
- except Exception as err:
- handler_err(logger=server_logger, err=err, err_name="更新记录失败")
- raise
-
- async def update_by_id(self, table: str, record_id: int, updates: Dict) -> bool:
- """
- 根据ID更新记录
-
- Args:
- table: 表名
- record_id: 记录ID
- updates: 要更新的字段和值
-
- Returns:
- bool: 更新是否成功
- """
- return await self.update_record(table, updates, {'id': record_id})
-
- async def update_with_condition(self, table: str, updates: Dict, where_sql: str, params: Tuple = None) -> bool:
- """
- 使用自定义WHERE条件更新记录
-
- Args:
- table: 表名
- updates: 要更新的字段和值
- where_sql: WHERE条件SQL
- params: WHERE条件参数
-
- Returns:
- bool: 更新是否成功
- """
- if not updates:
- raise ValueError("更新字段不能为空")
-
- try:
- # 构建 SET 子句
- set_clause = ", ".join([f"{field} = %s" for field in updates.keys()])
- set_values = list(updates.values())
-
- # 构建完整 SQL
- sql = f"UPDATE {table} SET {set_clause} WHERE {where_sql}"
-
- # 合并参数
- all_params = tuple(set_values) + (params if params else ())
-
- return await self.execute_query(sql, all_params)
-
- except Exception as err:
- handler_err(logger=server_logger, err=err, err_name="条件更新失败")
- raise
-
- async def batch_update(self, table: str, updates_list: List[Dict], id_field: str = 'id') -> bool:
- """
- 批量更新记录(根据ID)
-
- Args:
- table: 表名
- updates_list: 更新数据列表,每个元素包含id和要更新的字段
- id_field: ID字段名,默认为'id'
-
- Returns:
- bool: 批量更新是否成功
- """
- if not updates_list:
- raise ValueError("更新数据列表不能为空")
-
- try:
- # 使用事务确保批量操作的原子性
- async with self.db_pool.get_connection() as conn:
- async with conn.cursor(aiomysql.DictCursor) as cursor:
- for update_data in updates_list:
- if id_field not in update_data:
- raise ValueError(f"更新数据中缺少{id_field}字段")
-
- record_id = update_data[id_field]
- # 从更新数据中移除ID字段
- update_fields = {k: v for k, v in update_data.items() if k != id_field}
-
- if not update_fields:
- continue
-
- # 构建SET子句
- set_clause = ", ".join([f"{field} = %s" for field in update_fields.keys()])
- set_values = list(update_fields.values())
-
- # 执行更新
- sql = f"UPDATE {table} SET {set_clause} WHERE {id_field} = %s"
- params = set_values + [record_id]
-
- await cursor.execute(sql, params)
-
- # 提交事务
- await conn.commit()
- return True
-
- except Exception as err:
- handler_err(logger=server_logger, err=err, err_name="批量更新失败")
- raise
- class TestTabDAO(AsyncBaseDAO):
- """异步用户数据访问对象"""
-
- async def insert_user(self, name: str, email: str, age: int) -> int:
- """插入用户"""
- insert_sql = "INSERT INTO test_tab (name, email, age) VALUES (%s, %s, %s)"
- try:
- async with self.db_pool.get_cursor() as cursor:
- await cursor.execute(insert_sql, (name, email, age))
- return cursor.lastrowid
- except Exception as err:
- handler_err(logger=server_logger, err=err ,err_name="插入用户失败")
- raise
-
- async def get_user_by_id(self, user_id: int) -> Optional[Dict]:
- """根据ID获取用户"""
- query = "SELECT * FROM test_tab WHERE id = %s AND status = 'active'"
- return await self.fetch_one(query, (user_id,))
-
- async def get_all_users(self) -> List[Dict]:
- """获取所有用户"""
- query = "SELECT * FROM test_tab WHERE status = 'active' ORDER BY created_at DESC"
- return await self.fetch_all(query)
-
- async def get_users_by_condition(self, conditions: Dict) -> List[Dict]:
- """根据条件查询用户"""
- if not conditions:
- return await self.get_all_users()
-
- try:
- where_clause = " AND ".join([f"{field} = %s" for field in conditions.keys()])
- where_values = list(conditions.values())
-
- query = f"SELECT * FROM test_tab WHERE {where_clause} AND status = 'active' ORDER BY created_at DESC"
- return await self.fetch_all(query, tuple(where_values))
-
- except Exception as err:
- handler_err(logger=server_logger, err=err, err_name="条件查询用户失败")
- raise
-
- # ========== 修改方法 ==========
-
- async def update_user(self, user_id: int, **updates) -> bool:
- """
- 更新用户信息
-
- Args:
- user_id: 用户ID
- **updates: 要更新的字段,如 name='新名字', age=25, email='new@email.com'
-
- Returns:
- bool: 更新是否成功
- """
- if not updates:
- server_logger.warning("没有提供更新字段")
- return False
-
- # 过滤允许更新的字段
- allowed_fields = {'name', 'email', 'age', 'status'}
- valid_updates = {k: v for k, v in updates.items() if k in allowed_fields}
-
- if not valid_updates:
- server_logger.warning("没有有效的更新字段")
- return False
-
- try:
- return await self.update_by_id('test_tab', user_id, valid_updates)
-
- except Exception as err:
- handler_err(logger=server_logger, err=err, err_name="更新用户失败")
- raise
-
- async def update_user_by_email(self, email: str, **updates) -> bool:
- """
- 根据邮箱更新用户信息
-
- Args:
- email: 用户邮箱
- **updates: 要更新的字段
-
- Returns:
- bool: 更新是否成功
- """
- if not updates:
- server_logger.warning("没有提供更新字段")
- return False
-
- # 过滤允许更新的字段
- allowed_fields = {'name', 'age', 'status'}
- valid_updates = {k: v for k, v in updates.items() if k in allowed_fields}
-
- if not valid_updates:
- server_logger.warning("没有有效的更新字段")
- return False
-
- try:
- return await self.update_record('test_tab', valid_updates, {'email': email})
-
- except Exception as err:
- handler_err(logger=server_logger, err=err, err_name="根据邮箱更新用户失败")
- raise
-
- async def update_user_status(self, user_id: int, status: str) -> bool:
- """
- 更新用户状态
-
- Args:
- user_id: 用户ID
- status: 状态值 ('active' 或 'inactive')
-
- Returns:
- bool: 更新是否成功
- """
- if status not in ('active', 'inactive'):
- raise ValueError("状态值必须是 'active' 或 'inactive'")
-
- try:
- return await self.update_user(user_id, status=status)
-
- except Exception as err:
- handler_err(logger=server_logger, err=err, err_name="更新用户状态失败")
- raise
-
- async def batch_update_users(self, updates_list: List[Dict]) -> bool:
- """
- 批量更新用户信息
-
- Args:
- updates_list: 更新数据列表,每个元素必须包含id字段
-
- Returns:
- bool: 批量更新是否成功
- """
- try:
- return await self.batch_update('test_tab', updates_list, 'id')
-
- except Exception as err:
- handler_err(logger=server_logger, err=err, err_name="批量更新用户失败")
- raise
-
- async def update_users_age_range(self, min_age: int, max_age: int, updates: Dict) -> bool:
- """
- 更新年龄范围内的用户
-
- Args:
- min_age: 最小年龄
- max_age: 最大年龄
- updates: 要更新的字段
-
- Returns:
- bool: 更新是否成功
- """
- try:
- where_sql = "age BETWEEN %s AND %s AND status = 'active'"
- params = (min_age, max_age)
-
- return await self.update_with_condition('test_tab', updates, where_sql, params)
-
- except Exception as err:
- handler_err(logger=server_logger, err=err, err_name="更新年龄范围用户失败")
- raise
-
- async def increment_user_age(self, user_id: int, increment: int = 1) -> bool:
- """
- 增加用户年龄
-
- Args:
- user_id: 用户ID
- increment: 增加的值,默认为1
-
- Returns:
- bool: 更新是否成功
- """
- try:
- sql = "UPDATE test_tab SET age = age + %s WHERE id = %s AND status = 'active'"
- return await self.execute_query(sql, (increment, user_id))
-
- except Exception as err:
- handler_err(logger=server_logger, err=err, err_name="增加用户年龄失败")
- raise
|