redis_duplicate_checker.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. """
  2. 基于Redis的重复任务检查器
  3. 支持多进程间的重复任务检查
  4. """
  5. import os
  6. import json
  7. from datetime import datetime, timedelta
  8. import redis
  9. from foundation.logger.loggering import server_logger as logger
  10. class RedisDuplicateChecker:
  11. """基于Redis的重复任务检查器"""
  12. def __init__(self):
  13. try:
  14. # 从配置文件读取Redis连接信息
  15. from foundation.base.config import config_handler
  16. redis_host = config_handler.get('redis', 'REDIS_HOST', 'localhost')
  17. redis_port = config_handler.get('redis', 'REDIS_PORT', '6379')
  18. redis_password = config_handler.get('redis', 'REDIS_PASSWORD', '')
  19. # 构建Redis连接URL
  20. if redis_password:
  21. redis_url = f"redis://:{redis_password}@{redis_host}:{redis_port}/2"
  22. else:
  23. redis_url = f"redis://{redis_host}:{redis_port}/2"
  24. logger.info(f"连接Redis: {redis_url}")
  25. # 连接Redis
  26. self.redis_client = redis.from_url(redis_url, decode_responses=True)
  27. # 测试连接
  28. self.redis_client.ping()
  29. logger.info("Redis重复检查器连接成功")
  30. self.use_redis = True
  31. except Exception as e:
  32. logger.error(f"Redis连接失败,回退到内存模式: {str(e)}")
  33. # 回退到内存模式
  34. self.task_cache = {}
  35. self.use_redis = False
  36. else:
  37. self.use_redis = True
  38. async def is_duplicate_task(self, file_id: str) -> bool:
  39. """检查是否为重复任务"""
  40. try:
  41. if self.use_redis:
  42. # 使用Redis检查
  43. task_info = self.redis_client.get(f"task:{file_id}")
  44. if task_info:
  45. # 检查任务是否过期
  46. task_data = json.loads(task_info)
  47. created_at = datetime.fromisoformat(task_data['created_at'])
  48. if datetime.now() - created_at < timedelta(minutes=2):
  49. logger.info(f"发现重复任务: {file_id}")
  50. return True
  51. else:
  52. # 任务已过期,清理
  53. self.redis_client.delete(f"task:{file_id}")
  54. return False
  55. return False
  56. else:
  57. # 回退到内存模式
  58. if file_id in self.task_cache:
  59. logger.info(f"发现重复任务: {file_id}")
  60. return True
  61. return False
  62. except Exception as e:
  63. logger.error(f"检查重复任务失败: {str(e)}")
  64. return False
  65. async def register_task(self, file_info: dict, callback_task_id: str):
  66. """注册任务"""
  67. try:
  68. # 过滤掉不可序列化的字段(如file_content等bytes数据)
  69. serializable_file_info = {
  70. k: v for k, v in file_info.items()
  71. if k not in ['file_content'] and not isinstance(v, bytes)
  72. }
  73. task_data = {
  74. "callback_task_id": callback_task_id,
  75. "created_at": datetime.now().isoformat(),
  76. "used": False, # 标记任务是否已被使用启动审查
  77. "file_info": serializable_file_info
  78. }
  79. if self.use_redis:
  80. # 使用Redis存储,设置1小时过期
  81. self.redis_client.setex(
  82. f"task:{file_info['file_id']}",
  83. 3600, # 1小时
  84. json.dumps(task_data, ensure_ascii=False)
  85. )
  86. else:
  87. # 回退到内存模式
  88. self.task_cache[file_info['file_id']] = task_data
  89. logger.info(f"注册任务: {file_info['file_id']} -> {callback_task_id}")
  90. except Exception as e:
  91. logger.error(f"注册任务失败: {str(e)}")
  92. raise
  93. async def unregister_task(self, file_id: str):
  94. """取消注册任务"""
  95. try:
  96. if self.use_redis:
  97. self.redis_client.delete(f"task:{file_id}")
  98. else:
  99. if file_id in self.task_cache:
  100. del self.task_cache[file_id]
  101. logger.info(f"取消注册任务: {file_id}")
  102. except Exception as e:
  103. logger.error(f"取消注册任务失败: {str(e)}")
  104. async def is_valid_task_id(self, callback_task_id: str) -> bool:
  105. """验证任务ID是否存在且未过期"""
  106. try:
  107. if self.use_redis:
  108. # 遍历所有任务键,查找匹配的callback_task_id
  109. keys = self.redis_client.keys("task:*")
  110. for key in keys:
  111. task_info = self.redis_client.get(key)
  112. if task_info:
  113. task_data = json.loads(task_info)
  114. if task_data.get("callback_task_id") == callback_task_id:
  115. created_at = datetime.fromisoformat(task_data['created_at'])
  116. if datetime.now() - created_at < timedelta(minutes=2):
  117. return True
  118. else:
  119. # 任务已过期,清理
  120. self.redis_client.delete(key)
  121. return False
  122. else:
  123. # 内存模式检查
  124. for file_id, task_info in self.task_cache.items():
  125. if task_info.get("callback_task_id") == callback_task_id:
  126. created_at = datetime.fromisoformat(task_info['created_at'])
  127. if datetime.now() - created_at < timedelta(minutes=2):
  128. return True
  129. return False
  130. except Exception as e:
  131. logger.error(f"验证任务ID失败: {str(e)}")
  132. return False
  133. async def get_task_info(self, file_id: str) -> str:
  134. """获取任务信息"""
  135. try:
  136. if self.use_redis:
  137. task_info = self.redis_client.get(f"task:{file_id}")
  138. if task_info:
  139. task_data = json.loads(task_info)
  140. return task_data.get("callback_task_id", "")
  141. return ""
  142. else:
  143. if file_id in self.task_cache:
  144. return self.task_cache[file_id].get("callback_task_id", "")
  145. return ""
  146. except Exception as e:
  147. logger.error(f"获取任务信息失败: {str(e)}")
  148. return ""
  149. def cleanup_expired_cache(self):
  150. """清理过期缓存(Redis自动处理)"""
  151. try:
  152. if not self.use_redis:
  153. current_time = datetime.now()
  154. expired_files = []
  155. for file_id, task_info in list(self.task_cache.items()):
  156. created_at = datetime.fromisoformat(task_info['created_at'])
  157. if current_time - created_at > timedelta(hours=1):
  158. expired_files.append(file_id)
  159. for file_id in expired_files:
  160. del self.task_cache[file_id]
  161. if expired_files:
  162. logger.info(f"清理过期缓存: {len(expired_files)} 个文件")
  163. except Exception as e:
  164. logger.error(f"清理过期缓存失败: {str(e)}")
  165. async def is_task_already_used(self, callback_task_id: str) -> bool:
  166. """检查任务是否已经被使用启动审查"""
  167. try:
  168. if self.use_redis:
  169. # 遍历所有任务键,查找匹配的callback_task_id
  170. keys = self.redis_client.keys("task:*")
  171. for key in keys:
  172. task_info = self.redis_client.get(key)
  173. if task_info:
  174. task_data = json.loads(task_info)
  175. if task_data.get("callback_task_id") == callback_task_id:
  176. # 检查任务是否已被使用
  177. if task_data.get("used", False):
  178. logger.info(f"任务已被使用: {callback_task_id}")
  179. return True
  180. else:
  181. return False
  182. return False
  183. else:
  184. # 内存模式检查
  185. for file_id, task_info in self.task_cache.items():
  186. if task_info.get("callback_task_id") == callback_task_id:
  187. if task_info.get("used", False):
  188. return True
  189. else:
  190. return False
  191. return False
  192. except Exception as e:
  193. logger.error(f"检查任务使用状态失败: {str(e)}")
  194. return False
  195. async def mark_task_as_used(self, callback_task_id: str):
  196. """标记任务为已使用"""
  197. try:
  198. if self.use_redis:
  199. # 遍历所有任务键,查找匹配的callback_task_id
  200. keys = self.redis_client.keys("task:*")
  201. for key in keys:
  202. task_info = self.redis_client.get(key)
  203. if task_info:
  204. task_data = json.loads(task_info)
  205. if task_data.get("callback_task_id") == callback_task_id:
  206. # 更新used字段为True
  207. task_data["used"] = True
  208. self.redis_client.setex(
  209. key,
  210. 3600, # 1小时
  211. json.dumps(task_data, ensure_ascii=False)
  212. )
  213. logger.info(f"任务已标记为使用: {callback_task_id}")
  214. return
  215. else:
  216. # 内存模式
  217. for file_id, task_info in self.task_cache.items():
  218. if task_info.get("callback_task_id") == callback_task_id:
  219. task_info["used"] = True
  220. logger.info(f"任务已标记为使用: {callback_task_id}")
  221. return
  222. except Exception as e:
  223. logger.error(f"标记任务使用状态失败: {str(e)}")