async_mysql_base_dao.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. from typing import List, Tuple, Any, Optional, Dict
  2. from mysql.connector import Error
  3. from foundation.observability.logger.loggering import server_logger
  4. from foundation.utils.common import handler_err
  5. from async_mysql_conn_pool import AsyncMySQLPool
  6. import aiomysql
  7. class AsyncBaseDAO:
  8. """异步数据库访问基类"""
  9. def __init__(self, db_pool: AsyncMySQLPool):
  10. self.db_pool = db_pool
  11. async def execute_query(self, query: str, params: Tuple = None) -> bool:
  12. """执行写操作"""
  13. try:
  14. async with self.db_pool.get_cursor() as cursor:
  15. await cursor.execute(query, params or ())
  16. return True
  17. except Exception as err:
  18. handler_err(logger=server_logger, err=err ,err_name="执行查询失败")
  19. raise
  20. async def fetch_all(self, query: str, params: Tuple = None) -> List[Dict]:
  21. """查询多条记录"""
  22. try:
  23. async with self.db_pool.get_cursor() as cursor:
  24. await cursor.execute(query, params or ())
  25. return await cursor.fetchall()
  26. except Exception as err:
  27. handler_err(logger=server_logger, err=err ,err_name="查询数据失败")
  28. raise
  29. async def fetch_one(self, query: str, params: Tuple = None) -> Optional[Dict]:
  30. """查询单条记录"""
  31. try:
  32. async with self.db_pool.get_cursor() as cursor:
  33. await cursor.execute(query, params or ())
  34. return await cursor.fetchone()
  35. except Exception as err:
  36. handler_err(logger=server_logger, err=err ,err_name="查询单条数据失败")
  37. raise
  38. async def fetch_scalar(self, query: str, params: Tuple = None) -> Any:
  39. """查询单个值"""
  40. result = await self.fetch_one(query, params)
  41. return list(result.values())[0] if result else None
  42. async def execute_many(self, query: str, params_list: List[Tuple]) -> bool:
  43. """批量执行"""
  44. try:
  45. async with self.db_pool.get_cursor() as cursor:
  46. await cursor.executemany(query, params_list)
  47. return True
  48. except Exception as err:
  49. handler_err(logger=server_logger, err=err ,err_name="批量执行失败")
  50. raise
  51. async def update_record(self, table: str, updates: Dict, conditions: Dict) -> bool:
  52. """
  53. 通用更新记录方法
  54. Args:
  55. table: 表名
  56. updates: 要更新的字段和值,如 {'name': '新名字', 'age': 25}
  57. conditions: 更新条件,如 {'id': 1, 'status': 'active'}
  58. Returns:
  59. bool: 更新是否成功
  60. """
  61. if not updates:
  62. raise ValueError("更新字段不能为空")
  63. if not conditions:
  64. raise ValueError("更新条件不能为空")
  65. try:
  66. # 构建 SET 子句
  67. set_clause = ", ".join([f"{field} = %s" for field in updates.keys()])
  68. set_values = list(updates.values())
  69. # 构建 WHERE 子句
  70. where_clause = " AND ".join([f"{field} = %s" for field in conditions.keys()])
  71. where_values = list(conditions.values())
  72. # 构建完整 SQL
  73. sql = f"UPDATE {table} SET {set_clause} WHERE {where_clause}"
  74. params = set_values + where_values
  75. return await self.execute_query(sql, tuple(params))
  76. except Exception as err:
  77. handler_err(logger=server_logger, err=err, err_name="更新记录失败")
  78. raise
  79. async def update_by_id(self, table: str, record_id: int, updates: Dict) -> bool:
  80. """
  81. 根据ID更新记录
  82. Args:
  83. table: 表名
  84. record_id: 记录ID
  85. updates: 要更新的字段和值
  86. Returns:
  87. bool: 更新是否成功
  88. """
  89. return await self.update_record(table, updates, {'id': record_id})
  90. async def update_with_condition(self, table: str, updates: Dict, where_sql: str, params: Tuple = None) -> bool:
  91. """
  92. 使用自定义WHERE条件更新记录
  93. Args:
  94. table: 表名
  95. updates: 要更新的字段和值
  96. where_sql: WHERE条件SQL
  97. params: WHERE条件参数
  98. Returns:
  99. bool: 更新是否成功
  100. """
  101. if not updates:
  102. raise ValueError("更新字段不能为空")
  103. try:
  104. # 构建 SET 子句
  105. set_clause = ", ".join([f"{field} = %s" for field in updates.keys()])
  106. set_values = list(updates.values())
  107. # 构建完整 SQL
  108. sql = f"UPDATE {table} SET {set_clause} WHERE {where_sql}"
  109. # 合并参数
  110. all_params = tuple(set_values) + (params if params else ())
  111. return await self.execute_query(sql, all_params)
  112. except Exception as err:
  113. handler_err(logger=server_logger, err=err, err_name="条件更新失败")
  114. raise
  115. async def batch_update(self, table: str, updates_list: List[Dict], id_field: str = 'id') -> bool:
  116. """
  117. 批量更新记录(根据ID)
  118. Args:
  119. table: 表名
  120. updates_list: 更新数据列表,每个元素包含id和要更新的字段
  121. id_field: ID字段名,默认为'id'
  122. Returns:
  123. bool: 批量更新是否成功
  124. """
  125. if not updates_list:
  126. raise ValueError("更新数据列表不能为空")
  127. try:
  128. # 使用事务确保批量操作的原子性
  129. async with self.db_pool.get_connection() as conn:
  130. async with conn.cursor(aiomysql.DictCursor) as cursor:
  131. for update_data in updates_list:
  132. if id_field not in update_data:
  133. raise ValueError(f"更新数据中缺少{id_field}字段")
  134. record_id = update_data[id_field]
  135. # 从更新数据中移除ID字段
  136. update_fields = {k: v for k, v in update_data.items() if k != id_field}
  137. if not update_fields:
  138. continue
  139. # 构建SET子句
  140. set_clause = ", ".join([f"{field} = %s" for field in update_fields.keys()])
  141. set_values = list(update_fields.values())
  142. # 执行更新
  143. sql = f"UPDATE {table} SET {set_clause} WHERE {id_field} = %s"
  144. params = set_values + [record_id]
  145. await cursor.execute(sql, params)
  146. # 提交事务
  147. await conn.commit()
  148. return True
  149. except Exception as err:
  150. handler_err(logger=server_logger, err=err, err_name="批量更新失败")
  151. raise
  152. class TestTabDAO(AsyncBaseDAO):
  153. """异步用户数据访问对象"""
  154. async def insert_user(self, name: str, email: str, age: int) -> int:
  155. """插入用户"""
  156. insert_sql = "INSERT INTO test_tab (name, email, age) VALUES (%s, %s, %s)"
  157. try:
  158. async with self.db_pool.get_cursor() as cursor:
  159. await cursor.execute(insert_sql, (name, email, age))
  160. return cursor.lastrowid
  161. except Exception as err:
  162. handler_err(logger=server_logger, err=err ,err_name="插入用户失败")
  163. raise
  164. async def get_user_by_id(self, user_id: int) -> Optional[Dict]:
  165. """根据ID获取用户"""
  166. query = "SELECT * FROM test_tab WHERE id = %s AND status = 'active'"
  167. return await self.fetch_one(query, (user_id,))
  168. async def get_all_users(self) -> List[Dict]:
  169. """获取所有用户"""
  170. query = "SELECT * FROM test_tab WHERE status = 'active' ORDER BY created_at DESC"
  171. return await self.fetch_all(query)