redis_connection.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396
  1. # !/usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. '''
  4. @Project : lq-agent-api
  5. @File :redis_connection.py
  6. @IDE :PyCharm
  7. @Author :
  8. @Date :2025/7/21 15:07
  9. '''
  10. import redis # 同步专用
  11. # 尝试导入异步Redis模块
  12. try:
  13. from redis import asyncio as redis_asyncio
  14. except ImportError:
  15. try:
  16. import aioredis as redis_asyncio
  17. except ImportError:
  18. raise ImportError("Neither redis.asyncio nor aioredis is available. Please install 'redis[asyncio]' or 'aioredis'")
  19. # 导入Redis异常类
  20. from redis.exceptions import ConnectionError as redis_ConnectionError
  21. from typing import Optional, Protocol, Dict, Any, Set, Tuple
  22. from functools import wraps
  23. import asyncio
  24. from foundation.infrastructure.cache.redis_config import RedisConfig
  25. from foundation.infrastructure.cache.redis_config import load_config_from_env
  26. # 延迟导入logger以避免循环依赖
  27. def _get_redis_logger():
  28. try:
  29. from foundation.observability.logger.loggering import server_logger
  30. return server_logger
  31. except ImportError:
  32. import logging
  33. return logging.getLogger(__name__)
  34. from typing import Dict, Any, List, Tuple
  35. try:
  36. from langchain_community.storage import RedisStore
  37. except ModuleNotFoundError:
  38. RedisStore = None
  39. def with_redis_retry(max_retries: int = 3, delay: float = 1.0):
  40. """
  41. Redis操作重连装饰器
  42. Args:
  43. max_retries: 最大重试次数,默认3次
  44. delay: 重试间隔秒数,默认1秒
  45. """
  46. def decorator(func):
  47. @wraps(func)
  48. async def wrapper(self, *args, **kwargs):
  49. last_exception = None
  50. for attempt in range(max_retries + 1): # +1 包含第一次尝试
  51. try:
  52. return await func(self, *args, **kwargs)
  53. except (ConnectionResetError, redis_ConnectionError) as e:
  54. last_exception = e
  55. if attempt < max_retries:
  56. _get_redis_logger().warning(
  57. f"Redis连接异常 (尝试 {attempt + 1}/{max_retries + 1}): {str(e)}"
  58. )
  59. # 尝试重连
  60. try:
  61. await self._reconnect()
  62. except Exception as reconnect_error:
  63. _get_redis_logger().error(f"Redis重连失败: {str(reconnect_error)}")
  64. # 如果重连失败,继续重试
  65. await asyncio.sleep(delay * (attempt + 1)) # 指数退避
  66. continue
  67. _get_redis_logger().info(f"Redis重连成功,重新执行操作")
  68. await asyncio.sleep(delay) # 等待连接稳定
  69. else:
  70. _get_redis_logger().error(f"Redis操作失败,已达最大重试次数: {str(e)}")
  71. break
  72. except Exception as e:
  73. # 非连接相关的异常直接抛出
  74. raise e
  75. # 所有重试都失败了
  76. raise last_exception
  77. return wrapper
  78. return decorator
  79. class RedisConnection(Protocol):
  80. """
  81. Redis 接口协议
  82. """
  83. async def get(self, key: str) -> Any: ...
  84. async def set(self, key: str, value: Any, ex: Optional[int] = None, nx: bool = False) -> bool: ...
  85. async def hget(self, key: str, field: str) -> Any: ...
  86. async def hset(self, key: str, field: str, value: Any) -> int: ...
  87. async def hmset(self, key: str, mapping: Dict[str, Any]) -> bool: ...
  88. async def hgetall(self, key: str) -> Dict[str, Any]: ...
  89. async def delete(self, *keys: str) -> int: ...
  90. async def exists(self, key: str) -> int: ...
  91. async def expire(self, key: str, seconds: int) -> bool: ...
  92. async def scan(self, cursor: int, match: Optional[str] = None, count: Optional[int] = None) -> tuple[
  93. int, list[str]]: ...
  94. async def eval(self, script: str, keys: list[str], args: list[str]) -> Any: ...
  95. # 集合操作方法
  96. async def sadd(self, key: str, *values: str) -> int: ...
  97. async def scard(self, key: str) -> int: ...
  98. async def srem(self, key: str, *values: str) -> int: ...
  99. async def smembers(self, key: str) -> Set[str]: ...
  100. async def close(self) -> None: ...
  101. class RedisAdapter(RedisConnection):
  102. """
  103. Redis 适配器
  104. """
  105. def __init__(self, config: RedisConfig):
  106. self.config = config
  107. # 用于普通Redis 操作存储
  108. self._redis = None
  109. # 用于 langchain RedisStore 存储
  110. self._langchain_redis_client = None
  111. async def connect(self):
  112. """创建Redis连接"""
  113. # 简化的TCP Keep-Alive配置(兼容Windows系统)
  114. socket_options = {
  115. 'socket_keepalive': True,
  116. 'socket_connect_timeout': 10, # 连接超时10秒
  117. 'socket_timeout': 30, # 读写超时30秒
  118. }
  119. # 使用新版本的redis.asyncio
  120. self._redis = redis_asyncio.from_url(
  121. self.config.url,
  122. password=self.config.password,
  123. db=self.config.db,
  124. encoding="utf-8",
  125. decode_responses=True,
  126. max_connections=self.config.max_connections,
  127. **socket_options
  128. )
  129. # 用于 langchain RedisStore 存储
  130. # 必须设为 False(LangChain 需要 bytes 数据)
  131. self._langchain_redis_client = redis_asyncio.from_url(
  132. self.config.url,
  133. password=self.config.password,
  134. db=self.config.db,
  135. encoding="utf-8",
  136. decode_responses=False,
  137. max_connections=self.config.max_connections,
  138. **socket_options
  139. )
  140. # ✅ 使用同步 Redis 客户端
  141. # self._langchain_redis_client = redis.Redis.from_url(
  142. # self.config.url,
  143. # password=self.config.password,
  144. # db=self.config.db,
  145. # decode_responses=False, # LangChain 需要 bytes
  146. # )
  147. #错误:Expected Redis client, got Redis instead
  148. # self._langchain_redis_client = async_redis.from_url(
  149. # self.config.url,
  150. # password=self.config.password,
  151. # db=self.config.db,
  152. # decode_responses=False
  153. # )
  154. return self
  155. @with_redis_retry()
  156. async def get(self, key: str) -> Any:
  157. """获取Redis键值"""
  158. return await self._redis.get(key)
  159. @with_redis_retry()
  160. async def set(self, key: str, value: Any, ex: Optional[int] = None, nx: bool = False) -> bool:
  161. """设置Redis键值"""
  162. return await self._redis.set(key, value, ex=ex, nx=nx)
  163. @with_redis_retry()
  164. async def setex(self, key: str, time: int, value: Any) -> bool:
  165. """设置Redis键值并指定过期时间"""
  166. return await self._redis.setex(key, time, value)
  167. @with_redis_retry()
  168. async def hget(self, key: str, field: str) -> Any:
  169. return await self._redis.hget(key, field)
  170. @with_redis_retry()
  171. async def hset(self, key: str, field: str, value: Any) -> int:
  172. return await self._redis.hset(key, field, value)
  173. @with_redis_retry()
  174. async def hmset(self, key: str, mapping: Dict[str, Any]) -> bool:
  175. return await self._redis.hmset(key, mapping)
  176. @with_redis_retry()
  177. async def hgetall(self, key: str) -> Dict[str, Any]:
  178. return await self._redis.hgetall(key)
  179. @with_redis_retry()
  180. async def delete(self, *keys: str) -> int:
  181. return await self._redis.delete(*keys)
  182. @with_redis_retry()
  183. async def exists(self, key: str) -> int:
  184. return await self._redis.exists(key)
  185. @with_redis_retry()
  186. async def expire(self, key: str, seconds: int) -> bool:
  187. return await self._redis.expire(key, seconds)
  188. @with_redis_retry()
  189. async def scan(self, cursor: int, match: Optional[str] = None, count: Optional[int] = None) -> tuple[
  190. int, list[str]]:
  191. return await self._redis.scan(cursor, match=match, count=count)
  192. @with_redis_retry()
  193. async def eval(self, script: str, numkeys: int, *keys_and_args: str) -> Any:
  194. """执行Redis脚本"""
  195. return await self._redis.eval(script, numkeys, *keys_and_args) # 解包成独立参数
  196. # 集合操作方法实现
  197. @with_redis_retry()
  198. async def sadd(self, key: str, *values: str) -> int:
  199. """向集合添加成员,返回添加的成员数量"""
  200. return await self._redis.sadd(key, *values)
  201. @with_redis_retry()
  202. async def scard(self, key: str) -> int:
  203. """获取集合成员数量"""
  204. return await self._redis.scard(key)
  205. @with_redis_retry()
  206. async def srem(self, key: str, *values: str) -> int:
  207. """从集合删除成员,返回删除的成员数量"""
  208. return await self._redis.srem(key, *values)
  209. @with_redis_retry()
  210. async def smembers(self, key: str) -> Set[str]:
  211. """获取集合所有成员"""
  212. return await self._redis.smembers(key)
  213. def get_langchain_redis_client(self):
  214. return self._langchain_redis_client
  215. async def _reconnect(self) -> None:
  216. """重新连接Redis"""
  217. try:
  218. _get_redis_logger().info("正在重新连接Redis...")
  219. if self._redis:
  220. await self._redis.close()
  221. if self._langchain_redis_client:
  222. await self._langchain_redis_client.close()
  223. # 等待短暂时间后重连
  224. await asyncio.sleep(1)
  225. # 重新建立连接
  226. await self.connect()
  227. _get_redis_logger().info("Redis重连成功")
  228. except Exception as e:
  229. _get_redis_logger().error(f"Redis重连失败: {str(e)}")
  230. raise
  231. async def close(self) -> None:
  232. if self._redis:
  233. await self._redis.close()
  234. #await self._redis.wait_closed() #该方法已弃用
  235. if self._langchain_redis_client:
  236. await self._langchain_redis_client.close()
  237. #await self._langchain_redis_client.wait_closed()
  238. class RedisConnectionFactory:
  239. """
  240. redis 连接工厂函数
  241. """
  242. _connections: Dict[str, RedisConnection] = {}
  243. _stores: Dict[str, RedisStore] = {}
  244. _connection_loops: Dict[str, asyncio.AbstractEventLoop] = {} # 记录每个连接的事件循环
  245. @classmethod
  246. async def get_connection(cls) -> RedisConnection:
  247. """获取Redis连接(单例模式,支持事件循环检测)"""
  248. # 加载配置
  249. redis_config = load_config_from_env()
  250. #_get_redis_logger().info(f"redis_config={redis_config}")
  251. # 使用配置参数生成唯一标识
  252. conn_id = f"{redis_config.url}-{redis_config.db}"
  253. # 获取当前事件循环
  254. try:
  255. current_loop = asyncio.get_running_loop()
  256. except RuntimeError:
  257. # 如果没有运行的事件循环,创建一个新的
  258. current_loop = asyncio.new_event_loop()
  259. asyncio.set_event_loop(current_loop)
  260. # 检查连接是否存在以及事件循环是否匹配
  261. if conn_id in cls._connections:
  262. stored_loop = cls._connection_loops.get(conn_id)
  263. if stored_loop != current_loop:
  264. # 事件循环不匹配,需要重新创建连接
  265. _get_redis_logger().warning(
  266. f"检测到事件循环变化,重新创建Redis连接: {conn_id}"
  267. )
  268. # 关闭旧连接
  269. try:
  270. await cls._connections[conn_id].close()
  271. except Exception as e:
  272. _get_redis_logger().debug(f"关闭旧Redis连接时出错: {e}")
  273. # 删除旧连接
  274. del cls._connections[conn_id]
  275. del cls._connection_loops[conn_id]
  276. # 创建新连接
  277. if conn_id not in cls._connections:
  278. adapter = RedisAdapter(redis_config)
  279. await adapter.connect()
  280. cls._connections[conn_id] = adapter
  281. cls._connection_loops[conn_id] = current_loop
  282. _get_redis_logger().info(f"创建新的Redis连接: {conn_id}")
  283. return cls._connections[conn_id]
  284. @classmethod
  285. async def get_redis_store(cls) -> RedisStore:
  286. """获取 LangChain RedisStore 实例"""
  287. # 加载配置
  288. redis_config = load_config_from_env()
  289. conn = await cls.get_connection() # 或通过其他方式获取
  290. client = conn.get_langchain_redis_client()
  291. return client
  292. @classmethod
  293. async def get_langchain_redis_store(cls) -> RedisStore:
  294. """获取 LangChain RedisStore 实例
  295. 目前该方法存在问题
  296. """
  297. # 加载配置
  298. redis_config = load_config_from_env()
  299. # 使用配置参数生成唯一标识
  300. store_id = f"{redis_config.url}-{redis_config.db}"
  301. if store_id not in cls._stores:
  302. conn = await cls.get_connection() # 或通过其他方式获取
  303. client = conn.get_langchain_redis_client()
  304. store = client
  305. _get_redis_logger().info(f"client={client}")
  306. _get_redis_logger().info(f"store={dir(store)}")
  307. cls._stores[store_id] = store
  308. return cls._stores[store_id]
  309. @classmethod
  310. async def close_all(cls):
  311. """关闭所有Redis连接"""
  312. for conn in cls._connections.values():
  313. try:
  314. await conn.close()
  315. except Exception as e:
  316. _get_redis_logger().debug(f"关闭Redis连接时出错: {e}")
  317. cls._connections = {}
  318. cls._connection_loops = {} # 同时清理事件循环记录
  319. cls._stores = {}
  320. @classmethod
  321. def get_connection_count(cls) -> int:
  322. """获取当前连接数"""
  323. return len(cls._connections)