| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344 |
- # !/usr/bin/python
- # -*- coding: utf-8 -*-
- '''
- @Project : lq-agent-api
- @File :session_manager.py
- @IDE :PyCharm
- @Author :
- @Date :2025/7/24 03:03
- '''
- import asyncio
- import logging
- import re
- import time
- from typing import Any, Dict, List, Tuple
- # 假设的导入(根据实际框架调整)
- from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
- from base.config import config_handler
- from base.redis_config import load_config_from_env
- from base.redis_connection import RedisConnectionFactory
- from base.async_redis_lock import AsyncRedisLock
- from langchain.memory import ConversationBufferMemory
- from langchain_community.chat_message_histories import RedisChatMessageHistory
- from langchain_core.runnables.history import RunnableWithMessageHistory
- from langchain_core.messages import get_buffer_string
- from langchain_core.messages import messages_to_dict, messages_from_dict
- from langchain.prompts import PromptTemplate
- from utils.utils import get_models
- import warnings
- from langchain_core._api.deprecation import LangChainDeprecationWarning
- from logger.loggering import server_logger
- from utils.yaml_utils import system_prompt_config
- from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage, ToolMessage, FunctionMessage
class SessionManager:
    """
    Centralized management of one chat session's state and its Redis-backed session lock.

    The lock prevents the session from being used by another device at the same time
    (when the ``chat.CHAT_SESSION_LOCK`` setting is enabled); conversation memory is
    delegated to a ``SessionContextMemoryManager``.
    """

    # Strong references to in-flight fire-and-forget tasks created by save_update_memory_history.
    # The event loop keeps only weak references to tasks, so an unreferenced task may be
    # garbage-collected before it finishes. Stored on the class (not the instance) so the tasks
    # stay alive even after the SessionManager that started them is discarded.
    _background_tasks: set = set()

    def __init__(self, trace_id: str, lock_key_prefix: str, session_id: str, client_id: str = "default"):
        """
        :param trace_id: request trace id, attached to every log line
        :param lock_key_prefix: prefix of the Redis key used as this session's lock
        :param session_id: session identifier; appended to the lock prefix and used as the chat-history id
        :param client_id: identifier of the device/client using the session (used in log messages)
        """
        self.trace_id = trace_id
        self.session_id = session_id
        self.client_id = client_id
        # Non-None only while this instance actually holds the lock.
        self.session_lock = None
        # Redis client, obtained lazily (acquire_session_lock / is_session_locked).
        self.redis_client = None
        self.session_lock_key = lock_key_prefix + session_id
        # Conversation-memory (context) manager for this session.
        self.session_context_memory_manager = SessionContextMemoryManager(trace_id, session_id)

    @staticmethod
    def _lock_enabled(raw_value) -> bool:
        """
        Interpret the ``CHAT_SESSION_LOCK`` config value.

        Locking is disabled only for a false-like value ("False", "false", " FALSE ", or bool False);
        anything else (including the default "True") keeps locking enabled.
        """
        return str(raw_value).strip().lower() != "false"

    async def is_session_locked(self) -> bool:
        """Return True if the session lock key currently exists in Redis (i.e. some device holds it)."""
        if self.redis_client is None:
            # Previously this raised AttributeError unless acquire_session_lock had run first.
            self.redis_client = await RedisConnectionFactory.get_connection()
        return bool(await self.redis_client.exists(self.session_lock_key))

    async def acquire_session_lock(self, timeout: float = 5) -> bool:
        """
        Try to acquire the session lock, waiting at most ``timeout`` seconds.

        :return: True if the lock was acquired or locking is disabled by config; False on
                 timeout, failure, or when another holder keeps the lock.
        """
        config_is_lock = config_handler.get("chat", "CHAT_SESSION_LOCK", "True")
        server_logger.info(trace_id=self.trace_id, msg=f"创建新会话: {self.session_lock_key},config_is_lock:{config_is_lock} (锁定设备: {self.client_id})")
        if not self._lock_enabled(config_is_lock):
            return True

        try:
            # Obtain the redis connection through the connection factory.
            self.redis_client = await RedisConnectionFactory.get_connection()
            session_lock = AsyncRedisLock(self.redis_client, self.session_lock_key)
            server_logger.info(trace_id=self.trace_id, msg=f"创建新会话: {self.session_lock_key} (锁定设备: {self.client_id})")
            flag = await session_lock.acquire(timeout)
            server_logger.debug(trace_id=self.trace_id, msg=f"尝试获取锁:{self.session_lock_key}-{flag}")
        except asyncio.TimeoutError:
            server_logger.warning(trace_id=self.trace_id, msg=f"获取会话锁超时: {self.session_lock_key}")
            return False
        except Exception as e:
            server_logger.error(trace_id=self.trace_id, msg=f"获取会话锁失败: {self.session_lock_key}, 错误: {e}")
            return False
        # Remember the lock only if we actually hold it, so release_session_lock never
        # releases a lock that was not acquired by this instance.
        self.session_lock = session_lock if flag else None
        return flag

    async def release_session_lock(self):
        """Release the session lock if this instance holds it; errors are logged, not raised."""
        config_is_lock = config_handler.get("chat", "CHAT_SESSION_LOCK", "True")
        server_logger.info(trace_id=self.trace_id, msg=f"释放新会话: {self.session_lock_key},config_is_lock:{config_is_lock} (锁定设备: {self.client_id})")
        if not self._lock_enabled(config_is_lock):
            return
        if self.session_lock is None:
            return
        try:
            await self.session_lock.release()
            # Forget the lock so a repeated release is a no-op.
            self.session_lock = None
        except Exception as e:
            server_logger.error(trace_id=self.trace_id, msg=f"释放会话锁失败: {self.session_lock_key}, 错误: {e}")

    async def get_memory_history(self):
        """Return the session's stored conversation history (raw messages)."""
        return await self.session_context_memory_manager.get_memory_history()

    async def save_update_memory_history(self, history_messages, input_message, output_message):
        """
        Schedule saving (and possibly summarizing) the latest turn without blocking the caller.

        The work runs as a background task; a strong reference is kept until it completes and
        any failure is logged by ``_on_background_task_done``.
        """
        task = asyncio.create_task(
            self.session_context_memory_manager.save_update_memory_history(history_messages, input_message, output_message)
        )
        self._background_tasks.add(task)
        task.add_done_callback(self._on_background_task_done)
        server_logger.info(trace_id=self.trace_id, msg=f"{self.session_id}: 保存并更新历史会话任务已创建,主协程继续执行结束")

    def _on_background_task_done(self, task):
        """Drop the finished task's reference and log cancellation or failure (otherwise never retrieved)."""
        self._background_tasks.discard(task)
        if task.cancelled():
            server_logger.warning(trace_id=self.trace_id, msg=f"{self.session_id}: 保存并更新历史会话任务被取消")
            return
        exc = task.exception()
        if exc is not None:
            server_logger.error(trace_id=self.trace_id, msg=f"{self.session_id}: 保存并更新历史会话任务失败, 错误: {exc}")
class SessionContextMemoryManager:
    """
    Per-session conversation memory stored in Redis.

    Loads the raw message history of one session, renders a truncated view of it for intent
    recognition, and persists new turns. When the rendered history grows beyond ``max_length``
    characters it is replaced by an LLM-generated summary.
    """

    # Pre-compiled patterns used by _simplify_markdown_content.
    # Consecutive lines each containing at least two '|' characters (a Markdown table).
    _MD_TABLE_RE = re.compile(r'\|.*\|.*\n\|.*\|.*(\n\|.*\|.*)*')
    # ATX heading markers (1-6 '#') at the start of a line; anchored so '#' inside text
    # ("issue #42", "C#") is preserved.
    _MD_HEADING_RE = re.compile(r'^[ \t]{0,3}#{1,6}(?:[ \t]+|$)', re.MULTILINE)
    # Bold with matching delimiters: **text** or __text__.
    _MD_BOLD_RE = re.compile(r'(\*\*|__)(?=\S)(.+?)(?<=\S)\1')
    # Italic with asterisks: *text*.
    _MD_ITALIC_STAR_RE = re.compile(r'\*(?=\S)(.+?)(?<=\S)\*')
    # Italic with underscores: _text_, only outside words so snake_case identifiers survive.
    _MD_ITALIC_UNDERSCORE_RE = re.compile(r'(?<!\w)_(?=\S)(.+?)(?<=\S)_(?!\w)')
    # Decorative emoji/markers to drop.
    _MD_EMOJI_RE = re.compile(r'[✅❌📋📊]')
    # Three or more newlines collapse to one blank line.
    _MD_BLANK_LINES_RE = re.compile(r'\n{3,}')

    def __init__(self, trace_id: str, session_id: str):
        """
        :param trace_id: request trace id, attached to every log line
        :param session_id: id of the session whose history is stored in Redis
        """
        self.trace_id = trace_id
        self.session_id = session_id
        self.redis_memory = None
        # Threshold for summarization. It is compared against the character length of the
        # rendered history string (see compress_chat_history), despite the config key name.
        self.max_length = int(config_handler.get("lru", "AGENT_MAX_HISTORY_TOKENS"))
        # Maximum number of most recent messages passed to intent recognition.
        self.recognize_intent_max_history = int(config_handler.get("lru", "AGENT_RECOGNIZE_INTENT_MAX_HISTORY_MESSAGES"))
        # get_models() yields three models; only the first is used (for summarization).
        llm, _chat, _embed = get_models()
        self.llm = llm
        # Fixed summary prompt template; must contain a {history} placeholder.
        self.system_prompt = system_prompt_config["summary_system_prompt"]
        # Initialize the Redis-backed chat history.
        self.init_redis_chat_history_memory()

    def init_redis_chat_history_memory(self):
        """
        Create the Redis-backed ConversationBufferMemory for this session and return it.
        """
        # Suppress LangChain deprecation warnings only while constructing these objects.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=LangChainDeprecationWarning)
            redis_config = load_config_from_env()
            # Message history keyed by session_id in Redis.
            chat_history = RedisChatMessageHistory(
                session_id=self.session_id,
                url=redis_config.url,
            )
            # Memory wrapper returning message objects under the "chat_history" key.
            self.redis_memory = ConversationBufferMemory(
                memory_key="chat_history",
                return_messages=True,
                chat_memory=chat_history,
            )
        server_logger.info(trace_id=self.trace_id, msg=f"redis 内存上下文历史初始完成={self.redis_memory}")
        return self.redis_memory

    async def get_memory_history(self):
        """Return the raw stored message history of the session."""
        history_messages = self.load_memory_history()
        server_logger.debug(trace_id=self.trace_id, msg=f"begin session_id:{self.session_id}, session.history.len: {len(history_messages)}, session.history: {history_messages}")
        return history_messages

    async def get_memory_last_history_str(self):
        """
        Render the most recent ``recognize_intent_max_history`` messages as a single string.

        Messages are formatted with ``get_buffer_string`` (one "Human: ..." / "AI: ..." line per
        message) and curly braces are doubled. Returns "无" when there is no history to render
        (empty history or a non-positive limit).
        """
        history_messages = self.load_memory_history()
        if not history_messages:
            return "无"
        limit = self.recognize_intent_max_history
        if limit <= 0:
            # history_messages[-0:] would return the WHOLE list, so a zero limit must be special-cased.
            return "无"
        # Keep only the newest `limit` messages (the whole list if it is shorter).
        truncated_last_history_messages = history_messages[-limit:]
        history_messages_str = get_buffer_string(truncated_last_history_messages)
        # Double the braces, presumably because callers embed this text in a format-style
        # prompt template where single braces would be parsed as placeholders.
        history_messages_str = history_messages_str.replace('{', '{{').replace('}', '}}')
        server_logger.info(trace_id=self.trace_id, msg=f"recognize_intent_history session_id:{self.session_id}, session.history.len: {len(history_messages)}, truncated.last.history.len: {len(truncated_last_history_messages)}, history_messages_str: {history_messages_str}")
        return history_messages_str

    def load_memory_history(self):
        """Load the stored history messages from Redis."""
        return self.redis_memory.load_memory_variables({})["chat_history"]

    def save_memory_history(self, input, output):
        """Append one input/output turn to the stored history."""
        self.redis_memory.save_context({"input": input}, {"output": output})

    async def save_update_memory_history(self, history_messages, input_message, output_message):
        """
        Persist the latest turn, summarizing the whole history first if it has grown too long.

        :param history_messages: history loaded before this turn (may be None or empty)
        :param input_message: the user message of this turn
        :param output_message: the AI reply of this turn
        """
        tmp_messages = list(history_messages or []) + [input_message, output_message]
        try:
            compressed_messages, is_summary = await self.compress_chat_history(tmp_messages)
        except Exception as e:
            # Summarization failed (e.g. LLM error). Persist the raw turn instead of losing it;
            # the next save will attempt summarization again.
            server_logger.error(trace_id=self.trace_id, msg=f"compress_chat_history error: session_id={self.session_id}, {e}")
            compressed_messages, is_summary = tmp_messages, False
        server_logger.info(trace_id=self.trace_id, msg=f"保存更新历史记录处理:session_id={self.session_id},is_summary={is_summary}")

        if is_summary:
            # Replace the stored history with the single summary message.
            if compressed_messages and isinstance(compressed_messages[0], SystemMessage):
                self.clear_save_summary_memory(compressed_messages[0].content)
            server_logger.debug(trace_id=self.trace_id, msg=f"end session_id:{self.session_id}, session.history.len: {len(compressed_messages)}, session.history: {compressed_messages}")
        else:
            server_logger.debug(trace_id=self.trace_id, msg=f"end session_id:{self.session_id}, session.history.len: {len(tmp_messages)}, session.history: {tmp_messages}")
            self.redis_memory.save_context({"input": input_message.content}, {"output": output_message.content})

    def clear_save_summary_memory(self, summary_text):
        """
        Clear the stored history, then store the summary as a single turn. Errors are logged, not raised.
        """
        try:
            server_logger.info(trace_id=self.trace_id, msg=f"type(summary_text): {type(summary_text)}")
            server_logger.info(trace_id=self.trace_id, msg=f"summary_text: {summary_text}")
            if not isinstance(summary_text, str):
                # Defensive fallback: use .content if present, otherwise repr().
                try:
                    summary_text = str(summary_text.content)
                except AttributeError:
                    summary_text = repr(summary_text)
            self.redis_memory.clear()
            server_logger.info(trace_id=self.trace_id, msg=f"session_id={self.session_id}先清除历史会话记录完成")
            self.redis_memory.save_context({"input": "整理后对话摘要"}, {"output": summary_text})
            server_logger.info(trace_id=self.trace_id, msg=f"session_id={self.session_id}再保存摘要历史记录完成")
        except Exception as e:
            server_logger.error(trace_id=self.trace_id, msg=f"clear_save_summary_memory error: {e}")

    async def compress_chat_history(self, chat_history):
        """
        Summarize ``chat_history`` when its rendered text exceeds ``max_length`` characters.

        :return: ``(chat_history, False)`` when within the limit, otherwise
                 ``([SystemMessage(summary)], True)``.
        """
        history_str = get_buffer_string(chat_history)
        his_len = len(history_str)
        server_logger.info(trace_id=self.trace_id, msg=f"his_len={his_len},max_length={self.max_length}")
        if his_len <= self.max_length:
            return chat_history, False
        summary_prompt = PromptTemplate(
            input_variables=["history"],
            template=self.system_prompt,
        )
        # prompt | llm runnable (equivalent to the legacy LLMChain).
        chain = summary_prompt | self.llm
        summary_response = await chain.ainvoke({"history": history_str})
        # Chat models return a message with .content; plain LLMs return a str.
        summary_text = getattr(summary_response, "content", summary_response)
        server_logger.info(trace_id=self.trace_id, msg=f"session_id={self.session_id},summary_text={summary_text}")
        # Represent the summary as a single system message.
        return [SystemMessage(content=f"对话摘要:{summary_text}")], True

    def _simplify_messages_content(self, history_messages):
        """
        Return copies of the messages with Markdown simplified in string contents.

        Messages without a string ``content`` are passed through unchanged. Every other field
        (type, ids, metadata, tool calls, ...) is preserved because only ``content`` is replaced.
        """
        cleaned_messages = []
        for message in history_messages:
            content = getattr(message, "content", None)
            if not isinstance(content, str):
                cleaned_messages.append(message)
                continue
            simplified = self._simplify_markdown_content(content)
            cleaned_messages.append(self._copy_message_with_content(message, simplified))
        return cleaned_messages

    @staticmethod
    def _copy_message_with_content(message, content):
        """Copy a pydantic message model, replacing only its content (pydantic v2 model_copy, v1 copy)."""
        copier = getattr(message, "model_copy", None)
        if copier is None:
            copier = message.copy
        return copier(update={"content": content})

    def _simplify_markdown_content(self, content):
        """Simplify Markdown text: tables, headings, emphasis, selected emoji, excess blank lines."""
        content = self._MD_TABLE_RE.sub('[表格数据]', content)
        content = self._MD_HEADING_RE.sub('', content)
        content = self._MD_BOLD_RE.sub(r'\2', content)
        content = self._MD_ITALIC_STAR_RE.sub(r'\1', content)
        content = self._MD_ITALIC_UNDERSCORE_RE.sub(r'\1', content)
        content = self._MD_EMOJI_RE.sub('', content)
        content = self._MD_BLANK_LINES_RE.sub('\n\n', content)
        return content.strip()
|