|
|
@@ -0,0 +1,344 @@
|
|
|
+# !/usr/bin/python
|
|
|
+# -*- coding: utf-8 -*-
|
|
|
+'''
|
|
|
+@Project : lq-agent-api
|
|
|
+@File :session_manager.py
|
|
|
+@IDE :PyCharm
|
|
|
+@Author :
|
|
|
+@Date :2025/7/24 03:03
|
|
|
+'''
|
|
|
+
|
|
|
+import asyncio
|
|
|
+import logging
|
|
|
+import re
|
|
|
+import time
|
|
|
+from typing import Any, Dict, List, Tuple
|
|
|
+# 假设的导入(根据实际框架调整)
|
|
|
+from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
|
|
+from base.config import config_handler
|
|
|
+from base.redis_config import load_config_from_env
|
|
|
+from base.redis_connection import RedisConnectionFactory
|
|
|
+from base.async_redis_lock import AsyncRedisLock
|
|
|
+
|
|
|
+from langchain.memory import ConversationBufferMemory
|
|
|
+from langchain_community.chat_message_histories import RedisChatMessageHistory
|
|
|
+from langchain_core.runnables.history import RunnableWithMessageHistory
|
|
|
+from langchain_core.messages import get_buffer_string
|
|
|
+from langchain_core.messages import messages_to_dict, messages_from_dict
|
|
|
+from langchain.prompts import PromptTemplate
|
|
|
+from utils.utils import get_models
|
|
|
+import warnings
|
|
|
+from langchain_core._api.deprecation import LangChainDeprecationWarning
|
|
|
+from logger.loggering import server_logger
|
|
|
+from utils.yaml_utils import system_prompt_config
|
|
|
+
|
|
|
+from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage, ToolMessage, FunctionMessage
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+class SessionManager:
|
|
|
+ """集中管理会话状态和锁定机制"""
|
|
|
+
|
|
|
+ def __init__(self , trace_id: str , lock_key_prefix: str , session_id: str, client_id: str = "default"):
|
|
|
+ self.trace_id = trace_id
|
|
|
+ self.session_id = session_id
|
|
|
+ self.client_id = client_id
|
|
|
+ self.session_lock = None
|
|
|
+ self.session_lock_key = lock_key_prefix + session_id
|
|
|
+ # 上下文管理器
|
|
|
+ self.session_context_memory_manager = SessionContextMemoryManager(trace_id , session_id)
|
|
|
+
|
|
|
+
|
|
|
+ async def is_session_locked(self) -> bool:
|
|
|
+ """检查会话是否被其他设备锁定"""
|
|
|
+ if await self.redis_client.exists(self.session_lock_key):
|
|
|
+ return True
|
|
|
+ return False
|
|
|
+
|
|
|
+ async def acquire_session_lock(self, timeout: float = 5) -> bool:
|
|
|
+ """尝试获取会话锁,带超时机制"""
|
|
|
+ config_is_lock = config_handler.get("chat", "CHAT_SESSION_LOCK" , "True")
|
|
|
+ server_logger.info(trace_id =self.trace_id, msg=f"创建新会话: {self.session_lock_key},config_is_lock:{config_is_lock} (锁定设备: {self.client_id})")
|
|
|
+ if config_is_lock == "False":
|
|
|
+ return True
|
|
|
+
|
|
|
+ try:
|
|
|
+ # 通过工厂模式获取 redis 连接器
|
|
|
+ self.redis_client = await RedisConnectionFactory.get_connection()
|
|
|
+ self.session_lock = AsyncRedisLock(self.redis_client, self.session_lock_key)
|
|
|
+ server_logger.info(trace_id =self.trace_id, msg=f"创建新会话: {self.session_lock_key} (锁定设备: {self.client_id})")
|
|
|
+ flag = await self.session_lock.acquire(timeout)
|
|
|
+ server_logger.debug(trace_id =self.trace_id, msg=f"尝试获取锁:{self.session_lock_key}-{flag}")
|
|
|
+ return flag
|
|
|
+ except asyncio.TimeoutError:
|
|
|
+ server_logger.warning(trace_id =self.trace_id, msg=f"获取会话锁超时: {self.session_lock_key}")
|
|
|
+ return False
|
|
|
+ except Exception as e:
|
|
|
+ server_logger.error(trace_id =self.trace_id, msg=f"获取会话锁失败: {self.session_lock_key}, 错误: {e}")
|
|
|
+ return False
|
|
|
+
|
|
|
+ async def release_session_lock(self):
|
|
|
+ """释放会话锁"""
|
|
|
+ config_is_lock = config_handler.get("chat", "CHAT_SESSION_LOCK" , "True")
|
|
|
+ server_logger.info(trace_id =self.trace_id, msg=f"释放新会话: {self.session_lock_key},config_is_lock:{config_is_lock} (锁定设备: {self.client_id})")
|
|
|
+ if config_is_lock == "False":
|
|
|
+ return
|
|
|
+ try:
|
|
|
+ if self.session_lock:
|
|
|
+ await self.session_lock.release()
|
|
|
+ except Exception as e:
|
|
|
+ server_logger.error(trace_id =self.trace_id, msg=f"释放会话锁失败: {self.session_lock_key}, 错误: {e}")
|
|
|
+
|
|
|
+
|
|
|
+ async def get_memory_history(self):
|
|
|
+ """
|
|
|
+ 获取会话历史
|
|
|
+ """
|
|
|
+ return await self.session_context_memory_manager.get_memory_history()
|
|
|
+
|
|
|
+ async def save_update_memory_history(self , history_messages , input_message , output_message):
|
|
|
+ """
|
|
|
+ 保存并更新历史会话
|
|
|
+ """
|
|
|
+ # 同步执行保存更新会话记录操作
|
|
|
+ #await self.session_context_memory_manager.save_update_memory_history(history_messages , input_message , output_message)
|
|
|
+
|
|
|
+ # 创建任务但不等待(不阻塞)
|
|
|
+ asyncio.create_task(self.session_context_memory_manager.save_update_memory_history(history_messages , input_message , output_message))
|
|
|
+ server_logger.info(trace_id =self.trace_id, msg=f"{self.session_id}: 保存并更新历史会话任务已创建,主协程继续执行结束")
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+class SessionContextMemoryManager:
|
|
|
+ """
|
|
|
+ 会话内存上下文管理器
|
|
|
+ """
|
|
|
+
|
|
|
+ def __init__(self , trace_id: str, session_id: str):
|
|
|
+ self.trace_id = trace_id
|
|
|
+ self.session_id = session_id
|
|
|
+ self.redis_memory = None
|
|
|
+ # 最大历史记录长度,超过后进行摘要处理
|
|
|
+ self.max_length = int(config_handler.get("lru", "AGENT_MAX_HISTORY_TOKENS"))
|
|
|
+ # 意图识别 可以使用最大多少条历史记录
|
|
|
+ self.recognize_intent_max_history = int(config_handler.get("lru", "AGENT_RECOGNIZE_INTENT_MAX_HISTORY_MESSAGES"))
|
|
|
+ llm, chat, embed = get_models()
|
|
|
+ self.llm = llm
|
|
|
+ # 固定系统提示词
|
|
|
+ self.system_prompt = system_prompt_config["summary_system_prompt"]
|
|
|
+ # 初始化 redis 聊天历史
|
|
|
+ self.init_redis_chat_history_memory()
|
|
|
+
|
|
|
+
|
|
|
+ def init_redis_chat_history_memory(self):
|
|
|
+ """
|
|
|
+ 获取 Redis 中指定会话的聊天记录
|
|
|
+ """
|
|
|
+ # 使用 contextmanager 仅在该代码块内忽略警告
|
|
|
+ with warnings.catch_warnings():
|
|
|
+ warnings.filterwarnings("ignore", category=LangChainDeprecationWarning)
|
|
|
+ redis_config = load_config_from_env()
|
|
|
+ #server_logger.info(trace_id =self.trace_id, msg=f"redis_config={redis_config}")
|
|
|
+ # 使用 RedisChatMessageHistory 存储对话历史
|
|
|
+ chat_history = RedisChatMessageHistory(
|
|
|
+ session_id=self.session_id, # 唯一标识会话
|
|
|
+ url=redis_config.url # 或直接使用 redis_client
|
|
|
+ )
|
|
|
+ #使用 Redis 存储记忆
|
|
|
+ self.redis_memory = ConversationBufferMemory(
|
|
|
+ memory_key="chat_history",
|
|
|
+ return_messages=True,
|
|
|
+ chat_memory=chat_history # 或其他兼容存储
|
|
|
+ )
|
|
|
+
|
|
|
+ server_logger.info(trace_id=self.trace_id, msg=f"redis 内存上下文历史初始完成={self.redis_memory}")
|
|
|
+ return self.redis_memory
|
|
|
+
|
|
|
+
|
|
|
+ async def get_memory_history(self):
|
|
|
+ """
|
|
|
+ 获取内存历史(原始记录)
|
|
|
+ """
|
|
|
+ history_messages = self.load_memory_history()
|
|
|
+ server_logger.debug(trace_id=self.trace_id, msg=f"begin session_id:{self.session_id}, session.history.len: {len(history_messages)}, session.history: {history_messages}")
|
|
|
+ return history_messages
|
|
|
+
|
|
|
+
|
|
|
+ async def get_memory_last_history_str(self):
|
|
|
+ """
|
|
|
+ 获取内存最新的多少条历史记录(将消息列表序列化为字符串)
|
|
|
+ # 示例消息列表
|
|
|
+ messages = [
|
|
|
+ HumanMessage(content="你好!"),
|
|
|
+ AIMessage(content="我是AI助手。")
|
|
|
+ ]
|
|
|
+ # 转换为字符串
|
|
|
+ formatted_str = get_buffer_string(
|
|
|
+ messages,
|
|
|
+ human_prefix="User", # 人类消息的前缀(默认"Human")
|
|
|
+ ai_prefix="Assistant", # AI消息的前缀(默认"AI")
|
|
|
+ separator="\n" # 消息分隔符(默认"\n")
|
|
|
+ )
|
|
|
+ 输出结果:text
|
|
|
+ User: 你好!
|
|
|
+ Assistant: 我是AI助手。
|
|
|
+
|
|
|
+ """
|
|
|
+ history_messages = self.load_memory_history()
|
|
|
+ if history_messages is None or len(history_messages) == 0:
|
|
|
+ return "无"
|
|
|
+ truncated_last_history_messages = history_messages
|
|
|
+ if len(history_messages) > self.recognize_intent_max_history:
|
|
|
+ # 截取最新的 10 条消息-使用负数索引从末尾切片
|
|
|
+ truncated_last_history_messages = history_messages[-self.recognize_intent_max_history:]
|
|
|
+ # 使用安全转换 - 处理markdown格式
|
|
|
+ #truncated_last_history_messages = self._simplify_messages_content(truncated_last_history_messages)
|
|
|
+ # 获取历史字符串
|
|
|
+ history_messages_str = get_buffer_string(truncated_last_history_messages)
|
|
|
+ # 先转义大括号
|
|
|
+ history_messages_str = history_messages_str.replace('{', '{{').replace('}', '}}')
|
|
|
+ server_logger.info(trace_id=self.trace_id, msg=f"recognize_intent_history session_id:{self.session_id}, session.history.len: {len(history_messages)}, truncated.last.history.len: {len(truncated_last_history_messages)}, history_messages_str: {history_messages_str}")
|
|
|
+ return history_messages_str
|
|
|
+
|
|
|
+
|
|
|
+ def load_memory_history(self):
|
|
|
+ """
|
|
|
+ 加载历史会话
|
|
|
+ """
|
|
|
+ return self.redis_memory.load_memory_variables({})["chat_history"]
|
|
|
+
|
|
|
+
|
|
|
+ def save_memory_history(self, input , output):
|
|
|
+ """
|
|
|
+ 保存历史会话
|
|
|
+ """
|
|
|
+ self.redis_memory.save_context({"input": input}, {"output": output})
|
|
|
+
|
|
|
+
|
|
|
+ async def save_update_memory_history(self , history_messages , input_message , output_message):
|
|
|
+ """
|
|
|
+ 如果历史记录 超出 最大长度,则先清除历史记录,再保存历史会话的摘要
|
|
|
+ 输入参数:
|
|
|
+ history_messages: 历史会话
|
|
|
+ input_message: 输入
|
|
|
+ output_message: 输出
|
|
|
+ """
|
|
|
+ cur_messages = [input_message] + [output_message]
|
|
|
+ tmp_messages = list(history_messages) + cur_messages
|
|
|
+ history_messages , is_summary = await self.compress_chat_history(tmp_messages)
|
|
|
+ server_logger.info(trace_id =self.trace_id, msg=f"保存更新历史记录处理:session_id={self.session_id},is_summary={is_summary}")
|
|
|
+
|
|
|
+ if is_summary:
|
|
|
+ # 如果是摘要消息,则清除保存的摘要消息
|
|
|
+ if isinstance(history_messages[0], SystemMessage):
|
|
|
+ self.clear_save_summary_memory(history_messages[0].content)
|
|
|
+ server_logger.debug(trace_id=self.trace_id, msg=f"end session_id:{self.session_id}, session.history.len: {len(history_messages)}, session.history: {history_messages}")
|
|
|
+ else:
|
|
|
+ server_logger.debug(trace_id=self.trace_id, msg=f"end session_id:{self.session_id}, session.history.len: {len(tmp_messages)}, session.history: {tmp_messages}")
|
|
|
+ self.redis_memory.save_context({"input": input_message.content}, {"output": output_message.content})
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+ def clear_save_summary_memory(self , summary_text):
|
|
|
+ """
|
|
|
+ 清除原始的记录,保存的摘要历史会话
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ server_logger.info(trace_id =self.trace_id, msg=f"type(summary_text): {type(summary_text)}")
|
|
|
+ server_logger.info(trace_id =self.trace_id, msg=f"summary_text: {summary_text}")
|
|
|
+ if not isinstance(summary_text, str):
|
|
|
+ # 安全兜底:尝试提取 .content
|
|
|
+ try:
|
|
|
+ summary_text = str(summary_text.content)
|
|
|
+ except AttributeError:
|
|
|
+ summary_text = repr(summary_text) # 最后手段
|
|
|
+ self.redis_memory.clear()
|
|
|
+ server_logger.info(trace_id =self.trace_id, msg=f"session_id={self.session_id}先清除历史会话记录完成")
|
|
|
+ self.redis_memory.save_context({"input": "整理后对话摘要"}, {"output": summary_text})
|
|
|
+ server_logger.info(trace_id =self.trace_id, msg=f"session_id={self.session_id}再保存摘要历史记录完成")
|
|
|
+ except Exception as e:
|
|
|
+ server_logger.error(trace_id =self.trace_id, msg=f"clear_save_summary_memory error: {e}")
|
|
|
+
|
|
|
+
|
|
|
+ async def compress_chat_history(self , chat_history):
|
|
|
+ """
|
|
|
+ 压缩聊天历史,如果超过 max_length 字符则生成摘要
|
|
|
+ """
|
|
|
+ his_len = len(get_buffer_string(chat_history))
|
|
|
+ server_logger.info(trace_id =self.trace_id, msg=f"his_len={his_len},max_length={self.max_length}")
|
|
|
+ if his_len < self.max_length:
|
|
|
+ #server_logger.info(trace_id =self.trace_id, msg="get_buffer_string(chat_history) < max_length")
|
|
|
+ return chat_history , False
|
|
|
+
|
|
|
+ summary_prompt = PromptTemplate(
|
|
|
+ input_variables=["history"],
|
|
|
+ template=self.system_prompt
|
|
|
+ )
|
|
|
+ # 创建可运行链:prompt + llm
|
|
|
+ chain = summary_prompt | self.llm # 等价于 LLMChain 的功能
|
|
|
+ # 获取历史字符串
|
|
|
+ history_str = get_buffer_string(chat_history)
|
|
|
+ # 异步调用(新 API)
|
|
|
+ summary_response = await chain.ainvoke({"history": history_str})
|
|
|
+ server_logger.info(trace_id =self.trace_id, msg=f"session_id={self.session_id},summary_text")
|
|
|
+ # 返回一个“系统消息”表示摘要
|
|
|
+ system_message = [SystemMessage(content=f"对话摘要:{summary_response.content}")]
|
|
|
+ return system_message , True
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+ def _simplify_messages_content(self , history_messages):
|
|
|
+ """
|
|
|
+ 简化消息内容
|
|
|
+ 处理 Markdown 内容
|
|
|
+ """
|
|
|
+ cleaned_messages = []
|
|
|
+ # 在处理过程中简化内容
|
|
|
+ for message in history_messages:
|
|
|
+ if not hasattr(message, 'content') or not isinstance(message.content, str):
|
|
|
+ # 如果消息没有 content 或 content 不是字符串,直接添加
|
|
|
+ cleaned_messages.append(message)
|
|
|
+ continue
|
|
|
+
|
|
|
+ content = message.content
|
|
|
+ content = self._simplify_markdown_content(content)
|
|
|
+ # 创建新消息对象,保留原始消息的类型和元数据,只更新 content
|
|
|
+ if isinstance(message, HumanMessage):
|
|
|
+ new_msg = HumanMessage(content=content, additional_kwargs=message.additional_kwargs, response_metadata=message.response_metadata)
|
|
|
+ elif isinstance(message, AIMessage):
|
|
|
+ new_msg = AIMessage(content=content, additional_kwargs=message.additional_kwargs, response_metadata=message.response_metadata)
|
|
|
+ elif isinstance(message, SystemMessage):
|
|
|
+ new_msg = SystemMessage(content=content, additional_kwargs=message.additional_kwargs, response_metadata=message.response_metadata)
|
|
|
+ elif isinstance(message, ToolMessage):
|
|
|
+ # ToolMessage 的 content 可能不是用户生成的文本,通常不处理
|
|
|
+ new_msg = ToolMessage(content=content, tool_call_id=message.tool_call_id, name=message.name, additional_kwargs=message.additional_kwargs, response_metadata=message.response_metadata)
|
|
|
+ elif isinstance(message, FunctionMessage):
|
|
|
+ new_msg = FunctionMessage(content=content, name=message.name, additional_kwargs=message.additional_kwargs, response_metadata=message.response_metadata)
|
|
|
+ else:
|
|
|
+ # 对于未知类型,尝试通用方式(如果 BaseMessage 支持)
|
|
|
+ new_msg = message.__class__(content=content, **{k: v for k, v in message.__dict__.items() if k != 'content'})
|
|
|
+
|
|
|
+ cleaned_messages.append(new_msg)
|
|
|
+
|
|
|
+ return cleaned_messages
|
|
|
+
|
|
|
+
|
|
|
+ def _simplify_markdown_content(self , content):
|
|
|
+ """简化 Markdown 内容"""
|
|
|
+ # 移除表格
|
|
|
+ content = re.sub(r'\|.*\|.*\n\|.*\|.*(\n\|.*\|.*)*', '[表格数据]', content)
|
|
|
+ # 移除标题
|
|
|
+ content = re.sub(r'#{1,6}\s*', '', content)
|
|
|
+ # 移除粗体斜体
|
|
|
+ content = re.sub(r'[*_]{1,2}(.*?)[*_]{1,2}', r'\1', content)
|
|
|
+ # 移除表情符号和特殊标记
|
|
|
+ content = re.sub(r'[✅❌📋📊]', '', content)
|
|
|
+ # 标准化换行
|
|
|
+ content = re.sub(r'\n{3,}', '\n\n', content)
|
|
|
+ return content.strip()
|
|
|
+
|
|
|
+
|
|
|
+
|