Просмотр исходного кода

Merge branch 'dev' of http://192.168.0.3:3000/CRBC-MaaS-Platform-Project/LQAgentPlatform into dev

lingmin_package@163.com 3 месяцев назад
Родитель
Commit
cbe6776bfa
41 измененных файлов с 2359 добавлено и 1255 удалено
  1. 5 2
      .gitignore
  2. 6 32
      README.md
  3. 7 7
      config/config.ini
  4. 0 0
      core/base/doc_worker/__init__.py
  5. 1 1
      core/base/doc_worker/config.yaml
  6. 0 0
      core/base/doc_worker/config_loader.py
  7. 0 0
      core/base/doc_worker/core.py
  8. 0 0
      core/base/doc_worker/llm_classifier.py
  9. 0 0
      core/base/doc_worker/result_saver.py
  10. 0 0
      core/base/doc_worker/text_splitter.py
  11. 0 0
      core/base/doc_worker/toc_extractor.py
  12. 266 65
      core/base/progress_manager.py
  13. 96 1
      core/base/redis_duplicate_checker.py
  14. 63 104
      core/base/workflow_manager.py
  15. 31 16
      core/construction_review/component/ai_review_engine.py
  16. 4 4
      core/construction_review/component/document_processor.py
  17. 19 2
      core/construction_review/component/reviewers/base_reviewer.py
  18. 54 24
      core/construction_review/workflows/ai_review_workflow.py
  19. 13 19
      core/construction_review/workflows/document_workflow.py
  20. 7 11
      core/construction_review/workflows/report_workflow.py
  21. 22 1
      foundation/base/celery_app.py
  22. 1 1
      foundation/base/redis_config.py
  23. 123 12
      foundation/base/redis_connection.py
  24. 9 3
      foundation/base/tasks.py
  25. 19 4
      foundation/logger/loggering.py
  26. 1 1
      foundation/rag/vector/milvus_vector.py
  27. 121 0
      foundation/trace/celery_trace.py
  28. 153 0
      foundation/trace/trace_context.py
  29. 184 10
      foundation/utils/redis_utils.py
  30. 22 5
      foundation/utils/time_statistics.py
  31. 3 13
      temp/AI审查结果.json
  32. 0 281
      test/construction_review/api_test_client.py
  33. 0 370
      test/construction_review/test_error_codes_pytest.py
  34. 190 0
      test/system_trace_id_test.py
  35. 179 0
      test/test_sse_integration.py
  36. 4 4
      views/__init__.py
  37. 16 25
      views/construction_review/app.py
  38. 83 56
      views/construction_review/file_upload.py
  39. 358 0
      views/construction_review/launch_review.py
  40. 132 51
      views/construction_review/schemas/error_schemas.py
  41. 167 130
      views/construction_review/task_progress.py

+ 5 - 2
.gitignore

@@ -48,7 +48,9 @@ coverage.xml
 # Translations
 *.mo
 *.pot
-
+*.pdf
+*.docs
+*.doc
 # Django stuff:
 *.log
 langfuse/
@@ -61,4 +63,5 @@ target/
 todo.md
 .design
 .claude
-.R&D
+.R&D
+temp\AI审查结果.json

+ 6 - 32
README.md

@@ -13,6 +13,10 @@
     - python .\views\construction_review\app.py
 
 
+    
+
+
+
   
     pip install aioredis -i https://mirrors.aliyun.com/pypi/simple/
     pip install langgraph-checkpoint-postgres -i https://mirrors.aliyun.com/pypi/simple/
@@ -30,38 +34,6 @@
 
 
 
-### 向量数据库 milvus
-   - cd /home/cjb/lq_workspace/milvus
-   - docker-compose up -d
-   - 检查服务是否正常 http://192.168.0.3:9091/healthz
-   - 拉取并运行 Attu http://192.168.0.3:13000/#/connect
-
-
-#### 测试向量数据库 检索测试接口
-  - 测试接口1
-    http://localhost:8001/test/data/bfp/milvus/search
-        {
-        "config": {
-            "session_id":"5"
-        },
-        "input": "普通模版荷载计算"
-      }
-
-
-
-   - 测试接口2
-      http://localhost:8001/test/data/bfp/milvus/search
-        {
-        "config": {
-            "session_id":"3"
-        },
-        "input": "安全生产条件"
-      }
-
-
-
-
-
 ### 测试接口
 
   #### 生成模型接口 
@@ -180,3 +152,5 @@ curl -X POST "http://localhost:8001/test/agent/stream" \
       },
       "input": "安全生产条件"
     }
+
+

+ 7 - 7
config/config.ini

@@ -1,13 +1,13 @@
 
 
 [model]
-MODEL_TYPE=qwen
+MODEL_TYPE=qwen_local_14b
 
 
 
 [gemini]
 GEMINI_SERVER_URL=https://generativelanguage.googleapis.com
-GEMINI_MODEL_ID=gemini-2.5-flash
+GEMINI_MODEL_ID=gemini-2.0-flash
 GEMINI_API_KEY=AIzaSyDcL1AZS4u9N-8OyE7q7M25wvYZhj2okJc
 
 [deepseek]
@@ -99,8 +99,8 @@ PGVECTOR_PASSWORD=pg16@123
 
 
 [milvus]
-MILVUS_HOST=192.168.0.3
-MILVUS_PORT=19530
-MILVUS_DB=lq_db
-MILVUS_USER=
-MILVUS_PASSWORD=
+MILVUS_HOST=124.223.140.149
+MILVUS_PORT=7432
+MILVUS_DB=vector_db
+MILVUS_USER=vector_user
+MILVUS_PASSWORD=pg16@123

+ 0 - 0
core/construction_review/doc_worker/__init__.py → core/base/doc_worker/__init__.py


+ 1 - 1
core/construction_review/doc_worker/config.yaml → core/base/doc_worker/config.yaml

@@ -5,7 +5,7 @@ llm:
   # 模型API地址
   model_url: "http://172.16.35.50:8000/v1/chat/completions"
   # 模型名称
-  model_name: "Qwen2.5-7B-Instruct"
+  model_name: "Qwen2.5-1.5B-Instruct"
   # 温度参数(越低越确定)
   temperature: 0.1
   # 请求超时时间(秒)

+ 0 - 0
core/construction_review/doc_worker/config_loader.py → core/base/doc_worker/config_loader.py


+ 0 - 0
core/construction_review/doc_worker/core.py → core/base/doc_worker/core.py


+ 0 - 0
core/construction_review/doc_worker/llm_classifier.py → core/base/doc_worker/llm_classifier.py


+ 0 - 0
core/construction_review/doc_worker/result_saver.py → core/base/doc_worker/result_saver.py


+ 0 - 0
core/construction_review/doc_worker/text_splitter.py → core/base/doc_worker/text_splitter.py


+ 0 - 0
core/construction_review/doc_worker/toc_extractor.py → core/base/doc_worker/toc_extractor.py


+ 266 - 65
core/base/progress_manager.py

@@ -4,59 +4,219 @@
 """
 
 import json
+import asyncio
 from typing import Dict, Any, Optional
 from datetime import datetime
 
 from foundation.logger.loggering import server_logger as logger
+from foundation.base.config import config_handler
+
+class SSECallbackManager:
+    """SSE回调管理器 - 单例模式管理全局SSE回调"""
+    _instance = None
+    _callbacks = {}  # {callback_task_id: callback_function}
+
+    def __new__(cls):
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+        return cls._instance
+
+    def register_callback(self, callback_task_id: str, callback_func):
+        """注册SSE回调函数"""
+        self._callbacks[callback_task_id] = callback_func
+        logger.info(f"SSE回调注册, 当前注册数: {len(self._callbacks)}")
+
+    def unregister_callback(self, callback_task_id: str):
+        """注销SSE回调函数"""
+        if callback_task_id in self._callbacks:
+            del self._callbacks[callback_task_id]
+            logger.info(f"SSE回调注销, 剩余注册数: {len(self._callbacks)}")
+
+    async def trigger_callback(self, callback_task_id: str, current_data: dict):
+        """触发SSE回调"""
+        if callback_task_id in self._callbacks:
+            try:
+                # 直接异步执行回调,保持trace上下文
+                await self._callbacks[callback_task_id](callback_task_id, current_data)
+                logger.debug(f"SSE回调执行成功: {callback_task_id}")
+
+                logger.debug(f"SSE回调已触发: {callback_task_id}, 当前注册回调数: {len(self._callbacks)}")
+                return True
+
+            except Exception as e:
+                logger.error(f"SSE回调执行失败: {callback_task_id}, {e}")
+                return False
+        else:
+            logger.debug(f"未找到SSE回调: {callback_task_id}, 当前注册回调数: {len(self._callbacks)}, 已注册ID: {list(self._callbacks.keys())}")
+            return False
+
+    def get_callbacks_count(self):
+        """获取当前回调数量"""
+        return len(self._callbacks)
+
+    def clear_all_callbacks(self):
+        """清空所有回调"""
+        self._callbacks.clear()
+        logger.info("已清空所有SSE回调")
+
+# 全局SSE回调管理器实例
+sse_callback_manager = SSECallbackManager()
 
 class ProgressManager:
-    """任务进度管理器"""
+    """任务进度管理器 - 增长型进度管理版本"""
 
     def __init__(self):
-        self.progress_data = {}  # 简化:使用内存存储
+        self.redis_client = None
+        self.redis_connected = False
+        self._init_redis()
+
+    def _init_redis(self):
+        """初始化Redis连接"""
+        try:
+            import redis
+
+            redis_host = config_handler.get('redis', 'REDIS_HOST', 'localhost')
+            redis_port = config_handler.get('redis', 'REDIS_PORT', '6379')
+            redis_password = config_handler.get('redis', 'REDIS_PASSWORD', '')
+            redis_db = config_handler.get('redis', 'REDIS_DB', '0')
+
+            # 构建Redis连接URL
+            if redis_password:
+                redis_url = f"redis://:{redis_password}@{redis_host}:{redis_port}/{redis_db}"
+            else:
+                redis_url = f"redis://{redis_host}:{redis_port}/{redis_db}"
+
+            logger.debug(f"ProgressManager连接Redis: {redis_url}")
+
+            # 连接Redis
+            self.redis_client = redis.from_url(redis_url, decode_responses=True)
+
+            # 测试连接
+            self.redis_client.ping()
+            self.redis_connected = True
+            logger.debug(f"ProgressManager Redis连接成功: {redis_host}:{redis_port}")
+
+        except Exception as e:
+            logger.error(f"ProgressManager Redis连接失败: {e}")
+            self.redis_connected = False
+            logger.warning("ProgressManager将使用内存存储作为备选方案")
+            self.current_data = {}  # 备选内存存储
+
+    async def _get_redis_key(self, callback_task_id: str) -> str:
+        """获取Redis键名"""
+        return f"current:{callback_task_id}"
 
     async def initialize_progress(self, callback_task_id: str, user_id: str, stages: list):
         """初始化进度记录"""
         try:
-            self.progress_data[callback_task_id] = {
+
+            # 设置总量为100(百分比模式)
+            stage_name = stages[0]["stage_name"] if stages else ""
+            message = "任务开始"
+
+            current_data = {
                 "user_id": user_id,
-                "overall_progress": 0,
-                "current_stage": stages[0]["stage_name"] if stages else "",
-                "stages": stages,
-                "updated_at": datetime.now()
+                "current": 0,
+                "stage_name": "",
+                "status": "准备开始",
+                "message": "任务开始",
+                "updated_at": datetime.now().isoformat(),
+                "overall_task_status": "pending"
             }
-            logger.info(f"初始化任务进度: {callback_task_id}")
+
+            if self.redis_connected:
+                # 使用同步Redis操作避免异步任务销毁问题
+                try:
+                    redis_key = await self._get_redis_key(callback_task_id)
+                    self.redis_client.setex(
+                        redis_key,
+                        3600,  # 1小时过期
+                        json.dumps(current_data)
+                    )
+                    logger.info(f"初始化任务进度列表")
+                except Exception as redis_e:
+                    logger.warning(f"初始化进度到Redis失败: {callback_task_id}, {redis_e}")
+                    # 降级到内存存储
+                    if not hasattr(self, 'current_data'):
+                        self.current_data = {}
+                    self.current_data[callback_task_id] = current_data
+                    logger.info(f"降级使用内存存储: {callback_task_id}")
+            else:
+                # 使用内存存储
+                if not hasattr(self, 'current_data'):
+                    self.current_data = {}
+                self.current_data[callback_task_id] = current_data
+                logger.info(f"初始化任务进度到内存: {callback_task_id}")
 
         except Exception as e:
             logger.error(f"初始化进度失败: {str(e)}")
             raise
 
-    async def update_stage_progress(self, callback_task_id: str, stage_name: str,
-                                  progress: int, status: str, message: str = "",
-                                  sub_progress: int = 0):
+    async def update_stage_progress(self, callback_task_id: str, stage_name: str, current: int, status: str, message: str = ""):
         """更新阶段进度"""
         try:
-            if callback_task_id not in self.progress_data:
-                logger.warning(f"任务进度不存在: {callback_task_id}")
-                return
+            task_progress = None
+
+            if self.redis_connected:
+                # 从Redis读取
+                redis_key = await self._get_redis_key(callback_task_id)
+                progress_json = self.redis_client.get(redis_key)
+                if progress_json:
+                    task_progress = json.loads(progress_json)
+                else:
+                    logger.warning(f"Redis中未找到任务进度: {callback_task_id}")
+                    return
+            else:
+                # 从内存读取
+                if callback_task_id in self.current_data:
+                    task_progress = self.current_data[callback_task_id]
+                else:
+                    logger.warning(f"内存中未找到任务进度: {callback_task_id}")
+                    return
 
-            task_progress = self.progress_data[callback_task_id]
+            # 更新进度数据
+            task_progress["current"] = current
+            task_progress["stage_name"] = stage_name
+            task_progress["status"] = status
+            task_progress["message"] = message
+            task_progress["updated_at"] = datetime.now().isoformat()
 
-            # 更新阶段进度
-            for stage in task_progress["stages"]:
-                if stage["stage_name"] == stage_name:
-                    stage["progress"] = progress
-                    stage["stage_status"] = status
-                    stage["message"] = message
-                    stage["sub_progress"] = sub_progress
-                    break
+            # 保留overall_task_status字段,不要被普通进度更新覆盖
+            if "overall_task_status" not in task_progress:
+                task_progress["overall_task_status"] = "processing"
 
-            # 更新当前阶段和整体进度
-            task_progress["current_stage"] = stage_name
-            task_progress["overall_progress"] = self._calculate_overall_progress(task_progress["stages"])
-            task_progress["updated_at"] = datetime.now()
+            try:
+                if self.redis_connected:
+                    try:
+                        self.redis_client.setex(
+                            redis_key,
+                            3600,  # 1小时过期
+                            json.dumps(task_progress)
+                        )
+                        logger.debug(f"更新进度到Redis: {callback_task_id}, 进度: {current}%")
+                    except Exception as sync_e:
+                        logger.warning(f"同步Redis操作失败: {callback_task_id}, {sync_e}")
+                        # 同步操作也失败时,降级到内存存储
+                        if not hasattr(self, 'current_data'):
+                            self.current_data = {}
+                        self.current_data[callback_task_id] = task_progress
+                        logger.debug(f"降级使用内存存储: {callback_task_id}")
+                else:
+                    if not hasattr(self, 'current_data'):
+                        self.current_data = {}
+                    self.current_data[callback_task_id] = task_progress
+                    logger.debug(f"更新进度到内存: {callback_task_id}, 进度: {current}%")
+            except Exception as e:
+                logger.error(f"保存进度数据异常: {callback_task_id}, {e}")
+                if not hasattr(self, 'current_data'):
+                    self.current_data = {}
+                self.current_data[callback_task_id] = task_progress
 
-            logger.debug(f"更新进度: {callback_task_id}, 阶段: {stage_name}, 进度: {progress}%")
+            # 触发SSE推送 - 使用全局回调管理器
+            logger.debug(f"触发SSE推送: {callback_task_id}")
+            updated_progress = await self.get_progress(callback_task_id)
+            if updated_progress:
+                await sse_callback_manager.trigger_callback(callback_task_id, updated_progress)
 
         except Exception as e:
             logger.error(f"更新阶段进度失败: {str(e)}")
@@ -65,61 +225,102 @@ class ProgressManager:
     async def get_progress(self, callback_task_id: str) -> Optional[Dict[str, Any]]:
         """获取任务进度"""
         try:
-            if callback_task_id not in self.progress_data:
-                return None
-
-            task_progress = self.progress_data[callback_task_id]
-
-            # 计算整体状态
-            if any(stage["stage_status"] == "failed" for stage in task_progress["stages"]):
-                review_task_status = "failed"
-            elif all(stage["stage_status"] == "completed" for stage in task_progress["stages"]):
-                review_task_status = "completed"
-            elif any(stage["stage_status"] == "processing" for stage in task_progress["stages"]):
-                review_task_status = "processing"
+            logger.debug(f"开始获取进度: {callback_task_id}, Redis连接状态: {self.redis_connected}")
+            task_progress = None
+
+            if self.redis_connected:
+                # 从Redis读取
+                redis_key = await self._get_redis_key(callback_task_id)
+                logger.debug(f"Redis键: {redis_key}")
+                progress_json = self.redis_client.get(redis_key)
+                logger.debug(f"从Redis读取数据: {progress_json is not None}")
+                if progress_json:
+                    task_progress = json.loads(progress_json)
+                else:
+                    logger.debug(f"Redis中未找到任务进度: {callback_task_id}")
+                    return None
             else:
-                review_task_status = "pending"
+                # 从内存读取
+                if hasattr(self, 'current_data') and callback_task_id in self.current_data:
+                    task_progress = self.current_data[callback_task_id]
+                else:
+                    logger.debug(f"内存中未找到任务进度: {callback_task_id}")
+                    return None
+
+            # 获取overall_task_status,默认为"pending"
+            overall_task_status = task_progress.get("overall_task_status", "pending")
+
+            # 转换时间戳
+            updated_at = task_progress["updated_at"]
+            if isinstance(updated_at, str):
+                updated_at_timestamp = int(datetime.fromisoformat(updated_at).timestamp())
+            else:
+                updated_at_timestamp = int(updated_at.timestamp())
 
             return {
                 "callback_task_id": callback_task_id,
                 "user_id": task_progress["user_id"],
-                "review_task_status": review_task_status,
-                "overall_progress": task_progress["overall_progress"],
-                "stages": task_progress["stages"],
-                "updated_at": int(task_progress["updated_at"].timestamp()),
-                "estimated_remaining": 600
+                "current": task_progress["current"],
+                "stage_name": task_progress["stage_name"],
+                "status": task_progress["status"],
+                "message": task_progress["message"],
+                "overall_task_status": overall_task_status,
+                "updated_at": updated_at_timestamp
             }
 
         except Exception as e:
             logger.error(f"获取进度失败: {str(e)}")
             return None
 
-    async def complete_task(self, callback_task_id: str, result: Dict[str, Any]):
+    async def complete_task(self, callback_task_id: str):
         """标记任务完成"""
         try:
-            if callback_task_id in self.progress_data:
-                task_progress = self.progress_data[callback_task_id]
-
-                # 完成最后一个阶段
-                if task_progress["stages"]:
-                    task_progress["stages"][-1]["stage_status"] = "completed"
-                    task_progress["stages"][-1]["progress"] = 100
+            task_progress = None
+            logger.info(f"通知sse连接关闭: {callback_task_id}")
+            if self.redis_connected:
+                redis_key = await self._get_redis_key(callback_task_id)
+                progress_json = self.redis_client.get(redis_key)
+                if progress_json:
+                    task_progress = json.loads(progress_json)
+                else:
+                    logger.warning(f"Redis中未找到任务进度: {callback_task_id}")
+                    return
+            else:
+                # 从内存读取
+                if hasattr(self, 'current_data') and callback_task_id in self.current_data:
+                    task_progress = self.current_data[callback_task_id]
+                else:
+                    logger.warning(f"内存中未找到任务进度: {callback_task_id}")
+                    return
 
-                task_progress["overall_progress"] = 100
-                task_progress["updated_at"] = datetime.now()
+            task_progress["status"] = "completed"
+            task_progress["overall_task_status"] = "completed"
+            task_progress["message"] = "任务已全部完成"
+            task_progress["updated_at"] = datetime.now().isoformat()
 
-                # 保存结果
-                task_progress["result"] = result
 
-            logger.info(f"任务完成: {callback_task_id}")
+            # 保存更新后的数据
+            if self.redis_connected:
+                self.redis_client.setex(
+                    redis_key,
+                    3600,
+                    json.dumps(task_progress)
+                )
+            else:
+                if hasattr(self, 'current_data'):
+                    self.current_data[callback_task_id] = task_progress
 
+            # 触发SSE进度更新推送
+            completed_progress = await self.get_progress(callback_task_id)
+            if completed_progress:
+                await sse_callback_manager.trigger_callback(callback_task_id, completed_progress)
+                logger.debug(f"SSE完成进度已推送: {callback_task_id}")
+            else:
+                logger.warning(f"无法获取完成进度数据: {callback_task_id}")
         except Exception as e:
             logger.error(f"标记任务完成失败: {str(e)}")
             raise
 
-    def _calculate_overall_progress(self, stages: list) -> int:
-        """计算整体进度"""
-        if not stages:
-            return 0
-        total_progress = sum(stage["progress"] for stage in stages)
-        return int(total_progress / len(stages))
+
+
+    

+ 96 - 1
core/base/redis_duplicate_checker.py

@@ -87,6 +87,7 @@ class RedisDuplicateChecker:
             task_data = {
                 "callback_task_id": callback_task_id,
                 "created_at": datetime.now().isoformat(),
+                "used": False,  # 标记任务是否已被使用启动审查
                 "file_info": serializable_file_info
             }
 
@@ -121,6 +122,37 @@ class RedisDuplicateChecker:
         except Exception as e:
             logger.error(f"取消注册任务失败: {str(e)}")
 
+    async def is_valid_task_id(self, callback_task_id: str) -> bool:
+        """验证任务ID是否存在且未过期"""
+        try:
+            if self.use_redis:
+                # 遍历所有任务键,查找匹配的callback_task_id
+                keys = self.redis_client.keys("task:*")
+                for key in keys:
+                    task_info = self.redis_client.get(key)
+                    if task_info:
+                        task_data = json.loads(task_info)
+                        if task_data.get("callback_task_id") == callback_task_id:
+                            created_at = datetime.fromisoformat(task_data['created_at'])
+                            if datetime.now() - created_at < timedelta(minutes=2):
+                                return True
+                            else:
+                                # 任务已过期,清理
+                                self.redis_client.delete(key)
+                return False
+            else:
+                # 内存模式检查
+                for file_id, task_info in self.task_cache.items():
+                    if task_info.get("callback_task_id") == callback_task_id:
+                        created_at = datetime.fromisoformat(task_info['created_at'])
+                        if datetime.now() - created_at < timedelta(minutes=2):
+                            return True
+                return False
+
+        except Exception as e:
+            logger.error(f"验证任务ID失败: {str(e)}")
+            return False
+
     async def get_task_info(self, file_id: str) -> str:
         """获取任务信息"""
         try:
@@ -158,4 +190,67 @@ class RedisDuplicateChecker:
                     logger.info(f"清理过期缓存: {len(expired_files)} 个文件")
 
         except Exception as e:
-            logger.error(f"清理过期缓存失败: {str(e)}")
+            logger.error(f"清理过期缓存失败: {str(e)}")
+
+    async def is_task_already_used(self, callback_task_id: str) -> bool:
+        """检查任务是否已经被使用启动审查"""
+        try:
+            if self.use_redis:
+                # 遍历所有任务键,查找匹配的callback_task_id
+                keys = self.redis_client.keys("task:*")
+                for key in keys:
+                    task_info = self.redis_client.get(key)
+                    if task_info:
+                        task_data = json.loads(task_info)
+                        if task_data.get("callback_task_id") == callback_task_id:
+                            # 检查任务是否已被使用
+                            if task_data.get("used", False):
+                                logger.info(f"任务已被使用: {callback_task_id}")
+                                return True
+                            else:
+                                return False
+                return False
+            else:
+                # 内存模式检查
+                for file_id, task_info in self.task_cache.items():
+                    if task_info.get("callback_task_id") == callback_task_id:
+                        if task_info.get("used", False):
+                            return True
+                        else:
+                            return False
+                return False
+
+        except Exception as e:
+            logger.error(f"检查任务使用状态失败: {str(e)}")
+            return False
+
+    async def mark_task_as_used(self, callback_task_id: str):
+        """标记任务为已使用"""
+        try:
+            if self.use_redis:
+                # 遍历所有任务键,查找匹配的callback_task_id
+                keys = self.redis_client.keys("task:*")
+                for key in keys:
+                    task_info = self.redis_client.get(key)
+                    if task_info:
+                        task_data = json.loads(task_info)
+                        if task_data.get("callback_task_id") == callback_task_id:
+                            # 更新used字段为True
+                            task_data["used"] = True
+                            self.redis_client.setex(
+                                key,
+                                3600,  # 1小时
+                                json.dumps(task_data, ensure_ascii=False)
+                            )
+                            logger.info(f"任务已标记为使用: {callback_task_id}")
+                            return
+            else:
+                # 内存模式
+                for file_id, task_info in self.task_cache.items():
+                    if task_info.get("callback_task_id") == callback_task_id:
+                        task_info["used"] = True
+                        logger.info(f"任务已标记为使用: {callback_task_id}")
+                        return
+
+        except Exception as e:
+            logger.error(f"标记任务使用状态失败: {str(e)}")

+ 63 - 104
core/base/workflow_manager.py

@@ -12,6 +12,7 @@ from dataclasses import dataclass
 from langgraph.graph import StateGraph, END
 from langgraph.graph.message import add_messages
 from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
+import json
 
 from foundation.logger.loggering import server_logger as logger
 from foundation.utils.time_statistics import track_execution_time
@@ -19,6 +20,28 @@ from .progress_manager import ProgressManager
 from .redis_duplicate_checker import RedisDuplicateChecker
 from ..construction_review.workflows import DocumentWorkflow,AIReviewWorkflow,ReportWorkflow
 
+class ProgressManagerRegistry:
+    """ProgressManager注册表 - 为每个任务管理独立的ProgressManager实例"""
+    _registry = {}  # {callback_task_id: ProgressManager}
+
+    @classmethod
+    def register_progress_manager(cls, callback_task_id: str, progress_manager: ProgressManager):
+        """注册ProgressManager实例"""
+        cls._registry[callback_task_id] = progress_manager
+        logger.info(f"注册ProgressManager实例: {callback_task_id}, ID: {id(progress_manager)}")
+
+    @classmethod
+    def get_progress_manager(cls, callback_task_id: str) -> ProgressManager:
+        """获取ProgressManager实例"""
+        return cls._registry.get(callback_task_id)
+
+    @classmethod
+    def unregister_progress_manager(cls, callback_task_id: str):
+        """注销ProgressManager实例"""
+        if callback_task_id in cls._registry:
+            del cls._registry[callback_task_id]
+            logger.info(f"注销ProgressManager实例: {callback_task_id}")
+
 @dataclass
 class TaskChain:
     """任务链"""
@@ -48,7 +71,7 @@ class WorkflowManager:
         self.review_semaphore = asyncio.Semaphore(max_concurrent_reviews)
 
         # 服务组件
-        self.progress_manager = ProgressManager()
+        self.progress_manager = ProgressManager()  # 简化:直接使用实例
         self.redis_duplicate_checker = RedisDuplicateChecker()
 
         # 活跃任务跟踪
@@ -58,12 +81,16 @@ class WorkflowManager:
     async def submit_task_processing(self, file_info: dict) -> str:
         """异步提交任务处理(用于file_upload层)"""
         from foundation.base.tasks import submit_task_processing_task
+        from foundation.trace.celery_trace import CeleryTraceManager
 
         try:
             logger.info(f"提交文档处理任务到Celery: {file_info['file_id']}")
 
-            # 提交到Celery队列
-            task = submit_task_processing_task.delay(file_info)
+            # 使用CeleryTraceManager提交任务,自动传递trace_id
+            task = CeleryTraceManager.submit_celery_task(
+                submit_task_processing_task,
+                file_info
+            )
 
             logger.info(f"Celery任务已提交,Task ID: {task.id}")
             return task.id
@@ -76,7 +103,6 @@ class WorkflowManager:
         """同步提交任务处理(用于Celery worker)"""
         try:
 
-
             logger.info(f"提交文档处理任务: {file_info['file_id']}")
 
             # 1. 生成任务链ID
@@ -85,31 +111,26 @@ class WorkflowManager:
             # 2. 创建任务链
             task_chain = TaskChain(
                 callback_task_id=callback_task_id,
-                file_id=file_info['file_id'],
-                user_id=file_info['user_id'],
+                file_id=file_info.get('file_id', ''),
+                user_id=file_info.get('user_id', 'default_user'),
                 status="processing",
                 current_stage="document_processing",
                 created_at=datetime.now()
             )
 
-            # 4. 注册任务
-            asyncio.run(self.redis_duplicate_checker.register_task(file_info, callback_task_id))
+            # 3. 添加到活跃任务跟踪
             self.active_chains[callback_task_id] = task_chain
 
-            # 5. 初始化进度
+            # 4. 初始化进度管理
             asyncio.run(self.progress_manager.initialize_progress(
                 callback_task_id=callback_task_id,
-                user_id=file_info['user_id'],
-                stages=[
-                    {"stage_name": "文件上传", "progress": 100, "status": "completed"},
-                    {"stage_name": "文档处理", "progress": 0, "status": "pending"},
-                    {"stage_name": "AI审查", "progress": 0, "status": "pending"},
-                    {"stage_name": "报告生成", "progress": 0, "status": "pending"}
-                ]
+                user_id=file_info.get('user_id', 'default_user'),
+                stages=[]
             ))
 
             # 6. 启动处理流程(同步执行)
             self._process_task_chain_sync(task_chain, file_info['file_content'], file_info['file_type'])
+
             # logger.info(f"提交文档处理任务: {callback_task_id}")
             logger.info(f"施工方案审查任务已完成! ")
             logger.info(f"文件ID: {file_info['file_id']}")
@@ -120,94 +141,6 @@ class WorkflowManager:
             raise
     
 
-    async def _process_task_chain(self, task_chain: TaskChain, file_content: bytes, file_type: str):
-        """处理文档任务链 - 串行执行,内部并发"""
-        try:
-            task_chain.started_at = datetime.now()
-
-            # 阶段1:文档处理(串行)
-            async with self.doc_semaphore:
-                task_chain.current_stage = "document_processing"
-
-                document_workflow = DocumentWorkflow(
-                    file_id=task_chain.file_id,
-                    callback_task_id=task_chain.callback_task_id,
-                    user_id=task_chain.user_id,
-                    progress_manager=self.progress_manager,
-                    redis_duplicate_checker=self.redis_duplicate_checker
-                )
-
-                doc_result = await document_workflow.execute(file_content, file_type)
-                task_chain.results['document'] = doc_result
-
-            # 阶段2:AI审查(内部并发)
-            task_chain.current_stage = "ai_review"
-
-            structured_content = doc_result['structured_content']
-
-            # 读取AI审查配置
-            import configparser
-            config = configparser.ConfigParser()
-            config.read('config/config.ini', encoding='utf-8')
-
-            max_review_units = config.getint('ai_review', 'MAX_REVIEW_UNITS', fallback=None)
-            if max_review_units == 0:  # 如果配置为0,表示审查所有
-                max_review_units = None
-            review_mode = config.get('ai_review', 'REVIEW_MODE', fallback='all')
-
-            logger.info(f"AI审查配置: 最大审查条文数量={max_review_units}, 审查模式={review_mode}")
-
-            ai_workflow = AIReviewWorkflow(
-                file_id=task_chain.file_id,
-                callback_task_id=task_chain.callback_task_id,
-                user_id=task_chain.user_id,
-                structured_content=structured_content,
-                progress_manager=self.progress_manager,
-                max_review_units=max_review_units,
-                review_mode=review_mode
-            )
-
-            ai_result = await ai_workflow.execute()
-            task_chain.results['ai_review'] = ai_result
-
-            # 阶段3:报告生成(串行)
-            task_chain.current_stage = "report_generation"
-
-            report_workflow = ReportWorkflow(
-                file_id=task_chain.file_id,
-                callback_task_id=task_chain.callback_task_id,
-                user_id=task_chain.user_id,
-                ai_review_results=ai_result,
-                progress_manager=self.progress_manager
-            )
-
-            report_result = await report_workflow.execute()
-            task_chain.results['report'] = report_result
-
-            # 完成任务链
-            task_chain.status = "completed"
-            task_chain.completed_at = datetime.now()
-
-            # 清理任务注册
-            await self.redis_duplicate_checker.unregister_task(task_chain.file_id)
-
-            logger.info(f"文档处理任务链完成: {task_chain.callback_task_id}")
-
-        except Exception as e:
-            task_chain.status = "failed"
-            logger.error(f"文档处理任务链失败: {task_chain.callback_task_id}, 错误: {str(e)}")
-
-            # 清理任务注册
-            await self.redis_duplicate_checker.unregister_task(task_chain.file_id)
-
-            raise
-        finally:
-            # 清理活跃任务
-            if task_chain.callback_task_id in self.active_chains:
-                del self.active_chains[task_chain.callback_task_id]
-
-
-
     def _process_task_chain_sync(self, task_chain: TaskChain, file_content: bytes, file_type: str):
         """同步处理文档任务链(用于Celery worker)"""
         try:
@@ -292,6 +225,16 @@ class WorkflowManager:
 
             # 清理任务注册
             asyncio.run(self.redis_duplicate_checker.unregister_task(task_chain.file_id))
+            # 通知SSE连接任务完成
+            asyncio.run(self.progress_manager.complete_task(task_chain.callback_task_id))
+
+            # 清理Redis文件缓存
+            try:
+                from foundation.utils.redis_utils import delete_file_info
+                asyncio.run(delete_file_info(task_chain.file_id))
+                logger.info(f"已清理Redis文件缓存: {task_chain.file_id}")
+            except Exception as e:
+                logger.warning(f"清理Redis文件缓存失败: {str(e)}")
 
             logger.info(f"文档处理任务链完成: {task_chain.callback_task_id}")
             return task_chain.results
@@ -303,6 +246,22 @@ class WorkflowManager:
             # 清理任务注册
             asyncio.run(self.redis_duplicate_checker.unregister_task(task_chain.file_id))
 
+            # 清理Redis文件缓存(即使失败也清理)
+            try:
+                from foundation.utils.redis_utils import delete_file_info
+                asyncio.run(delete_file_info(task_chain.file_id))
+                logger.info(f"已清理Redis文件缓存: {task_chain.file_id}")
+            except Exception as cleanup_error:
+                logger.warning(f"清理Redis文件缓存失败: {str(cleanup_error)}")
+
+            # 通知SSE连接任务失败
+            error_result = {
+                "error": str(e),
+                "status": "failed",
+                "timestamp": datetime.now().isoformat()
+            }
+            asyncio.run(self.progress_manager.complete_task(task_chain.callback_task_id))
+
             raise
         finally:
             # 清理活跃任务

+ 31 - 16
core/construction_review/component/ai_review_engine.py

@@ -3,6 +3,7 @@ AI审查引擎
 负责执行AI审查,支持审查条目并发处理
 """
 
+import time
 import asyncio
 from enum import Enum
 from dataclasses import dataclass
@@ -61,7 +62,8 @@ class AIReviewEngine(BaseReviewer):
 
 
     
-    async def basic_compliance_check(self,trace_id_idx: str, unit_content: Dict[str, Any]) -> Dict[str, Any]:
+    async def basic_compliance_check(self,trace_id_idx: str, unit_content: Dict[str, Any],
+                                   stage_name: str = None, state: dict = None, current_progress: int = None) -> Dict[str, Any]:
         """基础合规性检查"""
         review_content = unit_content['content']
         review_references = unit_content.get('review_references')
@@ -70,7 +72,7 @@ class AIReviewEngine(BaseReviewer):
 
         async def check_with_semaphore(check_func, *args):
             async with self.semaphore:
-                return await check_func(*args)
+                return await check_func(*args, stage_name=stage_name, state=state, current_progress=current_progress)
 
         basic_tasks = [
             check_with_semaphore(self.check_grammar, trace_id_idx, review_content, review_references),
@@ -97,14 +99,15 @@ class AIReviewEngine(BaseReviewer):
             'overall_score': self._calculate_basic_score(grammar_result, semantic_result, completeness_result)
         }
 
-    async def technical_compliance_check(self, trace_id_idx: str, unit_content: Dict[str, Any]) -> Dict[str, Any]:
+    async def technical_compliance_check(self, trace_id_idx: str, unit_content: Dict[str, Any],
+                                      stage_name: str = None, state: dict = None, current_progress: int = None) -> Dict[str, Any]:
         """技术性合规检查"""
         review_content = unit_content['content']
         review_references = unit_content.get('review_references')
         logger.info(f"开始技术性合规检查,内容长度: {len(review_content)}")
         async def check_with_semaphore(check_func, *args):
             async with self.semaphore:
-                return await check_func(*args)
+                return await check_func(*args, stage_name=stage_name, state=state, current_progress=current_progress)
 
         technical_tasks = [
             check_with_semaphore(self.check_mandatory_standards, trace_id_idx, review_content,review_references),
@@ -148,47 +151,59 @@ class AIReviewEngine(BaseReviewer):
         }
 
 
-    async def check_grammar(self, trace_id_idx: str, review_content: str = None, review_references: str = None) -> Dict[str, Any]:
+    async def check_grammar(self, trace_id_idx: str, review_content: str = None, review_references: str = None,
+                          stage_name: str = None, state: dict = None, current_progress: int = None) -> Dict[str, Any]:
         """语法检查"""
         reviewer_type = Stage.BASIC.value['reviewer_type']
         prompt_name = Stage.BASIC.value['sensitive']
         trace_id = prompt_name+trace_id_idx
-        return await self.review("语法检查", trace_id, reviewer_type, prompt_name, review_content,review_references)
+        return await self.review("语法检查", trace_id, reviewer_type, prompt_name, review_content, review_references,
+                               stage_name, state, current_progress)
 
-    async def check_semantic_logic(self, trace_id_idx: str, review_content: str = None, review_references: str = None) -> Dict[str, Any]:
+    async def check_semantic_logic(self, trace_id_idx: str, review_content: str = None, review_references: str = None,
+                                 stage_name: str = None, state: dict = None, current_progress: int = None) -> Dict[str, Any]:
         """语义逻辑检查"""
         reviewer_type = Stage.BASIC.value['reviewer_type']
         prompt_name = Stage.BASIC.value['semantic']
         trace_id = prompt_name+trace_id_idx
-        return await self.review("语义逻辑检查", trace_id, reviewer_type, prompt_name, review_content,review_references)
+        return await self.review("语义逻辑检查", trace_id, reviewer_type, prompt_name, review_content, review_references,
+                               stage_name, state, current_progress)
 
-    async def check_completeness(self, trace_id_idx: str, review_content: str = None, review_references: str = None) -> Dict[str, Any]:
+    async def check_completeness(self, trace_id_idx: str, review_content: str = None, review_references: str = None,
+                               stage_name: str = None, state: dict = None, current_progress: int = None) -> Dict[str, Any]:
         """完整性检查"""
         reviewer_type = Stage.BASIC.value['reviewer_type']
         prompt_name = Stage.BASIC.value['completeness']
         trace_id = prompt_name+trace_id_idx
-        return await self.review("完整性检查", trace_id, reviewer_type, prompt_name, review_content,review_references)
+        return await self.review("完整性检查", trace_id, reviewer_type, prompt_name, review_content, review_references,
+                               stage_name, state, current_progress)
 
-    async def check_mandatory_standards(self, trace_id_idx: str, review_content: str = None, review_references: str = None) -> Dict[str, Any]:
+    async def check_mandatory_standards(self, trace_id_idx: str, review_content: str = None, review_references: str = None,
+                                        stage_name: str = None, state: dict = None, current_progress: int = None) -> Dict[str, Any]:
         """强制性标准检查"""
         reviewer_type = Stage.TECHNICAL.value['reviewer_type']
         prompt_name = Stage.TECHNICAL.value['mandatory']
         trace_id = prompt_name+trace_id_idx
-        return await self.review("强制性标准检查", trace_id, reviewer_type, prompt_name, review_content,review_references)
+        return await self.review("强制性标准检查", trace_id, reviewer_type, prompt_name, review_content, review_references,
+                               stage_name, state, current_progress)
 
-    async def check_design_values(self, trace_id_idx: str, review_content: str = None, review_references: str = None) -> Dict[str, Any]:
+    async def check_design_values(self, trace_id_idx: str, review_content: str = None, review_references: str = None,
+                                  stage_name: str = None, state: dict = None, current_progress: int = None) -> Dict[str, Any]:
         """设计值检查"""
         reviewer_type = Stage.TECHNICAL.value['reviewer_type']
         prompt_name = Stage.TECHNICAL.value['design']
         trace_id = prompt_name+trace_id_idx
-        return await self.review("设计值检查", trace_id, reviewer_type, prompt_name, review_content,review_references)
+        return await self.review("设计值检查", trace_id, reviewer_type, prompt_name, review_content, review_references,
+                               stage_name, state, current_progress)
 
-    async def check_technical_parameters(self, trace_id_idx: str, review_content: str = None, review_references: str = None) -> Dict[str, Any]:
+    async def check_technical_parameters(self, trace_id_idx: str, review_content: str = None, review_references: str = None,
+                                         stage_name: str = None, state: dict = None, current_progress: int = None) -> Dict[str, Any]:
         """技术参数检查"""
         reviewer_type = Stage.TECHNICAL.value['reviewer_type']
         prompt_name = Stage.TECHNICAL.value['technical']
         trace_id = prompt_name+trace_id_idx
-        return await self.review("技术参数检查", trace_id, reviewer_type, prompt_name, review_content,review_references)
+        return await self.review("技术参数检查", trace_id, reviewer_type, prompt_name, review_content, review_references,
+                               stage_name, state, current_progress)
 
     # RAG检索增强
     async def vector_search(self, content: str) -> List[Dict[str, Any]]:

+ 4 - 4
core/construction_review/component/document_processor.py

@@ -15,11 +15,11 @@ from foundation.logger.loggering import server_logger as logger
 
 # 引入doc_worker核心组件
 try:
-    from ..doc_worker import TOCExtractor, TextSplitter, LLMClassifier
-    from ..doc_worker.config_loader import get_config
+    from base.doc_worker import TOCExtractor, TextSplitter, LLMClassifier
+    from base.doc_worker.config_loader import get_config
 except ImportError:
-    from core.construction_review.doc_worker import TOCExtractor, TextSplitter, LLMClassifier
-    from core.construction_review.doc_worker.config_loader import get_config
+    from core.base.doc_worker import TOCExtractor, TextSplitter, LLMClassifier
+    from core.base.doc_worker.config_loader import get_config
 
 class DocumentProcessor:
     """文档处理器"""

+ 19 - 2
core/construction_review/component/reviewers/base_reviewer.py

@@ -4,6 +4,7 @@
 """
 
 
+import asyncio
 import uuid
 import time
 from abc import ABC
@@ -33,7 +34,8 @@ class BaseReviewer(ABC):
         self.prompt_loader = prompt_loader
     
     #@obverse
-    async def review(self, name: str, trace_id: str, reviewer_type: str, prompt_name: str, review_content: str,review_references: str = None) -> ReviewResult:
+    async def review(self, name: str, trace_id: str, reviewer_type: str, prompt_name: str, review_content: str, review_references: str = None,
+                    stage_name: str = None, state: dict = None, current_progress: int = None) -> ReviewResult:
         """
         执行审查
 
@@ -47,8 +49,10 @@ class BaseReviewer(ABC):
                 - rag: rag_enhanced_review, vector_search_review, hybrid_search_review
                 - ai: professional_suggestion, standardization_suggestion, completeness_suggestion, readability_suggestion
             review_content: 待审查内容 (必填)
-
             review_references: 审查参考内容 (可选)
+            stage_name: 阶段名称 (可选,用于进度更新)
+            state: 状态字典 (可选,用于进度更新)
+            current_progress: 当前进度 (可选,用于进度更新)
 
         Returns:
             ReviewResult: 审查结果
@@ -56,6 +60,19 @@ class BaseReviewer(ABC):
         start_time = time.time()
         name = prompt_name
         try:
+            # 添加进度更新
+            progress_message = f"{name}_{prompt_name}"
+            # 安全检查:确保所有必要参数都存在才执行进度更新
+            if state and state.get("progress_manager") and stage_name and current_progress is not None:
+                asyncio.create_task(
+                    state["progress_manager"].update_stage_progress(
+                        callback_task_id=state["callback_task_id"],
+                        stage_name=stage_name,
+                        current=current_progress,
+                        status="processing",
+                        message=progress_message
+                    )
+                )
             logger.info(f"开始执行 {name} 审查,trace_id: {trace_id},内容长度: {len(review_content)}")
             prompt_kwargs = {}
             prompt_kwargs["content"] = review_content

+ 54 - 24
core/construction_review/workflows/ai_review_workflow.py

@@ -202,13 +202,17 @@ class AIReviewWorkflow:
 
         # 更新进度
         if state["progress_manager"]:
+            logger.debug(f"AI审查工作流中ProgressManager ID: {id(state['progress_manager'])}")
+            logger.debug(f"AI审查工作流中ProgressManager有SSE回调: {hasattr(state['progress_manager'], 'sse_callback')}")
             await state["progress_manager"].update_stage_progress(
                 callback_task_id=state["callback_task_id"],
                 stage_name="AI审查",
-                progress=0,
+                current=0,
                 status="processing",
                 message="开始AI审查"
             )
+        else:
+            logger.warning(f"AI审查工作流中未找到ProgressManager: {state.get('progress_manager', 'None')}")
 
         state["messages"].append(AIMessage(content="进度初始化完成"))
 
@@ -231,30 +235,47 @@ class AIReviewWorkflow:
 
             logger.info(f"AI审查开始: 总单元数 {total_all_units}, 实际审查 {total_units} 个单元")
 
-            # 进度回调函数
-            def progress_callback(progress: int, message: str):
-                overall_progress = 50 + int(progress * 0.4)  # AI审查占整体进度的40%
-                if state["progress_manager"]:
-                    asyncio.create_task(
-                        state["progress_manager"].update_stage_progress(
-                            callback_task_id=state["callback_task_id"],
-                            stage_name="AI审查",
-                            progress=overall_progress,
-                            status="processing",
-                            message=message
-                        )
-                    )
-
+            # 开始AI审查进度
+            if state["progress_manager"]:
+                await state["progress_manager"].update_stage_progress(
+                    callback_task_id=state["callback_task_id"],
+                    stage_name="AI审查",
+                    current=0,
+                    status="processing",
+                    message=f"开始AI审查,共 {total_units} 个审查单元"
+                )
+
+            
             # 基本审查单元
             async def review_single_unit(unit_content: Dict[str, Any], unit_index: int,callback_task_id) -> ReviewResult:
                 """使用LangGraph编排的原子化组件方法审查单个单元"""
-                try:    
+                try:
                         # 构建Trace ID
                         trace_id_idx = "("+str(callback_task_id)+'-'+str(unit_index)+")"
+
+                        # 获取section_label用于stage_name
+                        section_label = unit_content.get('section_label', f'第{unit_index + 1}部分')
+                        stage_name = f"AI审查:{section_label}"
+
+                        # 方法内部进度计算(基于当前处理的单元)
+                        current_progress = int((unit_index / total_units) * 100)
+                        progress_message = f"正在处理第 {unit_index + 1}/{total_units} 个单元: {section_label}"
+
+                        if state["progress_manager"]:
+                            asyncio.create_task(
+                                state["progress_manager"].update_stage_progress(
+                                    callback_task_id=state["callback_task_id"],
+                                    stage_name=stage_name,
+                                    current=current_progress,
+                                    status="processing",
+                                    message=progress_message
+                                )
+                            )
+                        
                         # 并发执行各种原子化审查方法
                         review_tasks = [
-                            self.ai_review_engine.basic_compliance_check(trace_id_idx, unit_content),
-                            self.ai_review_engine.technical_compliance_check(trace_id_idx, unit_content),
+                            self.ai_review_engine.basic_compliance_check(trace_id_idx, unit_content, stage_name, state, current_progress),
+                            self.ai_review_engine.technical_compliance_check(trace_id_idx, unit_content, stage_name, state, current_progress),
                             # self.ai_review_engine.rag_enhanced_check(unit_content, trace_id_idx)
                         ]
 
@@ -274,11 +295,20 @@ class AIReviewWorkflow:
                         # 更新进度
                         nonlocal completed_units
                         completed_units += 1
-                        progress = int((completed_units / total_units) * 100)
+                        current = int((completed_units / total_units) * 100)
                         message = f"已完成 {completed_units}/{total_units} 个审查单元"
-
-                        if progress_callback:
-                            progress_callback(progress, message)
+                        logger.info(f"更新进度: {current}% {message}")
+                        # 更新ProgressManager进度
+                        if state["progress_manager"]:
+                            asyncio.create_task(
+                                state["progress_manager"].update_stage_progress(
+                                    callback_task_id=state["callback_task_id"],
+                                    stage_name="AI审查",
+                                    current=current,
+                                    status="processing",
+                                    message=message
+                                )
+                            )
 
                         return ReviewResult(
                         unit_index=unit_index,
@@ -351,7 +381,7 @@ class AIReviewWorkflow:
             await state["progress_manager"].update_stage_progress(
                 callback_task_id=state["callback_task_id"],
                 stage_name="AI审查",
-                progress=90,
+                current=90,
                 status="completed",
                 message="AI审查完成"
             )
@@ -372,7 +402,7 @@ class AIReviewWorkflow:
             await state["progress_manager"].update_stage_progress(
                 callback_task_id=state["callback_task_id"],
                 stage_name="AI审查",
-                progress=50,
+                current=50,
                 status="failed",
                 message=f"AI审查失败: {state['error_message']}"
             )

+ 13 - 19
core/construction_review/workflows/document_workflow.py

@@ -27,25 +27,19 @@ class DocumentWorkflow:
         try:
             logger.info(f"开始文档处理工作流,文件ID: {self.file_id}")
 
-            # 2. 初始化进度
-            await self.progress_manager.initialize_progress(
-                callback_task_id=self.callback_task_id,
-                user_id=self.user_id,
-                stages=[
-                    {"stage_name": "文档上传", "progress": 100, "status": "completed"},
-                    {"stage_name": "文档解析", "progress": 0, "status": "pending"},
-                    {"stage_name": "内容提取", "progress": 0, "status": "pending"},
-                    {"stage_name": "结构化处理", "progress": 0, "status": "pending"}
-                ]
-            )
+            # 检查是否已初始化进度,避免重复初始化
+            existing_progress = await self.progress_manager.get_progress(self.callback_task_id)
+            if not existing_progress:
+                logger.warning(f"文档处理工作流未找到进度数据: {self.callback_task_id}")
+
 
             # 4. 执行文档处理
-            def progress_callback(progress: int, message: str):
+            def progress_callback(current: int, message: str):
                 asyncio.create_task(
                     self.progress_manager.update_stage_progress(
                         callback_task_id=self.callback_task_id,
-                        stage_name="文档处理",
-                        progress=progress,
+                        stage_name="文档解析",
+                        current=current,
                         status="processing",
                         message=message
                     )
@@ -60,10 +54,10 @@ class DocumentWorkflow:
             # 5. 更新完成状态
             await self.progress_manager.update_stage_progress(
                 callback_task_id=self.callback_task_id,
-                stage_name="文档处理",
-                progress=100,
+                stage_name="文档解析",
+                current=100,
                 status="completed",
-                message="文档处理完成"
+                message="文档解析完成"
             )
 
             # 6. 保存处理结果
@@ -85,8 +79,8 @@ class DocumentWorkflow:
             if self.progress_manager:
                 await self.progress_manager.update_stage_progress(
                     callback_task_id=self.callback_task_id,
-                    stage_name="文档处理",
-                    progress=0,
+                    stage_name="文档解析",
+                    current=0,
                     status="failed",
                     message=f"处理失败: {str(e)}"
                 )

+ 7 - 11
core/construction_review/workflows/report_workflow.py

@@ -31,20 +31,20 @@ class ReportWorkflow:
             await self.progress_manager.update_stage_progress(
                 callback_task_id=self.callback_task_id,
                 stage_name="报告生成",
-                progress=0,
+                current=0,
                 status="processing",
                 message="开始生成报告"
             )
 
             # 2. 生成报告
-            def progress_callback(progress: int, message: str):
+            def progress_callback(current: int, message: str):
                 # 将报告生成的进度映射到整体进度
-                overall_progress = 90 + int(progress * 0.1)  # 报告生成占整体进度的10%
+                overall_progress = 90 + int(current * 0.1)  # 报告生成占整体进度的10%
                 asyncio.create_task(
                     self.progress_manager.update_stage_progress(
                         callback_task_id=self.callback_task_id,
                         stage_name="报告生成",
-                        progress=overall_progress,
+                        current=overall_progress,
                         status="processing",
                         message=message
                     )
@@ -60,16 +60,12 @@ class ReportWorkflow:
             await self.progress_manager.update_stage_progress(
                 callback_task_id=self.callback_task_id,
                 stage_name="报告生成",
-                progress=100,
+                current=100,
                 status="completed",
                 message="报告生成完成"
             )
 
-            # 4. 标记任务链完成
-            await self.progress_manager.complete_task(
-                callback_task_id=self.callback_task_id,
-                result=self._convert_report_to_dict(final_report)
-            )
+
 
             # 5. 处理结果
             result = self._convert_report_to_dict(final_report)
@@ -85,7 +81,7 @@ class ReportWorkflow:
                 await self.progress_manager.update_stage_progress(
                     callback_task_id=self.callback_task_id,
                     stage_name="报告生成",
-                    progress=90,
+                    current=90,
                     status="failed",
                     message=f"报告生成失败: {str(e)}"
                 )

+ 22 - 1
foundation/base/celery_app.py

@@ -7,6 +7,9 @@ import os
 from celery import Celery
 from .config import config_handler
 
+# 导入trace系统
+from foundation.trace.celery_trace import init
+
 # 从配置文件获取Redis连接信息
 redis_host = config_handler.get('redis', 'REDIS_HOST', 'localhost')
 redis_port = config_handler.get('redis', 'REDIS_PORT', '6379')
@@ -44,6 +47,17 @@ app.conf.update(
     worker_concurrency=2,          # 每个worker进程数(文档处理较重,不宜过多)
     worker_pool='solo',           # 使用单线程模式(避免GIL问题)
 
+    # 网络和连接配置 - 防止30分钟断连
+    broker_connection_timeout=30,      # 连接超时30秒
+    broker_connection_retry=True,      # 启用连接重试
+    broker_connection_retry_on_startup=True,  # 启动时重试
+    broker_connection_max_retries=10,  # 最大重试次数
+    broker_heartbeat=60,               # 心跳间隔60秒(默认是30秒的2倍)
+    broker_transport_options={
+        'visibility_timeout': 3600,    # 任务可见性超时
+        'socket_keepalive': True,      # 启用socket keepalive
+    },
+
     # 任务配置
     task_track_started=True,
     task_time_limit=600,           # 10分钟超时(文档处理较慢)
@@ -52,4 +66,11 @@ app.conf.update(
 
     # 结果过期时间
     result_expires=3600,           # 1小时后过期
-)
+
+    # 连接池配置
+    broker_pool_limit=None,        # 无连接池限制
+    result_backend_pool_limit=None, # 无结果后端连接池限制
+)
+
+# 初始化Celery trace系统
+init()

+ 1 - 1
foundation/base/redis_config.py

@@ -30,7 +30,7 @@ class RedisConfig:
 def load_config_from_env() -> tuple[RedisConfig]:
     """从环境变量加载配置"""
     redis_config = RedisConfig(
-        url=config_handler.get("redis", "REDIS_URL", "redis://127.0.0.1:6379"),
+        url=config_handler.get("redis", "REDIS_URL"),
         password=config_handler.get("redis", "REDIS_PASSWORD"),
         db=int(config_handler.get("redis", "REDIS_DB", "0")),
         max_connections=int(config_handler.get("redis", "REDIS_MAX_CONNECTIONS", "50"))

+ 123 - 12
foundation/base/redis_connection.py

@@ -8,17 +8,77 @@
 @Date       :2025/7/21 15:07
 '''
 import redis                     # 同步专用
-from redis import asyncio as aioredis
-
+# 尝试导入异步Redis模块
+try:
+    from redis import asyncio as redis_asyncio
+except ImportError:
+    try:
+        import aioredis as redis_asyncio
+    except ImportError:
+        raise ImportError("Neither redis.asyncio nor aioredis is available. Please install 'redis[asyncio]' or 'aioredis'")
+
+# 导入Redis异常类
+from redis.exceptions import ConnectionError as redis_ConnectionError
 
 from typing import Optional, Protocol, Dict, Any
+from functools import wraps
+import asyncio
 from foundation.base.redis_config import RedisConfig
 from foundation.base.redis_config import load_config_from_env
 from foundation.logger.loggering import server_logger
-from typing import Dict, Any, List
-from typing import Tuple, Optional
+from typing import Dict, Any, List, Tuple
 from langchain_community.storage import RedisStore
 
+
+def with_redis_retry(max_retries: int = 3, delay: float = 1.0):
+    """
+    Redis操作重连装饰器
+
+    Args:
+        max_retries: 最大重试次数,默认3次
+        delay: 重试间隔秒数,默认1秒
+    """
+    def decorator(func):
+        @wraps(func)
+        async def wrapper(self, *args, **kwargs):
+            last_exception = None
+
+            for attempt in range(max_retries + 1):  # +1 包含第一次尝试
+                try:
+                    return await func(self, *args, **kwargs)
+                except (ConnectionResetError, redis_ConnectionError) as e:
+                    last_exception = e
+
+                    if attempt < max_retries:
+                        server_logger.warning(
+                            f"Redis连接异常 (尝试 {attempt + 1}/{max_retries + 1}): {str(e)}"
+                        )
+
+                        # 尝试重连
+                        try:
+                            await self._reconnect()
+                        except Exception as reconnect_error:
+                            server_logger.error(f"Redis重连失败: {str(reconnect_error)}")
+                            # 如果重连失败,继续重试
+                            await asyncio.sleep(delay * (attempt + 1))  # 指数退避
+                            continue
+
+                        server_logger.info(f"Redis重连成功,重新执行操作")
+                        await asyncio.sleep(delay)  # 等待连接稳定
+                    else:
+                        server_logger.error(f"Redis操作失败,已达最大重试次数: {str(e)}")
+                        break
+                except Exception as e:
+                    # 非连接相关的异常直接抛出
+                    raise e
+
+            # 所有重试都失败了
+            raise last_exception
+
+        return wrapper
+    return decorator
+
+
 class RedisConnection(Protocol):
     """
     Redis 接口协议
@@ -64,23 +124,34 @@ class RedisAdapter(RedisConnection):
 
     async def connect(self):
         """创建Redis连接"""
-        self._redis = await aioredis.from_url(
+        # 简化的TCP Keep-Alive配置(兼容Windows系统)
+        socket_options = {
+            'socket_keepalive': True,
+            'socket_connect_timeout': 10,  # 连接超时10秒
+            'socket_timeout': 30,           # 读写超时30秒
+        }
+
+        # 使用新版本的redis.asyncio
+        self._redis = redis_asyncio.from_url(
             self.config.url,
             password=self.config.password,
             db=self.config.db,
             encoding="utf-8",
             decode_responses=True,
-            max_connections=self.config.max_connections
+            max_connections=self.config.max_connections,
+            **socket_options
         )
-        # 用于 langchain RedisStore 存储  
+
+        # 用于 langchain RedisStore 存储
         # 必须设为 False(LangChain 需要 bytes 数据)
-        self._langchain_redis_client = aioredis.from_url(
+        self._langchain_redis_client = redis_asyncio.from_url(
             self.config.url,
             password=self.config.password,
             db=self.config.db,
             encoding="utf-8",
             decode_responses=False,
-            max_connections=self.config.max_connections
+            max_connections=self.config.max_connections,
+            **socket_options
         )
        
         # ✅ 使用同步 Redis 客户端
@@ -100,51 +171,91 @@ class RedisAdapter(RedisConnection):
       
         return self
 
+    @with_redis_retry()
     async def get(self, key: str) -> Any:
+        """获取Redis键值"""
         return await self._redis.get(key)
 
+    @with_redis_retry()
     async def set(self, key: str, value: Any, ex: Optional[int] = None, nx: bool = False) -> bool:
+        """设置Redis键值"""
         return await self._redis.set(key, value, ex=ex, nx=nx)
 
+    @with_redis_retry()
+    async def setex(self, key: str, time: int, value: Any) -> bool:
+        """设置Redis键值并指定过期时间"""
+        return await self._redis.setex(key, time, value)
+
+    @with_redis_retry()
     async def hget(self, key: str, field: str) -> Any:
         return await self._redis.hget(key, field)
 
+    @with_redis_retry()
     async def hset(self, key: str, field: str, value: Any) -> int:
         return await self._redis.hset(key, field, value)
 
+    @with_redis_retry()
     async def hmset(self, key: str, mapping: Dict[str, Any]) -> bool:
         return await self._redis.hmset(key, mapping)
 
+    @with_redis_retry()
     async def hgetall(self, key: str) -> Dict[str, Any]:
         return await self._redis.hgetall(key)
 
+    @with_redis_retry()
     async def delete(self, *keys: str) -> int:
         return await self._redis.delete(*keys)
 
+    @with_redis_retry()
     async def exists(self, key: str) -> int:
         return await self._redis.exists(key)
 
+    @with_redis_retry()
     async def expire(self, key: str, seconds: int) -> bool:
         return await self._redis.expire(key, seconds)
 
+    @with_redis_retry()
     async def scan(self, cursor: int, match: Optional[str] = None, count: Optional[int] = None) -> tuple[
         int, list[str]]:
         return await self._redis.scan(cursor, match=match, count=count)
-    
+
+    @with_redis_retry()
     async def eval(self, script: str, numkeys: int, *keys_and_args: str) -> Any:
+        """执行Redis脚本"""
         return await self._redis.eval(script, numkeys, *keys_and_args) #  解包成独立参数
 
 
     def get_langchain_redis_client(self):
         return self._langchain_redis_client
 
+    async def _reconnect(self) -> None:
+        """重新连接Redis"""
+        try:
+            server_logger.info("正在重新连接Redis...")
+            if self._redis:
+                await self._redis.close()
+                await self._redis.wait_closed()
+            if self._langchain_redis_client:
+                await self._langchain_redis_client.close()
+                await self._langchain_redis_client.wait_closed()
+
+            # 等待短暂时间后重连
+            await asyncio.sleep(1)
+
+            # 重新建立连接
+            await self.connect()
+            server_logger.info("Redis重连成功")
+        except Exception as e:
+            server_logger.error(f"Redis重连失败: {str(e)}")
+            raise
+
     async def close(self) -> None:
         if self._redis:
             await self._redis.close()
-            await self._redis.wait_closed()
+            #await self._redis.wait_closed() #该方法已弃用
         if self._langchain_redis_client:
             await self._langchain_redis_client.close()
-            await self._langchain_redis_client.wait_closed()
+            #await self._langchain_redis_client.wait_closed()
 
 
 

+ 9 - 3
foundation/base/tasks.py

@@ -11,13 +11,19 @@ from foundation.utils.time_statistics import track_execution_time
 
 
 @app.task(bind=True)
-def submit_task_processing_task(self, file_info: dict):
+def submit_task_processing_task(self, file_info: dict, _system_trace_id: str = None):
     """
     提交任务处理到Celery队列
     这个任务只负责调用WorkflowManager,不包含业务逻辑
     """
     import traceback
 
+    # 恢复trace_id上下文
+    if _system_trace_id:
+        from foundation.trace.trace_context import TraceContext
+        TraceContext.set_trace_id(_system_trace_id)
+        logger.info(f"Celery任务恢复")
+
     # 添加调试信息
     logger.info("=== Celery任务接收调试 ===")
     logger.info(f"队列ID: {self.request.id}")
@@ -31,7 +37,7 @@ def submit_task_processing_task(self, file_info: dict):
     try:
         # 更新任务状态 - 开始处理
         self.update_state(
-            state='PROGRESS',
+            state='current',
             meta={
                 'current': 0,
                 'total': 100,
@@ -56,7 +62,7 @@ def submit_task_processing_task(self, file_info: dict):
 
         # 更新任务状态 - 完成
         self.update_state(
-            state='PROGRESS',
+            state='current',
             meta={
                 'current': 100,
                 'total': 100,

+ 19 - 4
foundation/logger/loggering.py

@@ -15,6 +15,9 @@ import sys
 import logging
 from logging.handlers import RotatingFileHandler
 
+# 导入trace系统
+from foundation.trace.trace_context import TraceContext, trace_filter
+
 
 class CompatibleLogger(logging.Logger):
     """
@@ -27,7 +30,7 @@ class CompatibleLogger(logging.Logger):
                  log_format=None, datefmt=None):
         # 初始化父类
         super().__init__(name)
-        self.setLevel(logging.DEBUG)  # 设置logger自身为最低级别
+        self.setLevel(logging.INFO)  # 设置logger自身为最低级别
 
         # 存储配置
         self.log_dir = log_dir
@@ -55,7 +58,8 @@ class CompatibleLogger(logging.Logger):
     def _set_formatter(self, log_format, datefmt):
         """设置日志格式"""
         if log_format is None:
-            log_format = 'P%(process)d.T%(thread)d | %(asctime)s | %(levelname)-8s | %(trace_id)-10s | %(log_type)-5s | %(message)s'
+            # 使用system_trace_id字段,通过TraceFilter自动注入
+            log_format = 'P%(process)d.T%(thread)d | %(asctime)s | %(levelname)-8s | %(system_trace_id)-15s | %(log_type)-5s | %(message)s'
 
         if datefmt is None:
             datefmt = '%Y-%m-%d %H:%M:%S'
@@ -84,18 +88,27 @@ class CompatibleLogger(logging.Logger):
             handler.setFormatter(self.formatter)
             # 为每个级别的日志文件都添加一个筛选器,确保记录该级别及其更高级别
             handler.addFilter(lambda record, lvl=level: record.levelno >= lvl)
+            # 添加trace_filter,自动注入system_trace_id
+            handler.addFilter(trace_filter)
             self.addHandler(handler)
 
     def _create_console_handler(self):
         """创建控制台日志处理器"""
         console_handler = logging.StreamHandler(sys.stdout)
-        console_handler.setLevel(logging.INFO)
+        console_handler.setLevel(logging.DEBUG)
         console_handler.setFormatter(self.formatter)
+        # 添加trace_filter,自动注入system_trace_id
+        console_handler.addFilter(trace_filter)
         self.addHandler(console_handler)
 
     def _log_with_context(self, level, msg, trace_id, log_type, *args, **kwargs):
-        """统一的日志记录方法"""
+        """统一的日志记录方法 - 兼容手动传递trace_id和自动获取trace_id"""
         extra = kwargs.get('extra', {})
+
+        # 如果没有手动传递trace_id,则从TraceContext自动获取
+        if not trace_id:
+            trace_id = TraceContext.get_trace_id()
+
         extra.update({
             'trace_id': trace_id,
             'log_type': log_type
@@ -140,6 +153,8 @@ server_logger = CompatibleLogger(
     backup_count=int(config_handler.get("log", "LOG_BACKUP_COUNT", "5"))
 )
 
+# 添加trace_filter到logger,自动注入system_trace_id
+server_logger.addFilter(trace_filter)
 
 # 设置日志级别
 server_logger.info("logging initialized")

+ 1 - 1
foundation/rag/vector/milvus_vector.py

@@ -1,6 +1,6 @@
 import time
 from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType, utility
-from sentence_transformers import SentenceTransformer
+# from sentence_transformers import SentenceTransformer
 import numpy as np
 from typing import List, Dict, Any, Optional
 import json

+ 121 - 0
foundation/trace/celery_trace.py

@@ -0,0 +1,121 @@
+"""
+Celery Trace管理
+负责在Celery队列任务中传递和恢复trace_id上下文
+"""
+
+from celery.signals import task_prerun, task_postrun, task_failure
+from foundation.trace.trace_context import TraceContext
+from foundation.logger.loggering import server_logger as logger
+
+
class CeleryTraceManager:
    """
    Celery trace-context manager.

    Restores the caller's trace_id inside worker processes (via Celery
    signals) and forwards it when submitting tasks.
    """

    @staticmethod
    def init_celery_signals():
        """Register Celery signal handlers that manage the trace_id lifecycle."""

        @task_prerun.connect
        def task_prerun_handler(sender=None, task_id=None, task=None, args=None, kwargs=None, **kwds):
            """Before execution: pull the trace_id from the task kwargs into TraceContext."""
            try:
                # NOTE(review): pop() assumes this kwargs dict is shared with the
                # actual task invocation — confirm the task no longer needs the key.
                trace_id = kwargs.pop('_system_trace_id', None) or kwargs.pop('callback_task_id', None)

                if trace_id:
                    TraceContext.set_trace_id(trace_id)
                    logger.info(f"Celery任务恢复trace_id: {trace_id}, 任务ID: {task_id}")
                else:
                    # No trace_id supplied — derive a temporary one from the task id.
                    # (task_id or 'unknown') guards against task_id being None,
                    # which previously raised TypeError on the slice.
                    fallback_trace = f"celery-{(task_id or 'unknown')[:8]}"
                    TraceContext.set_trace_id(fallback_trace)
                    logger.warning(f"Celery任务未找到trace_id,使用临时trace: {fallback_trace}")

            except Exception as e:
                logger.error(f"Celery任务trace_id恢复失败: {str(e)}")
                # Same None-guard as above for the error fallback.
                fallback_trace = f"celery-error-{(task_id or 'unknown')[:8]}"
                TraceContext.set_trace_id(fallback_trace)

        @task_postrun.connect
        def task_postrun_handler(sender=None, task_id=None, task=None, args=None, kwargs=None, retval=None, state=None, **kwds):
            """After execution: log completion (context cleanup intentionally skipped)."""
            try:
                trace_id = TraceContext.get_trace_id()
                logger.info(f"Celery任务完成: {trace_id}, 任务ID: {task_id}")
                # Optional cleanup deliberately omitted:
                # TraceContext.set_trace_id(None)
            except Exception as e:
                logger.error(f"Celery任务trace_id清理失败: {str(e)}")

        @task_failure.connect
        def task_failure_handler(sender=None, task_id=None, exception=None, traceback=None, einfo=None, **kwds):
            """On failure: log the error together with the active trace_id."""
            try:
                trace_id = TraceContext.get_trace_id()
                logger.error(f"Celery任务失败: {trace_id}, 任务ID: {task_id}, 错误: {str(exception)}")
            except Exception as e:
                logger.error(f"Celery任务失败trace_id记录失败: {str(e)}, 任务ID: {task_id}")

    @staticmethod
    def submit_celery_task(task_func, *args, **kwargs):
        """
        Submit a Celery task, forwarding the caller's current trace_id.

        Args:
            task_func: the Celery task to dispatch.
            *args: positional arguments for the task.
            **kwargs: keyword arguments for the task.

        Returns:
            The AsyncResult produced by ``task_func.delay``.
        """
        current_trace_id = TraceContext.get_trace_id()

        # Forward only a real trace_id — never the 'no-trace' placeholder.
        if current_trace_id and current_trace_id != 'no-trace':
            kwargs['_system_trace_id'] = current_trace_id

        logger.info(f"提交Celery任务")

        return task_func.delay(*args, **kwargs)
+
+
def add_trace_to_celery_task(celery_task_func):
    """
    Decorator: inject the caller's current trace_id into a Celery task call.

    The previous version returned an anonymous ``decorator`` function without
    ``functools.wraps``, which destroyed the wrapped task's name/docstring
    (breaking introspection and Celery task naming by wrapper identity).

    Usage:
        @add_trace_to_celery_task
        @app.task(bind=True)
        def my_task(self, file_info: dict):
            # task logic
            pass
    """
    from functools import wraps  # local import: module does not import functools

    @wraps(celery_task_func)
    def wrapper(*args, **kwargs):
        # Forward only a real trace_id — never the 'no-trace' placeholder.
        current_trace_id = TraceContext.get_trace_id()
        if current_trace_id and current_trace_id != 'no-trace':
            kwargs['_system_trace_id'] = current_trace_id

        return celery_task_func(*args, **kwargs)

    return wrapper
+
+
+# 自动初始化Celery信号
def init():
    """Entry point: wire up the Celery trace signal handlers."""
    CeleryTraceManager.init_celery_signals()
    logger.info("Celery trace系统初始化完成")

+ 153 - 0
foundation/trace/trace_context.py

@@ -0,0 +1,153 @@
+"""
+Trace Context Manager
+负责管理系统级别的trace_id上下文,支持异步并发和队列传播
+"""
+
+import contextvars
+import uuid
+import asyncio
+import threading
+from typing import Optional, Dict, Any, Callable
+from functools import wraps
+import logging
+
# Global trace-id context variable — propagates automatically across awaits/tasks.
system_trace_id: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar('system_trace_id', default=None)


class TraceContext:
    """Static helpers around the system-wide trace_id context variable."""

    @staticmethod
    def set_trace_id(trace_id: str) -> None:
        """Install ``trace_id`` into the current context (empty values are ignored)."""
        if trace_id:
            system_trace_id.set(trace_id)

    @staticmethod
    def get_trace_id() -> str:
        """Return the current trace_id, or the 'no-trace' placeholder when unset."""
        return system_trace_id.get() or 'no-trace'

    @staticmethod
    def generate_trace_id() -> str:
        """Generate a short (8-character) random trace_id."""
        return str(uuid.uuid4())[:8]

    @staticmethod
    def get_or_generate_trace_id() -> str:
        """Return the current trace_id, generating a fresh one if none is set."""
        current = system_trace_id.get()
        return current if current else TraceContext.generate_trace_id()

    @staticmethod
    def extract_context() -> Dict[str, Any]:
        """Snapshot the current context for hand-off to a queue/worker."""
        return {
            'system_trace_id': system_trace_id.get(),
            'thread_id': threading.get_ident(),
            # ContextVar exposes no public handle to its Context object; the old
            # hasattr(system_trace_id, '_context') probe always evaluated to
            # False, so this value was always None — make that explicit.
            'async_context': None,
        }

    @staticmethod
    def restore_context(context_data: Dict[str, Any]) -> None:
        """Re-install a trace_id previously captured by ``extract_context()``."""
        if context_data and 'system_trace_id' in context_data:
            trace_id = context_data['system_trace_id']
            if trace_id:
                system_trace_id.set(trace_id)

    @staticmethod
    def with_trace_context(trace_id: str):
        """Context manager that temporarily overrides the trace_id."""
        return _TraceContextManager(trace_id)
+
+
class _TraceContextManager:
    """Context manager: set a trace_id on entry, restore the previous one on exit."""

    def __init__(self, trace_id: str):
        self.trace_id = trace_id
        self.token = None  # reset token from ContextVar.set()

    def __enter__(self):
        # Keep the token so the prior value can be restored on exit.
        self.token = system_trace_id.set(self.trace_id)
        return self.trace_id

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.token is not None:
            system_trace_id.reset(self.token)
+
+
def _resolve_auto_trace_id(trace_id_param, generate_if_missing, args, kwargs):
    """Shared lookup: pull a trace_id out of the call arguments, optionally generating one."""
    trace_id = None

    # Only search the arguments when a parameter name was configured.
    if trace_id_param:
        if trace_id_param in kwargs:
            trace_id = kwargs[trace_id_param]
        elif args and isinstance(args[0], str):
            # Positional fallback: treat a leading string argument as the trace_id.
            trace_id = args[0]

    if not trace_id and generate_if_missing:
        trace_id = TraceContext.generate_trace_id()

    return trace_id


def auto_trace(trace_id_param: Optional[str] = 'callback_task_id', generate_if_missing: bool = False):
    """
    Auto-trace decorator — installs a trace_id into TraceContext before the call.

    Works on both sync and async functions; the previously duplicated sync/async
    lookup logic is factored into ``_resolve_auto_trace_id``.

    Args:
        trace_id_param: kwarg name to read the trace_id from; ``None`` disables
            argument lookup entirely (only ``generate_if_missing`` applies).
        generate_if_missing: when True, generate a fresh trace_id if none was found.
    """
    def decorator(func: Callable):
        if asyncio.iscoroutinefunction(func):
            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                trace_id = _resolve_auto_trace_id(trace_id_param, generate_if_missing, args, kwargs)
                if trace_id:
                    TraceContext.set_trace_id(trace_id)
                return await func(*args, **kwargs)
            return async_wrapper

        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            trace_id = _resolve_auto_trace_id(trace_id_param, generate_if_missing, args, kwargs)
            if trace_id:
                TraceContext.set_trace_id(trace_id)
            return func(*args, **kwargs)
        return sync_wrapper
    return decorator
+
+
class TraceFilter(logging.Filter):
    """Logging filter that stamps every record with the current system_trace_id."""

    def filter(self, record: logging.LogRecord) -> bool:
        # Never drop the record — we only enrich it with the trace field.
        record.system_trace_id = TraceContext.get_trace_id()
        return True


# Shared filter instance, attached to the application's loggers.
trace_filter = TraceFilter()

+ 184 - 10
foundation/utils/redis_utils.py

@@ -1,16 +1,21 @@
 
 import json
+import time
+import asyncio
+import sys
+from pathlib import Path
+# root_dir = Path(__file__).parent.parent.parent 
+# print(root_dir) 
+# sys.path.append(str(root_dir))  
+from typing import Dict, Optional, Any
+from .time_statistics import track_execution_time
+from foundation.base.config import config_handler
 from foundation.logger.loggering import server_logger
 from foundation.base.redis_connection import RedisConnectionFactory
-from foundation.base.config import config_handler
 # 缓存数据有效期 默认 3 分钟
 CACHE_DATA_EXPIRED_TIME = 3 * 60
 
 
-
-
-
-
 async def set_redis_result_cache_data(data_type: str , trace_id: str, value: str):
     """
       设置redis结果缓存数据
@@ -24,9 +29,6 @@ async def set_redis_result_cache_data(data_type: str , trace_id: str, value: str
     redis_store = await RedisConnectionFactory.get_redis_store()
     await redis_store.set(key, value , ex=expired_time) 
 
-
-
-
 async def get_redis_result_cache_data(data_type: str , trace_id: str):
     """
       获取redis结果缓存数据
@@ -41,7 +43,6 @@ async def get_redis_result_cache_data(data_type: str , trace_id: str):
     return value
 
 
-
 async def get_redis_result_cache_data_and_delete_key(data_type: str , trace_id: str):
     """
       获取redis结果缓存数据
@@ -61,4 +62,177 @@ async def get_redis_result_cache_data_and_delete_key(data_type: str , trace_id:
     data = json.loads(json_str)
     # 删除key
     #await redis_store.delete(key)
-    return data
+    return data
+
+
+
+
@track_execution_time
async def store_file_info(file_id: str, file_info: Dict[str, Any], expire_seconds: int = 3600) -> bool:
    """
    Persist a file's metadata (and optionally its content) in Redis.

    Metadata goes under ``meta:{file_id}``; when ``file_info`` carries a
    'file_content' entry, the content is written separately under
    ``content:{file_id}``, and the metadata records its size.

    Args:
        file_id: file identifier used to build the Redis keys.
        file_info: file info dict; may contain a 'file_content' entry.
        expire_seconds: TTL for both keys (default one hour).

    Returns:
        bool: True on success (or when the entry already exists), False on error.
    """
    try:
        redis_store = await RedisConnectionFactory.get_redis_store()

        # Idempotence guard: if metadata is already cached, store nothing.
        if await redis_store.get(f"meta:{file_id}"):
            server_logger.info(f"文件信息已存在,跳过存储: {file_id}")
            return True

        file_content = file_info.get('file_content')

        if file_content:
            file_size = len(file_content)
            server_logger.info(f"使用直接存储策略: {file_id}, {file_size/1024/1024:.2f}MB")

            # Metadata is everything except the (potentially large) content blob.
            metadata = {k: v for k, v in file_info.items() if k != 'file_content'}
            metadata['file_size'] = file_size

            # Write metadata and content concurrently for speed.
            await asyncio.gather(
                redis_store.setex(f"meta:{file_id}", expire_seconds, json.dumps(metadata)),
                redis_store.setex(f"content:{file_id}", expire_seconds, file_content),
            )
        else:
            # No content payload — persist metadata only.
            metadata = file_info.copy()
            await redis_store.setex(f"meta:{file_id}", expire_seconds, json.dumps(metadata))

        server_logger.info(f"文件信息已存储到Redis: {file_id}")
        return True

    except Exception as e:
        server_logger.error(f"存储文件信息到Redis失败: {str(e)}")
        return False
+
@track_execution_time
async def get_file_info(file_id: str, include_content: bool = True) -> Optional[Dict[str, Any]]:
    """
    Load a file's metadata (and optionally its content) from Redis by file_id.

    Args:
        file_id: file identifier.
        include_content: also fetch the stored content (default True); pass
            False for a cheaper metadata-only lookup.

    Returns:
        Dict: the file-info dict, or None when the metadata is missing,
        the content is missing, or the metadata fails to parse.
    """
    try:
        redis_store = await RedisConnectionFactory.get_redis_store()

        meta_key = f"meta:{file_id}"
        meta_bytes = await redis_store.get(meta_key)
        if not meta_bytes:
            server_logger.warning(f"文件元数据不存在: {meta_key}")
            return None

        # NOTE(review): presumably a bytes-mode client, hence the decode — verify.
        file_info = json.loads(meta_bytes.decode('utf-8'))

        # A 'file_size' entry signals that content was stored under its own key.
        if include_content and 'file_size' in file_info:
            content_key = f"content:{file_id}"
            file_content = await redis_store.get(content_key)
            if not file_content:
                server_logger.warning(f"文件内容不存在: {content_key}")
                return None  # content missing — treat the whole record as absent
            file_info['file_content'] = file_content

        server_logger.info(f"从Redis获取到文件信息: {meta_key}")
        return file_info

    except json.JSONDecodeError as e:
        server_logger.error(f"解析文件元数据JSON失败: {str(e)}")
        return None
    except Exception as e:
        server_logger.error(f"获取文件信息失败: {str(e)}")
        return None
+
+
async def delete_file_info(file_id: str) -> bool:
    """
    Delete a file's metadata and content keys from Redis.

    A dedicated connection is created (avoiding event-loop conflicts with the
    shared factory connection) and is always closed exactly once.

    Bug fix: the old version closed ``adapter`` in every branch AND again in
    ``finally``, and the ``finally`` clause raised NameError (masking the real
    error) whenever connection setup failed before ``adapter`` was bound.

    Args:
        file_id: file identifier.

    Returns:
        bool: True when at least one key was deleted or nothing existed;
        False on error or when no key was removed.
    """
    adapter = None  # must be bound before try so the finally guard is safe
    try:
        # Local imports keep this helper free of module-level import cycles.
        from foundation.base.redis_config import load_config_from_env
        from foundation.base.redis_connection import RedisAdapter

        adapter = RedisAdapter(load_config_from_env())
        await adapter.connect()
        redis_store = adapter.get_langchain_redis_client()

        # Fetch metadata first to learn whether a separate content key exists.
        meta_key = f"meta:{file_id}"
        meta_bytes = await redis_store.get(meta_key)
        if not meta_bytes:
            server_logger.warning(f"文件元数据不存在: {meta_key}")
            return True  # already gone — treat as success

        file_info = json.loads(meta_bytes.decode('utf-8'))

        deleted_count = await redis_store.delete(meta_key)

        # 'file_size' marks that content was stored under its own key.
        if 'file_size' in file_info:
            deleted_count += await redis_store.delete(f"content:{file_id}")

        if deleted_count > 0:
            server_logger.info(f"已删除文件信息: {file_id}, {deleted_count}个键")
            return True
        server_logger.warning(f"Redis缓存不存在,无法删除: {file_id}")
        return False

    except json.JSONDecodeError as e:
        server_logger.error(f"解析文件元数据JSON失败: {str(e)}")
        return False
    except Exception as e:
        server_logger.error(f"删除文件信息失败: {str(e)}")
        return False
    finally:
        # Close exactly once, and only if the adapter was actually created.
        if adapter is not None:
            await adapter.close()
+
+#asyncio.run(delete_file_info('e385049cde7d21a48c7de216182f0f23'))
+

+ 22 - 5
foundation/utils/time_statistics.py

@@ -1,21 +1,38 @@
 import time
+import asyncio
+import inspect
 from functools import wraps
 from ..logger.loggering import server_logger as logger
 
 def track_execution_time(func):
     """
     追踪函数执行时间并通过日志输出的装饰器
-    记录函数开始执行、执行完成及耗时(保留两位小数)
+    同时支持同步和异步函数,记录函数开始执行、执行完成及耗时(保留两位小数)
     """
     @wraps(func)
-    def wrapper(*args, **kwargs):
+    def sync_wrapper(*args, **kwargs):
         logger.info(f"[{func.__name__}] 开始执行")
         start_time = time.perf_counter()
-        
+
         try:
             return func(*args, **kwargs)
         finally:
             duration = time.perf_counter() - start_time
             logger.info(f"[{func.__name__}] 执行完成,耗时: {duration:.2f} 秒")
-    
-    return wrapper
+
+    @wraps(func)
+    async def async_wrapper(*args, **kwargs):
+        logger.info(f"[{func.__name__}] 开始执行")
+        start_time = time.perf_counter()
+
+        try:
+            return await func(*args, **kwargs)
+        finally:
+            duration = time.perf_counter() - start_time
+            logger.info(f"[{func.__name__}] 执行完成,耗时: {duration:.2f} 秒")
+
+    # 检查函数是否是异步函数
+    if inspect.iscoroutinefunction(func):
+        return async_wrapper
+    else:
+        return sync_wrapper

Разница между файлами не показана из-за своего большого размера
+ 3 - 13
temp/AI审查结果.json


+ 0 - 281
test/construction_review/api_test_client.py

@@ -1,281 +0,0 @@
-"""
-施工方案审查API测试客户端
-用于测试Mock接口和前端联调
-"""
-
-import requests
-import json
-import time
-import uuid
-from pathlib import Path
-from typing import Optional, Dict, Any
-
-class ConstructionReviewAPIClient:
-    """施工方案审查API客户端"""
-
-    def __init__(self, base_url: str = "http://127.0.0.1:8034", api_key: Optional[str] = None):
-        self.base_url = base_url.rstrip('/')
-        self.api_key = api_key
-        self.session = requests.Session()
-
-        if api_key:
-            self.session.headers.update({
-                'Authorization': f'Bearer {api_key}'
-            })
-
-    def upload_file(self, file_path: str, project_plan_type: str, user: str,
-                   callback_url: Optional[str] = None) -> Dict[str, Any]:
-        """
-        上传文件
-
-        Args:
-            file_path: 文件路径
-            project_plan_type: 工程方案类型
-            user: 用户标识
-            callback_url: 回调URL(可选)
-
-        Returns:
-            上传响应结果
-        """
-        url = f"{self.base_url}/sgsc/file_upload"
-
-        if not Path(file_path).exists():
-            raise FileNotFoundError(f"文件不存在: {file_path}")
-
-        with open(file_path, 'rb') as f:
-            files = {'file': f}
-            data = {
-                'project_plan_type': project_plan_type,
-                'user': user
-            }
-
-            if callback_url:
-                data['callback_url'] = callback_url
-
-            response = self.session.post(url, files=files, data=data)
-            response.raise_for_status()
-            return response.json()
-
-    def get_task_progress(self, callback_task_id: str, user: str) -> Dict[str, Any]:
-        """
-        查询任务进度
-
-        Args:
-            callback_task_id: 任务ID
-            user: 用户标识
-
-        Returns:
-            进度查询结果
-        """
-        url = f"{self.base_url}/sgsc/task_progress/{callback_task_id}"
-        params = {'user': user}
-
-        response = self.session.get(url, params=params)
-        response.raise_for_status()
-        return response.json()
-
-    def get_review_results(self, file_id: str, user: str, result_type: str) -> Dict[str, Any]:
-        """
-        获取审查结果
-
-        Args:
-            file_id: 文件ID
-            user: 用户标识
-            result_type: 结果类型 ("summary" 或 "issues")
-
-        Returns:
-            审查结果
-        """
-        url = f"{self.base_url}/sgsc/review_results"
-        data = {
-            'id': file_id,
-            'user': user,
-            'type': result_type
-        }
-
-        response = self.session.post(url, json=data)
-        response.raise_for_status()
-        return response.json()
-
-    def wait_for_completion(self, callback_task_id: str, user: str,
-                          timeout: int = 1800, interval: int = 10) -> Dict[str, Any]:
-        """
-        等待任务完成
-
-        Args:
-            callback_task_id: 任务ID
-            user: 用户标识
-            timeout: 超时时间(秒)
-            interval: 轮询间隔(秒)
-
-        Returns:
-            最终任务状态
-        """
-        start_time = time.time()
-
-        while time.time() - start_time < timeout:
-            try:
-                result = self.get_task_progress(callback_task_id, user)
-
-                if result['data']['review_task_status'] == 'completed':
-                    print(f"任务完成! 总进度: {result['data']['overall_progress']}%")
-                    return result
-                else:
-                    progress = result['data']['overall_progress']
-                    print(f"任务进行中... 进度: {progress}%")
-                    time.sleep(interval)
-
-            except Exception as e:
-                print(f"查询进度失败: {e}")
-                time.sleep(interval)
-
-        raise TimeoutError(f"任务超时,等待时间超过 {timeout} 秒")
-
-class MockTestRunner:
-    """Mock测试运行器"""
-
-    def __init__(self, client: ConstructionReviewAPIClient):
-        self.client = client
-
-    def test_file_upload(self, file_path: str = None) -> Dict[str, Any]:
-        """测试文件上传"""
-        print("=== 测试文件上传 ===")
-
-        # 创建测试文件(如果没有提供文件路径)
-        if not file_path:
-            test_file = Path(r"D:\wx_work\sichuan_luqiao\LQAgentPlatform\data_pipeline\test_rawdata\1f3e1d98-5b4a-4a06-87b3-c7f0413b901a.pdf")
-            if not test_file.exists():
-                # 创建一个简单的测试PDF文件内容
-                test_file.write_bytes(b"%PDF-1.4\n%Mock PDF for testing\n")
-            file_path = str(test_file)
-
-        try:
-            result = self.client.upload_file(
-                file_path=file_path,
-                project_plan_type="bridge_up_part",
-                user=str(uuid.uuid4()),
-                callback_url="https://client.example.com/callback"
-            )
-
-            print(f"✅ 文件上传成功")
-            print(f"文件ID: {result['data']['id']}")
-            print(f"任务ID: {result['data']['callback_task_id']}")
-
-            return result
-
-        except Exception as e:
-            print(f"❌ 文件上传失败: {e}")
-            raise
-
-    def test_progress_query(self, callback_task_id: str, user: str) -> None:
-        """测试进度查询"""
-        print("\n=== 测试进度查询 ===")
-
-        try:
-            result = self.client.get_task_progress(callback_task_id, user)
-
-            print(f"✅ 进度查询成功")
-            print(f"任务状态: {result['data']['review_task_status']}")
-            print(f"总进度: {result['data']['overall_progress']}%")
-
-            for stage in result['data']['stages']:
-                print(f"  - {stage['stage_name']}: {stage['progress']}% ({stage['stage_status']})")
-
-        except Exception as e:
-            print(f"❌ 进度查询失败: {e}")
-            raise
-
-    def test_review_results(self, file_id: str, user: str) -> None:
-        """测试审查结果获取"""
-        print("\n=== 测试审查结果获取 ===")
-
-        # 测试获取总结报告
-        try:
-            result = self.client.get_review_results(file_id, user, "summary")
-
-            print(f"✅ 总结报告获取成功")
-            print(f"风险统计: {result['data']['risk_stats']}")
-            print(f"四维评分: {result['data']['dimension_scores']}")
-            print(f"总结报告: {result['data']['summary_report']}")
-
-        except Exception as e:
-            print(f"❌ 总结报告获取失败: {e}")
-
-        # 测试获取问题条文
-        try:
-            result = self.client.get_review_results(file_id, user, "issues")
-
-            print(f"\n✅ 问题条文获取成功")
-            issues = result['data']['issues']
-            print(f"发现问题数量: {len(issues)}")
-
-            for i, issue in enumerate(issues):
-                print(f"\n问题 {i+1}:")
-                print(f"  ID: {issue['issue_id']}")
-                print(f"  页码: {issue['metadata']['page']}")
-                print(f"  章节: {issue['metadata']['chapter']}")
-                print(f"  风险等级: {issue['risk_summary']['max_risk_level']}")
-                print(f"  检查项数量: {len(issue['review_lists'])}")
-
-        except Exception as e:
-            print(f"❌ 问题条文获取失败: {e}")
-
-    def run_complete_test(self) -> None:
-        """运行完整测试流程"""
-        print("开始施工方案审查API完整测试...")
-
-        try:
-            # 1. 上传文件
-            upload_result = self.test_file_upload()
-            file_id = upload_result['data']['id']
-            callback_task_id = upload_result['data']['callback_task_id']
-            user = str(uuid.uuid4())  # 实际应该从上传响应中获取,这里简化
-
-            # 2. 查询进度(等待一段时间让任务完成)
-            print("\n等待任务完成...")
-            time.sleep(2)  # 短暂等待
-
-            # 先测试进度查询
-            self.test_progress_query(callback_task_id, user)
-
-            # 3. 获取审查结果(可能需要等待任务完成)
-            print("\n获取审查结果...")
-
-            # 如果任务还未完成,直接标记完成(仅用于Mock测试)
-            try:
-                self.test_review_results(file_id, user)
-            except Exception as e:
-                print(f"审查结果获取失败,尝试完成任务: {e}")
-
-                # 完成任务(Mock功能)
-                response = requests.post(f"{self.client.base_url}/sgsc/mock/complete_task",
-                                       data={"callback_task_id": callback_task_id})
-                print("任务已强制完成,重新获取结果...")
-
-                self.test_review_results(file_id, user)
-
-            print("\n🎉 完整测试流程执行成功!")
-
-        except Exception as e:
-            print(f"\n❌ 测试失败: {e}")
-            raise
-
-def main():
-    """主函数 - 运行测试"""
-    print("施工方案审查API Mock测试客户端")
-    print("=" * 50)
-
-    # 创建客户端
-    client = ConstructionReviewAPIClient(
-        base_url="http://127.0.0.1:8034",
-        api_key="mock-api-key-12345"
-    )
-
-    # 创建测试运行器
-    test_runner = MockTestRunner(client)
-
-    # 运行完整测试
-    test_runner.run_complete_test()
-
-if __name__ == "__main__":
-    main()

+ 0 - 370
test/construction_review/test_error_codes_pytest.py

@@ -1,370 +0,0 @@
-"""
-施工方案审查API错误码测试 - pytest版本
-使用pytest运行的标准测试套件
-"""
-
-import pytest
-import requests
-import json
-import uuid
-import time
-import os
-from typing import Dict, Any
-
-# pytest fixtures
@pytest.fixture(scope="class")
def api_config():
    """Shared API configuration: endpoint location plus known-valid inputs."""
    return dict(
        base_url="http://127.0.0.1:8034",
        api_prefix="/sgsc",
        valid_user="user-001",
        valid_project_type="bridge_up_part",
        test_callback_url="http://test.callback.com",
    )
-
@pytest.fixture
def test_file():
    """Per-test upload fixture: a real sample PDF when present, a mock tuple otherwise."""
    file_path = "data_pipeline/test_rawdata/1f3e1d98-5b4a-4a06-87b3-c7f0413b901a.pdf"

    class _UploadFixture:
        def __init__(self):
            if not os.path.exists(file_path):
                # Fall back to an in-memory fake PDF; nothing to close later.
                self.file = None
                self.file_tuple = ("test.pdf", b"mock pdf content", "application/pdf")
                self.close_file = False
            else:
                self.file = open(file_path, 'rb')
                self.file_tuple = (os.path.basename(file_path), self.file, 'application/pdf')
                self.close_file = True

        def get_file(self):
            """Return the (name, fileobj/bytes, mime) tuple, rewound if backed by a real file."""
            if self.close_file and self.file:
                # Rewind so repeated uploads read from the start.
                self.file.seek(0)
            return self.file_tuple

        def cleanup(self):
            """Close the real file handle, if one was opened."""
            if self.close_file and self.file:
                self.file.close()

    fixture = _UploadFixture()
    yield fixture
    fixture.cleanup()
-
class TestFileUploadErrors:
    """Error-code tests for the file-upload endpoint (/sgsc/file_upload).

    Each WJSCxxx code corresponds to one server-side validation failure.
    NOTE(review): these tests require the API service to be reachable at the
    base_url provided by the api_config fixture.
    """

    @pytest.mark.parametrize("test_case,expected_code", [
        ("missing_file", "WJSC001"),
        ("empty_file", "WJSC003"),
        ("unsupported_format", "WJSC004"),
        ("invalid_project_type", "WJSC006")
    ])
    def test_file_upload_errors(self, api_config, test_case, expected_code):
        """Drive one malformed upload per test_case and assert its error code."""
        url = f"{api_config['base_url']}{api_config['api_prefix']}/file_upload"

        if test_case == "missing_file":
            # No "files" part at all -> server should report WJSC001.
            data = {
                "callback_url": api_config["test_callback_url"],
                "project_plan_type": api_config["valid_project_type"],
                "user": api_config["valid_user"]
            }
            response = requests.post(url, data=data)

        elif test_case == "empty_file":
            # Zero-byte file body -> WJSC003.
            files = {"file": ("empty.pdf", b"", "application/pdf")}
            data = {
                "callback_url": api_config["test_callback_url"],
                "project_plan_type": api_config["valid_project_type"],
                "user": api_config["valid_user"]
            }
            response = requests.post(url, files=files, data=data)

        elif test_case == "unsupported_format":
            # Plain-text upload is not an accepted document type -> WJSC004.
            files = {"file": ("test.txt", b"text content", "text/plain")}
            data = {
                "callback_url": api_config["test_callback_url"],
                "project_plan_type": api_config["valid_project_type"],
                "user": api_config["valid_user"]
            }
            response = requests.post(url, files=files, data=data)

        elif test_case == "invalid_project_type":
            # Unknown project_plan_type value -> WJSC006.
            files = {"file": ("test.pdf", b"mock pdf content", "application/pdf")}
            data = {
                "callback_url": api_config["test_callback_url"],
                "project_plan_type": "invalid_type",
                "user": api_config["valid_user"]
            }
            response = requests.post(url, files=files, data=data)

        # Business-level failures must surface as one of these 4xx statuses.
        assert response.status_code in [400, 403, 404]  # allowed business error status codes

        try:
            error_data = response.json()
            assert error_data["code"] == expected_code
            assert "error_type" in error_data
            assert "message" in error_data
        except json.JSONDecodeError:
            pytest.fail(f"响应不是有效的JSON: {response.text}")

    def test_file_upload_success(self, api_config, test_file):
        """Happy path: a valid upload returns 200 with ids in the payload."""
        url = f"{api_config['base_url']}{api_config['api_prefix']}/file_upload"
        files = {"file": test_file.get_file()}
        data = {
            "callback_url": api_config["test_callback_url"],
            "project_plan_type": api_config["valid_project_type"],
            "user": api_config["valid_user"]
        }

        response = requests.post(url, files=files, data=data)
        assert response.status_code == 200

        result = response.json()
        assert "data" in result
        assert "callback_task_id" in result["data"]
        assert "id" in result["data"]

    @pytest.mark.skip(reason="文件大小检查未实现")
    def test_wjsc005_file_size_exceeded(self, api_config):
        """WJSC005: file too large — skipped because the server check is not implemented."""
        pass

    @pytest.mark.skip(reason="认证检查未实现")
    def test_wjsc007_unauthorized(self, api_config):
        """WJSC007: authentication failure — skipped because not implemented."""
        pass
-
-
class TestTaskProgressErrors:
    """Error-code tests for the progress endpoint (/sgsc/task_progress/{id})."""

    def test_jdlx001_missing_parameters(self, api_config):
        """JDLX001: required query parameter is missing."""
        url = f"{api_config['base_url']}{api_config['api_prefix']}/task_progress/test-callback-id"
        response = requests.get(url)  # deliberately omit the "user" parameter

        assert response.status_code == 400
        error_data = response.json()
        assert error_data["code"] == "JDLX001"

    @pytest.mark.parametrize("invalid_id", ["short", "123", "invalid-format"])
    def test_jdlx002_invalid_param_format(self, api_config, invalid_id):
        """JDLX002: callback_task_id with an invalid format is rejected with 400."""
        url = f"{api_config['base_url']}{api_config['api_prefix']}/task_progress/{invalid_id}"
        params = {"user": api_config["valid_user"]}

        response = requests.get(url, params=params)
        assert response.status_code == 400

        error_data = response.json()
        assert error_data["code"] == "JDLX002"

    @pytest.mark.parametrize("invalid_user", ["invalid_user", "user-999", ""])
    def test_jdlx004_invalid_user(self, api_config, test_file, invalid_user):
        """JDLX004: unknown or empty user identifier is rejected with 403."""
        # Upload first so a genuinely valid callback_task_id exists.
        callback_task_id = self._upload_file_and_get_callback(api_config, test_file)

        url = f"{api_config['base_url']}{api_config['api_prefix']}/task_progress/{callback_task_id}"
        params = {"user": invalid_user}

        response = requests.get(url, params=params)
        assert response.status_code == 403

        error_data = response.json()
        assert error_data["code"] == "JDLX004"

    def test_jdlx005_task_not_found(self, api_config):
        """JDLX005: well-formed but unknown task id yields 404."""
        fake_callback_id = f"{uuid.uuid4()}-{int(time.time())}"
        url = f"{api_config['base_url']}{api_config['api_prefix']}/task_progress/{fake_callback_id}"
        params = {"user": api_config["valid_user"]}

        response = requests.get(url, params=params)
        assert response.status_code == 404

        error_data = response.json()
        assert error_data["code"] == "JDLX005"

    def _upload_file_and_get_callback(self, api_config, test_file):
        """Helper: upload the fixture file and return its callback_task_id."""
        url = f"{api_config['base_url']}{api_config['api_prefix']}/file_upload"
        files = {"file": test_file.get_file()}
        data = {
            "callback_url": api_config["test_callback_url"],
            "project_plan_type": api_config["valid_project_type"],
            "user": api_config["valid_user"]
        }

        response = requests.post(url, files=files, data=data)
        assert response.status_code == 200
        result = response.json()
        return result["data"]["callback_task_id"]

    @pytest.mark.skip(reason="认证检查未实现")
    def test_jdlx003_unauthorized(self, api_config):
        """JDLX003: authentication failure — skipped because not implemented."""
        pass
-
-
class TestReviewResultsErrors:
    """Error-code tests for the review-results endpoint (/sgsc/review_results)."""

    @pytest.mark.parametrize("invalid_type", ["invalid", "risk", "detail", ""])
    def test_scjg001_invalid_type(self, api_config, invalid_type):
        """SCJG001: result "type" field outside the accepted set is rejected."""
        url = f"{api_config['base_url']}{api_config['api_prefix']}/review_results"
        payload = {
            "id": str(uuid.uuid4()),
            "user": api_config["valid_user"],
            "type": invalid_type
        }

        response = requests.post(url, json=payload)
        assert response.status_code == 400

        error_data = response.json()
        assert error_data["code"] == "SCJG001"

    @pytest.mark.parametrize("invalid_id", ["", None])
    def test_scjg002_missing_param_id(self, api_config, invalid_id):
        """SCJG002: document id missing — covers both absent key and empty string."""
        url = f"{api_config['base_url']}{api_config['api_prefix']}/review_results"

        if invalid_id is None:
            # Omit the "id" key entirely.
            payload = {
                "user": api_config["valid_user"],
                "type": "summary"
            }
        else:
            # Send an empty "id" value.
            payload = {
                "id": invalid_id,
                "user": api_config["valid_user"],
                "type": "summary"
            }

        response = requests.post(url, json=payload)
        assert response.status_code == 400

        error_data = response.json()
        assert error_data["code"] == "SCJG002"

    @pytest.mark.parametrize("invalid_format", ["123", "short-id", "invalid-uuid-format"])
    def test_scjg003_invalid_id_format(self, api_config, invalid_format):
        """SCJG003: document id present but malformed."""
        url = f"{api_config['base_url']}{api_config['api_prefix']}/review_results"
        payload = {
            "id": invalid_format,
            "user": api_config["valid_user"],
            "type": "summary"
        }

        response = requests.post(url, json=payload)
        assert response.status_code == 400

        error_data = response.json()
        assert error_data["code"] == "SCJG003"

    def test_scjg005_invalid_user_review_results(self, api_config, test_file):
        """SCJG005: valid document id but unknown user is rejected with 403."""
        # Upload first so a genuinely valid file id exists.
        file_id = self._upload_file_and_get_file_id(api_config, test_file)

        url = f"{api_config['base_url']}{api_config['api_prefix']}/review_results"
        payload = {
            "id": file_id,
            "user": "invalid_user",
            "type": "summary"
        }

        response = requests.post(url, json=payload)
        assert response.status_code == 403

        error_data = response.json()
        assert error_data["code"] == "SCJG005"

    def test_scjg006_task_not_found_review_results(self, api_config):
        """SCJG006: well-formed but unknown document id yields 404."""
        url = f"{api_config['base_url']}{api_config['api_prefix']}/review_results"
        payload = {
            "id": str(uuid.uuid4()),
            "user": api_config["valid_user"],
            "type": "summary"
        }

        response = requests.post(url, json=payload)
        assert response.status_code == 404

        error_data = response.json()
        assert error_data["code"] == "SCJG006"

    def _upload_file_and_get_file_id(self, api_config, test_file):
        """Helper: upload the fixture file and return data.id."""
        url = f"{api_config['base_url']}{api_config['api_prefix']}/file_upload"
        files = {"file": test_file.get_file()}
        data = {
            "callback_url": api_config["test_callback_url"],
            "project_plan_type": api_config["valid_project_type"],
            "user": api_config["valid_user"]
        }

        response = requests.post(url, files=files, data=data)
        assert response.status_code == 200
        result = response.json()
        return result["data"]["id"]

    @pytest.mark.skip(reason="认证检查未实现")
    def test_scjg004_unauthorized(self, api_config):
        """SCJG004: authentication failure — skipped because not implemented."""
        pass
-
-
class TestIntegration:
    """End-to-end happy-path test across the upload and progress endpoints."""

    def test_complete_workflow_success(self, api_config, test_file):
        """Upload a document, then confirm its progress can be queried."""
        # Step 1: upload and obtain the callback task id.
        callback_task_id = self._upload_file_and_get_callback(api_config, test_file)
        assert callback_task_id is not None

        # Step 2: the progress lookup for that task must succeed.
        progress_url = (
            f"{api_config['base_url']}{api_config['api_prefix']}"
            f"/task_progress/{callback_task_id}"
        )
        response = requests.get(progress_url, params={"user": api_config["valid_user"]})
        assert response.status_code == 200

    def _upload_file_and_get_callback(self, api_config, test_file):
        """Helper: upload the fixture file and return data.callback_task_id."""
        upload_url = f"{api_config['base_url']}{api_config['api_prefix']}/file_upload"
        response = requests.post(
            upload_url,
            files={"file": test_file.get_file()},
            data={
                "callback_url": api_config["test_callback_url"],
                "project_plan_type": api_config["valid_project_type"],
                "user": api_config["valid_user"]
            },
        )
        assert response.status_code == 200
        return response.json()["data"]["callback_task_id"]
-
-
if __name__ == "__main__":
    # Running the file directly: point the user at pytest instead.
    for hint in (
        "请使用 pytest 运行此测试文件:",
        "pytest test/construction_review/test_error_codes_pytest.py -v",
        "或者运行所有测试:",
        "pytest test/ -v",
    ):
        print(hint)

+ 190 - 0
test/system_trace_id_test.py

@@ -0,0 +1,190 @@
+"""
+系统Trace ID测试
+验证trace_id在异步并发和队列中的正确传播
+"""
+import os
+import sys
+# Add the parent directory (LQAgentPlatform) to sys.path so we can import foundation
+project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+sys.path.append(project_root)
+import asyncio
+import time
+from foundation.trace.trace_context import TraceContext, auto_trace
+from foundation.logger.loggering import server_logger as logger
+
+
class TraceIDTest:
    """Test cases for trace-id propagation: context set/get, async fan-out,
    the @auto_trace decorator, the context manager, and Celery hand-off."""

    @staticmethod
    async def test_basic_context():
        """Basic set/get round-trip on the trace context."""
        print("\n=== 测试1: 基础上下文功能 ===")

        # Install a freshly generated trace_id on the current context.
        trace_id = TraceContext.generate_trace_id()
        TraceContext.set_trace_id(trace_id)

        logger.info("测试基础日志,应该包含trace_id")
        logger.info(f"手动设置的trace_id: {trace_id}")
        logger.info(f"自动获取的trace_id: {TraceContext.get_trace_id()}")

        assert TraceContext.get_trace_id() == trace_id, "trace_id设置失败"
        print("[PASS] 基础上下文功能测试通过")

    @staticmethod
    async def test_async_propagation():
        """Concurrent tasks must inherit the parent trace_id and must not
        leak their own modifications back to the parent or to siblings."""
        print("\n=== 测试2: 异步并发传播 ===")

        # Trace id of the "parent" coroutine.
        main_trace = "main-async-test"
        TraceContext.set_trace_id(main_trace)

        logger.info("主异步任务开始")

        async def concurrent_task(task_id: int):
            """One concurrent child task."""
            current_trace = TraceContext.get_trace_id()
            logger.info(f"并发任务 {task_id} 获取到的trace_id: {current_trace}")

            # Mutating the trace_id inside a task should not affect others.
            new_trace = f"{main_trace}-subtask-{task_id}"
            TraceContext.set_trace_id(new_trace)

            await asyncio.sleep(0.1)
            logger.info(f"并发任务 {task_id} 修改后的trace_id: {new_trace}")

            return current_trace

        # Fan out three concurrent tasks.
        tasks = [concurrent_task(i) for i in range(3)]
        results = await asyncio.gather(*tasks)

        # Every child must have observed the parent's trace_id.
        for i, result in enumerate(results):
            assert result == main_trace, f"任务 {i} 没有继承主trace_id"

        # The parent's trace_id must be untouched by the children.
        assert TraceContext.get_trace_id() == main_trace, "主trace_id被并发任务污染"

        logger.info("主异步任务完成")
        print("[PASS] 异步并发传播测试通过")

    @staticmethod
    @auto_trace('callback_task_id')
    async def test_decorator_auto_trace(callback_task_id: str):
        """@auto_trace should bind the named argument as the trace_id."""
        print(f"\n=== 测试3: 装饰器自动trace ===")

        # No manual set_trace_id here; the decorator is expected to do it.
        current_trace = TraceContext.get_trace_id()
        logger.info("装饰器自动设置的日志")

        assert current_trace == callback_task_id, "装饰器没有正确设置trace_id"

        # Nested coroutines must inherit the decorator-set trace_id.
        async def nested_task():
            nested_trace = TraceContext.get_trace_id()
            logger.info("嵌套异步任务")
            return nested_trace

        nested_result = await nested_task()
        assert nested_result == callback_task_id, "嵌套任务没有继承装饰器设置的trace_id"

        print(f"[PASS] 装饰器自动trace测试通过,trace_id: {callback_task_id}")

    @staticmethod
    async def test_context_manager():
        """with_trace_context must temporarily override and then restore the trace_id."""
        print("\n=== 测试4: 上下文管理器 ===")

        original_trace = TraceContext.get_trace_id()
        logger.info(f"原始trace_id: {original_trace}")

        # Temporarily install a different trace_id inside the with-block.
        temp_trace = "temporary-trace"
        with TraceContext.with_trace_context(temp_trace) as ctx:
            logger.info("上下文管理器内的日志")
            current_trace = TraceContext.get_trace_id()
            assert current_trace == temp_trace, "上下文管理器没有正确设置trace_id"

        # Leaving the block must restore the previous trace_id.
        restored_trace = TraceContext.get_trace_id()
        logger.info(f"恢复后的trace_id: {restored_trace}")
        assert restored_trace == original_trace, "上下文管理器没有正确恢复trace_id"

        print("[PASS] 上下文管理器测试通过")

    @staticmethod
    def test_celery_task_simulation():
        """Simulate passing the trace_id to a Celery task via a keyword argument."""
        print("\n=== 测试5: Celery任务trace_id模拟 ===")

        # Trace id active at task-submission time.
        submit_trace = "celery-submit-test"
        TraceContext.set_trace_id(submit_trace)

        logger.info("准备提交Celery任务")

        # Stand-in for the real Celery task body.
        def simulate_celery_task_execution(file_info: dict, _system_trace_id=None):
            """Simulated Celery task: re-installs the submitted trace_id, if any."""
            if _system_trace_id:
                TraceContext.set_trace_id(_system_trace_id)

            current_trace = TraceContext.get_trace_id()
            logger.info("Celery任务执行中")
            logger.info(f"文件ID: {file_info.get('file_id')}")

            return current_trace

        # "Submit" the task: capture the current trace_id as it would be serialized.
        file_info = {'file_id': 'test-file-123'}
        extracted_trace = TraceContext.get_trace_id()

        # "Execute" the task with the captured trace_id.
        task_trace = simulate_celery_task_execution(
            file_info,
            _system_trace_id=extracted_trace
        )

        assert task_trace == submit_trace, "Celery任务没有正确获取到trace_id"

        print("[PASS] Celery任务trace_id模拟测试通过")
+
+
async def run_all_tests():
    """Execute every trace-id test case in order; return True on success."""
    print("开始运行系统Trace ID测试...\n")

    try:
        # Async cases 1-4, then the synchronous Celery simulation.
        await TraceIDTest.test_basic_context()
        await TraceIDTest.test_async_propagation()
        await TraceIDTest.test_decorator_auto_trace("decorator-test-123")
        await TraceIDTest.test_context_manager()
        TraceIDTest.test_celery_task_simulation()
    except Exception as exc:
        print(f"\n[FAIL] 测试失败: {str(exc)}")
        import traceback
        traceback.print_exc()
        return False

    print("\n[SUCCESS] 所有测试通过!系统Trace ID机制工作正常")
    return True
+
+
if __name__ == "__main__":
    # Exit code mirrors the overall outcome: 0 on success, 1 on failure.
    exit(0 if asyncio.run(run_all_tests()) else 1)

+ 179 - 0
test/test_sse_integration.py

@@ -0,0 +1,179 @@
+#!/usr/bin/env python3
+"""
+SSE进度推送集成测试脚本
+模拟文件上传后建立SSE连接的完整流程
+"""
+
+import requests
+import json
+import time
+import threading
+import os
+import signal
+import sys
+from typing import Dict, Any
+
# Global flag used to request a cooperative shutdown of the long-running loops.
should_exit = False

def signal_handler(signum, frame):
    """Handle SIGINT/SIGTERM by requesting a graceful exit.

    Sets the module-level ``should_exit`` flag that the countdown loop and the
    SSE listening loop poll.  The original version also called ``sys.exit(0)``
    here, which raised SystemExit immediately and made every
    ``if should_exit:`` check in this file unreachable dead code; removing it
    lets the loops finish their current step and print their shutdown
    messages as intended.

    Args:
        signum: The received signal number.
        frame: Current stack frame (unused; required by the handler signature).
    """
    global should_exit
    print(f"\n\n⚡ 接收到信号 {signum},正在退出...")
    should_exit = True

# Register the handler for Ctrl+C and termination requests.
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
+
class ProgressTracker:
    """Collects SSE progress events and tracks completion / error state."""

    def __init__(self):
        # All events received so far, in arrival order.
        self.progress_events = []
        # Terminal-state flags derived from "completed" / "error" events.
        self.completed = False
        self.error = None

    def handle_sse_event(self, event_data: Dict[str, Any]):
        """Record one SSE event and print a human-readable summary of it."""
        self.progress_events.append(event_data)

        event_type = event_data.get("type", "unknown")
        data = event_data.get("data", {})

        # Summary lines are built first, then emitted in one pass.
        summary = [
            f"\n📡 SSE事件 [{event_type}]:",
            f"   时间: {data.get('timestamp', 'N/A')}",
            f"   进度: {data.get('current', 0)}%",
            f"   阶段: {data.get('stage_name', 'N/A')}",
            f"   状态: {data.get('status', 'N/A')}",
            f"   消息: {data.get('message', 'N/A')}",
        ]
        for line in summary:
            print(line)

        if event_type == "completed":
            self.completed = True
            print(f"\n✅ 任务完成: {data.get('task_status', 'N/A')}")
        elif event_type == "error":
            self.error = data.get("message", "未知错误")
            print(f"\n❌ 任务错误: {self.error}")
+
def test_file_upload(
    file_path: str = "D:/wx_work/sichuan_luqiao/项目背景资料/路桥桥梁工程施工方案 7 份/罗成依达大桥拱座竖桩专项施工方案.pdf",
    upload_url: str = "http://127.0.0.1:8035/sgsc/file_upload",
):
    """Upload a construction-plan document to the review service.

    Generalized: the previously hard-coded absolute Windows path and endpoint
    URL are now parameters with the original values as defaults, so existing
    callers (``test_file_upload()``) behave identically while other machines
    can point at their own sample file/server.

    Args:
        file_path: Path of the document to upload.
        upload_url: Upload endpoint of the review service.

    Returns:
        The ``callback_task_id`` from the upload response, or ``None`` on any
        failure (missing file, connection error, non-2xx status, bad payload).
    """
    try:
        with open(file_path, 'rb') as f:
            files = {
                'file': f
            }
            data = {
                'callback_url': 'https://client.example.com/callback?task_id=ocr-12345',
                'project_plan_type': 'bridge_up_part',
                'user': 'user-001'
            }

            print("📤 开始上传文件...")
            response = requests.post(upload_url, files=files, data=data)
            response.raise_for_status()

            result = response.json()
            print(f"✅ 文件上传成功")
            print(f"   文件ID: {result['data']['id']}")
            print(f"   文件名: {result['data']['name']}")
            print(f"   回调任务ID: {result['data']['callback_task_id']}")

            return result['data']['callback_task_id']

    except Exception as e:
        # Best-effort: report and signal failure to the caller instead of raising.
        print(f"❌ 文件上传失败: {e}")
        return None
+
def test_sse_connection(callback_task_id: str, tracker: ProgressTracker):
    """Open the SSE progress stream for a task and feed its events to *tracker*.

    Blocks on a streamed HTTP response, echoing every raw line; lines of the
    form ``data: {...}`` are parsed as JSON and forwarded to
    ``tracker.handle_sse_event``.  The loop ends when the tracker reports
    completion, the global ``should_exit`` flag is set, or the stream closes.

    NOTE(review): events are forwarded with a fixed ``"type": "data"``, while
    ``ProgressTracker`` only sets ``completed`` for type ``"completed"`` — so
    the completion break below appears unreachable from this path; verify the
    intended event typing against the server payload.
    """
    # NOTE(review): user is hard-coded to user-001 — presumably it must match
    # the user sent with the upload request; confirm against the server checks.
    sse_url = f"http://127.0.0.1:8035/sgsc/sse/progress/{callback_task_id}?user=user-001"

    try:
        print(f"🔗 建立SSE连接: {sse_url}")
        print("=" * 60)
        print("📡 SSE原始响应:")
        print("-" * 60)

        # Establish the SSE connection; stream=True keeps the body open.
        response = requests.get(sse_url, stream=True)
        response.raise_for_status()

        for line in response.iter_lines():
            if should_exit:
                print("\n👋 用户请求退出,停止监听SSE")
                break

            if line:
                # Echo the raw SSE line before any parsing.
                line_str = line.decode('utf-8')
                print(f"原始响应: {line_str}")

                # Parse the SSE "data:" payload (optional).
                if line_str.startswith('data: '):
                    try:
                        event_data = json.loads(line_str[6:])
                        tracker.handle_sse_event({
                            "type": "data",
                            "data": event_data
                        })

                        # Stop listening once the task reports completion.
                        if tracker.completed:
                            print("-" * 60)
                            print("📡 任务完成,结束监听")
                            break

                    except json.JSONDecodeError:
                        # Malformed payloads are ignored; keep listening.
                        pass

    except Exception as e:
        print(f"❌ SSE连接失败: {e}")
+
def main():
    """Full integration flow: upload a file, then stream its SSE progress."""
    print("🚀 开始SSE进度推送集成测试")
    print("💡 提示: 按 Ctrl+C 可以随时退出")
    print("=" * 60)

    # Step 1: upload and obtain the callback task id.
    callback_task_id = test_file_upload()
    if not callback_task_id or should_exit:
        print("❌ 文件上传失败,测试终止")
        return

    # Short countdown before connecting, abortable via the exit flag.
    print(f"\n⏳ 等待2秒后建立SSE连接...")
    countdown = 2
    while countdown > 0:
        if should_exit:
            print("\n👋 用户请求退出,测试终止")
            return
        print(f"   {countdown}秒...")
        time.sleep(1)
        countdown -= 1

    # Step 2: listen for progress events over SSE in the main thread.
    tracker = ProgressTracker()
    test_sse_connection(callback_task_id, tracker)

    # Summarise, unless the user asked to quit mid-run.
    if should_exit:
        return
    print("\n" + "=" * 60)
    print("📊 测试结果汇总:")
    print(f"   收到事件数量: {len(tracker.progress_events)}")
    print(f"   任务是否完成: {'✅' if tracker.completed else '❌'}")
    if tracker.error:
        print(f"   错误信息: {tracker.error}")
    if tracker.completed:
        print("🎉 SSE实时推送测试成功!")
    else:
        print("⚠️ 任务可能未在预期时间内完成")



if __name__ == "__main__":
    main()

+ 4 - 4
views/__init__.py

@@ -22,10 +22,10 @@ async def lifespan(app: FastAPI):
     #await mcp_server.get_mcp_tools()
     # 全局数据库连接池实例
     async_db_pool = None
-    async_db_pool = AsyncMySQLPool()
-    await async_db_pool.initialize()
-    app.state.async_db_pool = async_db_pool
-    server_logger.info(f"✅ MySQL数据库连接池:{app.state.async_db_pool}")
+    # async_db_pool = AsyncMySQLPool()
+    # await async_db_pool.initialize()
+    # app.state.async_db_pool = async_db_pool
+    #server_logger.info(f"✅ MySQL数据库连接池:{app.state.async_db_pool}")
 
     yield
     # 关闭时清理

+ 16 - 25
views/construction_review/app.py

@@ -3,31 +3,25 @@
 整合所有接口,提供统一的测试服务
 """
 
-import datetime
 import sys
-import os
-import threading
-import subprocess
 import time
-from multiprocessing import Process
-
-# 添加项目根目录到Python路径
-current_dir = os.path.dirname(os.path.abspath(__file__))
-project_root = os.path.dirname(os.path.dirname(current_dir))
-sys.path.insert(0, project_root)
+import uvicorn
+import datetime
+import threading
+from pathlib import Path
+root_dir = Path(__file__).parent.parent.parent
+sys.path.append(str(root_dir))
 
-# 现在可以正常导入了
-from foundation.logger.loggering import server_logger as logger
-from foundation.base.celery_app import app as celery_app
 from fastapi import FastAPI, HTTPException
-from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse
-import uvicorn
+from fastapi.middleware.cors import CORSMiddleware
+from foundation.base.celery_app import app as celery_app
+from foundation.logger.loggering import server_logger as logger
+
 
-# 现在可以正常导入了
 from views.construction_review.file_upload import file_upload_router
-from views.construction_review.task_progress import task_progress_router
 from views.construction_review.review_results import review_results_router
+from views.construction_review.launch_review import launch_review_router
 
 def create_app() -> FastAPI:
     """创建接口服务"""
@@ -48,8 +42,8 @@ def create_app() -> FastAPI:
 
     # 添加路由
     app.include_router(file_upload_router)
-    app.include_router(task_progress_router)
     app.include_router(review_results_router)
+    app.include_router(launch_review_router)
 
     # 全局异常处理
     @app.exception_handler(HTTPException)
@@ -91,10 +85,10 @@ def create_app() -> FastAPI:
                     "description": "上传施工方案文档"
                 },
                 {
-                    "name": "进度查询",
-                    "path": "/sgsc/task_progress/{callback_task_id}",
-                    "method": "GET",
-                    "description": "查询审查任务进度"
+                    "name": "审查启动",
+                    "path": "/sgsc/launch_review",
+                    "method": "POST",
+                    "description": "启动AI审查工作流"
                 },
                 {
                     "name": "结果获取",
@@ -125,9 +119,6 @@ class CeleryWorkerManager:
             return True
 
         try:
-            # 导入Celery应用
-            from foundation.base.celery_app import app as celery_app
-
             # 创建Worker函数
             def run_celery_worker():
                 try:

+ 83 - 56
views/construction_review/file_upload.py

@@ -7,16 +7,17 @@ import traceback
 import uuid
 import time
 from datetime import datetime
-from fastapi import APIRouter, UploadFile, File, Form, HTTPException
-from pydantic import BaseModel
+
+from pydantic import BaseModel, Field
 from typing import Optional,List
 from foundation.utils import md5
-from core.base.redis_duplicate_checker import RedisDuplicateChecker
-from core.base.workflow_manager import WorkflowManager
-from foundation.logger.loggering import server_logger as logger
 from foundation.base.config import config_handler
 from .schemas.error_schemas import FileUploadErrors
-
+from core.base.workflow_manager import WorkflowManager
+from foundation.logger.loggering import server_logger as logger
+from fastapi import APIRouter, UploadFile, File, Form, HTTPException, Request
+from core.base.redis_duplicate_checker import RedisDuplicateChecker
+from foundation.trace.trace_context import TraceContext, auto_trace
 
 
 file_upload_router = APIRouter(prefix="/sgsc", tags=["文档上传"])
@@ -29,10 +30,26 @@ workflow_manager = WorkflowManager(
 # 使用workflow_manager的duplicatechecker实例,确保一致性
 duplicatechecker = workflow_manager.redis_duplicate_checker
 
+
+
 class FileUploadResponse(BaseModel):
     code: int
     data: dict
 
+def validate_upload_parameters(form_data) -> None:
+    """验证请求参数"""
+    allowed_params = {'file', 'user'}  # 只允许这两个参数
+
+    # 检查是否有不允许的参数
+    extra_params = []
+    for key in form_data.keys():
+        if key not in allowed_params:
+            extra_params.append(key)
+
+    if extra_params:
+        logger.warning(f"检测到不支持的参数: {extra_params}")
+        raise FileUploadErrors.invalid_parameters(extra_params)
+
 def get_file_size(file: UploadFile) -> int:
     """获取文件大小的可靠同步方法(兼容 seek 仅支持单参数的情况)"""
     try:
@@ -70,22 +87,20 @@ def validate_file(file: UploadFile, file_content: bytes = None) -> None:
     logger.info(f"文件类型验证通过: {actual_file_type} (扩展名: {file_extension}, MIME: {file.content_type})")
 
 @file_upload_router.post("/file_upload", response_model=FileUploadResponse)
+@auto_trace(generate_if_missing=True)  # 由于使用@auto_trace需要输入callback_task_id,但此时callback_task_id还未产生,所以暂时用初始trace_id替代
 async def file_upload(
-    file: List[UploadFile] = File([]),  
-    callback_url: str = Form(None),
-    project_plan_type: str = Form(None),
-    user: str = Form(None)  
+    request: Request,
+    file: List[UploadFile] = File([]),
+    user: str = Form(None)
 ):
     """
     文件上传接口
     """
     try:
-        # 验证工程方案类型
-        valid_project_types = {
-            'bridge_up_part',  # 桥梁上部结构
-            'tunnel_construction',  # 隧道施工
-            'road_repair'  # 道路维修
-        }
+        # 验证请求参数
+        form_data = await request.form()
+        validate_upload_parameters(form_data)
+
         valid_users = ast.literal_eval(config_handler.get("user_lists", "USERS"))
         
         # 验证文件上传
@@ -114,25 +129,16 @@ async def file_upload(
             raise FileUploadErrors.file_missing()
 
         # 验证文件大小限制
-        if file_size_mb > 30:  # 文件大小不能超过30MB
+        if file_size_mb > 50:  # 文件大小不能超过50MB
             raise FileUploadErrors.file_size_exceeded()
-        
-        # 验证回调地址
-        if callback_url is '':
-            raise FileUploadErrors.callback_url_missing()
- 
+
         # 验证用户标识
         if user is None or user not in valid_users:
             raise FileUploadErrors.invalid_user()
-        
-        # 工程方案类型校验
-        if project_plan_type not in valid_project_types:
-            raise FileUploadErrors.project_plan_type_invalid()
 
         # 生成文件MD5ID
         file_id = md5.md5_id(content)
-        if await duplicatechecker.is_duplicate_task(file_id):
-            raise FileUploadErrors.task_already_exists()
+
 
         created_at = int(time.time())
 
@@ -143,8 +149,6 @@ async def file_upload(
         logger.info(f"文件头信息: {content[:50] if 'content' in locals() else '未读取'}")
         logger.info(f"文件大小: {file_size_mb} MB")
         logger.info(f"========================", log_type="upload")
-        logger.info(f"请求参数 - 回调URL: {callback_url}\n, 工程类型: {project_plan_type}",
-                    log_type="upload")
         logger.info(f"用户标识: {user}")
 
         # 确定文件类型
@@ -156,10 +160,11 @@ async def file_upload(
         else:
             file_type = 'unknown'
 
-
-        # 生成回调任务ID
-        callback_task_id = f"{file_id}-{int(datetime.now().timestamp())}"
-
+        # 生成任务ID
+        callback_task_id = f"{file_id}-{int(datetime.now().timestamp())}"         
+        TraceContext.set_trace_id(callback_task_id)
+        logger.info(f"设置任务trace_id: {callback_task_id}")
+        
         # 记录文件信息
         file_info = {
                 'file_id': file_id,
@@ -169,32 +174,54 @@ async def file_upload(
                 'callback_task_id': callback_task_id,
                 "file_name": file[0].filename,
                 "file_size": file_size_mb,
-                "project_plan_type": project_plan_type,
                 'updated_at': created_at
             }
 
+        # 存储文件信息到Redis缓存,以file_id为键,供启动审查接口使用
         try:
-            # 提交处理任务到工作流管理器
-            await workflow_manager.submit_task_processing(file_info)
-            logger.info(f"文档处理任务已提交,任务ID: {callback_task_id}")
-
-
-
-            return FileUploadResponse(
-                code=200,
-                data={
-                    "id": file_info['file_id'],
-                    "name": file_info['file_name'],
-                    "size": file_size_mb,
-                    "created_at": created_at,
-                    "status": "processing",
-                    "callback_task_id": file_info['callback_task_id']
-                }
-            )
-
-        except Exception as workflow_error:
-            logger.error(f"工作流提交失败: {str(workflow_error)}")
-            raise FileUploadErrors.internal_error(workflow_error)
+            from foundation.utils.redis_utils import store_file_info
+
+            # 使用file_id作为键存储文件信息(1小时过期)
+            success = await store_file_info(file_id, file_info, 3600)
+            if success:
+                logger.info(f"文件信息已缓存到Redis: file_info:{file_id}")
+            else:
+                logger.warning(f"缓存文件信息到Redis失败")
+
+        except Exception as e:
+            logger.warning(f"缓存文件信息到Redis失败: {str(e)}")
+            # 不影响主流程,继续处理
+
+        # 预注册任务到重复检查器,以便启动审查时验证任务ID
+        try:
+            await duplicatechecker.register_task(file_info, callback_task_id)
+            logger.info(f"任务已预注册: {callback_task_id}")
+        except Exception as e:
+            logger.error(f"任务预注册失败: {str(e)}")
+            # 预注册失败不应影响文件上传成功
+
+        # try:
+            # # 提交处理任务到工作流管理器
+            # await workflow_manager.submit_task_processing(file_info)
+            # logger.info(f"文档处理任务已提交,任务ID: {callback_task_id}")
+
+
+
+        return FileUploadResponse(
+            code=200,
+            data={
+                "id": file_info['file_id'],
+                "name": file_info['file_name'],
+                "size": file_size_mb,
+                "created_at": created_at,
+                "status": "file_upload_success",
+                "callback_task_id": file_info['callback_task_id']
+            }
+        )
+
+        # except Exception as workflow_error:
+        #     logger.error(f"工作流提交失败: {str(workflow_error)}")
+        #     raise FileUploadErrors.internal_error(workflow_error)
 
     except HTTPException:
         logger.error(f"HTTP异常: {traceback.format_exc()}")

+ 358 - 0
views/construction_review/launch_review.py

@@ -0,0 +1,358 @@
+"""
+施工方案审查启动接口
+接收审查配置参数,启动AI审查工作流
+"""
+
+import uuid
+import time
+import json
+import asyncio
+import traceback
+from datetime import datetime
+from typing import List, Optional, Dict, Any
+from pydantic import BaseModel, Field
+from fastapi import APIRouter, HTTPException, Query
+from fastapi.responses import StreamingResponse
+from core.base.redis_duplicate_checker import RedisDuplicateChecker
+from foundation.logger.loggering import server_logger as logger
+from foundation.trace.trace_context import TraceContext, auto_trace
+from foundation.utils.redis_utils import get_file_info, delete_file_info
+from core.base.workflow_manager import WorkflowManager
+from core.base.progress_manager import ProgressManager, sse_callback_manager
+from views.construction_review.file_upload import validate_upload_parameters
+from .schemas.error_schemas import LaunchReviewErrors
+
+launch_review_router = APIRouter(prefix="/sgsc", tags=["审查启动"])
+duplicatechecker = RedisDuplicateChecker()
+# 初始化工作流管理器
+workflow_manager = WorkflowManager(
+    max_concurrent_docs=3,
+    max_concurrent_reviews=5
+)
+# 初始化进度管理器
+progress_manager = ProgressManager()
+
+async def sse_progress_callback(callback_task_id: str, current_data: dict):
+    """SSE推送回调函数 - 接收进度更新并推送到客户端"""
+    await sse_manager.send_progress(callback_task_id, current_data)
+
+class SimpleSSEManager:
+    """SSE连接管理器 - 管理客户端SSE连接和消息推送"""
+
+    def __init__(self):
+        self.connections: Dict[str, asyncio.Queue] = {}
+
+    async def connect(self, callback_task_id: str):
+        """建立SSE连接 - 创建消息队列并发送连接确认"""
+        queue = asyncio.Queue()
+        self.connections[callback_task_id] = queue
+
+        await queue.put({
+            "type": "connection_established",
+            "callback_task_id": callback_task_id,
+            "timestamp": datetime.now().isoformat()
+        })
+
+        logger.info(f"SSE连接: {callback_task_id}")
+        return queue
+
+    async def disconnect(self, callback_task_id: str):
+        """断开SSE连接 - 清理连接队列"""
+        if callback_task_id in self.connections:
+            del self.connections[callback_task_id]
+        logger.info(f"SSE连接已断开: {callback_task_id}")
+
+    async def send_progress(self, callback_task_id: str, current_data: dict):
+        """发送进度更新 - 将进度数据放入队列推送给客户端"""
+        queue = self.connections.get(callback_task_id)
+        if queue:
+            await queue.put({
+                "type": "progress_update",
+                "data": current_data,
+                "timestamp": datetime.now().isoformat()
+            })
+            logger.debug(f"SSE进度已推送: {callback_task_id}")
+
+sse_manager = SimpleSSEManager()
+
+def format_sse_event(event_type: str, data: str) -> str:
+    """格式化SSE事件 - 按照SSE协议格式化事件数据"""
+    lines = [
+        f"event: {event_type}",
+        f"data: {data}",
+        "",
+        ""
+    ]
+    return "\n".join(lines) + "\n"
+
+
+class LaunchReviewRequest(BaseModel):
+    """启动审查请求模型"""
+    callback_task_id: str = Field(..., description="回调任务ID,从文件上传接口获取")
+    review_config: List[str] = Field(
+        ...,
+        description="审查配置列表,包含的项为启用状态"
+    )
+    project_plan_type: str = Field(
+        "bridge_up_part",
+        description="工程方案类型,当前仅支持 bridge_up_part"
+    )
+
+    class Config:
+        extra = "forbid"  # 禁止额外的字段
+
+
+class LaunchReviewResponse(BaseModel):
+    """启动审查响应模型"""
+    code: int
+    data: dict
+
+
+def validate_review_config(review_config: List[str]) -> None:
+    """验证审查配置参数"""
+    # 检查review_config是否为空
+    if not review_config or len(review_config) == 0:
+        raise LaunchReviewErrors.enum_type_cannot_be_null()
+
+    # 支持的审查项枚举值
+    supported_review_items = {
+        'sensitive_word_check',       # 词句语法检查
+        'semantic_logic_check',       # 语义逻辑审查
+        'completeness_check',         # 条文完整性审查
+        'timeliness_check',           # 时效性审查
+        'reference_check',            # 规范性审查
+        'sensitive_words_check',      # 敏感词审查
+        'mandatory_standards_check',  # 强制性标准检查
+        'technical_parameters_check', # 技术参数精确检查
+        'design_values_check'         # 设计值符合性检查
+    }
+
+    # 检查是否包含不支持的审查项
+    unsupported_items = set(review_config) - supported_review_items
+    if unsupported_items:
+        raise LaunchReviewErrors.enum_type_invalid()
+
+def validate_project_plan_type(project_plan_type: str) -> None:
+    """验证工程方案类型"""
+    # 当前支持的工程方案类型
+    supported_types = {'bridge_up_part'}  # 桥梁上部结构
+
+    if project_plan_type not in supported_types:
+        raise LaunchReviewErrors.project_plan_type_invalid()
+
+
+@launch_review_router.post("/sse/launch_review")
+@auto_trace(generate_if_missing=True)
+async def launch_review_sse(request_data: LaunchReviewRequest):
+    """
+    启动施工方案审查并返回SSE进度流
+
+    Args:
+        request_data: 启动审查请求参数
+
+    Returns:
+        StreamingResponse: SSE事件流,包含任务启动状态和进度
+    """
+    callback_task_id = request_data.callback_task_id
+    TraceContext.set_trace_id(callback_task_id)
+    review_config = request_data.review_config
+    project_plan_type = request_data.project_plan_type
+
+    logger.info(f"收到审查启动SSE请求: callback_task_id={callback_task_id}")
+
+    # 验证审查配置
+    validate_review_config(review_config)
+
+    # 验证工程方案类型
+    validate_project_plan_type(project_plan_type)
+
+    # 注册SSE回调
+    sse_callback_manager.register_callback(callback_task_id, sse_progress_callback)
+    queue = await sse_manager.connect(callback_task_id)
+
+    async def generate_launch_review_events():
+        """生成启动审查SSE事件流"""
+        try:
+            # 发送连接确认
+            connected_data = json.dumps({
+                "callback_task_id": callback_task_id,
+                "message": "启动审查SSE连接已建立,正在处理请求...",
+                "timestamp": datetime.now().isoformat()
+            }, ensure_ascii=False)
+            yield format_sse_event("connected", connected_data)
+
+            # 处理启动审查逻辑
+            try:
+                from foundation.utils.redis_utils import get_file_info
+
+                # 从callback_task_id中提取file_id (格式: file_id-timestamp)
+                file_id = callback_task_id.rsplit('-', 1)[0] if '-' in callback_task_id else callback_task_id
+                logger.info(f"处理文件: {file_id}")
+                # 发送处理状态
+                status_data = json.dumps({
+                    "callback_task_id": callback_task_id,
+                    "stage": "validation",
+                    "message": f"正在验证文件信息: {file_id}",
+                    "timestamp": datetime.now().isoformat()
+                }, ensure_ascii=False)
+                yield format_sse_event("processing", status_data)
+
+                # 验证任务ID是否存在且未过期
+                if not await duplicatechecker.is_valid_task_id(callback_task_id):
+                    raise LaunchReviewErrors.task_not_found_or_expired()
+
+                # 检查任务是否已经被使用启动审查
+                if await duplicatechecker.is_task_already_used(callback_task_id):
+                    raise LaunchReviewErrors.task_already_exists()
+
+                # 标记任务为已使用
+                await duplicatechecker.mark_task_as_used(callback_task_id)
+
+                # 获取文件信息
+                status_data = json.dumps({
+                    "callback_task_id": callback_task_id,
+                    "stage": "loading",
+                    "message": "正在加载文件信息...",
+                    "timestamp": datetime.now().isoformat()
+                }, ensure_ascii=False)
+                yield format_sse_event("processing", status_data)
+
+                file_info = await get_file_info(file_id, include_content=True)
+
+                if not file_info:
+                    logger.error(f"文件信息获取失败: {file_id}")
+                    raise LaunchReviewErrors.file_info_not_found()
+
+                # 添加审查配置到文件信息
+                file_info.update({
+                    'review_config': review_config,
+                    'project_plan_type': project_plan_type,
+                    'launched_at': int(time.time())
+                })
+
+
+
+                # 发送提交任务状态
+                status_data = json.dumps({
+                    "callback_task_id": callback_task_id,
+                    "stage": "submitting",
+                    "message": "正在提交AI审查任务...",
+                    "timestamp": datetime.now().isoformat()
+                }, ensure_ascii=False)
+                yield format_sse_event("processing", status_data)
+
+                # 提交处理任务到工作流管理器
+                task_id = await workflow_manager.submit_task_processing(file_info)
+
+                # 发送成功启动状态
+                success_data = json.dumps({
+                    "callback_task_id": callback_task_id,
+                    "file_id": file_info['file_id'],
+                    "review_config": review_config,
+                    "project_plan_type": project_plan_type,
+                    "status": "submitted",
+                    "submitted_at": file_info['launched_at'],
+                    "message": "AI审查任务已成功启动",
+                    "timestamp": datetime.now().isoformat()
+                }, ensure_ascii=False)
+                yield format_sse_event("submitted", success_data)
+
+                # 继续监听工作流进度
+                logger.info(f"开始监听工作流进度: {callback_task_id}")
+                while True:
+                    try:
+                        message = await queue.get()
+
+                        if message.get("type") == "progress_update":
+                            current_data = message.get("data")
+                            if current_data:
+                                progress_json = json.dumps(current_data, ensure_ascii=False)
+                                yield format_sse_event("progress", progress_json)
+
+                                overall_task_status = current_data.get("overall_task_status")
+                                if overall_task_status in ["completed", "failed"]:
+                                    completion_data = {
+                                        "callback_task_id": callback_task_id,
+                                        "task_status": overall_task_status,
+                                        "overall_progress": current_data.get("current", 100),
+                                        "timestamp": datetime.now().isoformat(),
+                                        "message": "审查任务处理完成!"
+                                    }
+                                    completion_json = json.dumps(completion_data, ensure_ascii=False)
+                                    yield format_sse_event("completed", completion_json)
+                                    break
+
+                    except Exception as e:
+                        logger.error(f"队列消息处理异常: {callback_task_id}")
+                        logger.error(f"异常详情: {str(e)}")
+                        logger.error(f"异常堆栈: {traceback.format_exc()}")
+                        break
+
+            except HTTPException as e:
+                logger.error(f"HTTP异常: {callback_task_id}")
+                logger.error(f"异常详情: {str(e)}")
+                logger.error(f"异常堆栈: {traceback.format_exc()}")
+                error_data = json.dumps({
+                    "callback_task_id": callback_task_id,
+                    "error": e.detail.get("code") if isinstance(getattr(e, "detail", None), dict) else "http_error",
+                    "message": e.detail.get("message") if isinstance(getattr(e, "detail", None), dict) else str(e),
+                    "timestamp": datetime.now().isoformat()
+                }, ensure_ascii=False)
+                yield format_sse_event("error", error_data)
+
+            except Exception as e:
+                logger.error(f"启动审查处理异常: {callback_task_id}")
+                logger.error(f"异常详情: {str(e)}")
+                logger.error(f"异常堆栈: {traceback.format_exc()}")
+                error_data = json.dumps({
+                    "callback_task_id": callback_task_id,
+                    "error": "internal_error",
+                    "message": f"服务端内部错误: {str(e)}",
+                    "timestamp": datetime.now().isoformat()
+                }, ensure_ascii=False)
+                yield format_sse_event("error", error_data)
+
+        except Exception as e:
+            logger.error(f"启动审查SSE事件流异常: {callback_task_id}")
+            logger.error(f"异常详情: {str(e)}")
+            logger.error(f"异常堆栈: {traceback.format_exc()}")
+            error_data = json.dumps({
+                "callback_task_id": callback_task_id,
+                "error": "sse_error",
+                "message": f"SSE流异常: {str(e)}",
+                "timestamp": datetime.now().isoformat()
+            }, ensure_ascii=False)
+            yield format_sse_event("error", error_data)
+
+        finally:
+            # 清理回调连接
+            sse_callback_manager.unregister_callback(callback_task_id)
+            await sse_manager.disconnect(callback_task_id)
+            logger.debug(f"启动审查SSE流已结束: {callback_task_id}")
+
+    return StreamingResponse(
+        generate_launch_review_events(),
+        media_type="text/event-stream",
+        headers={
+            "Cache-Control": "no-cache, no-store, must-revalidate",
+            "Connection": "keep-alive",
+            "Access-Control-Allow-Origin": "*",
+            "Access-Control-Allow-Headers": "Cache-Control, EventSource, Content-Type",
+            "Access-Control-Allow-Methods": "GET, POST, OPTIONS",
+            "X-Accel-Buffering": "no",
+            "X-Content-Type-Options": "nosniff"
+        }
+    )
+
+
+
+
+@launch_review_router.get("/sse/launch_review_status")
+async def get_launch_review_sse_status():
+    """获取启动审查SSE连接状态 - 返回当前活跃的启动审查SSE连接信息"""
+    return {
+        "active_connections": len(sse_manager.connections),
+        "connections": list(sse_manager.connections.keys()),
+        "timestamp": datetime.now().isoformat(),
+        "service": "launch_review_sse"
+    }

+ 132 - 51
views/construction_review/schemas/error_schemas.py

@@ -42,16 +42,11 @@ class ErrorCodes:
     WJSC005 = {
         "code": "WJSC005",
         "error_type": "FILE_SIZE_EXCEEDED",
-        "message": "文件过大(最大不超过30MB)",
+        "message": "文件过大(最大不超过50MB)",
         "status_code": 400
     }
 
-    WJSC006 = {
-        "code": "WJSC006",
-        "error_type": "PROJECT_PLAN_TYPE_INVALID",
-        "message": "无效工程方案类型(未提供或未注册)",
-        "status_code": 400
-    }
+
 
     WJSC007 = {
         "code": "WJSC007",
@@ -69,17 +64,11 @@ class ErrorCodes:
 
     WJSC009 = {
         "code": "WJSC009",
-        "error_type": "CALLBACK_URL_MISS",
-        "message": "回调客户端地址缺失,请提供回调客户端地址",
-        "status_code": 403
+        "error_type": "INVALID_PARAMETERS",
+        "message": "请求参数无效或不支持",
+        "status_code": 400
     }
 
-    WJSC010 = {
-        "code": "WJSC010",
-        "error_type": "TASK_ALREADY_EXISTS",
-        "message": "任务已存在,请勿重复提交",
-        "status_code": 409
-    }
 
     WJSC011 = {
         "code": "WJSC011",
@@ -88,49 +77,102 @@ class ErrorCodes:
         "status_code": 500
     }
 
-    # 进度查询接口错误码 (JDLX001-JDLX006)
-    JDLX001 = {
-        "code": "JDLX001",
+
+
+    # 启动审查接口错误码 (QDSC001-QDSC006)
+    QDSC001 = {
+        "code": "QDSC001",
         "error_type": "MISSING_PARAMETERS",
         "message": "请求参数缺失",
         "status_code": 400
     }
 
-    JDLX002 = {
-        "code": "JDLX002",
+    QDSC002 = {
+        "code": "QDSC002",
         "error_type": "INVALID_PARAM_FORMAT",
         "message": "请求参数格式错误",
         "status_code": 400
     }
 
-    JDLX003 = {
-        "code": "JDLX003",
+    QDSC003 = {
+        "code": "QDSC003",
         "error_type": "UNAUTHORIZED",
         "message": "认证失败(未提供或无效的Authorization)",
         "status_code": 401
     }
 
-    JDLX004 = {
-        "code": "JDLX004",
+    QDSC004 = {
+        "code": "QDSC004",
         "error_type": "INVALID_USER",
         "message": "用户标识未提供或无效",
         "status_code": 403
     }
 
-    JDLX005 = {
-        "code": "JDLX005",
+    QDSC005 = {
+        "code": "QDSC005",
         "error_type": "TASK_NOT_FOUND",
         "message": "任务ID不存在或已过期",
         "status_code": 404
     }
 
-    JDLX006 = {
-        "code": "JDLX006",
+    
+    QDSC006 = {
+        "code": "QDSC006",
+        "error_type": "TASK_ALREADY_EXISTS",
+        "message": "任务已存在,请勿重复提交",
+        "status_code": 409
+    }
+
+    QDSC007 = {
+        "code": "QDSC007",
+        "error_type": "PROJECT_PLAN_TYPE_INVALID",
+        "message": "无效工程方案类型(未提供或未注册)",
+        "status_code": 400
+    }
+
+    QDSC008 = {
+        "code": "QDSC008",
+        "error_type": "ENUM_TYPE_INVALID",
+        "message": "审查枚举类型无效",
+        "status_code": 400
+    }
+
+    QDSC009 = {
+        "code": "QDSC009",
+        "error_type": "ENUM_TYPE_CANNOT_BE_NULL",
+        "message": "审查枚举类型不能为空",
+        "status_code": 400
+    }
+
+    QDSC010 = {
+        "code": "QDSC010",
+        "error_type": "FILE_INFO_NOT_FOUND",
+        "message": "文件信息获取失败",
+        "status_code": 500
+    }
+
+    QDSC011 = {
+        "code": "QDSC011",
         "error_type": "SERVER_INTERNAL_ERROR",
         "message": "服务端内部错误",
         "status_code": 500
     }
 
+    QDSC012 = {
+        "code": "QDSC012",
+        "error_type": "TASK_NOT_FOUND_OR_EXPIRED",
+        "message": "任务ID不存在或已过期,请重新检查callback_task_id是否正确,或重新上传文件",
+        "status_code": 404
+    }
+
+    QDSC013 = {
+        "code": "QDSC013",
+        "error_type": "FILE_INFO_NOT_FOUND",
+        "message": "文件信息获取失败,任务ID不存在或已过期",
+        "status_code": 404
+    }
+
+
     # 审查结果接口错误码 (SCJG001-SCJG008)
     SCJG001 = {
         "code": "SCJG001",
@@ -217,7 +259,7 @@ def create_server_error(error_code: str, original_error: Exception) -> HTTPExcep
     创建服务器内部错误异常
 
     Args:
-        error_code: 错误码 (如 "WJSC008", "JDLX006", "SCJG008")
+        error_code: 错误码 (如 "WJSC008", "QDSC006", "SCJG008")
         original_error: 原始异常
 
     Returns:
@@ -225,7 +267,7 @@ def create_server_error(error_code: str, original_error: Exception) -> HTTPExcep
     """
     error_map = {
         "WJSC011": ErrorCodes.WJSC011,
-        "JDLX006": ErrorCodes.JDLX006,
+        "QDSC006": ErrorCodes.QDSC006,
         "SCJG008": ErrorCodes.SCJG008
     }
 
@@ -279,10 +321,6 @@ class FileUploadErrors:
         logger.error(ErrorCodes.WJSC008)
         return create_http_exception(ErrorCodes.WJSC008)
     
-    @staticmethod
-    def callback_url_missing():
-        logger.error(ErrorCodes.WJSC009)
-        return create_http_exception(ErrorCodes.WJSC009)
 
 
     @staticmethod
@@ -290,44 +328,85 @@ class FileUploadErrors:
-        logger.error(ErrorCodes.WJSC010)
-        return create_http_exception(ErrorCodes.WJSC010)
+        logger.error(ErrorCodes.QDSC006)
+        return create_http_exception(ErrorCodes.QDSC006)
 
+    
+    @staticmethod
+    def invalid_parameters():
+        logger.error(ErrorCodes.WJSC009)
+        return create_http_exception(ErrorCodes.WJSC009)
+
     @staticmethod
     def internal_error(original_error: Exception):
         logger.error(ErrorCodes.WJSC011)
         return create_server_error("WJSC011", original_error)
 
 
-class TaskProgressErrors:
-    """进度查询接口错误"""
+class LaunchReviewErrors:
+    """启动审查接口错误"""
 
     @staticmethod
     def missing_parameters():
-        logger.error(ErrorCodes.JDLX001)
-        return create_http_exception(ErrorCodes.JDLX001)
+        logger.error(ErrorCodes.QDSC001)
+        return create_http_exception(ErrorCodes.QDSC001)
 
     @staticmethod
     def invalid_param_format():
-        logger.error(ErrorCodes.JDLX002)
-        return create_http_exception(ErrorCodes.JDLX002)
+        logger.error(ErrorCodes.QDSC002)
+        return create_http_exception(ErrorCodes.QDSC002)
 
     @staticmethod
     def unauthorized():
-        logger.error(ErrorCodes.JDLX003)
-        return create_http_exception(ErrorCodes.JDLX003)
+        logger.error(ErrorCodes.QDSC003)
+        return create_http_exception(ErrorCodes.QDSC003)
 
     @staticmethod
     def invalid_user():
-        logger.error(ErrorCodes.JDLX004)
-        return create_http_exception(ErrorCodes.JDLX004)
+        logger.error(ErrorCodes.QDSC004)
+        return create_http_exception(ErrorCodes.QDSC004)
 
     @staticmethod
     def task_not_found():
-        logger.error(ErrorCodes.JDLX005)
-        return create_http_exception(ErrorCodes.JDLX005)
+        logger.error(ErrorCodes.QDSC005)
+        return create_http_exception(ErrorCodes.QDSC005)
 
     @staticmethod
-    def server_internal_error(original_error: Exception):
-        logger.error(ErrorCodes.JDLX006, original_error)
-        return create_server_error("JDLX006", original_error)
+    def task_already_exists():
+        logger.error(ErrorCodes.QDSC006)
+        return create_http_exception(ErrorCodes.QDSC006)
+
+    @staticmethod
+    def project_plan_type_invalid():
+        logger.error(ErrorCodes.QDSC007)
+        return create_http_exception(ErrorCodes.QDSC007)
+
+    @staticmethod
+    def enum_type_invalid():
+        logger.error(ErrorCodes.QDSC008)
+        return create_http_exception(ErrorCodes.QDSC008)
+
+    @staticmethod
+    def enum_type_cannot_be_null():
+        logger.error(ErrorCodes.QDSC009)
+        return create_http_exception(ErrorCodes.QDSC009)
+
+    @staticmethod
+    def file_info_load_error(original_error: Exception):
+        logger.error(ErrorCodes.QDSC010)
+        return create_server_error("QDSC010", original_error)
+
+    @staticmethod
+    def internal_error(original_error: Exception):
+        logger.error(ErrorCodes.QDSC011)
+        return create_server_error("QDSC011", original_error)
+
+    @staticmethod
+    def task_not_found_or_expired():
+        logger.error(ErrorCodes.QDSC012)
+        return create_http_exception(ErrorCodes.QDSC012)
+
+    @staticmethod
+    def file_info_not_found():
+        logger.error(ErrorCodes.QDSC013)
+        return create_http_exception(ErrorCodes.QDSC013)
 
 
 class ReviewResultsErrors:
@@ -371,4 +450,6 @@ class ReviewResultsErrors:
     @staticmethod
     def server_error(original_error: Exception):
         logger.error(ErrorCodes.SCJG008)
-        return create_server_error("SCJG008", original_error)
+        return create_server_error("SCJG008", original_error)
+
+

+ 167 - 130
views/construction_review/task_progress.py

@@ -1,158 +1,195 @@
 """
-审查进度轮询接口
-支持Celery任务状态查询和进度展示
+审查进度SSE实时推送接口
 """
 
-import time
-import random
+import json
+import asyncio
+from typing import Dict
 from datetime import datetime
-from fastapi import APIRouter, HTTPException, Query
 from pydantic import BaseModel
-from typing import Optional
-from celery.result import AsyncResult
-from foundation.base.celery_app import app
+from fastapi import APIRouter, Query
+from .schemas.error_schemas import LaunchReviewErrors
+from fastapi.responses import StreamingResponse
+from foundation.logger.loggering import server_logger as logger
+from foundation.trace.trace_context import TraceContext, auto_trace
+from core.base.progress_manager import ProgressManager, sse_callback_manager
 
-task_progress_router = APIRouter(prefix="/sgsc", tags=["进度轮询"])
+progress_manager = ProgressManager()
 
+task_progress_router = APIRouter(prefix="/sgsc", tags=["进度推送"])
 
-# 导入错误码定义
-from .schemas.error_schemas import TaskProgressErrors
+async def sse_progress_callback(callback_task_id: str, current_data: dict):
+    """SSE推送回调函数 - 接收进度更新并推送到客户端"""
+    await sse_manager.send_progress(callback_task_id, current_data)
 
 class TaskProgressResponse(BaseModel):
     code: int
     data: dict
 
-def update_task_progress(callback_task_id: str) -> dict:
-    """更新任务进度(模拟真实的处理过程)"""
-    if callback_task_id not in uploaded_files:
-        return None
 
-    task_info = uploaded_files[callback_task_id]
-    current_time = int(time.time())
+class SimpleSSEManager:
+    """SSE连接管理器 - 管理客户端SSE连接和消息推送"""
 
-    # 根据时间模拟进度推进
-    time_elapsed = current_time - task_info.get("updated_at", current_time)
 
-    # 定义各阶段的时间分配(总时长约30分钟)
-    stage_durations = {
-        "格式校验": 60,      # 1分钟
-        "内容提取": 900,     # 15分钟
-        "智能审查": 840      # 14分钟
-    }
+    def __init__(self):
+        self.connections: Dict[str, asyncio.Queue] = {}
 
-    total_duration = sum(stage_durations.values())
 
-    # 计算当前应该处于哪个阶段
-    accumulated_time = 0
-    overall_progress = 0
-    stages = []
+    async def connect(self, callback_task_id: str):
+        """建立SSE连接 - 创建消息队列并发送连接确认"""
+        queue = asyncio.Queue()
+        self.connections[callback_task_id] = queue
 
-    for stage_name, duration in stage_durations.items():
-        if time_elapsed > accumulated_time + duration:
-            # 阶段已完成
-            stages.append({
-                "stage_name": stage_name,
-                "progress": 100,
-                "stage_status": "completed"
-            })
-            accumulated_time += duration
-        elif time_elapsed > accumulated_time:
-            # 阶段进行中
-            stage_progress = min(100, int((time_elapsed - accumulated_time) / duration * 100))
-            stages.append({
-                "stage_name": stage_name,
-                "progress": stage_progress,
-                "stage_status": "processing"
-            })
-            accumulated_time += duration
-        else:
-            # 阶段未开始
-            stages.append({
-                "stage_name": stage_name,
-                "progress": 0,
-                "stage_status": "pending"
+        await queue.put({
+            "type": "connection_established",
+            "callback_task_id": callback_task_id,
+            "timestamp": datetime.now().isoformat()
+        })
+
+        logger.info(f"SSE连接: {callback_task_id}")
+        return queue
+
+
+    async def disconnect(self, callback_task_id: str):
+        """断开SSE连接 - 清理连接队列"""
+        if callback_task_id in self.connections:
+            del self.connections[callback_task_id]
+        logger.info(f"SSE连接已断开: {callback_task_id}")
+
+
+    async def send_progress(self, callback_task_id: str, current_data: dict):
+        """发送进度更新 - 将进度数据放入队列推送给客户端"""
+        queue = self.connections.get(callback_task_id)
+        if queue:
+            await queue.put({
+                "type": "progress_update",
+                "data": current_data,
+                "timestamp": datetime.now().isoformat()
             })
+            logger.debug(f"SSE进度已推送: {callback_task_id}")
 
-    # 计算总进度
-    overall_progress = min(100, int(time_elapsed / total_duration * 100))
-
-    # 确定任务状态
-    if overall_progress >= 100:
-        review_task_status = "completed"
-        estimated_remaining = 0
-    else:
-        review_task_status = "processing"
-        estimated_remaining = max(0, total_duration - time_elapsed)
-
-    # 更新任务信息
-    task_info.update({
-        "review_task_status": review_task_status,
-        "overall_progress": overall_progress,
-        "stages": stages,
-        "updated_at": current_time,
-        "estimated_remaining": estimated_remaining
-    })
-
-    return task_info
-
-@task_progress_router.get("/task_progress/{callback_task_id}", response_model=TaskProgressResponse)
-async def task_progress(
+sse_manager = SimpleSSEManager()
+
+def format_sse_event(event_type: str, data: str) -> str:
+    """格式化SSE事件 - 按照SSE协议格式化事件数据"""
+    lines = [
+        f"event: {event_type}",
+        f"data: {data}",
+        "",
+        ""
+    ]
+    return "\n".join(lines) + "\n" 
+
+
@task_progress_router.get("/sse/progress/{callback_task_id}")
@auto_trace("callback_task_id")
async def sse_progress_stream(
    callback_task_id: str,
    user: str = Query(..., description="用户标识")
):
    """Establish an SSE connection and stream task progress in real time.

    Validates the caller, registers a progress callback for the task, then
    streams ``connected`` / ``current`` / ``completed`` / ``error`` SSE
    events until the task reaches a terminal state or the client drops.

    Args:
        callback_task_id: identifier of the task whose progress is streamed.
        user: caller identity; must be in the hard-coded whitelist.

    Raises:
        LaunchReviewErrors.invalid_user: when *user* is not whitelisted.
        LaunchReviewErrors.internal_error: on any failure while setting up
            the stream.
    """
    # Validate BEFORE the try block: otherwise the invalid_user error would be
    # swallowed by the generic `except Exception` below and re-raised as an
    # internal error, masking the real status from the client.
    # NOTE(review): whitelist is hard-coded here — consider moving to config.
    valid_users = {"user-001", "user-002", "user-003"}
    if user not in valid_users:
        raise LaunchReviewErrors.invalid_user()

    try:
        sse_callback_manager.register_callback(callback_task_id, sse_progress_callback)

        queue = await sse_manager.connect(callback_task_id)

        async def generate_events():
            """Yield SSE frames: connection ack, current snapshot, then live updates."""
            try:
                logger.info(f"开始SSE事件流: {callback_task_id}")

                # Acknowledge the connection so the client knows the stream is live.
                connected_data = json.dumps({
                    "callback_task_id": callback_task_id,
                    "message": "SSE连接已建立,等待进度更新...",
                    "timestamp": datetime.now().isoformat()
                }, ensure_ascii=False)
                yield format_sse_event("connected", connected_data)

                # Push the latest known progress immediately so a late subscriber
                # does not have to wait for the next update to see any state.
                current_progress = await progress_manager.get_progress(callback_task_id)
                if current_progress:
                    yield format_sse_event("current", json.dumps(current_progress, ensure_ascii=False))

                logger.debug(f"开始监听队列中的进度更新: {callback_task_id}")

                while True:
                    try:
                        # NOTE(review): no timeout — if the task never reaches a
                        # terminal state this blocks until the client disconnects.
                        message = await queue.get()

                        if message.get("type") == "progress_update":
                            current_data = message.get("data")
                            if current_data:
                                # Single quotes inside the f-string: nested "" here
                                # is a SyntaxError on Python < 3.12 (pre-PEP 701).
                                logger.info(f"总流程处理进度: {current_data.get('message')}")

                                yield format_sse_event("current", json.dumps(current_data, ensure_ascii=False))

                                overall_task_status = current_data.get("overall_task_status")

                                if overall_task_status in ["completed", "failed"]:
                                    completion_data = {
                                        "callback_task_id": callback_task_id,
                                        "task_status": overall_task_status,
                                        "overall_progress": current_data.get("current", 100),
                                        "timestamp": datetime.now().isoformat(),
                                        "message": "全部任务完成!"
                                    }
                                    yield format_sse_event("completed", json.dumps(completion_data, ensure_ascii=False))

                                    logger.info(f"全部任务完成,SSE连接已关闭: {callback_task_id}, 状态: {overall_task_status}")
                                    # Cleanup (unregister + disconnect) happens exactly
                                    # once, in the finally block below — the original
                                    # disconnected here AND in finally.
                                    break

                        elif message.get("type") == "connection_established":
                            pass  # Handshake echo; nothing to forward.

                    except Exception as e:
                        logger.error(f"队列消息处理异常: {callback_task_id}, {e}")
                        break

            except Exception as e:
                logger.error(f"SSE事件流异常: {callback_task_id}, {e}")
                error_data = json.dumps({
                    "error": f"SSE异常: {str(e)}",
                    "timestamp": datetime.now().isoformat()
                }, ensure_ascii=False)
                yield format_sse_event("error", error_data)

            finally:
                # Always release the callback registration and the connection,
                # whether the stream completed, errored, or the client left.
                sse_callback_manager.unregister_callback(callback_task_id)
                await sse_manager.disconnect(callback_task_id)
                logger.debug(f"SSE流已结束: {callback_task_id}")

        return StreamingResponse(
            generate_events(),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache, no-store, must-revalidate",
                "Connection": "keep-alive",
                "Access-Control-Allow-Origin": "*",
                "Access-Control-Allow-Headers": "Cache-Control, EventSource",
                "Access-Control-Allow-Methods": "GET, POST, OPTIONS",
                "X-Accel-Buffering": "no",  # disable reverse-proxy buffering (nginx)
                "X-Content-Type-Options": "nosniff"
            }
        )

    except Exception as e:
        logger.error(f"SSE连接失败: {callback_task_id}, {e}")
        raise LaunchReviewErrors.internal_error(e)
+
+
@task_progress_router.get("/sse/status")
async def get_sse_status():
    """Report the currently active SSE connections.

    Returns a snapshot containing the number of open connections, their
    callback task ids, and the server timestamp of the snapshot.
    """
    active = sse_manager.connections
    snapshot = {
        "active_connections": len(active),
        "connections": list(active.keys()),
        "timestamp": datetime.now().isoformat()
    }
    return snapshot
+

Некоторые файлы не были показаны из-за большого количества измененных файлов