| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274 |
- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
- '''
- @Project : lq-agent-api
- @File : sse_manager.py
- @IDE : VsCode
- @Author :
- @Date : 2025-12-04 10:58:00
- =================================
- 📋 统一SSE管理器 (Unified SSE Manager)
- 🏗️ 核心功能:
- ├── UnifiedSSEManager() # 统一SSE管理器(单例)
- ├── establish_connection() # 建立连接并注册回调
- ├── close_connection() # 关闭连接(清理连接和回调)
- ├── send_progress() # 发送进度消息
- └── trigger_callback() # 触发回调函数
- 📊 状态管理:
- ├── connections # 消息队列字典
- ├── callbacks # 回调函数字典
- └── get_connection_count() # 获取连接数统计
- 🔧 实用方法:
- ├── is_connected() # 检查连接是否存在
- ├── is_callback_registered() # 检查回调是否已注册
- ├── get_stats() # 获取详细统计信息
- └── clear_all() # 清理所有连接和回调
- '''
- import asyncio
- from typing import Dict, Any, Optional, Callable
- from datetime import datetime
- from foundation.observability.logger.loggering import review_logger as logger
- class UnifiedSSEManager:
- """
- 统一的SSE管理器 - 管理SSE连接、回调函数和消息推送
- 功能:
- 1. 管理SSE消息队列连接
- 2. 管理回调函数注册和触发
- 3. 提供统一的消息推送接口
- 4. 确保连接和回调状态同步
- """
- _instance = None
- def __new__(cls):
- """单例模式实现"""
- if cls._instance is None:
- cls._instance = super().__new__(cls)
- cls._instance.connections = {} # 消息队列字典
- cls._instance.callbacks = {} # 回调函数字典
- return cls._instance
- def __init__(self):
- """初始化统一SSE管理器"""
- pass # 在__new__中已完成初始化
- async def establish_connection(self, callback_task_id: str, callback_func: Optional[Callable] = None):
- """
- 建立SSE连接并注册回调函数
- Args:
- callback_task_id: 回调任务ID
- callback_func: 可选的回调函数
- Returns:
- asyncio.Queue: 消息队列,用于SSE事件流
- """
- try:
- # 创建消息队列
- queue = asyncio.Queue()
- self.connections[callback_task_id] = queue
- # 注册回调函数(如果提供)
- if callback_func:
- self.callbacks[callback_task_id] = callback_func
- # 发送连接建立确认消息
- await queue.put({
- "type": "connection_established",
- "callback_task_id": callback_task_id,
- "timestamp": datetime.now().isoformat()
- })
- logger.info(f"SSE连接已建立: {callback_task_id}")
- logger.info(f"当前连接数: {len(self.connections)}, 回调数: {len(self.callbacks)}")
- return queue
- except Exception as e:
- logger.error(f"建立SSE连接失败: {callback_task_id}, 错误: {str(e)}")
- raise
- async def close_connection(self, callback_task_id: str):
- """
- 关闭SSE连接(同时清理连接和回调)
- Args:
- callback_task_id: 回调任务ID
- """
- try:
- connection_existed = False
- callback_existed = False
- # 1. 先向队列发送结束信号,让SSE流能够正常结束
- if callback_task_id in self.connections:
- queue = self.connections[callback_task_id]
- try:
- await queue.put({
- "type": "connection_closed",
- "callback_task_id": callback_task_id,
- "timestamp": datetime.now().isoformat()
- })
- logger.info(f"已发送连接关闭信号到队列: {callback_task_id}")
- except Exception as queue_error:
- logger.warning(f"发送关闭信号失败,队列可能已关闭: {callback_task_id}, 错误: {str(queue_error)}")
- # 2. 清理连接
- if callback_task_id in self.connections:
- del self.connections[callback_task_id]
- connection_existed = True
- logger.info(f"SSE连接已断开: {callback_task_id}")
- # 3. 清理回调
- if callback_task_id in self.callbacks:
- del self.callbacks[callback_task_id]
- callback_existed = True
- logger.info(f"SSE回调已注销: {callback_task_id}")
- if not connection_existed and not callback_existed:
- logger.debug(f"SSE连接和回调均不存在: {callback_task_id}")
- else:
- logger.info(f"SSE连接清理完成: {callback_task_id}, 剩余连接数: {len(self.connections)}, 剩余回调数: {len(self.callbacks)}")
- except Exception as e:
- logger.error(f"关闭SSE连接时出错: {callback_task_id}, 错误: {str(e)}")
- async def send_progress(self, callback_task_id: str, current_data: dict):
- """
- 发送进度消息到指定连接
- Args:
- callback_task_id: 回调任务ID
- current_data: 进度数据
- """
- try:
- queue = self.connections.get(callback_task_id)
- if queue:
- # 确定事件类型
- event_type = current_data.get("event_type", "processing")
- # 处理特殊的单元审查事件
- if event_type == "unit_review" or (event_type == "processing" and current_data.get("status") == "unit_review_update"):
- event_type = "unit_review_update"
- # 添加时间戳
- message = {
- "type": event_type,
- "data": current_data,
- "timestamp": datetime.now().isoformat()
- }
- await queue.put(message)
- logger.debug(f"SSE进度已推送: {callback_task_id}, 事件类型: {event_type}")
- else:
- logger.warning(f"SSE连接不存在,跳过进度推送: {callback_task_id} - 任务继续执行")
- except Exception as e:
- logger.error(f"发送SSE进度消息失败: {callback_task_id}, 错误: {str(e)}")
- async def trigger_callback(self, callback_task_id: str, current_data: dict):
- """
- 触发指定任务的回调函数
- Args:
- callback_task_id: 回调任务ID
- current_data: 传递给回调的数据
- Returns:
- bool: 回调是否成功触发
- """
- try:
- callback_func = self.callbacks.get(callback_task_id)
- if callback_func:
- await callback_func(callback_task_id, current_data)
- logger.debug(f"SSE回调执行成功: {callback_task_id}")
- return True
- else:
- logger.debug(f"未找到SSE回调: {callback_task_id}, 已注册ID: {list(self.callbacks.keys())}")
- return False
- except Exception as e:
- logger.error(f"SSE回调执行失败: {callback_task_id}, 错误: {str(e)}")
- return False
- def is_connected(self, callback_task_id: str) -> bool:
- """检查SSE连接是否存在"""
- return callback_task_id in self.connections
- def is_callback_registered(self, callback_task_id: str) -> bool:
- """检查回调函数是否已注册"""
- return callback_task_id in self.callbacks
- def get_connection_count(self) -> int:
- """获取当前连接数"""
- return len(self.connections)
- def get_callback_count(self) -> int:
- """获取当前回调数"""
- return len(self.callbacks)
- def get_stats(self) -> Dict[str, Any]:
- """获取详细的统计信息"""
- return {
- "connections": {
- "count": len(self.connections),
- "ids": list(self.connections.keys())
- },
- "callbacks": {
- "count": len(self.callbacks),
- "ids": list(self.callbacks.keys())
- },
- "synchronized": len(self.connections) == len(self.callbacks)
- }
- async def clear_all(self):
- """清理所有连接和回调"""
- try:
- connection_count = len(self.connections)
- callback_count = len(self.callbacks)
- self.connections.clear()
- self.callbacks.clear()
- logger.info(f"已清理所有SSE连接和回调: {connection_count}个连接, {callback_count}个回调")
- except Exception as e:
- logger.error(f"清理所有SSE连接和回调时出错: {str(e)}")
- def register_callback_only(self, callback_task_id: str, callback_func: Callable):
- """
- 仅注册回调函数(不建立连接)
- Args:
- callback_task_id: 回调任务ID
- callback_func: 回调函数
- """
- self.callbacks[callback_task_id] = callback_func
- logger.info(f"SSE回调已注册: {callback_task_id}, 当前回调数: {len(self.callbacks)}")
- def unregister_callback_only(self, callback_task_id: str):
- """
- 仅注销回调函数(不关闭连接)
- Args:
- callback_task_id: 回调任务ID
- """
- if callback_task_id in self.callbacks:
- del self.callbacks[callback_task_id]
- logger.info(f"SSE回调已注销: {callback_task_id}, 剩余回调数: {len(self.callbacks)}")
- else:
- logger.debug(f"SSE回调不存在: {callback_task_id}")
- # 创建全局单例实例
- unified_sse_manager = UnifiedSSEManager()
|