#!/usr/bin/env python # -*- coding: utf-8 -*- ''' @Project : lq-agent-api @File : sse_manager.py @IDE : VsCode @Author : @Date : 2025-12-04 10:58:00 ================================= 📋 统一SSE管理器 (Unified SSE Manager) 🏗️ 核心功能: ├── UnifiedSSEManager() # 统一SSE管理器(单例) ├── establish_connection() # 建立连接并注册回调 ├── close_connection() # 关闭连接(清理连接和回调) ├── send_progress() # 发送进度消息 └── trigger_callback() # 触发回调函数 📊 状态管理: ├── connections # 消息队列字典 ├── callbacks # 回调函数字典 └── get_connection_count() # 获取连接数统计 🔧 实用方法: ├── is_connected() # 检查连接是否存在 ├── is_callback_registered() # 检查回调是否已注册 ├── get_stats() # 获取详细统计信息 └── clear_all() # 清理所有连接和回调 ''' import asyncio from typing import Dict, Any, Optional, Callable from datetime import datetime from foundation.logger.loggering import server_logger as logger class UnifiedSSEManager: """ 统一的SSE管理器 - 管理SSE连接、回调函数和消息推送 功能: 1. 管理SSE消息队列连接 2. 管理回调函数注册和触发 3. 提供统一的消息推送接口 4. 确保连接和回调状态同步 """ _instance = None def __new__(cls): """单例模式实现""" if cls._instance is None: cls._instance = super().__new__(cls) cls._instance.connections = {} # 消息队列字典 cls._instance.callbacks = {} # 回调函数字典 return cls._instance def __init__(self): """初始化统一SSE管理器""" pass # 在__new__中已完成初始化 async def establish_connection(self, callback_task_id: str, callback_func: Optional[Callable] = None): """ 建立SSE连接并注册回调函数 Args: callback_task_id: 回调任务ID callback_func: 可选的回调函数 Returns: asyncio.Queue: 消息队列,用于SSE事件流 """ try: # 创建消息队列 queue = asyncio.Queue() self.connections[callback_task_id] = queue # 注册回调函数(如果提供) if callback_func: self.callbacks[callback_task_id] = callback_func # 发送连接建立确认消息 await queue.put({ "type": "connection_established", "callback_task_id": callback_task_id, "timestamp": datetime.now().isoformat() }) logger.info(f"SSE连接已建立: {callback_task_id}") logger.info(f"当前连接数: {len(self.connections)}, 回调数: {len(self.callbacks)}") return queue except Exception as e: logger.error(f"建立SSE连接失败: {callback_task_id}, 错误: {str(e)}") raise async def close_connection(self, callback_task_id: str): """ 关闭SSE连接(同时清理连接和回调) Args: callback_task_id: 回调任务ID """ try: connection_existed = False callback_existed = False # 1. 先向队列发送结束信号,让SSE流能够正常结束 if callback_task_id in self.connections: queue = self.connections[callback_task_id] try: await queue.put({ "type": "connection_closed", "callback_task_id": callback_task_id, "timestamp": datetime.now().isoformat() }) logger.info(f"已发送连接关闭信号到队列: {callback_task_id}") except Exception as queue_error: logger.warning(f"发送关闭信号失败,队列可能已关闭: {callback_task_id}, 错误: {str(queue_error)}") # 2. 清理连接 if callback_task_id in self.connections: del self.connections[callback_task_id] connection_existed = True logger.info(f"SSE连接已断开: {callback_task_id}") # 3. 清理回调 if callback_task_id in self.callbacks: del self.callbacks[callback_task_id] callback_existed = True logger.info(f"SSE回调已注销: {callback_task_id}") if not connection_existed and not callback_existed: logger.debug(f"SSE连接和回调均不存在: {callback_task_id}") else: logger.info(f"SSE连接清理完成: {callback_task_id}, 剩余连接数: {len(self.connections)}, 剩余回调数: {len(self.callbacks)}") except Exception as e: logger.error(f"关闭SSE连接时出错: {callback_task_id}, 错误: {str(e)}") async def send_progress(self, callback_task_id: str, current_data: dict): """ 发送进度消息到指定连接 Args: callback_task_id: 回调任务ID current_data: 进度数据 """ try: queue = self.connections.get(callback_task_id) if queue: # 确定事件类型 event_type = current_data.get("event_type", "processing") # 处理特殊的单元审查事件 if event_type == "unit_review" or (event_type == "processing" and current_data.get("status") == "unit_review_update"): event_type = "unit_review_update" # 添加时间戳 message = { "type": event_type, "data": current_data, "timestamp": datetime.now().isoformat() } await queue.put(message) logger.debug(f"SSE进度已推送: {callback_task_id}, 事件类型: {event_type}") else: logger.warning(f"SSE连接不存在,跳过进度推送: {callback_task_id} - 任务继续执行") except Exception as e: logger.error(f"发送SSE进度消息失败: {callback_task_id}, 错误: {str(e)}") async def trigger_callback(self, callback_task_id: str, current_data: dict): """ 触发指定任务的回调函数 Args: callback_task_id: 回调任务ID current_data: 传递给回调的数据 Returns: bool: 回调是否成功触发 """ try: callback_func = self.callbacks.get(callback_task_id) if callback_func: await callback_func(callback_task_id, current_data) logger.debug(f"SSE回调执行成功: {callback_task_id}") return True else: logger.debug(f"未找到SSE回调: {callback_task_id}, 已注册ID: {list(self.callbacks.keys())}") return False except Exception as e: logger.error(f"SSE回调执行失败: {callback_task_id}, 错误: {str(e)}") return False def is_connected(self, callback_task_id: str) -> bool: """检查SSE连接是否存在""" return callback_task_id in self.connections def is_callback_registered(self, callback_task_id: str) -> bool: """检查回调函数是否已注册""" return callback_task_id in self.callbacks def get_connection_count(self) -> int: """获取当前连接数""" return len(self.connections) def get_callback_count(self) -> int: """获取当前回调数""" return len(self.callbacks) def get_stats(self) -> Dict[str, Any]: """获取详细的统计信息""" return { "connections": { "count": len(self.connections), "ids": list(self.connections.keys()) }, "callbacks": { "count": len(self.callbacks), "ids": list(self.callbacks.keys()) }, "synchronized": len(self.connections) == len(self.callbacks) } async def clear_all(self): """清理所有连接和回调""" try: connection_count = len(self.connections) callback_count = len(self.callbacks) self.connections.clear() self.callbacks.clear() logger.info(f"已清理所有SSE连接和回调: {connection_count}个连接, {callback_count}个回调") except Exception as e: logger.error(f"清理所有SSE连接和回调时出错: {str(e)}") def register_callback_only(self, callback_task_id: str, callback_func: Callable): """ 仅注册回调函数(不建立连接) Args: callback_task_id: 回调任务ID callback_func: 回调函数 """ self.callbacks[callback_task_id] = callback_func logger.info(f"SSE回调已注册: {callback_task_id}, 当前回调数: {len(self.callbacks)}") def unregister_callback_only(self, callback_task_id: str): """ 仅注销回调函数(不关闭连接) Args: callback_task_id: 回调任务ID """ if callback_task_id in self.callbacks: del self.callbacks[callback_task_id] logger.info(f"SSE回调已注销: {callback_task_id}, 剩余回调数: {len(self.callbacks)}") else: logger.debug(f"SSE回调不存在: {callback_task_id}") # 创建全局单例实例 unified_sse_manager = UnifiedSSEManager()