redis_connection.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. # !/usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. '''
  4. @Project : lq-agent-api
  5. @File :redis_connection.py.py
  6. @IDE :PyCharm
  7. @Author :
  8. @Date :2025/7/21 15:07
  9. '''
  10. import redis # 同步专用
  11. from redis import asyncio as aioredis
  12. from typing import Optional, Protocol, Dict, Any
  13. from foundation.base.redis_config import RedisConfig
  14. from foundation.base.redis_config import load_config_from_env
  15. from foundation.logger.loggering import server_logger
  16. from typing import Dict, Any, List
  17. from typing import Tuple, Optional
  18. from langchain_community.storage import RedisStore
  class RedisConnection(Protocol):
      """
      Structural interface for the async Redis operations used in this module.

      ``RedisAdapter`` implements it, and ``RedisConnectionFactory`` hands out
      values typed as this protocol. Every method the factory calls on those
      values is therefore declared here, including ``get_langchain_redis_client``.
      """

      async def get(self, key: str) -> Any: ...

      async def set(self, key: str, value: Any, ex: Optional[int] = None, nx: bool = False) -> bool: ...

      async def hget(self, key: str, field: str) -> Any: ...

      async def hset(self, key: str, field: str, value: Any) -> int: ...

      async def hmset(self, key: str, mapping: Dict[str, Any]) -> bool: ...

      async def hgetall(self, key: str) -> Dict[str, Any]: ...

      async def delete(self, *keys: str) -> int: ...

      async def exists(self, key: str) -> int: ...

      async def expire(self, key: str, seconds: int) -> bool: ...

      async def scan(self, cursor: int, match: Optional[str] = None, count: Optional[int] = None) -> tuple[
          int, list[str]]: ...

      # Same calling convention as the adapter and redis-py's
      # ``Redis.eval(script, numkeys, *keys_and_args)``: the first ``numkeys``
      # positional values are keys, the rest are script arguments.
      async def eval(self, script: str, numkeys: int, *keys_and_args: str) -> Any: ...

      # Synchronous accessor for the client intended for LangChain storage.
      def get_langchain_redis_client(self) -> Any: ...

      async def close(self) -> None: ...
  class RedisAdapter(RedisConnection):
      """
      Redis adapter built on ``redis.asyncio``.

      ``connect()`` creates two clients from the same URL/password/db settings:

      * ``_redis`` -- ``decode_responses=True``; used by all data methods below.
      * ``_langchain_redis_client`` -- ``decode_responses=False`` (LangChain
        needs bytes); exposed via ``get_langchain_redis_client()``.

      Call ``connect()`` before using any data method.
      """

      def __init__(self, config: RedisConfig):
          self.config = config
          # Client for ordinary Redis operations (decoded responses).
          self._redis = None
          # Client intended for LangChain RedisStore storage (raw bytes responses).
          self._langchain_redis_client = None

      def _client(self):
          """Return the general-purpose client, or raise if ``connect()`` has not run."""
          if self._redis is None:
              raise RuntimeError("Redis connection not established; call connect() first")
          return self._redis

      async def connect(self):
          """Create both Redis clients and return ``self`` so calls can be chained."""
          common = dict(
              password=self.config.password,
              db=self.config.db,
              encoding="utf-8",
              max_connections=self.config.max_connections,
          )
          self._redis = await aioredis.from_url(
              self.config.url,
              decode_responses=True,
              **common,
          )
          # Must stay False: LangChain expects bytes.
          # NOTE(review): an earlier attempt to use an async client with LangChain's
          # RedisStore failed with "Expected Redis client, got Redis instead", which
          # suggests RedisStore wants a *sync* ``redis.Redis``. Confirm before
          # passing this async client to RedisStore.
          self._langchain_redis_client = aioredis.from_url(
              self.config.url,
              decode_responses=False,
              **common,
          )
          return self

      async def get(self, key: str) -> Any:
          return await self._client().get(key)

      async def set(self, key: str, value: Any, ex: Optional[int] = None, nx: bool = False) -> bool:
          return await self._client().set(key, value, ex=ex, nx=nx)

      async def hget(self, key: str, field: str) -> Any:
          return await self._client().hget(key, field)

      async def hset(self, key: str, field: str, value: Any) -> int:
          return await self._client().hset(key, field, value)

      async def hmset(self, key: str, mapping: Dict[str, Any]) -> bool:
          # redis-py deprecates HMSET in favour of hset(key, mapping=...); it is kept
          # here because callers rely on its boolean return value.
          return await self._client().hmset(key, mapping)

      async def hgetall(self, key: str) -> Dict[str, Any]:
          return await self._client().hgetall(key)

      async def delete(self, *keys: str) -> int:
          return await self._client().delete(*keys)

      async def exists(self, key: str) -> int:
          return await self._client().exists(key)

      async def expire(self, key: str, seconds: int) -> bool:
          return await self._client().expire(key, seconds)

      async def scan(self, cursor: int, match: Optional[str] = None, count: Optional[int] = None) -> tuple[
          int, list[str]]:
          return await self._client().scan(cursor, match=match, count=count)

      async def eval(self, script: str, numkeys: int, *keys_and_args: str) -> Any:
          # Keys and args are forwarded as separate positional values.
          return await self._client().eval(script, numkeys, *keys_and_args)

      def get_langchain_redis_client(self):
          """Return the bytes-mode client (``None`` before ``connect()``)."""
          return self._langchain_redis_client

      async def close(self) -> None:
          """Close both clients if they exist; calling it again is a no-op."""
          # Detach first so a repeated close() never touches an already-closed client.
          redis_client, self._redis = self._redis, None
          langchain_client, self._langchain_redis_client = self._langchain_redis_client, None
          try:
              if redis_client is not None:
                  await self._close_client(redis_client)
          finally:
              # Always attempt the second client, even if the first close failed.
              if langchain_client is not None:
                  await self._close_client(langchain_client)

      @staticmethod
      async def _close_client(client) -> None:
          """Close one ``redis.asyncio`` client."""
          # redis.asyncio clients have no wait_closed() (that was aioredis 1.x API).
          # redis-py >= 5.0.1 provides aclose(); older versions only have close().
          closer = getattr(client, "aclose", None) or client.close
          await closer()
  class RedisConnectionFactory:
      """
      Factory and registry for Redis connections.

      Adapters are cached per ``"{url}-{db}"`` key derived from the
      ``RedisConfig`` returned by ``load_config_from_env()``. Repeated calls with
      the same settings therefore share one connected ``RedisAdapter``.
      """

      # Connected adapters, keyed by "{url}-{db}".
      _connections: Dict[str, "RedisConnection"] = {}
      # LangChain-facing clients, keyed by "{url}-{db}". These are the objects
      # returned by get_langchain_redis_client(), not RedisStore instances.
      _stores: Dict[str, Any] = {}

      @staticmethod
      def _config_key(redis_config) -> str:
          """Build the cache key for a config (expects ``.url`` and ``.db`` attributes)."""
          return f"{redis_config.url}-{redis_config.db}"

      @classmethod
      async def get_connection(cls) -> "RedisConnection":
          """Return the shared, connected adapter for the current environment config."""
          redis_config = load_config_from_env()
          conn_id = cls._config_key(redis_config)
          conn = cls._connections.get(conn_id)
          if conn is None:
              adapter = RedisAdapter(redis_config)
              # Cache only after connect() succeeds, so a failure is not remembered.
              await adapter.connect()
              cls._connections[conn_id] = conn = adapter
          return conn

      @classmethod
      async def get_redis_store(cls) -> Any:
          """
          Return the LangChain-facing client of the shared connection.

          Despite the name, this returns the raw bytes-mode client from
          ``get_langchain_redis_client()``, not a ``RedisStore``.
          """
          conn = await cls.get_connection()
          return conn.get_langchain_redis_client()

      @classmethod
      async def get_langchain_redis_store(cls) -> Any:
          """
          Return the cached LangChain-facing client for the current config.

          Like ``get_redis_store`` this returns the raw client, not a
          ``RedisStore``. NOTE: the original author marked this method as
          "currently problematic".
          """
          redis_config = load_config_from_env()
          store_id = cls._config_key(redis_config)
          if store_id not in cls._stores:
              conn = await cls.get_connection()
              client = conn.get_langchain_redis_client()
              server_logger.info(f"client={client}")
              cls._stores[store_id] = client
          return cls._stores[store_id]

      @classmethod
      async def close_all(cls):
          """
          Close every cached connection and reset both registries.

          The registries are cleared before closing, so later lookups build fresh
          connections instead of reusing closed ones. Every connection is closed
          even if some fail; the first failure is re-raised afterwards.
          """
          connections = list(cls._connections.values())
          cls._connections = {}
          # Stored clients belong to the connections being closed; drop them too.
          cls._stores = {}
          first_error = None
          for conn in connections:
              try:
                  await conn.close()
              except Exception as exc:
                  if first_error is None:
                      first_error = exc
          if first_error is not None:
              raise first_error

      @classmethod
      def get_connection_count(cls) -> int:
          """Return how many connections are currently cached."""
          return len(cls._connections)
  164. def get_connection_count(cls) -> int:
  165. """获取当前连接数"""
  166. return len(cls._connections)