# async_mysql_base_dao.py
import re
from typing import List, Tuple, Any, Optional, Dict

import aiomysql
from mysql.connector import Error

from foundation.base.mysql.async_mysql_conn_pool import AsyncMySQLPool
from foundation.logger.loggering import server_logger
from foundation.utils.common import handler_err
  7. class AsyncBaseDAO:
  8. """异步数据库访问基类"""
  9. def __init__(self, db_pool: AsyncMySQLPool):
  10. self.db_pool = db_pool
  11. async def execute_query(self, query: str, params: Tuple = None) -> bool:
  12. """执行写操作"""
  13. try:
  14. async with self.db_pool.get_cursor() as cursor:
  15. await cursor.execute(query, params or ())
  16. return True
  17. except Exception as err:
  18. handler_err(logger=server_logger, err=err ,err_name="执行查询失败")
  19. raise
  20. async def fetch_all(self, query: str, params: Tuple = None) -> List[Dict]:
  21. """查询多条记录"""
  22. try:
  23. async with self.db_pool.get_cursor() as cursor:
  24. await cursor.execute(query, params or ())
  25. return await cursor.fetchall()
  26. except Exception as err:
  27. handler_err(logger=server_logger, err=err ,err_name="查询数据失败")
  28. raise
  29. async def fetch_one(self, query: str, params: Tuple = None) -> Optional[Dict]:
  30. """查询单条记录"""
  31. try:
  32. async with self.db_pool.get_cursor() as cursor:
  33. await cursor.execute(query, params or ())
  34. return await cursor.fetchone()
  35. except Exception as err:
  36. handler_err(logger=server_logger, err=err ,err_name="查询单条数据失败")
  37. raise
  38. async def fetch_scalar(self, query: str, params: Tuple = None) -> Any:
  39. """查询单个值"""
  40. result = await self.fetch_one(query, params)
  41. return list(result.values())[0] if result else None
  42. async def execute_many(self, query: str, params_list: List[Tuple]) -> bool:
  43. """批量执行"""
  44. try:
  45. async with self.db_pool.get_cursor() as cursor:
  46. await cursor.executemany(query, params_list)
  47. return True
  48. except Exception as err:
  49. handler_err(logger=server_logger, err=err ,err_name="批量执行失败")
  50. raise
  51. async def update_record(self, table: str, updates: Dict, conditions: Dict) -> bool:
  52. """
  53. 通用更新记录方法
  54. Args:
  55. table: 表名
  56. updates: 要更新的字段和值,如 {'name': '新名字', 'age': 25}
  57. conditions: 更新条件,如 {'id': 1, 'status': 'active'}
  58. Returns:
  59. bool: 更新是否成功
  60. """
  61. if not updates:
  62. raise ValueError("更新字段不能为空")
  63. if not conditions:
  64. raise ValueError("更新条件不能为空")
  65. try:
  66. # 构建 SET 子句
  67. set_clause = ", ".join([f"{field} = %s" for field in updates.keys()])
  68. set_values = list(updates.values())
  69. # 构建 WHERE 子句
  70. where_clause = " AND ".join([f"{field} = %s" for field in conditions.keys()])
  71. where_values = list(conditions.values())
  72. # 构建完整 SQL
  73. sql = f"UPDATE {table} SET {set_clause} WHERE {where_clause}"
  74. params = set_values + where_values
  75. return await self.execute_query(sql, tuple(params))
  76. except Exception as err:
  77. handler_err(logger=server_logger, err=err, err_name="更新记录失败")
  78. raise
  79. async def update_by_id(self, table: str, record_id: int, updates: Dict) -> bool:
  80. """
  81. 根据ID更新记录
  82. Args:
  83. table: 表名
  84. record_id: 记录ID
  85. updates: 要更新的字段和值
  86. Returns:
  87. bool: 更新是否成功
  88. """
  89. return await self.update_record(table, updates, {'id': record_id})
  90. async def update_with_condition(self, table: str, updates: Dict, where_sql: str, params: Tuple = None) -> bool:
  91. """
  92. 使用自定义WHERE条件更新记录
  93. Args:
  94. table: 表名
  95. updates: 要更新的字段和值
  96. where_sql: WHERE条件SQL
  97. params: WHERE条件参数
  98. Returns:
  99. bool: 更新是否成功
  100. """
  101. if not updates:
  102. raise ValueError("更新字段不能为空")
  103. try:
  104. # 构建 SET 子句
  105. set_clause = ", ".join([f"{field} = %s" for field in updates.keys()])
  106. set_values = list(updates.values())
  107. # 构建完整 SQL
  108. sql = f"UPDATE {table} SET {set_clause} WHERE {where_sql}"
  109. # 合并参数
  110. all_params = tuple(set_values) + (params if params else ())
  111. return await self.execute_query(sql, all_params)
  112. except Exception as err:
  113. handler_err(logger=server_logger, err=err, err_name="条件更新失败")
  114. raise
  115. async def batch_update(self, table: str, updates_list: List[Dict], id_field: str = 'id') -> bool:
  116. """
  117. 批量更新记录(根据ID)
  118. Args:
  119. table: 表名
  120. updates_list: 更新数据列表,每个元素包含id和要更新的字段
  121. id_field: ID字段名,默认为'id'
  122. Returns:
  123. bool: 批量更新是否成功
  124. """
  125. if not updates_list:
  126. raise ValueError("更新数据列表不能为空")
  127. try:
  128. # 使用事务确保批量操作的原子性
  129. async with self.db_pool.get_connection() as conn:
  130. async with conn.cursor(aiomysql.DictCursor) as cursor:
  131. for update_data in updates_list:
  132. if id_field not in update_data:
  133. raise ValueError(f"更新数据中缺少{id_field}字段")
  134. record_id = update_data[id_field]
  135. # 从更新数据中移除ID字段
  136. update_fields = {k: v for k, v in update_data.items() if k != id_field}
  137. if not update_fields:
  138. continue
  139. # 构建SET子句
  140. set_clause = ", ".join([f"{field} = %s" for field in update_fields.keys()])
  141. set_values = list(update_fields.values())
  142. # 执行更新
  143. sql = f"UPDATE {table} SET {set_clause} WHERE {id_field} = %s"
  144. params = set_values + [record_id]
  145. await cursor.execute(sql, params)
  146. # 提交事务
  147. await conn.commit()
  148. return True
  149. except Exception as err:
  150. handler_err(logger=server_logger, err=err, err_name="批量更新失败")
  151. raise
  152. class TestTabDAO(AsyncBaseDAO):
  153. """异步用户数据访问对象"""
  154. async def insert_user(self, name: str, email: str, age: int) -> int:
  155. """插入用户"""
  156. insert_sql = "INSERT INTO test_tab (name, email, age) VALUES (%s, %s, %s)"
  157. try:
  158. async with self.db_pool.get_cursor() as cursor:
  159. await cursor.execute(insert_sql, (name, email, age))
  160. return cursor.lastrowid
  161. except Exception as err:
  162. handler_err(logger=server_logger, err=err ,err_name="插入用户失败")
  163. raise
  164. async def get_user_by_id(self, user_id: int) -> Optional[Dict]:
  165. """根据ID获取用户"""
  166. query = "SELECT * FROM test_tab WHERE id = %s AND status = 'active'"
  167. return await self.fetch_one(query, (user_id,))
  168. async def get_all_users(self) -> List[Dict]:
  169. """获取所有用户"""
  170. query = "SELECT * FROM test_tab WHERE status = 'active' ORDER BY created_at DESC"
  171. return await self.fetch_all(query)
  172. async def get_users_by_condition(self, conditions: Dict) -> List[Dict]:
  173. """根据条件查询用户"""
  174. if not conditions:
  175. return await self.get_all_users()
  176. try:
  177. where_clause = " AND ".join([f"{field} = %s" for field in conditions.keys()])
  178. where_values = list(conditions.values())
  179. query = f"SELECT * FROM test_tab WHERE {where_clause} AND status = 'active' ORDER BY created_at DESC"
  180. return await self.fetch_all(query, tuple(where_values))
  181. except Exception as err:
  182. handler_err(logger=server_logger, err=err, err_name="条件查询用户失败")
  183. raise
  184. # ========== 修改方法 ==========
  185. async def update_user(self, user_id: int, **updates) -> bool:
  186. """
  187. 更新用户信息
  188. Args:
  189. user_id: 用户ID
  190. **updates: 要更新的字段,如 name='新名字', age=25, email='new@email.com'
  191. Returns:
  192. bool: 更新是否成功
  193. """
  194. if not updates:
  195. server_logger.warning("没有提供更新字段")
  196. return False
  197. # 过滤允许更新的字段
  198. allowed_fields = {'name', 'email', 'age', 'status'}
  199. valid_updates = {k: v for k, v in updates.items() if k in allowed_fields}
  200. if not valid_updates:
  201. server_logger.warning("没有有效的更新字段")
  202. return False
  203. try:
  204. return await self.update_by_id('test_tab', user_id, valid_updates)
  205. except Exception as err:
  206. handler_err(logger=server_logger, err=err, err_name="更新用户失败")
  207. raise
  208. async def update_user_by_email(self, email: str, **updates) -> bool:
  209. """
  210. 根据邮箱更新用户信息
  211. Args:
  212. email: 用户邮箱
  213. **updates: 要更新的字段
  214. Returns:
  215. bool: 更新是否成功
  216. """
  217. if not updates:
  218. server_logger.warning("没有提供更新字段")
  219. return False
  220. # 过滤允许更新的字段
  221. allowed_fields = {'name', 'age', 'status'}
  222. valid_updates = {k: v for k, v in updates.items() if k in allowed_fields}
  223. if not valid_updates:
  224. server_logger.warning("没有有效的更新字段")
  225. return False
  226. try:
  227. return await self.update_record('test_tab', valid_updates, {'email': email})
  228. except Exception as err:
  229. handler_err(logger=server_logger, err=err, err_name="根据邮箱更新用户失败")
  230. raise
  231. async def update_user_status(self, user_id: int, status: str) -> bool:
  232. """
  233. 更新用户状态
  234. Args:
  235. user_id: 用户ID
  236. status: 状态值 ('active' 或 'inactive')
  237. Returns:
  238. bool: 更新是否成功
  239. """
  240. if status not in ('active', 'inactive'):
  241. raise ValueError("状态值必须是 'active' 或 'inactive'")
  242. try:
  243. return await self.update_user(user_id, status=status)
  244. except Exception as err:
  245. handler_err(logger=server_logger, err=err, err_name="更新用户状态失败")
  246. raise
  247. async def batch_update_users(self, updates_list: List[Dict]) -> bool:
  248. """
  249. 批量更新用户信息
  250. Args:
  251. updates_list: 更新数据列表,每个元素必须包含id字段
  252. Returns:
  253. bool: 批量更新是否成功
  254. """
  255. try:
  256. return await self.batch_update('test_tab', updates_list, 'id')
  257. except Exception as err:
  258. handler_err(logger=server_logger, err=err, err_name="批量更新用户失败")
  259. raise
  260. async def update_users_age_range(self, min_age: int, max_age: int, updates: Dict) -> bool:
  261. """
  262. 更新年龄范围内的用户
  263. Args:
  264. min_age: 最小年龄
  265. max_age: 最大年龄
  266. updates: 要更新的字段
  267. Returns:
  268. bool: 更新是否成功
  269. """
  270. try:
  271. where_sql = "age BETWEEN %s AND %s AND status = 'active'"
  272. params = (min_age, max_age)
  273. return await self.update_with_condition('test_tab', updates, where_sql, params)
  274. except Exception as err:
  275. handler_err(logger=server_logger, err=err, err_name="更新年龄范围用户失败")
  276. raise
  277. async def increment_user_age(self, user_id: int, increment: int = 1) -> bool:
  278. """
  279. 增加用户年龄
  280. Args:
  281. user_id: 用户ID
  282. increment: 增加的值,默认为1
  283. Returns:
  284. bool: 更新是否成功
  285. """
  286. try:
  287. sql = "UPDATE test_tab SET age = age + %s WHERE id = %s AND status = 'active'"
  288. return await self.execute_query(sql, (increment, user_id))
  289. except Exception as err:
  290. handler_err(logger=server_logger, err=err, err_name="增加用户年龄失败")
  291. raise