|
|
@@ -1,8 +1,3 @@
|
|
|
-"""
|
|
|
-任务进度管理器
|
|
|
-负责任务进度的存储、更新和查询
|
|
|
-"""
|
|
|
-
|
|
|
import json
|
|
|
import asyncio
|
|
|
from typing import Dict, Any, Optional
|
|
|
@@ -12,9 +7,9 @@ from foundation.logger.loggering import server_logger as logger
|
|
|
from foundation.base.config import config_handler
|
|
|
|
|
|
class SSECallbackManager:
|
|
|
- """SSE回调管理器 - 单例模式管理全局SSE回调"""
|
|
|
+ """SSE回调管理器 - 单例模式"""
|
|
|
_instance = None
|
|
|
- _callbacks = {} # {callback_task_id: callback_function}
|
|
|
+ _callbacks = {}
|
|
|
|
|
|
def __new__(cls):
|
|
|
if cls._instance is None:
|
|
|
@@ -22,27 +17,21 @@ class SSECallbackManager:
|
|
|
return cls._instance
|
|
|
|
|
|
def register_callback(self, callback_task_id: str, callback_func):
|
|
|
- """注册SSE回调函数"""
|
|
|
self._callbacks[callback_task_id] = callback_func
|
|
|
logger.info(f"SSE回调注册, 当前注册数: {len(self._callbacks)}")
|
|
|
|
|
|
def unregister_callback(self, callback_task_id: str):
|
|
|
- """注销SSE回调函数"""
|
|
|
if callback_task_id in self._callbacks:
|
|
|
del self._callbacks[callback_task_id]
|
|
|
logger.info(f"SSE回调注销, 剩余注册数: {len(self._callbacks)}")
|
|
|
|
|
|
async def trigger_callback(self, callback_task_id: str, current_data: dict):
|
|
|
- """触发SSE回调"""
|
|
|
if callback_task_id in self._callbacks:
|
|
|
try:
|
|
|
- # 直接异步执行回调,保持trace上下文
|
|
|
await self._callbacks[callback_task_id](callback_task_id, current_data)
|
|
|
logger.debug(f"SSE回调执行成功: {callback_task_id}")
|
|
|
-
|
|
|
logger.debug(f"SSE回调已触发: {callback_task_id}, 当前注册回调数: {len(self._callbacks)}")
|
|
|
return True
|
|
|
-
|
|
|
except Exception as e:
|
|
|
logger.error(f"SSE回调执行失败: {callback_task_id}, {e}")
|
|
|
return False
|
|
|
@@ -51,19 +40,16 @@ class SSECallbackManager:
|
|
|
return False
|
|
|
|
|
|
def get_callbacks_count(self):
|
|
|
- """获取当前回调数量"""
|
|
|
return len(self._callbacks)
|
|
|
|
|
|
def clear_all_callbacks(self):
|
|
|
- """清空所有回调"""
|
|
|
self._callbacks.clear()
|
|
|
logger.info("已清空所有SSE回调")
|
|
|
|
|
|
-# 全局SSE回调管理器实例
|
|
|
sse_callback_manager = SSECallbackManager()
|
|
|
|
|
|
class ProgressManager:
|
|
|
- """任务进度管理器 - 增长型进度管理版本"""
|
|
|
+ """任务进度管理器"""
|
|
|
|
|
|
def __init__(self):
|
|
|
self.redis_client = None
|
|
|
@@ -71,7 +57,6 @@ class ProgressManager:
|
|
|
self._init_redis()
|
|
|
|
|
|
def _init_redis(self):
|
|
|
- """初始化Redis连接"""
|
|
|
try:
|
|
|
import redis
|
|
|
|
|
|
@@ -80,7 +65,6 @@ class ProgressManager:
|
|
|
redis_password = config_handler.get('redis', 'REDIS_PASSWORD', '')
|
|
|
redis_db = config_handler.get('redis', 'REDIS_DB', '0')
|
|
|
|
|
|
- # 构建Redis连接URL
|
|
|
if redis_password:
|
|
|
redis_url = f"redis://:{redis_password}@{redis_host}:{redis_port}/{redis_db}"
|
|
|
else:
|
|
|
@@ -88,10 +72,7 @@ class ProgressManager:
|
|
|
|
|
|
logger.debug(f"ProgressManager连接Redis: {redis_url}")
|
|
|
|
|
|
- # 连接Redis
|
|
|
self.redis_client = redis.from_url(redis_url, decode_responses=True)
|
|
|
-
|
|
|
- # 测试连接
|
|
|
self.redis_client.ping()
|
|
|
self.redis_connected = True
|
|
|
logger.debug(f"ProgressManager Redis连接成功: {redis_host}:{redis_port}")
|
|
|
@@ -100,20 +81,13 @@ class ProgressManager:
|
|
|
logger.error(f"ProgressManager Redis连接失败: {e}")
|
|
|
self.redis_connected = False
|
|
|
logger.warning("ProgressManager将使用内存存储作为备选方案")
|
|
|
- self.current_data = {} # 备选内存存储
|
|
|
+ self.current_data = {}
|
|
|
|
|
|
async def _get_redis_key(self, callback_task_id: str) -> str:
|
|
|
- """获取Redis键名"""
|
|
|
return f"current:{callback_task_id}"
|
|
|
|
|
|
async def initialize_progress(self, callback_task_id: str, user_id: str, stages: list):
|
|
|
- """初始化进度记录"""
|
|
|
try:
|
|
|
-
|
|
|
- # 设置总量为100(百分比模式)
|
|
|
- stage_name = stages[0]["stage_name"] if stages else ""
|
|
|
- message = "任务开始"
|
|
|
-
|
|
|
current_data = {
|
|
|
"user_id": user_id,
|
|
|
"current": 0,
|
|
|
@@ -125,24 +99,21 @@ class ProgressManager:
|
|
|
}
|
|
|
|
|
|
if self.redis_connected:
|
|
|
- # 使用同步Redis操作避免异步任务销毁问题
|
|
|
try:
|
|
|
redis_key = await self._get_redis_key(callback_task_id)
|
|
|
self.redis_client.setex(
|
|
|
redis_key,
|
|
|
- 3600, # 1小时过期
|
|
|
+ 3600,
|
|
|
json.dumps(current_data)
|
|
|
)
|
|
|
logger.info(f"初始化任务进度列表")
|
|
|
except Exception as redis_e:
|
|
|
logger.warning(f"初始化进度到Redis失败: {callback_task_id}, {redis_e}")
|
|
|
- # 降级到内存存储
|
|
|
if not hasattr(self, 'current_data'):
|
|
|
self.current_data = {}
|
|
|
self.current_data[callback_task_id] = current_data
|
|
|
logger.info(f"降级使用内存存储: {callback_task_id}")
|
|
|
else:
|
|
|
- # 使用内存存储
|
|
|
if not hasattr(self, 'current_data'):
|
|
|
self.current_data = {}
|
|
|
self.current_data[callback_task_id] = current_data
|
|
|
@@ -152,8 +123,19 @@ class ProgressManager:
|
|
|
logger.error(f"初始化进度失败: {str(e)}")
|
|
|
raise
|
|
|
|
|
|
- async def update_stage_progress(self, callback_task_id: str, stage_name: str, current: int, status: str, message: str = ""):
|
|
|
- """更新阶段进度"""
|
|
|
+ async def update_stage_progress(self, callback_task_id: str, stage_name: str = None, current: int = None, status: str = None, message: str = None, issues=None, user_id: str = None, overall_task_status: str = None):
|
|
|
+ """更新阶段进度 - 除callback_task_id外,其他参数都可选
|
|
|
+
|
|
|
+ Args:
|
|
|
+ callback_task_id: 回调任务ID(必需)
|
|
|
+ stage_name: 阶段名称(可选)
|
|
|
+ current: 当前进度(可选)
|
|
|
+ status: 状态(可选)
|
|
|
+ message: 消息(可选)
|
|
|
+ issues: 问题列表(可选)
|
|
|
+ user_id: 用户ID(可选)
|
|
|
+ overall_task_status: 整体任务状态(可选)
|
|
|
+ """
|
|
|
try:
|
|
|
task_progress = None
|
|
|
|
|
|
@@ -174,15 +156,25 @@ class ProgressManager:
|
|
|
logger.warning(f"内存中未找到任务进度: {callback_task_id}")
|
|
|
return
|
|
|
|
|
|
- # 更新进度数据
|
|
|
- task_progress["current"] = current
|
|
|
- task_progress["stage_name"] = stage_name
|
|
|
- task_progress["status"] = status
|
|
|
- task_progress["message"] = message
|
|
|
+ # 更新进度数据 - 只有非空参数才更新
|
|
|
+ if current is not None:
|
|
|
+ task_progress["current"] = current
|
|
|
+ if stage_name is not None:
|
|
|
+ task_progress["stage_name"] = stage_name
|
|
|
+ if status is not None:
|
|
|
+ task_progress["status"] = status
|
|
|
+ if message is not None:
|
|
|
+ task_progress["message"] = message
|
|
|
task_progress["updated_at"] = datetime.now().isoformat()
|
|
|
-
|
|
|
- # 保留overall_task_status字段,不要被普通进度更新覆盖
|
|
|
- if "overall_task_status" not in task_progress:
|
|
|
+ if issues is not None:
|
|
|
+ task_progress["issues"] = issues
|
|
|
+ else:
|
|
|
+ task_progress["issues"] = []
|
|
|
+ if user_id is not None:
|
|
|
+ task_progress["user_id"] = user_id
|
|
|
+ if overall_task_status is not None:
|
|
|
+ task_progress["overall_task_status"] = overall_task_status
|
|
|
+ elif "overall_task_status" not in task_progress:
|
|
|
task_progress["overall_task_status"] = "processing"
|
|
|
|
|
|
try:
|
|
|
@@ -215,7 +207,10 @@ class ProgressManager:
|
|
|
# 触发SSE推送 - 使用全局回调管理器
|
|
|
logger.debug(f"触发SSE推送: {callback_task_id}")
|
|
|
updated_progress = await self.get_progress(callback_task_id)
|
|
|
- if updated_progress:
|
|
|
+ issues = task_progress.get("issues")
|
|
|
+ if updated_progress and issues and len(issues) > 0 and issues[0] != 'clear':
|
|
|
+ await sse_callback_manager.trigger_callback(callback_task_id, updated_progress)
|
|
|
+ elif updated_progress and not issues: # 空列表时也要推送
|
|
|
await sse_callback_manager.trigger_callback(callback_task_id, updated_progress)
|
|
|
|
|
|
except Exception as e:
|
|
|
@@ -257,7 +252,8 @@ class ProgressManager:
|
|
|
else:
|
|
|
updated_at_timestamp = int(updated_at.timestamp())
|
|
|
|
|
|
- return {
|
|
|
+ # 构建返回数据
|
|
|
+ result = {
|
|
|
"callback_task_id": callback_task_id,
|
|
|
"user_id": task_progress["user_id"],
|
|
|
"current": task_progress["current"],
|
|
|
@@ -268,6 +264,12 @@ class ProgressManager:
|
|
|
"updated_at": updated_at_timestamp
|
|
|
}
|
|
|
|
|
|
+ # 添加可选字段
|
|
|
+ if "issues" in task_progress:
|
|
|
+ result["issues"] = task_progress["issues"]
|
|
|
+
|
|
|
+ return result
|
|
|
+
|
|
|
except Exception as e:
|
|
|
logger.error(f"获取进度失败: {str(e)}")
|
|
|
return None
|