# !/usr/bin/python
# -*- coding: utf-8 -*-
'''
@Project : lq-agent-api
@File :redis_connection.py
@IDE :PyCharm
@Author :
@Date :2025/7/21 15:07

Async Redis connection helpers: a structural interface (RedisConnection),
a redis.asyncio-backed implementation (RedisAdapter) and a caching factory
(RedisConnectionFactory).
'''
from typing import Optional, Protocol, Dict, Any
from typing import Dict, Any, List
from typing import Tuple, Optional

import redis  # synchronous redis-py client
from redis import asyncio as aioredis
from langchain_community.storage import RedisStore

from foundation.base.redis_config import RedisConfig
from foundation.base.redis_config import load_config_from_env
from foundation.logger.loggering import server_logger


class RedisConnection(Protocol):
    """
    Structural interface for the async Redis operations used by this project.

    Signatures mirror the redis-py asyncio client methods they wrap, so that
    RedisAdapter (the redis.asyncio-backed implementation) conforms to it.
    """

    async def get(self, key: str) -> Any:
        """Return the value stored at ``key`` (None if missing)."""
        ...

    async def set(self, key: str, value: Any, ex: Optional[int] = None, nx: bool = False) -> bool:
        """Set ``key`` to ``value``; ``ex`` is a TTL in seconds, ``nx`` sets only if absent."""
        ...

    async def hget(self, key: str, field: str) -> Any:
        """Return one field of the hash at ``key``."""
        ...

    async def hset(self, key: str, field: str, value: Any) -> int:
        """Set one field of the hash at ``key``."""
        ...

    async def hmset(self, key: str, mapping: Dict[str, Any]) -> bool:
        """Set several hash fields at once."""
        ...

    async def hgetall(self, key: str) -> Dict[str, Any]:
        """Return every field/value of the hash at ``key``."""
        ...

    async def delete(self, *keys: str) -> int:
        """Delete ``keys``; returns the number of keys removed."""
        ...

    async def exists(self, key: str) -> int:
        """Return how many of the given keys exist."""
        ...

    async def expire(self, key: str, seconds: int) -> bool:
        """Set a TTL of ``seconds`` on ``key``."""
        ...

    async def scan(self, cursor: int, match: Optional[str] = None, count: Optional[int] = None) -> tuple[
        int, list[str]]:
        """Incrementally iterate keys; returns ``(next_cursor, keys)``."""
        ...

    async def eval(self, script: str, numkeys: int, *keys_and_args: str) -> Any:
        """Run a Lua script.

        Uses the redis-py calling convention: the first ``numkeys`` values of
        ``keys_and_args`` are passed as KEYS, the remainder as ARGV.
        """
        ...

    def get_langchain_redis_client(self) -> Any:
        """Return the bytes-mode client intended for LangChain storage."""
        ...

    async def close(self) -> None:
        """Close the underlying client(s)."""
        ...
class RedisAdapter(RedisConnection):
    """
    Redis adapter built on ``redis.asyncio``.

    It owns two clients created from the same RedisConfig:

    * ``_redis``: ``decode_responses=True``, used for ordinary operations
      (values come back as ``str``);
    * ``_langchain_redis_client``: ``decode_responses=False`` (values come
      back as ``bytes``), intended for LangChain RedisStore storage.
    """

    def __init__(self, config: RedisConfig):
        self.config = config
        # Client for ordinary Redis operations; created by connect().
        self._redis = None
        # Client intended for LangChain RedisStore storage; created by connect().
        self._langchain_redis_client = None

    async def connect(self):
        """Create both Redis clients and return ``self`` (allows chaining)."""
        self._redis = await aioredis.from_url(
            self.config.url,
            password=self.config.password,
            db=self.config.db,
            encoding="utf-8",
            decode_responses=True,
            max_connections=self.config.max_connections,
        )
        # LangChain needs raw bytes, so responses must NOT be decoded here.
        # NOTE(review): a previous attempt to hand an asyncio client to
        # LangChain's RedisStore failed with "Expected Redis client, got Redis
        # instead"; RedisStore appears to require the synchronous redis.Redis
        # client. Confirm before wiring this client into a RedisStore.
        self._langchain_redis_client = aioredis.from_url(
            self.config.url,
            password=self.config.password,
            db=self.config.db,
            encoding="utf-8",
            decode_responses=False,
            max_connections=self.config.max_connections,
        )
        return self

    async def get(self, key: str) -> Any:
        return await self._redis.get(key)

    async def set(self, key: str, value: Any, ex: Optional[int] = None, nx: bool = False) -> bool:
        return await self._redis.set(key, value, ex=ex, nx=nx)

    async def hget(self, key: str, field: str) -> Any:
        return await self._redis.hget(key, field)

    async def hset(self, key: str, field: str, value: Any) -> int:
        return await self._redis.hset(key, field, value)

    async def hmset(self, key: str, mapping: Dict[str, Any]) -> bool:
        # NOTE: redis-py deprecates hmset in favour of hset(key, mapping=...);
        # kept because hset returns a field count rather than True.
        return await self._redis.hmset(key, mapping)

    async def hgetall(self, key: str) -> Dict[str, Any]:
        return await self._redis.hgetall(key)

    async def delete(self, *keys: str) -> int:
        return await self._redis.delete(*keys)

    async def exists(self, key: str) -> int:
        return await self._redis.exists(key)

    async def expire(self, key: str, seconds: int) -> bool:
        return await self._redis.expire(key, seconds)

    async def scan(self, cursor: int, match: Optional[str] = None, count: Optional[int] = None) -> tuple[
        int, list[str]]:
        return await self._redis.scan(cursor, match=match, count=count)

    async def eval(self, script: str, numkeys: int, *keys_and_args: str) -> Any:
        """Run a Lua script: the first ``numkeys`` values are KEYS, the rest ARGV."""
        # Forward keys/args as separate positional arguments, as redis-py expects.
        return await self._redis.eval(script, numkeys, *keys_and_args)

    def get_langchain_redis_client(self):
        """Return the bytes-mode client created for LangChain (None before connect())."""
        return self._langchain_redis_client

    @staticmethod
    async def _close_client(client) -> None:
        """Close one redis.asyncio client and its connection pool."""
        # redis-py >= 5.0.1 provides aclose(); close() is the older/deprecated name.
        # redis.asyncio clients have no wait_closed() (that was aioredis 1.x API).
        closer = getattr(client, "aclose", None) or client.close
        await closer()

    async def close(self) -> None:
        """Close both clients if they were created.

        The LangChain client is closed even if closing the main client raises.
        """
        try:
            if self._redis is not None:
                await self._close_client(self._redis)
        finally:
            try:
                if self._langchain_redis_client is not None:
                    await self._close_client(self._langchain_redis_client)
            finally:
                self._redis = None
                self._langchain_redis_client = None


class RedisConnectionFactory:
    """
    Factory that caches one RedisAdapter per ``(url, db)`` pair.

    Caches are class-level, so they are shared by every caller in the process.
    """
    _connections: Dict[str, RedisConnection] = {}
    _stores: Dict[str, RedisStore] = {}

    @classmethod
    async def get_connection(cls) -> RedisConnection:
        """Return the cached connection for the env config, creating it on first use."""
        redis_config = load_config_from_env()
        # Connections are keyed by url + db.
        conn_id = f"{redis_config.url}-{redis_config.db}"
        if conn_id not in cls._connections:
            adapter = RedisAdapter(redis_config)
            await adapter.connect()
            cls._connections[conn_id] = adapter
        return cls._connections[conn_id]

    @classmethod
    async def get_redis_store(cls) -> RedisStore:
        """Return the adapter's LangChain-oriented client.

        NOTE(review): despite the annotation, this returns the raw
        redis.asyncio client from get_langchain_redis_client(), not a
        langchain RedisStore instance.
        """
        conn = await cls.get_connection()
        return conn.get_langchain_redis_client()

    @classmethod
    async def get_langchain_redis_store(cls) -> RedisStore:
        """Return (and cache per url + db) the adapter's LangChain-oriented client.

        Known issue: the cached object is the raw redis.asyncio client, not a
        langchain RedisStore, despite the return annotation.
        """
        redis_config = load_config_from_env()
        store_id = f"{redis_config.url}-{redis_config.db}"
        if store_id not in cls._stores:
            conn = await cls.get_connection()
            client = conn.get_langchain_redis_client()
            store = client
            server_logger.info(f"client={client}")
            server_logger.info(f"store={dir(store)}")
            cls._stores[store_id] = store
        return cls._stores[store_id]

    @classmethod
    async def close_all(cls):
        """Close every cached connection and clear both caches.

        All connections are attempted; the first error (if any) is re-raised
        after the loop.
        """
        connections = list(cls._connections.values())
        cls._connections = {}
        # _stores holds clients owned by the adapters being closed; drop them so
        # later calls do not hand out closed clients.
        cls._stores = {}
        first_error = None
        for conn in connections:
            try:
                await conn.close()
            except Exception as exc:
                if first_error is None:
                    first_error = exc
        if first_error is not None:
            raise first_error

    @classmethod
    def get_connection_count(cls) -> int:
        """Return the number of cached connections."""
        return len(cls._connections)