session_manager.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344
  1. # !/usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. '''
  4. @Project : lq-agent-api
  5. @File :session_manager.py
  6. @IDE :PyCharm
  7. @Author :
  8. @Date :2025/7/24 03:03
  9. '''
  10. import asyncio
  11. import logging
  12. import re
  13. import time
  14. from typing import Any, Dict, List, Tuple
  15. # 假设的导入(根据实际框架调整)
  16. from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
  17. from base.config import config_handler
  18. from base.redis_config import load_config_from_env
  19. from base.redis_connection import RedisConnectionFactory
  20. from base.async_redis_lock import AsyncRedisLock
  21. from langchain.memory import ConversationBufferMemory
  22. from langchain_community.chat_message_histories import RedisChatMessageHistory
  23. from langchain_core.runnables.history import RunnableWithMessageHistory
  24. from langchain_core.messages import get_buffer_string
  25. from langchain_core.messages import messages_to_dict, messages_from_dict
  26. from langchain.prompts import PromptTemplate
  27. from utils.utils import get_models
  28. import warnings
  29. from langchain_core._api.deprecation import LangChainDeprecationWarning
  30. from logger.loggering import server_logger
  31. from utils.yaml_utils import system_prompt_config
  32. from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage, ToolMessage, FunctionMessage
  33. class SessionManager:
  34. """集中管理会话状态和锁定机制"""
  35. def __init__(self , trace_id: str , lock_key_prefix: str , session_id: str, client_id: str = "default"):
  36. self.trace_id = trace_id
  37. self.session_id = session_id
  38. self.client_id = client_id
  39. self.session_lock = None
  40. self.session_lock_key = lock_key_prefix + session_id
  41. # 上下文管理器
  42. self.session_context_memory_manager = SessionContextMemoryManager(trace_id , session_id)
  43. async def is_session_locked(self) -> bool:
  44. """检查会话是否被其他设备锁定"""
  45. if await self.redis_client.exists(self.session_lock_key):
  46. return True
  47. return False
  48. async def acquire_session_lock(self, timeout: float = 5) -> bool:
  49. """尝试获取会话锁,带超时机制"""
  50. config_is_lock = config_handler.get("chat", "CHAT_SESSION_LOCK" , "True")
  51. server_logger.info(trace_id =self.trace_id, msg=f"创建新会话: {self.session_lock_key},config_is_lock:{config_is_lock} (锁定设备: {self.client_id})")
  52. if config_is_lock == "False":
  53. return True
  54. try:
  55. # 通过工厂模式获取 redis 连接器
  56. self.redis_client = await RedisConnectionFactory.get_connection()
  57. self.session_lock = AsyncRedisLock(self.redis_client, self.session_lock_key)
  58. server_logger.info(trace_id =self.trace_id, msg=f"创建新会话: {self.session_lock_key} (锁定设备: {self.client_id})")
  59. flag = await self.session_lock.acquire(timeout)
  60. server_logger.debug(trace_id =self.trace_id, msg=f"尝试获取锁:{self.session_lock_key}-{flag}")
  61. return flag
  62. except asyncio.TimeoutError:
  63. server_logger.warning(trace_id =self.trace_id, msg=f"获取会话锁超时: {self.session_lock_key}")
  64. return False
  65. except Exception as e:
  66. server_logger.error(trace_id =self.trace_id, msg=f"获取会话锁失败: {self.session_lock_key}, 错误: {e}")
  67. return False
  68. async def release_session_lock(self):
  69. """释放会话锁"""
  70. config_is_lock = config_handler.get("chat", "CHAT_SESSION_LOCK" , "True")
  71. server_logger.info(trace_id =self.trace_id, msg=f"释放新会话: {self.session_lock_key},config_is_lock:{config_is_lock} (锁定设备: {self.client_id})")
  72. if config_is_lock == "False":
  73. return
  74. try:
  75. if self.session_lock:
  76. await self.session_lock.release()
  77. except Exception as e:
  78. server_logger.error(trace_id =self.trace_id, msg=f"释放会话锁失败: {self.session_lock_key}, 错误: {e}")
  79. async def get_memory_history(self):
  80. """
  81. 获取会话历史
  82. """
  83. return await self.session_context_memory_manager.get_memory_history()
  84. async def save_update_memory_history(self , history_messages , input_message , output_message):
  85. """
  86. 保存并更新历史会话
  87. """
  88. # 同步执行保存更新会话记录操作
  89. #await self.session_context_memory_manager.save_update_memory_history(history_messages , input_message , output_message)
  90. # 创建任务但不等待(不阻塞)
  91. asyncio.create_task(self.session_context_memory_manager.save_update_memory_history(history_messages , input_message , output_message))
  92. server_logger.info(trace_id =self.trace_id, msg=f"{self.session_id}: 保存并更新历史会话任务已创建,主协程继续执行结束")
  93. class SessionContextMemoryManager:
  94. """
  95. 会话内存上下文管理器
  96. """
  97. def __init__(self , trace_id: str, session_id: str):
  98. self.trace_id = trace_id
  99. self.session_id = session_id
  100. self.redis_memory = None
  101. # 最大历史记录长度,超过后进行摘要处理
  102. self.max_length = int(config_handler.get("lru", "AGENT_MAX_HISTORY_TOKENS"))
  103. # 意图识别 可以使用最大多少条历史记录
  104. self.recognize_intent_max_history = int(config_handler.get("lru", "AGENT_RECOGNIZE_INTENT_MAX_HISTORY_MESSAGES"))
  105. llm, chat, embed = get_models()
  106. self.llm = llm
  107. # 固定系统提示词
  108. self.system_prompt = system_prompt_config["summary_system_prompt"]
  109. # 初始化 redis 聊天历史
  110. self.init_redis_chat_history_memory()
  111. def init_redis_chat_history_memory(self):
  112. """
  113. 获取 Redis 中指定会话的聊天记录
  114. """
  115. # 使用 contextmanager 仅在该代码块内忽略警告
  116. with warnings.catch_warnings():
  117. warnings.filterwarnings("ignore", category=LangChainDeprecationWarning)
  118. redis_config = load_config_from_env()
  119. #server_logger.info(trace_id =self.trace_id, msg=f"redis_config={redis_config}")
  120. # 使用 RedisChatMessageHistory 存储对话历史
  121. chat_history = RedisChatMessageHistory(
  122. session_id=self.session_id, # 唯一标识会话
  123. url=redis_config.url # 或直接使用 redis_client
  124. )
  125. #使用 Redis 存储记忆
  126. self.redis_memory = ConversationBufferMemory(
  127. memory_key="chat_history",
  128. return_messages=True,
  129. chat_memory=chat_history # 或其他兼容存储
  130. )
  131. server_logger.info(trace_id=self.trace_id, msg=f"redis 内存上下文历史初始完成={self.redis_memory}")
  132. return self.redis_memory
  133. async def get_memory_history(self):
  134. """
  135. 获取内存历史(原始记录)
  136. """
  137. history_messages = self.load_memory_history()
  138. server_logger.debug(trace_id=self.trace_id, msg=f"begin session_id:{self.session_id}, session.history.len: {len(history_messages)}, session.history: {history_messages}")
  139. return history_messages
  140. async def get_memory_last_history_str(self):
  141. """
  142. 获取内存最新的多少条历史记录(将消息列表序列化为字符串)
  143. # 示例消息列表
  144. messages = [
  145. HumanMessage(content="你好!"),
  146. AIMessage(content="我是AI助手。")
  147. ]
  148. # 转换为字符串
  149. formatted_str = get_buffer_string(
  150. messages,
  151. human_prefix="User", # 人类消息的前缀(默认"Human")
  152. ai_prefix="Assistant", # AI消息的前缀(默认"AI")
  153. separator="\n" # 消息分隔符(默认"\n")
  154. )
  155. 输出结果:text
  156. User: 你好!
  157. Assistant: 我是AI助手。
  158. """
  159. history_messages = self.load_memory_history()
  160. if history_messages is None or len(history_messages) == 0:
  161. return "无"
  162. truncated_last_history_messages = history_messages
  163. if len(history_messages) > self.recognize_intent_max_history:
  164. # 截取最新的 10 条消息-使用负数索引从末尾切片
  165. truncated_last_history_messages = history_messages[-self.recognize_intent_max_history:]
  166. # 使用安全转换 - 处理markdown格式
  167. #truncated_last_history_messages = self._simplify_messages_content(truncated_last_history_messages)
  168. # 获取历史字符串
  169. history_messages_str = get_buffer_string(truncated_last_history_messages)
  170. # 先转义大括号
  171. history_messages_str = history_messages_str.replace('{', '{{').replace('}', '}}')
  172. server_logger.info(trace_id=self.trace_id, msg=f"recognize_intent_history session_id:{self.session_id}, session.history.len: {len(history_messages)}, truncated.last.history.len: {len(truncated_last_history_messages)}, history_messages_str: {history_messages_str}")
  173. return history_messages_str
  174. def load_memory_history(self):
  175. """
  176. 加载历史会话
  177. """
  178. return self.redis_memory.load_memory_variables({})["chat_history"]
  179. def save_memory_history(self, input , output):
  180. """
  181. 保存历史会话
  182. """
  183. self.redis_memory.save_context({"input": input}, {"output": output})
  184. async def save_update_memory_history(self , history_messages , input_message , output_message):
  185. """
  186. 如果历史记录 超出 最大长度,则先清除历史记录,再保存历史会话的摘要
  187. 输入参数:
  188. history_messages: 历史会话
  189. input_message: 输入
  190. output_message: 输出
  191. """
  192. cur_messages = [input_message] + [output_message]
  193. tmp_messages = list(history_messages) + cur_messages
  194. history_messages , is_summary = await self.compress_chat_history(tmp_messages)
  195. server_logger.info(trace_id =self.trace_id, msg=f"保存更新历史记录处理:session_id={self.session_id},is_summary={is_summary}")
  196. if is_summary:
  197. # 如果是摘要消息,则清除保存的摘要消息
  198. if isinstance(history_messages[0], SystemMessage):
  199. self.clear_save_summary_memory(history_messages[0].content)
  200. server_logger.debug(trace_id=self.trace_id, msg=f"end session_id:{self.session_id}, session.history.len: {len(history_messages)}, session.history: {history_messages}")
  201. else:
  202. server_logger.debug(trace_id=self.trace_id, msg=f"end session_id:{self.session_id}, session.history.len: {len(tmp_messages)}, session.history: {tmp_messages}")
  203. self.redis_memory.save_context({"input": input_message.content}, {"output": output_message.content})
  204. def clear_save_summary_memory(self , summary_text):
  205. """
  206. 清除原始的记录,保存的摘要历史会话
  207. """
  208. try:
  209. server_logger.info(trace_id =self.trace_id, msg=f"type(summary_text): {type(summary_text)}")
  210. server_logger.info(trace_id =self.trace_id, msg=f"summary_text: {summary_text}")
  211. if not isinstance(summary_text, str):
  212. # 安全兜底:尝试提取 .content
  213. try:
  214. summary_text = str(summary_text.content)
  215. except AttributeError:
  216. summary_text = repr(summary_text) # 最后手段
  217. self.redis_memory.clear()
  218. server_logger.info(trace_id =self.trace_id, msg=f"session_id={self.session_id}先清除历史会话记录完成")
  219. self.redis_memory.save_context({"input": "整理后对话摘要"}, {"output": summary_text})
  220. server_logger.info(trace_id =self.trace_id, msg=f"session_id={self.session_id}再保存摘要历史记录完成")
  221. except Exception as e:
  222. server_logger.error(trace_id =self.trace_id, msg=f"clear_save_summary_memory error: {e}")
  223. async def compress_chat_history(self , chat_history):
  224. """
  225. 压缩聊天历史,如果超过 max_length 字符则生成摘要
  226. """
  227. his_len = len(get_buffer_string(chat_history))
  228. server_logger.info(trace_id =self.trace_id, msg=f"his_len={his_len},max_length={self.max_length}")
  229. if his_len < self.max_length:
  230. #server_logger.info(trace_id =self.trace_id, msg="get_buffer_string(chat_history) < max_length")
  231. return chat_history , False
  232. summary_prompt = PromptTemplate(
  233. input_variables=["history"],
  234. template=self.system_prompt
  235. )
  236. # 创建可运行链:prompt + llm
  237. chain = summary_prompt | self.llm # 等价于 LLMChain 的功能
  238. # 获取历史字符串
  239. history_str = get_buffer_string(chat_history)
  240. # 异步调用(新 API)
  241. summary_response = await chain.ainvoke({"history": history_str})
  242. server_logger.info(trace_id =self.trace_id, msg=f"session_id={self.session_id},summary_text")
  243. # 返回一个“系统消息”表示摘要
  244. system_message = [SystemMessage(content=f"对话摘要:{summary_response.content}")]
  245. return system_message , True
  246. def _simplify_messages_content(self , history_messages):
  247. """
  248. 简化消息内容
  249. 处理 Markdown 内容
  250. """
  251. cleaned_messages = []
  252. # 在处理过程中简化内容
  253. for message in history_messages:
  254. if not hasattr(message, 'content') or not isinstance(message.content, str):
  255. # 如果消息没有 content 或 content 不是字符串,直接添加
  256. cleaned_messages.append(message)
  257. continue
  258. content = message.content
  259. content = self._simplify_markdown_content(content)
  260. # 创建新消息对象,保留原始消息的类型和元数据,只更新 content
  261. if isinstance(message, HumanMessage):
  262. new_msg = HumanMessage(content=content, additional_kwargs=message.additional_kwargs, response_metadata=message.response_metadata)
  263. elif isinstance(message, AIMessage):
  264. new_msg = AIMessage(content=content, additional_kwargs=message.additional_kwargs, response_metadata=message.response_metadata)
  265. elif isinstance(message, SystemMessage):
  266. new_msg = SystemMessage(content=content, additional_kwargs=message.additional_kwargs, response_metadata=message.response_metadata)
  267. elif isinstance(message, ToolMessage):
  268. # ToolMessage 的 content 可能不是用户生成的文本,通常不处理
  269. new_msg = ToolMessage(content=content, tool_call_id=message.tool_call_id, name=message.name, additional_kwargs=message.additional_kwargs, response_metadata=message.response_metadata)
  270. elif isinstance(message, FunctionMessage):
  271. new_msg = FunctionMessage(content=content, name=message.name, additional_kwargs=message.additional_kwargs, response_metadata=message.response_metadata)
  272. else:
  273. # 对于未知类型,尝试通用方式(如果 BaseMessage 支持)
  274. new_msg = message.__class__(content=content, **{k: v for k, v in message.__dict__.items() if k != 'content'})
  275. cleaned_messages.append(new_msg)
  276. return cleaned_messages
  277. def _simplify_markdown_content(self , content):
  278. """简化 Markdown 内容"""
  279. # 移除表格
  280. content = re.sub(r'\|.*\|.*\n\|.*\|.*(\n\|.*\|.*)*', '[表格数据]', content)
  281. # 移除标题
  282. content = re.sub(r'#{1,6}\s*', '', content)
  283. # 移除粗体斜体
  284. content = re.sub(r'[*_]{1,2}(.*?)[*_]{1,2}', r'\1', content)
  285. # 移除表情符号和特殊标记
  286. content = re.sub(r'[✅❌📋📊]', '', content)
  287. # 标准化换行
  288. content = re.sub(r'\n{3,}', '\n\n', content)
  289. return content.strip()