redis_connection.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323
  1. # !/usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. '''
  4. @Project : lq-agent-api
  5. @File :redis_connection.py.py
  6. @IDE :PyCharm
  7. @Author :
  8. @Date :2025/7/21 15:07
  9. '''
  10. import redis # 同步专用
  11. # 尝试导入异步Redis模块
  12. try:
  13. from redis import asyncio as redis_asyncio
  14. except ImportError:
  15. try:
  16. import aioredis as redis_asyncio
  17. except ImportError:
  18. raise ImportError("Neither redis.asyncio nor aioredis is available. Please install 'redis[asyncio]' or 'aioredis'")
  19. # 导入Redis异常类
  20. from redis.exceptions import ConnectionError as redis_ConnectionError
  21. from typing import Optional, Protocol, Dict, Any
  22. from functools import wraps
  23. import asyncio
  24. from foundation.base.redis_config import RedisConfig
  25. from foundation.base.redis_config import load_config_from_env
  26. from foundation.logger.loggering import server_logger
  27. from typing import Dict, Any, List, Tuple
  28. from langchain_community.storage import RedisStore
  29. def with_redis_retry(max_retries: int = 3, delay: float = 1.0):
  30. """
  31. Redis操作重连装饰器
  32. Args:
  33. max_retries: 最大重试次数,默认3次
  34. delay: 重试间隔秒数,默认1秒
  35. """
  36. def decorator(func):
  37. @wraps(func)
  38. async def wrapper(self, *args, **kwargs):
  39. last_exception = None
  40. for attempt in range(max_retries + 1): # +1 包含第一次尝试
  41. try:
  42. return await func(self, *args, **kwargs)
  43. except (ConnectionResetError, redis_ConnectionError) as e:
  44. last_exception = e
  45. if attempt < max_retries:
  46. server_logger.warning(
  47. f"Redis连接异常 (尝试 {attempt + 1}/{max_retries + 1}): {str(e)}"
  48. )
  49. # 尝试重连
  50. try:
  51. await self._reconnect()
  52. except Exception as reconnect_error:
  53. server_logger.error(f"Redis重连失败: {str(reconnect_error)}")
  54. # 如果重连失败,继续重试
  55. await asyncio.sleep(delay * (attempt + 1)) # 指数退避
  56. continue
  57. server_logger.info(f"Redis重连成功,重新执行操作")
  58. await asyncio.sleep(delay) # 等待连接稳定
  59. else:
  60. server_logger.error(f"Redis操作失败,已达最大重试次数: {str(e)}")
  61. break
  62. except Exception as e:
  63. # 非连接相关的异常直接抛出
  64. raise e
  65. # 所有重试都失败了
  66. raise last_exception
  67. return wrapper
  68. return decorator
  69. class RedisConnection(Protocol):
  70. """
  71. Redis 接口协议
  72. """
  73. async def get(self, key: str) -> Any: ...
  74. async def set(self, key: str, value: Any, ex: Optional[int] = None, nx: bool = False) -> bool: ...
  75. async def hget(self, key: str, field: str) -> Any: ...
  76. async def hset(self, key: str, field: str, value: Any) -> int: ...
  77. async def hmset(self, key: str, mapping: Dict[str, Any]) -> bool: ...
  78. async def hgetall(self, key: str) -> Dict[str, Any]: ...
  79. async def delete(self, *keys: str) -> int: ...
  80. async def exists(self, key: str) -> int: ...
  81. async def expire(self, key: str, seconds: int) -> bool: ...
  82. async def scan(self, cursor: int, match: Optional[str] = None, count: Optional[int] = None) -> tuple[
  83. int, list[str]]: ...
  84. async def eval(self, script: str, keys: list[str], args: list[str]) -> Any: ...
  85. async def close(self) -> None: ...
  86. class RedisAdapter(RedisConnection):
  87. """
  88. Redis 适配器
  89. """
  90. def __init__(self, config: RedisConfig):
  91. self.config = config
  92. # 用于普通Redis 操作存储
  93. self._redis = None
  94. # 用于 langchain RedisStore 存储
  95. self._langchain_redis_client = None
  96. async def connect(self):
  97. """创建Redis连接"""
  98. # 简化的TCP Keep-Alive配置(兼容Windows系统)
  99. socket_options = {
  100. 'socket_keepalive': True,
  101. 'socket_connect_timeout': 10, # 连接超时10秒
  102. 'socket_timeout': 30, # 读写超时30秒
  103. }
  104. # 使用新版本的redis.asyncio
  105. self._redis = redis_asyncio.from_url(
  106. self.config.url,
  107. password=self.config.password,
  108. db=self.config.db,
  109. encoding="utf-8",
  110. decode_responses=True,
  111. max_connections=self.config.max_connections,
  112. **socket_options
  113. )
  114. # 用于 langchain RedisStore 存储
  115. # 必须设为 False(LangChain 需要 bytes 数据)
  116. self._langchain_redis_client = redis_asyncio.from_url(
  117. self.config.url,
  118. password=self.config.password,
  119. db=self.config.db,
  120. encoding="utf-8",
  121. decode_responses=False,
  122. max_connections=self.config.max_connections,
  123. **socket_options
  124. )
  125. # ✅ 使用同步 Redis 客户端
  126. # self._langchain_redis_client = redis.Redis.from_url(
  127. # self.config.url,
  128. # password=self.config.password,
  129. # db=self.config.db,
  130. # decode_responses=False, # LangChain 需要 bytes
  131. # )
  132. #错误:Expected Redis client, got Redis instead
  133. # self._langchain_redis_client = async_redis.from_url(
  134. # self.config.url,
  135. # password=self.config.password,
  136. # db=self.config.db,
  137. # decode_responses=False
  138. # )
  139. return self
  140. @with_redis_retry()
  141. async def get(self, key: str) -> Any:
  142. """获取Redis键值"""
  143. return await self._redis.get(key)
  144. @with_redis_retry()
  145. async def set(self, key: str, value: Any, ex: Optional[int] = None, nx: bool = False) -> bool:
  146. """设置Redis键值"""
  147. return await self._redis.set(key, value, ex=ex, nx=nx)
  148. @with_redis_retry()
  149. async def setex(self, key: str, time: int, value: Any) -> bool:
  150. """设置Redis键值并指定过期时间"""
  151. return await self._redis.setex(key, time, value)
  152. @with_redis_retry()
  153. async def hget(self, key: str, field: str) -> Any:
  154. return await self._redis.hget(key, field)
  155. @with_redis_retry()
  156. async def hset(self, key: str, field: str, value: Any) -> int:
  157. return await self._redis.hset(key, field, value)
  158. @with_redis_retry()
  159. async def hmset(self, key: str, mapping: Dict[str, Any]) -> bool:
  160. return await self._redis.hmset(key, mapping)
  161. @with_redis_retry()
  162. async def hgetall(self, key: str) -> Dict[str, Any]:
  163. return await self._redis.hgetall(key)
  164. @with_redis_retry()
  165. async def delete(self, *keys: str) -> int:
  166. return await self._redis.delete(*keys)
  167. @with_redis_retry()
  168. async def exists(self, key: str) -> int:
  169. return await self._redis.exists(key)
  170. @with_redis_retry()
  171. async def expire(self, key: str, seconds: int) -> bool:
  172. return await self._redis.expire(key, seconds)
  173. @with_redis_retry()
  174. async def scan(self, cursor: int, match: Optional[str] = None, count: Optional[int] = None) -> tuple[
  175. int, list[str]]:
  176. return await self._redis.scan(cursor, match=match, count=count)
  177. @with_redis_retry()
  178. async def eval(self, script: str, numkeys: int, *keys_and_args: str) -> Any:
  179. """执行Redis脚本"""
  180. return await self._redis.eval(script, numkeys, *keys_and_args) # 解包成独立参数
  181. def get_langchain_redis_client(self):
  182. return self._langchain_redis_client
  183. async def _reconnect(self) -> None:
  184. """重新连接Redis"""
  185. try:
  186. server_logger.info("正在重新连接Redis...")
  187. if self._redis:
  188. await self._redis.close()
  189. await self._redis.wait_closed()
  190. if self._langchain_redis_client:
  191. await self._langchain_redis_client.close()
  192. await self._langchain_redis_client.wait_closed()
  193. # 等待短暂时间后重连
  194. await asyncio.sleep(1)
  195. # 重新建立连接
  196. await self.connect()
  197. server_logger.info("Redis重连成功")
  198. except Exception as e:
  199. server_logger.error(f"Redis重连失败: {str(e)}")
  200. raise
  201. async def close(self) -> None:
  202. if self._redis:
  203. await self._redis.close()
  204. #await self._redis.wait_closed() #该方法已弃用
  205. if self._langchain_redis_client:
  206. await self._langchain_redis_client.close()
  207. #await self._langchain_redis_client.wait_closed()
  208. class RedisConnectionFactory:
  209. """
  210. redis 连接工厂函数
  211. """
  212. _connections: Dict[str, RedisConnection] = {}
  213. _stores: Dict[str, RedisStore] = {}
  214. @classmethod
  215. async def get_connection(cls) -> RedisConnection:
  216. """获取Redis连接(单例模式)"""
  217. # 加载配置
  218. redis_config = load_config_from_env()
  219. #server_logger.info(f"redis_config={redis_config}")
  220. # 使用配置参数生成唯一标识
  221. conn_id = f"{redis_config.url}-{redis_config.db}"
  222. if conn_id not in cls._connections:
  223. adapter = RedisAdapter(redis_config)
  224. await adapter.connect()
  225. cls._connections[conn_id] = adapter
  226. return cls._connections[conn_id]
  227. @classmethod
  228. async def get_redis_store(cls) -> RedisStore:
  229. """获取 LangChain RedisStore 实例"""
  230. # 加载配置
  231. redis_config = load_config_from_env()
  232. conn = await cls.get_connection() # 或通过其他方式获取
  233. client = conn.get_langchain_redis_client()
  234. return client
  235. @classmethod
  236. async def get_langchain_redis_store(cls) -> RedisStore:
  237. """获取 LangChain RedisStore 实例
  238. 目前该方法存在问题
  239. """
  240. # 加载配置
  241. redis_config = load_config_from_env()
  242. # 使用配置参数生成唯一标识
  243. store_id = f"{redis_config.url}-{redis_config.db}"
  244. if store_id not in cls._stores:
  245. conn = await cls.get_connection() # 或通过其他方式获取
  246. client = conn.get_langchain_redis_client()
  247. store = client
  248. server_logger.info(f"client={client}")
  249. server_logger.info(f"store={dir(store)}")
  250. cls._stores[store_id] = store
  251. return cls._stores[store_id]
  252. @classmethod
  253. async def close_all(cls):
  254. """关闭所有Redis连接"""
  255. for conn in cls._connections.values():
  256. await conn.close()
  257. cls._connections = {}
  258. @classmethod
  259. def get_connection_count(cls) -> int:
  260. """获取当前连接数"""
  261. return len(cls._connections)