#!/usr/bin/python
# -*- coding: utf-8 -*-
'''
@Project    : lq-agent-api
@File       :session_manager.py
@IDE        :PyCharm
@Author     :
@Date       :2025/7/24 03:03
'''
import asyncio
import logging
import re
import time
from typing import Any, Dict, List, Tuple

# Assumed imports (adjust to the actual framework)
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from base.config import config_handler
from base.redis_config import load_config_from_env
from base.redis_connection import RedisConnectionFactory
from base.async_redis_lock import AsyncRedisLock
from langchain.memory import ConversationBufferMemory
from langchain_community.chat_message_histories import RedisChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.messages import get_buffer_string
from langchain_core.messages import messages_to_dict, messages_from_dict
from langchain.prompts import PromptTemplate
from utils.utils import get_models
import warnings
from langchain_core._api.deprecation import LangChainDeprecationWarning
from logger.loggering import server_logger
from utils.yaml_utils import system_prompt_config
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage, ToolMessage, FunctionMessage


class SessionManager:
    """Centrally manage per-session state and the cross-device session lock.

    Each instance wraps one ``session_id``: an optional Redis-backed lock
    (key = ``lock_key_prefix + session_id``) and a
    ``SessionContextMemoryManager`` that holds the conversation history.
    """

    # Strong references to fire-and-forget history-saving tasks. The event loop
    # only keeps weak references to tasks, so an unreferenced task may be
    # garbage-collected before it finishes.
    _background_tasks: set = set()

    def __init__(self, trace_id: str, lock_key_prefix: str, session_id: str, client_id: str = "default"):
        """
        Args:
            trace_id: Request trace id, forwarded to every log call.
            lock_key_prefix: Prefix prepended to ``session_id`` to build the lock key.
            session_id: Conversation/session identifier.
            client_id: Identifier of the device taking the lock (only used in log messages here).
        """
        self.trace_id = trace_id
        self.session_id = session_id
        self.client_id = client_id
        self.session_lock = None
        # Assigned lazily by acquire_session_lock() or is_session_locked().
        self.redis_client = None
        self.session_lock_key = lock_key_prefix + session_id
        # Conversation-context (history) manager for this session.
        self.session_context_memory_manager = SessionContextMemoryManager(trace_id, session_id)

    @staticmethod
    def _is_lock_enabled(config_is_lock) -> bool:
        """Return False only when CHAT_SESSION_LOCK is "false" (case- and whitespace-insensitive)."""
        return str(config_is_lock).strip().lower() != "false"

    async def is_session_locked(self) -> bool:
        """Return True if the session lock key exists in Redis (i.e. some device holds it).

        A Redis connection is fetched on demand, so this may be called before
        acquire_session_lock().
        """
        if self.redis_client is None:
            self.redis_client = await RedisConnectionFactory.get_connection()
        return bool(await self.redis_client.exists(self.session_lock_key))

    async def acquire_session_lock(self, timeout: float = 5) -> bool:
        """Try to acquire the session lock with a timeout.

        Args:
            timeout: Passed straight to ``AsyncRedisLock.acquire`` (presumably seconds — confirm).

        Returns:
            True if the lock was acquired, or if locking is disabled via the
            ``chat / CHAT_SESSION_LOCK`` config. False on timeout or on any
            other error (errors are logged, not raised).
        """
        config_is_lock = config_handler.get("chat", "CHAT_SESSION_LOCK", "True")
        server_logger.info(trace_id=self.trace_id,
                           msg=f"创建新会话: {self.session_lock_key},config_is_lock:{config_is_lock} (锁定设备: {self.client_id})")
        if not self._is_lock_enabled(config_is_lock):
            return True

        try:
            # Obtain the Redis connection through the connection factory.
            self.redis_client = await RedisConnectionFactory.get_connection()
            self.session_lock = AsyncRedisLock(self.redis_client, self.session_lock_key)
            server_logger.info(trace_id=self.trace_id,
                               msg=f"创建新会话: {self.session_lock_key} (锁定设备: {self.client_id})")
            flag = await self.session_lock.acquire(timeout)
            server_logger.debug(trace_id=self.trace_id, msg=f"尝试获取锁:{self.session_lock_key}-{flag}")
            return flag
        except asyncio.TimeoutError:
            server_logger.warning(trace_id=self.trace_id, msg=f"获取会话锁超时: {self.session_lock_key}")
            return False
        except Exception as e:
            server_logger.error(trace_id=self.trace_id, msg=f"获取会话锁失败: {self.session_lock_key}, 错误: {e}")
            return False

    async def release_session_lock(self):
        """Release the session lock if one was created; errors are logged, not raised."""
        config_is_lock = config_handler.get("chat", "CHAT_SESSION_LOCK", "True")
        server_logger.info(trace_id=self.trace_id,
                           msg=f"释放新会话: {self.session_lock_key},config_is_lock:{config_is_lock} (锁定设备: {self.client_id})")
        if not self._is_lock_enabled(config_is_lock):
            return
        try:
            if self.session_lock:
                await self.session_lock.release()
        except Exception as e:
            server_logger.error(trace_id=self.trace_id, msg=f"释放会话锁失败: {self.session_lock_key}, 错误: {e}")

    async def get_memory_history(self):
        """Return the raw session history messages."""
        return await self.session_context_memory_manager.get_memory_history()

    async def save_update_memory_history(self, history_messages, input_message, output_message):
        """Schedule saving (and summarising, if needed) the session history without blocking.

        The work runs as a background asyncio task. A strong reference to it is
        kept until it completes, and any exception it raises is logged instead
        of being silently dropped.

        Args:
            history_messages: Existing history messages.
            input_message: The current user message.
            output_message: The current assistant reply.
        """
        task = asyncio.create_task(
            self.session_context_memory_manager.save_update_memory_history(
                history_messages, input_message, output_message
            )
        )
        SessionManager._background_tasks.add(task)
        task.add_done_callback(self._on_background_task_done)
        server_logger.info(trace_id=self.trace_id,
                           msg=f"{self.session_id}: 保存并更新历史会话任务已创建,主协程继续执行结束")

    def _on_background_task_done(self, task: asyncio.Task) -> None:
        """Drop the reference to a finished background task and log its failure, if any."""
        SessionManager._background_tasks.discard(task)
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            server_logger.error(trace_id=self.trace_id,
                                msg=f"{self.session_id}: 保存并更新历史会话任务失败: {exc}")


class SessionContextMemoryManager:
    """Redis-backed conversation memory for one session, compressed into a summary when too long."""

    def __init__(self, trace_id: str, session_id: str):
        """
        Args:
            trace_id: Request trace id, forwarded to every log call.
            session_id: Session identifier; also the RedisChatMessageHistory session key.
        """
        self.trace_id = trace_id
        self.session_id = session_id
        self.redis_memory = None
        # Summarisation threshold. NOTE: despite the "TOKENS" config key,
        # compress_chat_history compares it with the *character* length of the
        # rendered history string.
        self.max_length = int(config_handler.get("lru", "AGENT_MAX_HISTORY_TOKENS"))
        # Maximum number of most-recent history messages used for intent recognition.
        self.recognize_intent_max_history = int(
            config_handler.get("lru", "AGENT_RECOGNIZE_INTENT_MAX_HISTORY_MESSAGES"))
        llm, _chat, _embed = get_models()
        self.llm = llm
        # Summary prompt template; rendered with a single `history` variable.
        self.system_prompt = system_prompt_config["summary_system_prompt"]
        # Initialise the Redis chat history.
        self.init_redis_chat_history_memory()

    def init_redis_chat_history_memory(self):
        """Build a ConversationBufferMemory backed by the Redis history of this session.

        Returns:
            The memory object, also stored on ``self.redis_memory``.
        """
        # Suppress LangChain deprecation warnings only inside this block.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=LangChainDeprecationWarning)
            redis_config = load_config_from_env()
            # Store the conversation history in Redis, keyed by session id.
            chat_history = RedisChatMessageHistory(
                session_id=self.session_id,
                url=redis_config.url,
            )
            self.redis_memory = ConversationBufferMemory(
                memory_key="chat_history",
                return_messages=True,
                chat_memory=chat_history,
            )
        server_logger.info(trace_id=self.trace_id, msg=f"redis 内存上下文历史初始完成={self.redis_memory}")
        return self.redis_memory

    async def get_memory_history(self):
        """Return the raw history messages stored for this session."""
        history_messages = self.load_memory_history()
        server_logger.debug(trace_id=self.trace_id,
                            msg=f"begin session_id:{self.session_id}, session.history.len: {len(history_messages)}, session.history: {history_messages}")
        return history_messages

    async def get_memory_last_history_str(self):
        """Return the newest history messages serialised as a template-safe string.

        At most ``self.recognize_intent_max_history`` of the newest messages are
        rendered with ``get_buffer_string``. That gives one "Human: ..." or
        "AI: ..." line per message, using the default prefixes. Then ``{`` and
        ``}`` are doubled, presumably so the text can be embedded in an
        f-string-format prompt template without being parsed as variables.

        Returns:
            The serialised history, or "无" ("none") when there is no history
            or the configured limit is <= 0.
        """
        history_messages = self.load_memory_history()
        max_history = self.recognize_intent_max_history
        # The <= 0 guard matters: history_messages[-0:] would return the whole list.
        if not history_messages or max_history <= 0:
            return "无"

        # Keep only the newest `max_history` messages (whole list if shorter).
        truncated_last_history_messages = history_messages[-max_history:]
        # NOTE: self._simplify_messages_content() can strip Markdown here if needed; currently unused.
        history_messages_str = get_buffer_string(truncated_last_history_messages)
        # Escape braces so they are not treated as template variables.
        history_messages_str = history_messages_str.replace('{', '{{').replace('}', '}}')
        server_logger.info(trace_id=self.trace_id,
                           msg=f"recognize_intent_history session_id:{self.session_id}, session.history.len: {len(history_messages)}, truncated.last.history.len: {len(truncated_last_history_messages)}, history_messages_str: {history_messages_str}")
        return history_messages_str

    def load_memory_history(self):
        """Load the stored history messages from the Redis-backed memory."""
        return self.redis_memory.load_memory_variables({})["chat_history"]

    def save_memory_history(self, input, output):
        """Append one input/output exchange to the Redis-backed memory."""
        self.redis_memory.save_context({"input": input}, {"output": output})

    async def save_update_memory_history(self, history_messages, input_message, output_message):
        """Save the current exchange, replacing the history with a summary if it grew too long.

        If the history plus the current exchange exceeds ``self.max_length``
        characters, the stored history is cleared. It is replaced by a single
        summary entry that already covers the current exchange. Otherwise the
        current exchange is appended as-is.

        Args:
            history_messages: Existing history messages.
            input_message: The current user message (its ``.content`` is stored).
            output_message: The current assistant reply (its ``.content`` is stored).
        """
        tmp_messages = list(history_messages) + [input_message, output_message]
        compressed_messages, is_summary = await self.compress_chat_history(tmp_messages)
        server_logger.info(trace_id=self.trace_id,
                           msg=f"保存更新历史记录处理:session_id={self.session_id},is_summary={is_summary}")
        if is_summary:
            # Replace the stored history with the summary message.
            if isinstance(compressed_messages[0], SystemMessage):
                self.clear_save_summary_memory(compressed_messages[0].content)
            server_logger.debug(trace_id=self.trace_id,
                                msg=f"end session_id:{self.session_id}, session.history.len: {len(compressed_messages)}, session.history: {compressed_messages}")
        else:
            server_logger.debug(trace_id=self.trace_id,
                                msg=f"end session_id:{self.session_id}, session.history.len: {len(tmp_messages)}, session.history: {tmp_messages}")
            self.save_memory_history(input_message.content, output_message.content)

    def clear_save_summary_memory(self, summary_text):
        """Clear the stored history, then save ``summary_text`` as the only exchange.

        Non-string inputs are coerced via ``.content`` or ``repr``. Errors are
        logged, not raised.
        """
        try:
            server_logger.info(trace_id=self.trace_id, msg=f"type(summary_text): {type(summary_text)}")
            server_logger.info(trace_id=self.trace_id, msg=f"summary_text: {summary_text}")
            if not isinstance(summary_text, str):
                # Fallback: try to extract .content, otherwise use repr() as a last resort.
                try:
                    summary_text = str(summary_text.content)
                except AttributeError:
                    summary_text = repr(summary_text)

            self.redis_memory.clear()
            server_logger.info(trace_id=self.trace_id, msg=f"session_id={self.session_id}先清除历史会话记录完成")
            self.redis_memory.save_context({"input": "整理后对话摘要"}, {"output": summary_text})
            server_logger.info(trace_id=self.trace_id, msg=f"session_id={self.session_id}再保存摘要历史记录完成")
        except Exception as e:
            server_logger.error(trace_id=self.trace_id, msg=f"clear_save_summary_memory error: {e}")

    async def compress_chat_history(self, chat_history):
        """Summarise ``chat_history`` with the LLM if its rendered length reaches ``self.max_length``.

        Returns:
            ``(messages, is_summary)``:
            - ``([SystemMessage(summary)], True)`` when a summary was produced.
            - ``(chat_history, False)`` when the history is short enough.
            - ``(chat_history, False)`` when summarisation failed; the failure
              is logged, so the caller still saves the current exchange
              instead of losing it.
        """
        history_str = get_buffer_string(chat_history)
        his_len = len(history_str)
        server_logger.info(trace_id=self.trace_id, msg=f"his_len={his_len},max_length={self.max_length}")
        if his_len < self.max_length:
            return chat_history, False

        summary_prompt = PromptTemplate(
            input_variables=["history"],
            template=self.system_prompt,
        )
        # prompt | llm is the runnable equivalent of an LLMChain.
        chain = summary_prompt | self.llm
        try:
            summary_response = await chain.ainvoke({"history": history_str})
        except Exception as e:
            server_logger.error(trace_id=self.trace_id,
                                msg=f"session_id={self.session_id},compress_chat_history error: {e}")
            return chat_history, False

        # Chat models return a message with .content; plain LLMs return a str.
        summary_text = getattr(summary_response, "content", summary_response)
        server_logger.info(trace_id=self.trace_id, msg=f"session_id={self.session_id},summary_text={summary_text}")
        # Represent the summary as a single system message.
        return [SystemMessage(content=f"对话摘要:{summary_text}")], True

    def _simplify_messages_content(self, history_messages):
        """Return copies of the messages with Markdown simplified in string contents.

        Messages whose content is missing or not a string are passed through
        unchanged. Type, additional_kwargs and response_metadata are kept.
        """
        cleaned_messages = []
        for message in history_messages:
            if not hasattr(message, 'content') or not isinstance(message.content, str):
                # No string content to simplify: keep the message as-is.
                cleaned_messages.append(message)
                continue

            content = self._simplify_markdown_content(message.content)

            # Rebuild a message of the same type, only replacing content.
            if isinstance(message, HumanMessage):
                new_msg = HumanMessage(content=content, additional_kwargs=message.additional_kwargs,
                                       response_metadata=message.response_metadata)
            elif isinstance(message, AIMessage):
                new_msg = AIMessage(content=content, additional_kwargs=message.additional_kwargs,
                                    response_metadata=message.response_metadata)
            elif isinstance(message, SystemMessage):
                new_msg = SystemMessage(content=content, additional_kwargs=message.additional_kwargs,
                                        response_metadata=message.response_metadata)
            elif isinstance(message, ToolMessage):
                new_msg = ToolMessage(content=content, tool_call_id=message.tool_call_id, name=message.name,
                                      additional_kwargs=message.additional_kwargs,
                                      response_metadata=message.response_metadata)
            elif isinstance(message, FunctionMessage):
                new_msg = FunctionMessage(content=content, name=message.name,
                                          additional_kwargs=message.additional_kwargs,
                                          response_metadata=message.response_metadata)
            else:
                # Unknown type: generic rebuild from the instance's attributes.
                new_msg = message.__class__(content=content,
                                            **{k: v for k, v in message.__dict__.items() if k != 'content'})
            cleaned_messages.append(new_msg)
        return cleaned_messages

    def _simplify_markdown_content(self, content):
        """Simplify Markdown text.

        Steps, in order:
        - tables become "[表格数据]";
        - heading markers are removed;
        - bold/italic markers are removed;
        - a few emoji are removed;
        - runs of three or more newlines are collapsed to two;
        - the result is stripped.
        """
        # Replace tables with a placeholder.
        content = re.sub(r'\|.*\|.*\n\|.*\|.*(\n\|.*\|.*)*', '[表格数据]', content)
        # Remove heading markers.
        content = re.sub(r'#{1,6}\s*', '', content)
        # Remove bold/italic markers.
        content = re.sub(r'[*_]{1,2}(.*?)[*_]{1,2}', r'\1', content)
        # Remove selected emoji/status markers.
        content = re.sub(r'[✅❌📋📊]', '', content)
        # Normalise blank lines.
        content = re.sub(r'\n{3,}', '\n\n', content)
        return content.strip()