| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328 |
- import json
- import asyncio
- from typing import Dict, Any, Optional
- from datetime import datetime
- from foundation.logger.loggering import server_logger as logger
- from foundation.base.config import config_handler
class SSECallbackManager:
    """Process-wide registry of SSE push callbacks, keyed by callback task id.

    Singleton: every ``SSECallbackManager()`` call returns the same instance,
    and the registry dict is stored on the class itself. A registered callback
    is awaited as ``callback(callback_task_id, current_data)``.
    """

    _instance = None
    _callbacks = {}

    def __new__(cls):
        instance = cls._instance
        if instance is None:
            instance = super().__new__(cls)
            cls._instance = instance
        return instance

    def register_callback(self, callback_task_id: str, callback_func):
        """Register (or replace) the callback for ``callback_task_id``."""
        registry = self._callbacks
        registry[callback_task_id] = callback_func
        logger.info(f"SSE回调注册, 当前注册数: {len(registry)}")

    def unregister_callback(self, callback_task_id: str):
        """Remove the callback for ``callback_task_id``; unknown ids are ignored silently."""
        try:
            del self._callbacks[callback_task_id]
        except KeyError:
            return
        logger.info(f"SSE回调注销, 剩余注册数: {len(self._callbacks)}")

    async def trigger_callback(self, callback_task_id: str, current_data: dict):
        """Await the callback registered for ``callback_task_id`` with ``current_data``.

        Returns:
            True if the callback ran without raising; False if no callback is
            registered or the callback raised (the error is logged, not
            propagated).
        """
        if callback_task_id not in self._callbacks:
            logger.debug(f"未找到SSE回调: {callback_task_id}, 当前注册回调数: {len(self._callbacks)}, 已注册ID: {list(self._callbacks)}")
            return False
        try:
            handler = self._callbacks[callback_task_id]
            await handler(callback_task_id, current_data)
            logger.debug(f"SSE回调执行成功: {callback_task_id}")
            logger.debug(f"SSE回调已触发: {callback_task_id}, 当前注册回调数: {len(self._callbacks)}")
        except Exception as e:
            logger.error(f"SSE回调执行失败: {callback_task_id}, {e}")
            return False
        return True

    def get_callbacks_count(self):
        """Return how many callbacks are currently registered."""
        return len(self._callbacks)

    def clear_all_callbacks(self):
        """Drop every registered callback."""
        self._callbacks.clear()
        logger.info("已清空所有SSE回调")


# Shared module-level instance; because of the singleton, any later
# SSECallbackManager() call returns this same object.
sse_callback_manager = SSECallbackManager()
class ProgressManager:
    """Task progress manager backed by Redis, with an in-memory fallback.

    Each task's progress is a JSON-serialisable dict stored under the Redis key
    ``current:<callback_task_id>`` with a TTL of ``PROGRESS_TTL_SECONDS``.
    When Redis could not be reached at construction time, or an individual
    Redis read/write fails, the record is kept in ``self.current_data``
    instead. Reads consult that dict after Redis, so records written while
    Redis was failing remain visible.

    NOTE(review): the client comes from the synchronous ``redis`` package
    (``redis.from_url``), so every Redis call below blocks the event loop while
    it runs; consider ``redis.asyncio`` if that latency matters.
    """

    # Redis expiry for one progress record, in seconds (1 hour).
    PROGRESS_TTL_SECONDS = 3600

    def __init__(self):
        self.redis_client = None
        self.redis_connected = False
        # In-memory fallback store: callback_task_id -> progress dict. Created
        # unconditionally so single Redis failures can fall back to it even
        # when the start-up connection succeeded.
        self.current_data = {}
        self._init_redis()

    def _init_redis(self):
        """Connect to Redis using the ``redis`` section of the config.

        ``redis_connected`` becomes True only after a successful ``ping``. Any
        failure (``redis`` package missing, bad settings, server unreachable)
        is logged and leaves the manager on the in-memory store.
        """
        try:
            import redis
            redis_host = config_handler.get('redis', 'REDIS_HOST', 'localhost')
            redis_port = config_handler.get('redis', 'REDIS_PORT', '6379')
            redis_password = config_handler.get('redis', 'REDIS_PASSWORD', '')
            redis_db = config_handler.get('redis', 'REDIS_DB', '0')
            auth = f":{redis_password}@" if redis_password else ""
            redis_url = f"redis://{auth}{redis_host}:{redis_port}/{redis_db}"
            # Log a masked URL: the real one carries the password in clear text.
            masked_auth = ":***@" if redis_password else ""
            logger.debug(f"ProgressManager连接Redis: redis://{masked_auth}{redis_host}:{redis_port}/{redis_db}")
            self.redis_client = redis.from_url(redis_url, decode_responses=True)
            self.redis_client.ping()
            self.redis_connected = True
            logger.debug(f"ProgressManager Redis连接成功: {redis_host}:{redis_port}")
        except Exception as e:
            logger.error(f"ProgressManager Redis连接失败: {e}")
            self.redis_connected = False
            logger.warning("ProgressManager将使用内存存储作为备选方案")

    async def _get_redis_key(self, callback_task_id: str) -> str:
        """Return the Redis key that holds ``callback_task_id``'s progress."""
        return f"current:{callback_task_id}"

    def _store_label(self) -> str:
        """Name of the primary store, used in "record not found" log messages."""
        return "Redis" if self.redis_connected else "内存"

    async def _load_progress(self, callback_task_id: str) -> Optional[Dict[str, Any]]:
        """Return the stored progress dict for a task, or None if there is none.

        Redis is tried first when connected. On a miss or a Redis error the
        in-memory store is checked, because writes that could not reach Redis
        are parked there. In memory mode the returned dict is the stored
        object itself (callers mutate it and then save it back).
        """
        if self.redis_connected:
            try:
                redis_key = await self._get_redis_key(callback_task_id)
                progress_json = self.redis_client.get(redis_key)
            except Exception as redis_e:
                logger.warning(f"从Redis读取进度失败: {callback_task_id}, {redis_e}")
            else:
                if progress_json:
                    return json.loads(progress_json)
        return self.current_data.get(callback_task_id)

    async def _save_progress(self, callback_task_id: str, task_progress: Dict[str, Any]) -> bool:
        """Persist a progress dict; return True if it was written to Redis.

        Falls back to the in-memory store (returning False) when Redis is not
        connected or the write fails. A successful Redis write discards any
        in-memory copy so a stale copy cannot resurface after the Redis key
        expires.
        """
        if self.redis_connected:
            try:
                redis_key = await self._get_redis_key(callback_task_id)
                self.redis_client.setex(
                    redis_key,
                    self.PROGRESS_TTL_SECONDS,
                    json.dumps(task_progress)
                )
            except Exception as redis_e:
                logger.warning(f"保存进度到Redis失败, 降级使用内存存储: {callback_task_id}, {redis_e}")
            else:
                self.current_data.pop(callback_task_id, None)
                return True
        self.current_data[callback_task_id] = task_progress
        return False

    async def initialize_progress(self, callback_task_id: str, user_id: str, stages: list):
        """Create (or overwrite) the initial progress record for a task.

        Args:
            callback_task_id: Callback task id the record is stored under.
            user_id: Owner of the task.
            stages: Currently unused; kept for interface compatibility.

        Raises:
            Re-raises any unexpected error after logging it. Redis write
            failures do not raise; the record is kept in memory instead.
        """
        try:
            current_data = {
                "user_id": user_id,
                "current": 0,
                "stage_name": "",
                "status": "准备开始",
                "message": "任务开始",
                "updated_at": datetime.now().isoformat(),
                "overall_task_status": "pending"
            }
            if await self._save_progress(callback_task_id, current_data):
                logger.info(f"初始化任务进度列表: {callback_task_id}")
            else:
                logger.info(f"初始化任务进度到内存: {callback_task_id}")
        except Exception as e:
            logger.error(f"初始化进度失败: {str(e)}")
            raise

    async def update_stage_progress(self, callback_task_id: str, stage_name: Optional[str] = None, current: Optional[int] = None, status: Optional[str] = None, message: Optional[str] = None, issues=None, user_id: Optional[str] = None, overall_task_status: Optional[str] = None):
        """Update a task's progress record and push the result over SSE.

        Only ``callback_task_id`` is required. Every other field is written
        only when its argument is not None, with two exceptions: ``updated_at``
        is always refreshed, and ``issues`` is reset to ``[]`` when not given.
        If the record has no ``overall_task_status`` and none is given, it is
        set to ``"processing"``.

        After saving, the refreshed progress (as returned by ``get_progress``)
        is sent to the registered SSE callback, unless ``issues`` is non-empty
        and its first element is the string ``'clear'``.

        Args:
            callback_task_id: Callback task id (required).
            stage_name: Stage name.
            current: Current progress value (logged as a percentage).
            status: Status text.
            message: Message text.
            issues: Issue list for this update.
            user_id: User id.
            overall_task_status: Overall task status.

        Raises:
            Re-raises any unexpected error after logging it. A missing record
            is only logged, and the call returns without pushing.
        """
        try:
            task_progress = await self._load_progress(callback_task_id)
            if task_progress is None:
                logger.warning(f"{self._store_label()}中未找到任务进度: {callback_task_id}")
                return

            optional_fields = {
                "current": current,
                "stage_name": stage_name,
                "status": status,
                "message": message,
                "user_id": user_id,
                "overall_task_status": overall_task_status,
            }
            for field_name, value in optional_fields.items():
                if value is not None:
                    task_progress[field_name] = value
            task_progress["updated_at"] = datetime.now().isoformat()
            # Issues are per-update: an update without issues clears the old ones.
            task_progress["issues"] = issues if issues is not None else []
            task_progress.setdefault("overall_task_status", "processing")

            if await self._save_progress(callback_task_id, task_progress):
                logger.debug(f"更新进度到Redis: {callback_task_id}, 进度: {task_progress.get('current')}%")
            else:
                logger.debug(f"更新进度到内存: {callback_task_id}, 进度: {task_progress.get('current')}%")

            # 触发SSE推送 - 使用全局回调管理器 (push via the global SSE callback manager)
            logger.debug(f"触发SSE推送: {callback_task_id}")
            updated_progress = await self.get_progress(callback_task_id)
            stored_issues = task_progress.get("issues")
            # A leading 'clear' entry suppresses the push; every other update,
            # including one with an empty issues list, is pushed.
            if updated_progress and (not stored_issues or stored_issues[0] != 'clear'):
                await sse_callback_manager.trigger_callback(callback_task_id, updated_progress)
        except Exception as e:
            logger.error(f"更新阶段进度失败: {str(e)}")
            raise

    async def get_progress(self, callback_task_id: str) -> Optional[Dict[str, Any]]:
        """Return a task's progress in the SSE payload shape, or None.

        The payload contains ``callback_task_id``, ``user_id``, ``current``,
        ``stage_name``, ``status``, ``message``, ``overall_task_status``
        (defaulting to ``"pending"``), ``updated_at`` as an integer Unix
        timestamp, and ``issues`` when the record has them.

        Returns:
            The payload dict, or None when no record exists or any error occurs
            (errors are logged, not raised).
        """
        try:
            logger.debug(f"开始获取进度: {callback_task_id}, Redis连接状态: {self.redis_connected}")
            task_progress = await self._load_progress(callback_task_id)
            if task_progress is None:
                logger.debug(f"{self._store_label()}中未找到任务进度: {callback_task_id}")
                return None

            # Records store ISO-format strings; convert to an integer timestamp.
            updated_at = task_progress["updated_at"]
            if isinstance(updated_at, str):
                updated_at = datetime.fromisoformat(updated_at)

            result = {
                "callback_task_id": callback_task_id,
                "user_id": task_progress["user_id"],
                "current": task_progress["current"],
                "stage_name": task_progress["stage_name"],
                "status": task_progress["status"],
                "message": task_progress["message"],
                "overall_task_status": task_progress.get("overall_task_status", "pending"),
                "updated_at": int(updated_at.timestamp())
            }
            if "issues" in task_progress:
                result["issues"] = task_progress["issues"]
            return result
        except Exception as e:
            logger.error(f"获取进度失败: {str(e)}")
            return None

    async def complete_task(self, callback_task_id: str):
        """Mark a task as completed, save it, and push the final state over SSE.

        Sets ``status`` and ``overall_task_status`` to ``"completed"`` and
        replaces the message. A Redis write failure falls back to the
        in-memory store, as the other writers do.

        Raises:
            Re-raises any unexpected error after logging it. A missing record
            is only logged.
        """
        try:
            logger.info(f"通知sse连接关闭: {callback_task_id}")
            task_progress = await self._load_progress(callback_task_id)
            if task_progress is None:
                logger.warning(f"{self._store_label()}中未找到任务进度: {callback_task_id}")
                return

            task_progress["status"] = "completed"
            task_progress["overall_task_status"] = "completed"
            task_progress["message"] = "施工方案审查任务已完成!"
            task_progress["updated_at"] = datetime.now().isoformat()
            await self._save_progress(callback_task_id, task_progress)

            completed_progress = await self.get_progress(callback_task_id)
            if completed_progress:
                await sse_callback_manager.trigger_callback(callback_task_id, completed_progress)
                logger.debug(f"SSE完成进度已推送: {callback_task_id}")
            else:
                logger.warning(f"无法获取完成进度数据: {callback_task_id}")
        except Exception as e:
            logger.error(f"标记任务完成失败: {str(e)}")
            raise
-
|