import json
import asyncio
from typing import Dict, Any, Optional
from datetime import datetime
from urllib.parse import quote

from foundation.logger.loggering import server_logger as logger
from foundation.base.config import config_handler


class SSECallbackManager:
    """Registry of per-task SSE callbacks (singleton).

    ``__new__`` always hands back the same instance and ``_callbacks`` is a
    class attribute, so every ``SSECallbackManager()`` shares one registry
    keyed by ``callback_task_id``.
    """

    _instance = None
    _callbacks = {}

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def register_callback(self, callback_task_id: str, callback_func):
        """Register (or replace) the async callback for ``callback_task_id``.

        ``callback_func`` is awaited as ``callback_func(callback_task_id, data)``
        by :meth:`trigger_callback`.
        """
        self._callbacks[callback_task_id] = callback_func
        logger.info(f"SSE回调注册, 当前注册数: {len(self._callbacks)}")

    def unregister_callback(self, callback_task_id: str):
        """Remove the callback for ``callback_task_id``; no-op when absent."""
        if callback_task_id in self._callbacks:
            del self._callbacks[callback_task_id]
            logger.info(f"SSE回调注销, 剩余注册数: {len(self._callbacks)}")

    def is_callback_registered(self, callback_task_id: str) -> bool:
        """Return True if a callback is registered for ``callback_task_id``."""
        return callback_task_id in self._callbacks

    async def trigger_callback(self, callback_task_id: str, current_data: dict):
        """Await the registered callback with ``(callback_task_id, current_data)``.

        Returns:
            True if a callback was found and finished without raising; False
            if none is registered or it raised (the error is logged, not
            propagated).
        """
        callback = self._callbacks.get(callback_task_id)
        if callback is None:
            logger.debug(f"未找到SSE回调: {callback_task_id}, 当前注册回调数: {len(self._callbacks)}, 已注册ID: {list(self._callbacks.keys())}")
            return False
        try:
            await callback(callback_task_id, current_data)
        except Exception as e:
            logger.error(f"SSE回调执行失败: {callback_task_id}, {e}")
            return False
        logger.debug(f"SSE回调执行成功: {callback_task_id}")
        logger.debug(f"SSE回调已触发: {callback_task_id}, 当前注册回调数: {len(self._callbacks)}")
        return True

    def get_callbacks_count(self):
        """Return the number of registered callbacks."""
        return len(self._callbacks)

    def clear_all_callbacks(self):
        """Drop every registered callback."""
        self._callbacks.clear()
        logger.info("已清空所有SSE回调")


sse_callback_manager = SSECallbackManager()


class ProgressManager:
    """Task progress store backed by Redis, with an in-process dict fallback.

    Records live in Redis under ``current:<callback_task_id>`` with a TTL of
    ``PROGRESS_TTL_SECONDS``. If Redis was unreachable at start-up, or a
    Redis read/write fails, the record is kept in ``self.current_data``
    instead, and reads consult that dict whenever Redis has nothing.

    NOTE(review): the Redis client is synchronous, so its calls block the
    event loop inside these ``async`` methods; consider ``redis.asyncio``.
    """

    # Lifetime of a progress record in Redis, in seconds (setex TTL).
    PROGRESS_TTL_SECONDS = 3600

    def __init__(self):
        self.redis_client = None
        self.redis_connected = False
        # Always present: the memory fallback is also used when Redis was
        # connected at start-up but a later read/write fails.
        self.current_data = {}
        self._init_redis()

    def _init_redis(self):
        """Connect to Redis using the ``redis`` config section.

        On any failure ``redis_connected`` stays False and the manager runs
        purely on the in-memory store.
        """
        try:
            import redis

            redis_host = config_handler.get('redis', 'REDIS_HOST', 'localhost')
            redis_port = config_handler.get('redis', 'REDIS_PORT', '6379')
            redis_password = config_handler.get('redis', 'REDIS_PASSWORD', '')
            redis_db = config_handler.get('redis', 'REDIS_DB', '0')

            if redis_password:
                # Percent-encode so '@', ':', '/' or '#' in the password do not
                # corrupt the URL; redis-py unquotes the password when parsing.
                encoded_password = quote(str(redis_password), safe='')
                redis_url = f"redis://:{encoded_password}@{redis_host}:{redis_port}/{redis_db}"
                # Never write the real password to the log.
                loggable_url = f"redis://:******@{redis_host}:{redis_port}/{redis_db}"
            else:
                redis_url = f"redis://{redis_host}:{redis_port}/{redis_db}"
                loggable_url = redis_url

            logger.debug(f"ProgressManager连接Redis: {loggable_url}")
            self.redis_client = redis.from_url(redis_url, decode_responses=True)
            self.redis_client.ping()
            self.redis_connected = True
            logger.debug(f"ProgressManager Redis连接成功: {redis_host}:{redis_port}")
        except Exception as e:
            logger.error(f"ProgressManager Redis连接失败: {e}")
            self.redis_connected = False
            logger.warning("ProgressManager将使用内存存储作为备选方案")

    async def _get_redis_key(self, callback_task_id: str) -> str:
        """Return the Redis key under which a task's progress is stored."""
        return f"current:{callback_task_id}"

    async def _load_progress(self, callback_task_id: str) -> Optional[Dict[str, Any]]:
        """Fetch a task's stored progress dict, preferring Redis over memory.

        Falls back to ``self.current_data`` when Redis is disconnected, the
        read fails, or the key is missing, so records saved by the memory
        fallback stay visible. Returns None when neither store has the task.
        ``json.loads`` errors on a corrupted Redis value propagate.
        """
        if self.redis_connected:
            redis_key = await self._get_redis_key(callback_task_id)
            logger.debug(f"Redis键: {redis_key}")
            try:
                progress_json = self.redis_client.get(redis_key)
            except Exception as e:
                logger.warning(f"从Redis读取进度失败: {callback_task_id}, {e}")
                progress_json = None
            logger.debug(f"从Redis读取数据: {progress_json is not None}")
            if progress_json:
                return json.loads(progress_json)
        return self.current_data.get(callback_task_id)

    async def initialize_progress(self, callback_task_id: str, user_id: str, stages: list):
        """Create the initial progress record for a task.

        Args:
            callback_task_id: Task identifier used as the storage key.
            user_id: Owner of the task, stored in the record.
            stages: Accepted for interface compatibility; not used here.

        Raises:
            Re-raises any unexpected error after logging it. A Redis write
            failure is not raised; the record goes to memory instead.
        """
        try:
            current_data = {
                "user_id": user_id,
                "current": 0,
                "stage_name": "",
                "status": "准备开始",
                "message": "任务开始",
                "updated_at": datetime.now().isoformat(),
                "overall_task_status": "pending"
            }

            if self.redis_connected:
                try:
                    redis_key = await self._get_redis_key(callback_task_id)
                    self.redis_client.setex(redis_key, self.PROGRESS_TTL_SECONDS, json.dumps(current_data))
                    # Drop any stale fallback copy so it cannot shadow Redis later.
                    self.current_data.pop(callback_task_id, None)
                    logger.info(f"初始化任务进度列表")
                except Exception as redis_e:
                    logger.warning(f"初始化进度到Redis失败: {callback_task_id}, {redis_e}")
                    self.current_data[callback_task_id] = current_data
                    logger.info(f"降级使用内存存储: {callback_task_id}")
            else:
                self.current_data[callback_task_id] = current_data
                logger.info(f"初始化任务进度到内存: {callback_task_id}")
        except Exception as e:
            logger.error(f"初始化进度失败: {str(e)}")
            raise

    async def update_stage_progress(self, callback_task_id: str, stage_name: str = None, current: int = None,
                                    status: str = None, message: str = None, issues=None, user_id: str = None,
                                    overall_task_status: str = None, event_type: str = "processing"):
        """Update stage progress; every argument except callback_task_id is optional.

        Only non-None arguments overwrite stored fields, with two exceptions:
        ``updated_at`` is always refreshed, and ``issues`` is reset to ``[]``
        when not supplied. After saving, an SSE push is triggered unless
        ``overall_task_status`` is ``"completed"`` or the first element of
        ``issues`` is the marker ``'clear'``.

        Args:
            callback_task_id: Callback task ID (required).
            stage_name: Stage name.
            current: Current progress value.
            status: Status text.
            message: Message text.
            issues: Issue list.
            user_id: User ID.
            overall_task_status: Overall task status. When omitted and the
                record has none yet, it is set to "processing".
            event_type: SSE event type (default "processing"; None leaves it unchanged).

        Raises:
            Re-raises any unexpected error after logging it.
        """
        try:
            task_progress = await self._load_progress(callback_task_id)
            if task_progress is None:
                logger.warning(f"未找到任务进度: {callback_task_id}")
                return

            # Only overwrite fields the caller actually supplied.
            if current is not None:
                task_progress["current"] = current
            if stage_name is not None:
                task_progress["stage_name"] = stage_name
            if status is not None:
                task_progress["status"] = status
            if message is not None:
                task_progress["message"] = message
            task_progress["updated_at"] = datetime.now().isoformat()
            task_progress["issues"] = issues if issues is not None else []
            if user_id is not None:
                task_progress["user_id"] = user_id
            if overall_task_status is not None:
                task_progress["overall_task_status"] = overall_task_status
            elif "overall_task_status" not in task_progress:
                task_progress["overall_task_status"] = "processing"
            if event_type is not None:
                task_progress["event_type"] = event_type
                logger.debug(f"设置event_type: {event_type} for {callback_task_id}")
            else:
                logger.debug(f"event_type为None,不设置 for {callback_task_id}")

            if self.redis_connected:
                redis_key = await self._get_redis_key(callback_task_id)
                try:
                    self.redis_client.setex(redis_key, self.PROGRESS_TTL_SECONDS, json.dumps(task_progress))
                    # Redis now holds the latest copy; discard any fallback copy.
                    self.current_data.pop(callback_task_id, None)
                    logger.debug(f"更新进度到Redis: {callback_task_id}, 进度: {current}%")
                except Exception as sync_e:
                    logger.warning(f"同步Redis操作失败: {callback_task_id}, {sync_e}")
                    # Redis write failed: degrade to the in-memory store.
                    self.current_data[callback_task_id] = task_progress
                    logger.debug(f"降级使用内存存储: {callback_task_id}")
            else:
                self.current_data[callback_task_id] = task_progress
                logger.debug(f"更新进度到内存: {callback_task_id}, 进度: {current}%")

            if task_progress.get("overall_task_status") == "completed":
                # The completion signal is pushed by a higher layer, not here.
                logger.info(f"任务完成,不触发redis更新SSE推送,将完成信号交由lanch_review上层sse推送: {callback_task_id}")
            else:
                logger.info(f"触发SSE推送: {callback_task_id}")
                updated_progress = await self.get_progress(callback_task_id)
                issues = task_progress.get("issues")
                event_type = task_progress.get("event_type", "processing")
                logger.info(f"触发SSE回调: {callback_task_id}, event_type: {event_type}")
                # Empty issue lists are still pushed; a leading 'clear' marker suppresses the push.
                if updated_progress and (not issues or issues[0] != 'clear'):
                    await sse_callback_manager.trigger_callback(callback_task_id, updated_progress)
        except Exception as e:
            logger.error(f"更新阶段进度失败: {str(e)}")
            raise

    async def get_progress(self, callback_task_id: str) -> Optional[Dict[str, Any]]:
        """Return a task's progress as an SSE-ready dict, or None.

        ``updated_at`` is converted to an integer Unix timestamp, and a
        missing ``overall_task_status`` defaults to "pending". ``issues``
        and ``event_type`` are included only when stored. Any error is
        logged and yields None.
        """
        try:
            logger.debug(f"开始获取进度: {callback_task_id}, Redis连接状态: {self.redis_connected}")
            task_progress = await self._load_progress(callback_task_id)
            if task_progress is None:
                logger.debug(f"未找到任务进度: {callback_task_id}")
                return None

            overall_task_status = task_progress.get("overall_task_status", "pending")

            # Stored as an ISO string; tolerate a datetime object as well.
            updated_at = task_progress["updated_at"]
            if isinstance(updated_at, str):
                updated_at_timestamp = int(datetime.fromisoformat(updated_at).timestamp())
            else:
                updated_at_timestamp = int(updated_at.timestamp())

            result = {
                "callback_task_id": callback_task_id,
                "user_id": task_progress["user_id"],
                "current": task_progress["current"],
                "stage_name": task_progress["stage_name"],
                "status": task_progress["status"],
                "message": task_progress["message"],
                "overall_task_status": overall_task_status,
                "updated_at": updated_at_timestamp
            }
            for optional_key in ("issues", "event_type"):
                if optional_key in task_progress:
                    result[optional_key] = task_progress[optional_key]
            return result
        except Exception as e:
            logger.error(f"获取进度失败: {str(e)}")
            return None

    async def complete_task(self, callback_task_id: str, user_id: str = None, current_data: dict = None):
        """Mark a task completed and unregister its SSE callback.

        The record is updated with ``overall_task_status='completed'``, which
        makes update_stage_progress skip its own SSE push. ``current`` and
        ``issues`` are taken from ``current_data`` when given, otherwise 100
        and ``[]``.

        Raises:
            Re-raises any error after logging it.
        """
        try:
            logger.info(f"保存审查结果: {callback_task_id}")
            await self.update_stage_progress(
                callback_task_id=callback_task_id,
                user_id=user_id,
                current=current_data.get("current", 100) if current_data else 100,
                stage_name="审查完成",
                status="completed",
                message="施工审查方案处理完成!",
                overall_task_status='completed',
                issues=current_data.get("issues", []) if current_data else []
            )

            logger.info(f"取消注册任务: {callback_task_id}")
            # Unregister so no further pushes reach this task's SSE stream.
            sse_callback_manager.unregister_callback(callback_task_id)
            logger.info(f"任务关闭: {callback_task_id}")
        except Exception as e:
            logger.error(f"标记任务完成失败: {str(e)}")
            raise