|
|
@@ -1,344 +0,0 @@
|
|
|
-# !/usr/bin/python
|
|
|
-# -*- coding: utf-8 -*-
|
|
|
-'''
|
|
|
-@Project : lq-agent-api
|
|
|
-@File :session_manager.py
|
|
|
-@IDE :PyCharm
|
|
|
-@Author :
|
|
|
-@Date :2025/7/24 03:03
|
|
|
-'''
|
|
|
-
|
|
|
-import asyncio
|
|
|
-import logging
|
|
|
-import re
|
|
|
-import time
|
|
|
-from typing import Any, Dict, List, Tuple
|
|
|
-# 假设的导入(根据实际框架调整)
|
|
|
-from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
|
|
-from base.config import config_handler
|
|
|
-from base.redis_config import load_config_from_env
|
|
|
-from base.redis_connection import RedisConnectionFactory
|
|
|
-from base.async_redis_lock import AsyncRedisLock
|
|
|
-
|
|
|
-from langchain.memory import ConversationBufferMemory
|
|
|
-from langchain_community.chat_message_histories import RedisChatMessageHistory
|
|
|
-from langchain_core.runnables.history import RunnableWithMessageHistory
|
|
|
-from langchain_core.messages import get_buffer_string
|
|
|
-from langchain_core.messages import messages_to_dict, messages_from_dict
|
|
|
-from langchain.prompts import PromptTemplate
|
|
|
-from utils.utils import get_models
|
|
|
-import warnings
|
|
|
-from langchain_core._api.deprecation import LangChainDeprecationWarning
|
|
|
-from logger.loggering import server_logger
|
|
|
-from utils.yaml_utils import system_prompt_config
|
|
|
-
|
|
|
-from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage, ToolMessage, FunctionMessage
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-class SessionManager:
|
|
|
- """集中管理会话状态和锁定机制"""
|
|
|
-
|
|
|
- def __init__(self , trace_id: str , lock_key_prefix: str , session_id: str, client_id: str = "default"):
|
|
|
- self.trace_id = trace_id
|
|
|
- self.session_id = session_id
|
|
|
- self.client_id = client_id
|
|
|
- self.session_lock = None
|
|
|
- self.session_lock_key = lock_key_prefix + session_id
|
|
|
- # 上下文管理器
|
|
|
- self.session_context_memory_manager = SessionContextMemoryManager(trace_id , session_id)
|
|
|
-
|
|
|
-
|
|
|
- async def is_session_locked(self) -> bool:
|
|
|
- """检查会话是否被其他设备锁定"""
|
|
|
- if await self.redis_client.exists(self.session_lock_key):
|
|
|
- return True
|
|
|
- return False
|
|
|
-
|
|
|
- async def acquire_session_lock(self, timeout: float = 5) -> bool:
|
|
|
- """尝试获取会话锁,带超时机制"""
|
|
|
- config_is_lock = config_handler.get("chat", "CHAT_SESSION_LOCK" , "True")
|
|
|
- server_logger.info(trace_id =self.trace_id, msg=f"创建新会话: {self.session_lock_key},config_is_lock:{config_is_lock} (锁定设备: {self.client_id})")
|
|
|
- if config_is_lock == "False":
|
|
|
- return True
|
|
|
-
|
|
|
- try:
|
|
|
- # 通过工厂模式获取 redis 连接器
|
|
|
- self.redis_client = await RedisConnectionFactory.get_connection()
|
|
|
- self.session_lock = AsyncRedisLock(self.redis_client, self.session_lock_key)
|
|
|
- server_logger.info(trace_id =self.trace_id, msg=f"创建新会话: {self.session_lock_key} (锁定设备: {self.client_id})")
|
|
|
- flag = await self.session_lock.acquire(timeout)
|
|
|
- server_logger.debug(trace_id =self.trace_id, msg=f"尝试获取锁:{self.session_lock_key}-{flag}")
|
|
|
- return flag
|
|
|
- except asyncio.TimeoutError:
|
|
|
- server_logger.warning(trace_id =self.trace_id, msg=f"获取会话锁超时: {self.session_lock_key}")
|
|
|
- return False
|
|
|
- except Exception as e:
|
|
|
- server_logger.error(trace_id =self.trace_id, msg=f"获取会话锁失败: {self.session_lock_key}, 错误: {e}")
|
|
|
- return False
|
|
|
-
|
|
|
- async def release_session_lock(self):
|
|
|
- """释放会话锁"""
|
|
|
- config_is_lock = config_handler.get("chat", "CHAT_SESSION_LOCK" , "True")
|
|
|
- server_logger.info(trace_id =self.trace_id, msg=f"释放新会话: {self.session_lock_key},config_is_lock:{config_is_lock} (锁定设备: {self.client_id})")
|
|
|
- if config_is_lock == "False":
|
|
|
- return
|
|
|
- try:
|
|
|
- if self.session_lock:
|
|
|
- await self.session_lock.release()
|
|
|
- except Exception as e:
|
|
|
- server_logger.error(trace_id =self.trace_id, msg=f"释放会话锁失败: {self.session_lock_key}, 错误: {e}")
|
|
|
-
|
|
|
-
|
|
|
- async def get_memory_history(self):
|
|
|
- """
|
|
|
- 获取会话历史
|
|
|
- """
|
|
|
- return await self.session_context_memory_manager.get_memory_history()
|
|
|
-
|
|
|
- async def save_update_memory_history(self , history_messages , input_message , output_message):
|
|
|
- """
|
|
|
- 保存并更新历史会话
|
|
|
- """
|
|
|
- # 同步执行保存更新会话记录操作
|
|
|
- #await self.session_context_memory_manager.save_update_memory_history(history_messages , input_message , output_message)
|
|
|
-
|
|
|
- # 创建任务但不等待(不阻塞)
|
|
|
- asyncio.create_task(self.session_context_memory_manager.save_update_memory_history(history_messages , input_message , output_message))
|
|
|
- server_logger.info(trace_id =self.trace_id, msg=f"{self.session_id}: 保存并更新历史会话任务已创建,主协程继续执行结束")
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-class SessionContextMemoryManager:
|
|
|
- """
|
|
|
- 会话内存上下文管理器
|
|
|
- """
|
|
|
-
|
|
|
- def __init__(self , trace_id: str, session_id: str):
|
|
|
- self.trace_id = trace_id
|
|
|
- self.session_id = session_id
|
|
|
- self.redis_memory = None
|
|
|
- # 最大历史记录长度,超过后进行摘要处理
|
|
|
- self.max_length = int(config_handler.get("lru", "AGENT_MAX_HISTORY_TOKENS"))
|
|
|
- # 意图识别 可以使用最大多少条历史记录
|
|
|
- self.recognize_intent_max_history = int(config_handler.get("lru", "AGENT_RECOGNIZE_INTENT_MAX_HISTORY_MESSAGES"))
|
|
|
- llm, chat, embed = get_models()
|
|
|
- self.llm = llm
|
|
|
- # 固定系统提示词
|
|
|
- self.system_prompt = system_prompt_config["summary_system_prompt"]
|
|
|
- # 初始化 redis 聊天历史
|
|
|
- self.init_redis_chat_history_memory()
|
|
|
-
|
|
|
-
|
|
|
- def init_redis_chat_history_memory(self):
|
|
|
- """
|
|
|
- 获取 Redis 中指定会话的聊天记录
|
|
|
- """
|
|
|
- # 使用 contextmanager 仅在该代码块内忽略警告
|
|
|
- with warnings.catch_warnings():
|
|
|
- warnings.filterwarnings("ignore", category=LangChainDeprecationWarning)
|
|
|
- redis_config = load_config_from_env()
|
|
|
- #server_logger.info(trace_id =self.trace_id, msg=f"redis_config={redis_config}")
|
|
|
- # 使用 RedisChatMessageHistory 存储对话历史
|
|
|
- chat_history = RedisChatMessageHistory(
|
|
|
- session_id=self.session_id, # 唯一标识会话
|
|
|
- url=redis_config.url # 或直接使用 redis_client
|
|
|
- )
|
|
|
- #使用 Redis 存储记忆
|
|
|
- self.redis_memory = ConversationBufferMemory(
|
|
|
- memory_key="chat_history",
|
|
|
- return_messages=True,
|
|
|
- chat_memory=chat_history # 或其他兼容存储
|
|
|
- )
|
|
|
-
|
|
|
- server_logger.info(trace_id=self.trace_id, msg=f"redis 内存上下文历史初始完成={self.redis_memory}")
|
|
|
- return self.redis_memory
|
|
|
-
|
|
|
-
|
|
|
- async def get_memory_history(self):
|
|
|
- """
|
|
|
- 获取内存历史(原始记录)
|
|
|
- """
|
|
|
- history_messages = self.load_memory_history()
|
|
|
- server_logger.debug(trace_id=self.trace_id, msg=f"begin session_id:{self.session_id}, session.history.len: {len(history_messages)}, session.history: {history_messages}")
|
|
|
- return history_messages
|
|
|
-
|
|
|
-
|
|
|
- async def get_memory_last_history_str(self):
|
|
|
- """
|
|
|
- 获取内存最新的多少条历史记录(将消息列表序列化为字符串)
|
|
|
- # 示例消息列表
|
|
|
- messages = [
|
|
|
- HumanMessage(content="你好!"),
|
|
|
- AIMessage(content="我是AI助手。")
|
|
|
- ]
|
|
|
- # 转换为字符串
|
|
|
- formatted_str = get_buffer_string(
|
|
|
- messages,
|
|
|
- human_prefix="User", # 人类消息的前缀(默认"Human")
|
|
|
- ai_prefix="Assistant", # AI消息的前缀(默认"AI")
|
|
|
- separator="\n" # 消息分隔符(默认"\n")
|
|
|
- )
|
|
|
- 输出结果:text
|
|
|
- User: 你好!
|
|
|
- Assistant: 我是AI助手。
|
|
|
-
|
|
|
- """
|
|
|
- history_messages = self.load_memory_history()
|
|
|
- if history_messages is None or len(history_messages) == 0:
|
|
|
- return "无"
|
|
|
- truncated_last_history_messages = history_messages
|
|
|
- if len(history_messages) > self.recognize_intent_max_history:
|
|
|
- # 截取最新的 10 条消息-使用负数索引从末尾切片
|
|
|
- truncated_last_history_messages = history_messages[-self.recognize_intent_max_history:]
|
|
|
- # 使用安全转换 - 处理markdown格式
|
|
|
- #truncated_last_history_messages = self._simplify_messages_content(truncated_last_history_messages)
|
|
|
- # 获取历史字符串
|
|
|
- history_messages_str = get_buffer_string(truncated_last_history_messages)
|
|
|
- # 先转义大括号
|
|
|
- history_messages_str = history_messages_str.replace('{', '{{').replace('}', '}}')
|
|
|
- server_logger.info(trace_id=self.trace_id, msg=f"recognize_intent_history session_id:{self.session_id}, session.history.len: {len(history_messages)}, truncated.last.history.len: {len(truncated_last_history_messages)}, history_messages_str: {history_messages_str}")
|
|
|
- return history_messages_str
|
|
|
-
|
|
|
-
|
|
|
- def load_memory_history(self):
|
|
|
- """
|
|
|
- 加载历史会话
|
|
|
- """
|
|
|
- return self.redis_memory.load_memory_variables({})["chat_history"]
|
|
|
-
|
|
|
-
|
|
|
- def save_memory_history(self, input , output):
|
|
|
- """
|
|
|
- 保存历史会话
|
|
|
- """
|
|
|
- self.redis_memory.save_context({"input": input}, {"output": output})
|
|
|
-
|
|
|
-
|
|
|
- async def save_update_memory_history(self , history_messages , input_message , output_message):
|
|
|
- """
|
|
|
- 如果历史记录 超出 最大长度,则先清除历史记录,再保存历史会话的摘要
|
|
|
- 输入参数:
|
|
|
- history_messages: 历史会话
|
|
|
- input_message: 输入
|
|
|
- output_message: 输出
|
|
|
- """
|
|
|
- cur_messages = [input_message] + [output_message]
|
|
|
- tmp_messages = list(history_messages) + cur_messages
|
|
|
- history_messages , is_summary = await self.compress_chat_history(tmp_messages)
|
|
|
- server_logger.info(trace_id =self.trace_id, msg=f"保存更新历史记录处理:session_id={self.session_id},is_summary={is_summary}")
|
|
|
-
|
|
|
- if is_summary:
|
|
|
- # 如果是摘要消息,则清除保存的摘要消息
|
|
|
- if isinstance(history_messages[0], SystemMessage):
|
|
|
- self.clear_save_summary_memory(history_messages[0].content)
|
|
|
- server_logger.debug(trace_id=self.trace_id, msg=f"end session_id:{self.session_id}, session.history.len: {len(history_messages)}, session.history: {history_messages}")
|
|
|
- else:
|
|
|
- server_logger.debug(trace_id=self.trace_id, msg=f"end session_id:{self.session_id}, session.history.len: {len(tmp_messages)}, session.history: {tmp_messages}")
|
|
|
- self.redis_memory.save_context({"input": input_message.content}, {"output": output_message.content})
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
- def clear_save_summary_memory(self , summary_text):
|
|
|
- """
|
|
|
- 清除原始的记录,保存的摘要历史会话
|
|
|
- """
|
|
|
- try:
|
|
|
- server_logger.info(trace_id =self.trace_id, msg=f"type(summary_text): {type(summary_text)}")
|
|
|
- server_logger.info(trace_id =self.trace_id, msg=f"summary_text: {summary_text}")
|
|
|
- if not isinstance(summary_text, str):
|
|
|
- # 安全兜底:尝试提取 .content
|
|
|
- try:
|
|
|
- summary_text = str(summary_text.content)
|
|
|
- except AttributeError:
|
|
|
- summary_text = repr(summary_text) # 最后手段
|
|
|
- self.redis_memory.clear()
|
|
|
- server_logger.info(trace_id =self.trace_id, msg=f"session_id={self.session_id}先清除历史会话记录完成")
|
|
|
- self.redis_memory.save_context({"input": "整理后对话摘要"}, {"output": summary_text})
|
|
|
- server_logger.info(trace_id =self.trace_id, msg=f"session_id={self.session_id}再保存摘要历史记录完成")
|
|
|
- except Exception as e:
|
|
|
- server_logger.error(trace_id =self.trace_id, msg=f"clear_save_summary_memory error: {e}")
|
|
|
-
|
|
|
-
|
|
|
- async def compress_chat_history(self , chat_history):
|
|
|
- """
|
|
|
- 压缩聊天历史,如果超过 max_length 字符则生成摘要
|
|
|
- """
|
|
|
- his_len = len(get_buffer_string(chat_history))
|
|
|
- server_logger.info(trace_id =self.trace_id, msg=f"his_len={his_len},max_length={self.max_length}")
|
|
|
- if his_len < self.max_length:
|
|
|
- #server_logger.info(trace_id =self.trace_id, msg="get_buffer_string(chat_history) < max_length")
|
|
|
- return chat_history , False
|
|
|
-
|
|
|
- summary_prompt = PromptTemplate(
|
|
|
- input_variables=["history"],
|
|
|
- template=self.system_prompt
|
|
|
- )
|
|
|
- # 创建可运行链:prompt + llm
|
|
|
- chain = summary_prompt | self.llm # 等价于 LLMChain 的功能
|
|
|
- # 获取历史字符串
|
|
|
- history_str = get_buffer_string(chat_history)
|
|
|
- # 异步调用(新 API)
|
|
|
- summary_response = await chain.ainvoke({"history": history_str})
|
|
|
- server_logger.info(trace_id =self.trace_id, msg=f"session_id={self.session_id},summary_text")
|
|
|
- # 返回一个“系统消息”表示摘要
|
|
|
- system_message = [SystemMessage(content=f"对话摘要:{summary_response.content}")]
|
|
|
- return system_message , True
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
- def _simplify_messages_content(self , history_messages):
|
|
|
- """
|
|
|
- 简化消息内容
|
|
|
- 处理 Markdown 内容
|
|
|
- """
|
|
|
- cleaned_messages = []
|
|
|
- # 在处理过程中简化内容
|
|
|
- for message in history_messages:
|
|
|
- if not hasattr(message, 'content') or not isinstance(message.content, str):
|
|
|
- # 如果消息没有 content 或 content 不是字符串,直接添加
|
|
|
- cleaned_messages.append(message)
|
|
|
- continue
|
|
|
-
|
|
|
- content = message.content
|
|
|
- content = self._simplify_markdown_content(content)
|
|
|
- # 创建新消息对象,保留原始消息的类型和元数据,只更新 content
|
|
|
- if isinstance(message, HumanMessage):
|
|
|
- new_msg = HumanMessage(content=content, additional_kwargs=message.additional_kwargs, response_metadata=message.response_metadata)
|
|
|
- elif isinstance(message, AIMessage):
|
|
|
- new_msg = AIMessage(content=content, additional_kwargs=message.additional_kwargs, response_metadata=message.response_metadata)
|
|
|
- elif isinstance(message, SystemMessage):
|
|
|
- new_msg = SystemMessage(content=content, additional_kwargs=message.additional_kwargs, response_metadata=message.response_metadata)
|
|
|
- elif isinstance(message, ToolMessage):
|
|
|
- # ToolMessage 的 content 可能不是用户生成的文本,通常不处理
|
|
|
- new_msg = ToolMessage(content=content, tool_call_id=message.tool_call_id, name=message.name, additional_kwargs=message.additional_kwargs, response_metadata=message.response_metadata)
|
|
|
- elif isinstance(message, FunctionMessage):
|
|
|
- new_msg = FunctionMessage(content=content, name=message.name, additional_kwargs=message.additional_kwargs, response_metadata=message.response_metadata)
|
|
|
- else:
|
|
|
- # 对于未知类型,尝试通用方式(如果 BaseMessage 支持)
|
|
|
- new_msg = message.__class__(content=content, **{k: v for k, v in message.__dict__.items() if k != 'content'})
|
|
|
-
|
|
|
- cleaned_messages.append(new_msg)
|
|
|
-
|
|
|
- return cleaned_messages
|
|
|
-
|
|
|
-
|
|
|
- def _simplify_markdown_content(self , content):
|
|
|
- """简化 Markdown 内容"""
|
|
|
- # 移除表格
|
|
|
- content = re.sub(r'\|.*\|.*\n\|.*\|.*(\n\|.*\|.*)*', '[表格数据]', content)
|
|
|
- # 移除标题
|
|
|
- content = re.sub(r'#{1,6}\s*', '', content)
|
|
|
- # 移除粗体斜体
|
|
|
- content = re.sub(r'[*_]{1,2}(.*?)[*_]{1,2}', r'\1', content)
|
|
|
- # 移除表情符号和特殊标记
|
|
|
- content = re.sub(r'[✅❌📋📊]', '', content)
|
|
|
- # 标准化换行
|
|
|
- content = re.sub(r'\n{3,}', '\n\n', content)
|
|
|
- return content.strip()
|
|
|
-
|
|
|
-
|
|
|
-
|