Răsfoiți Sursa

v0.0.2-增加全局trace_id

WangXuMing 3 luni în urmă
părinte
comite
1f5f5ea60a

+ 2 - 2
config/config.ini

@@ -1,7 +1,7 @@
 
 
 [model]
-MODEL_TYPE=qwen_local_14b
+MODEL_TYPE=qwen_local_1.5b
 
 
 
@@ -29,7 +29,7 @@ QWEN_API_KEY=ms-9ad4a379-d592-4acd-b92c-8bac08a4a045
 
 [ai_review]
 # 调试模式配置
-MAX_REVIEW_UNITS=1
+MAX_REVIEW_UNITS=10
 REVIEW_MODE=random
 # REVIEW_MODE=all/random/first
 

+ 1 - 88
core/base/workflow_manager.py

@@ -110,6 +110,7 @@ class WorkflowManager:
 
             # 6. 启动处理流程(同步执行)
             self._process_task_chain_sync(task_chain, file_info['file_content'], file_info['file_type'])
+
             # logger.info(f"提交文档处理任务: {callback_task_id}")
             logger.info(f"施工方案审查任务已完成! ")
             logger.info(f"文件ID: {file_info['file_id']}")
@@ -120,94 +121,6 @@ class WorkflowManager:
             raise
     
 
-    async def _process_task_chain(self, task_chain: TaskChain, file_content: bytes, file_type: str):
-        """处理文档任务链 - 串行执行,内部并发"""
-        try:
-            task_chain.started_at = datetime.now()
-
-            # 阶段1:文档处理(串行)
-            async with self.doc_semaphore:
-                task_chain.current_stage = "document_processing"
-
-                document_workflow = DocumentWorkflow(
-                    file_id=task_chain.file_id,
-                    callback_task_id=task_chain.callback_task_id,
-                    user_id=task_chain.user_id,
-                    progress_manager=self.progress_manager,
-                    redis_duplicate_checker=self.redis_duplicate_checker
-                )
-
-                doc_result = await document_workflow.execute(file_content, file_type)
-                task_chain.results['document'] = doc_result
-
-            # 阶段2:AI审查(内部并发)
-            task_chain.current_stage = "ai_review"
-
-            structured_content = doc_result['structured_content']
-
-            # 读取AI审查配置
-            import configparser
-            config = configparser.ConfigParser()
-            config.read('config/config.ini', encoding='utf-8')
-
-            max_review_units = config.getint('ai_review', 'MAX_REVIEW_UNITS', fallback=None)
-            if max_review_units == 0:  # 如果配置为0,表示审查所有
-                max_review_units = None
-            review_mode = config.get('ai_review', 'REVIEW_MODE', fallback='all')
-
-            logger.info(f"AI审查配置: 最大审查条文数量={max_review_units}, 审查模式={review_mode}")
-
-            ai_workflow = AIReviewWorkflow(
-                file_id=task_chain.file_id,
-                callback_task_id=task_chain.callback_task_id,
-                user_id=task_chain.user_id,
-                structured_content=structured_content,
-                progress_manager=self.progress_manager,
-                max_review_units=max_review_units,
-                review_mode=review_mode
-            )
-
-            ai_result = await ai_workflow.execute()
-            task_chain.results['ai_review'] = ai_result
-
-            # 阶段3:报告生成(串行)
-            task_chain.current_stage = "report_generation"
-
-            report_workflow = ReportWorkflow(
-                file_id=task_chain.file_id,
-                callback_task_id=task_chain.callback_task_id,
-                user_id=task_chain.user_id,
-                ai_review_results=ai_result,
-                progress_manager=self.progress_manager
-            )
-
-            report_result = await report_workflow.execute()
-            task_chain.results['report'] = report_result
-
-            # 完成任务链
-            task_chain.status = "completed"
-            task_chain.completed_at = datetime.now()
-
-            # 清理任务注册
-            await self.redis_duplicate_checker.unregister_task(task_chain.file_id)
-
-            logger.info(f"文档处理任务链完成: {task_chain.callback_task_id}")
-
-        except Exception as e:
-            task_chain.status = "failed"
-            logger.error(f"文档处理任务链失败: {task_chain.callback_task_id}, 错误: {str(e)}")
-
-            # 清理任务注册
-            await self.redis_duplicate_checker.unregister_task(task_chain.file_id)
-
-            raise
-        finally:
-            # 清理活跃任务
-            if task_chain.callback_task_id in self.active_chains:
-                del self.active_chains[task_chain.callback_task_id]
-
-
-
     def _process_task_chain_sync(self, task_chain: TaskChain, file_content: bytes, file_type: str):
         """同步处理文档任务链(用于Celery worker)"""
         try:

+ 2 - 2
core/construction_review/component/reviewers/base_reviewer.py

@@ -9,7 +9,7 @@ import time
 from abc import ABC
 from typing import Dict, Any, Optional
 from dataclasses import dataclass
-from langfuse import obverse
+# from langfuse import obverse
 from foundation.agent.monitor.ai_trace_monitor import lf
 from foundation.agent.generate.model_generate import generate_model_client
 from core.construction_review.component.reviewers.utils.prompt_loader import prompt_loader
@@ -32,7 +32,7 @@ class BaseReviewer(ABC):
         self.model_client = generate_model_client
         self.prompt_loader = prompt_loader
     
-    @obverse
+    # @obverse
     async def review(self, name: str, trace_id: str, reviewer_type: str, prompt_name: str, review_content: str,review_references: str = None) -> ReviewResult:
         """
         执行审查

+ 7 - 1
foundation/base/celery_app.py

@@ -7,6 +7,9 @@ import os
 from celery import Celery
 from .config import config_handler
 
+# 导入trace系统
+from foundation.trace.celery_trace import init
+
 # 从配置文件获取Redis连接信息
 redis_host = config_handler.get('redis', 'REDIS_HOST', 'localhost')
 redis_port = config_handler.get('redis', 'REDIS_PORT', '6379')
@@ -52,4 +55,7 @@ app.conf.update(
 
     # 结果过期时间
     result_expires=3600,           # 1小时后过期
-)
+)
+
+# 初始化Celery trace系统
+init()

+ 7 - 1
foundation/base/tasks.py

@@ -11,13 +11,19 @@ from foundation.utils.time_statistics import track_execution_time
 
 
 @app.task(bind=True)
-def submit_task_processing_task(self, file_info: dict):
+def submit_task_processing_task(self, file_info: dict, _system_trace_id: str = None):
     """
     提交任务处理到Celery队列
     这个任务只负责调用WorkflowManager,不包含业务逻辑
     """
     import traceback
 
+    # 恢复trace_id上下文
+    if _system_trace_id:
+        from foundation.trace.trace_context import TraceContext
+        TraceContext.set_trace_id(_system_trace_id)
+        logger.info(f"Celery任务恢复trace_id: {_system_trace_id}")
+
     # 添加调试信息
     logger.info("=== Celery任务接收调试 ===")
     logger.info(f"队列ID: {self.request.id}")

+ 17 - 2
foundation/logger/loggering.py

@@ -15,6 +15,9 @@ import sys
 import logging
 from logging.handlers import RotatingFileHandler
 
+# 导入trace系统
+from foundation.trace.trace_context import TraceContext, trace_filter
+
 
 class CompatibleLogger(logging.Logger):
     """
@@ -55,7 +58,8 @@ class CompatibleLogger(logging.Logger):
     def _set_formatter(self, log_format, datefmt):
         """设置日志格式"""
         if log_format is None:
-            log_format = 'P%(process)d.T%(thread)d | %(asctime)s | %(levelname)-8s | %(trace_id)-10s | %(log_type)-5s | %(message)s'
+            # 使用system_trace_id字段,通过TraceFilter自动注入
+            log_format = 'P%(process)d.T%(thread)d | %(asctime)s | %(levelname)-8s | %(system_trace_id)-15s | %(log_type)-5s | %(message)s'
 
         if datefmt is None:
             datefmt = '%Y-%m-%d %H:%M:%S'
@@ -84,6 +88,8 @@ class CompatibleLogger(logging.Logger):
             handler.setFormatter(self.formatter)
             # 为每个级别的日志文件都添加一个筛选器,确保记录该级别及其更高级别
             handler.addFilter(lambda record, lvl=level: record.levelno >= lvl)
+            # 添加trace_filter,自动注入system_trace_id
+            handler.addFilter(trace_filter)
             self.addHandler(handler)
 
     def _create_console_handler(self):
@@ -91,11 +97,18 @@ class CompatibleLogger(logging.Logger):
         console_handler = logging.StreamHandler(sys.stdout)
         console_handler.setLevel(logging.INFO)
         console_handler.setFormatter(self.formatter)
+        # 添加trace_filter,自动注入system_trace_id
+        console_handler.addFilter(trace_filter)
         self.addHandler(console_handler)
 
     def _log_with_context(self, level, msg, trace_id, log_type, *args, **kwargs):
-        """统一的日志记录方法"""
+        """统一的日志记录方法 - 兼容手动传递trace_id和自动获取trace_id"""
         extra = kwargs.get('extra', {})
+
+        # 如果没有手动传递trace_id,则从TraceContext自动获取
+        if not trace_id:
+            trace_id = TraceContext.get_trace_id()
+
         extra.update({
             'trace_id': trace_id,
             'log_type': log_type
@@ -140,6 +153,8 @@ server_logger = CompatibleLogger(
     backup_count=int(config_handler.get("log", "LOG_BACKUP_COUNT", "5"))
 )
 
+# 添加trace_filter到logger,自动注入system_trace_id
+server_logger.addFilter(trace_filter)
 
 # 设置日志级别
 server_logger.info("logging initialized")

+ 121 - 0
foundation/trace/celery_trace.py

@@ -0,0 +1,121 @@
+"""
+Celery Trace管理
+负责在Celery队列任务中传递和恢复trace_id上下文
+"""
+
+from celery.signals import task_prerun, task_postrun, task_failure
+from foundation.trace.trace_context import TraceContext
+from foundation.logger.loggering import server_logger as logger
+
+
class CeleryTraceManager:
    """Celery trace context manager.

    Uses Celery's task signals to carry the system trace_id from the
    submitting process into the worker process that executes the task.
    """

    @staticmethod
    def init_celery_signals():
        """Register signal handlers that manage the trace_id around each task."""

        @task_prerun.connect
        def task_prerun_handler(sender=None, task_id=None, task=None, args=None, kwargs=None, **kwds):
            """
            Pre-run hook: pull the trace_id out of the task kwargs and
            install it into TraceContext.

            Fix: read with ``get`` instead of ``pop``. The signal receives the
            *same* kwargs dict the task is invoked with, so popping
            ``callback_task_id`` here silently stripped a real business
            argument from the task call (and popping ``_system_trace_id``
            defeated tasks that declare it, e.g. submit_task_processing_task).
            """
            try:
                task_kwargs = kwargs or {}  # kwargs may be None for tasks called without keywords
                trace_id = task_kwargs.get('_system_trace_id') or task_kwargs.get('callback_task_id')

                if trace_id:
                    TraceContext.set_trace_id(trace_id)
                    logger.info(f"Celery任务恢复trace_id: {trace_id}, 任务ID: {task_id}")
                else:
                    # No trace_id travelled with the task: derive a temporary
                    # one from the Celery task id so logs stay correlatable.
                    fallback_trace = f"celery-{task_id[:8]}"
                    TraceContext.set_trace_id(fallback_trace)
                    logger.warning(f"Celery任务未找到trace_id,使用临时trace: {fallback_trace}")

            except Exception as e:
                logger.error(f"Celery任务trace_id恢复失败: {str(e)}")
                # Last resort: still give the task a usable trace_id.
                fallback_trace = f"celery-error-{task_id[:8]}"
                TraceContext.set_trace_id(fallback_trace)

        @task_postrun.connect
        def task_postrun_handler(sender=None, task_id=None, task=None, args=None, kwargs=None, retval=None, state=None, **kwds):
            """
            Post-run hook: log task completion with its trace_id.

            The trace_id is deliberately left in place; the next task's
            prerun handler overwrites it.
            """
            try:
                trace_id = TraceContext.get_trace_id()
                logger.info(f"Celery任务完成: {trace_id}, 任务ID: {task_id}")
            except Exception as e:
                logger.error(f"Celery任务trace_id清理失败: {str(e)}")

        @task_failure.connect
        def task_failure_handler(sender=None, task_id=None, exception=None, traceback=None, einfo=None, **kwds):
            """Failure hook: log the failure together with the current trace_id."""
            try:
                trace_id = TraceContext.get_trace_id()
                logger.error(f"Celery任务失败: {trace_id}, 任务ID: {task_id}, 错误: {str(exception)}")
            except Exception as e:
                logger.error(f"Celery任务失败trace_id记录失败: {str(e)}, 任务ID: {task_id}")

    @staticmethod
    def submit_celery_task(task_func, *args, **kwargs):
        """
        Submit a Celery task, forwarding the current trace_id.

        The trace_id is injected as the ``_system_trace_id`` keyword, so the
        target task must accept it (declare the parameter or ``**kwargs``).

        Args:
            task_func: Celery task to submit
            *args: positional arguments for the task
            **kwargs: keyword arguments for the task

        Returns:
            The AsyncResult returned by ``task_func.delay``.
        """
        current_trace_id = TraceContext.get_trace_id()

        # Only forward a real trace_id; 'no-trace' is the unset sentinel.
        if current_trace_id and current_trace_id != 'no-trace':
            kwargs['_system_trace_id'] = current_trace_id

        logger.info(f"提交Celery任务,trace_id: {current_trace_id}")

        return task_func.delay(*args, **kwargs)
+
+
def add_trace_to_celery_task(celery_task_func):
    """
    Decorator: inject the current trace_id into a task call's kwargs.

    When a real trace_id is active (i.e. not the 'no-trace' sentinel) it is
    added as ``_system_trace_id`` so the worker-side prerun hook can restore
    it into TraceContext.

    NOTE(review): the wrapper returned here is a plain function, so Celery
    task attributes (``.delay`` / ``.apply_async``) are NOT preserved.
    Apply this decorator *below* ``@app.task`` (to the raw function), or
    submit via CeleryTraceManager.submit_celery_task instead.

    Usage:
        @app.task(bind=True)
        @add_trace_to_celery_task
        def my_task(self, file_info: dict, _system_trace_id=None):
            # task logic
            pass
    """
    from functools import wraps  # local import: module does not import functools at top level

    @wraps(celery_task_func)  # fix: preserve the wrapped task's name/docstring
    def decorator(*args, **kwargs):
        # Capture the trace_id in effect at call time.
        current_trace_id = TraceContext.get_trace_id()

        # Only forward a meaningful trace_id; 'no-trace' means unset.
        if current_trace_id and current_trace_id != 'no-trace':
            kwargs['_system_trace_id'] = current_trace_id

        return celery_task_func(*args, **kwargs)

    return decorator
+
+
def init():
    """Entry point: install the Celery signal handlers for trace propagation.

    Invoked once when the Celery app module is imported (see
    foundation/base/celery_app.py).
    """
    CeleryTraceManager.init_celery_signals()
    logger.info("Celery trace系统初始化完成")

+ 153 - 0
foundation/trace/trace_context.py

@@ -0,0 +1,153 @@
+"""
+Trace Context Manager
+负责管理系统级别的trace_id上下文,支持异步并发和队列传播
+"""
+
+import contextvars
+import uuid
+import asyncio
+import threading
+from typing import Optional, Dict, Any, Callable
+from functools import wraps
+import logging
+
# Global trace_id context variable. Each asyncio task runs in a copy of the
# current contextvars.Context, so the value propagates across awaits and
# concurrent tasks automatically.
system_trace_id: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar('system_trace_id', default=None)


class TraceContext:
    """Trace context manager.

    Thin static facade over the ``system_trace_id`` context variable; all
    accessors are static so callers never need an instance.
    """

    @staticmethod
    def set_trace_id(trace_id: str) -> None:
        """Set the system-level trace_id.

        Falsy values are ignored so an existing trace_id is never
        accidentally cleared.
        """
        if trace_id:
            system_trace_id.set(trace_id)

    @staticmethod
    def get_trace_id() -> str:
        """Return the current trace_id, or the 'no-trace' sentinel when unset."""
        return system_trace_id.get() or 'no-trace'

    @staticmethod
    def generate_trace_id() -> str:
        """Generate a new short trace_id.

        NOTE(review): only the first 8 hex chars of a uuid4 are kept, which
        trades collision resistance for log readability.
        """
        return str(uuid.uuid4())[:8]

    @staticmethod
    def get_or_generate_trace_id() -> str:
        """Return the current trace_id, or a freshly generated one if unset.

        Does NOT store a generated id; callers must call set_trace_id()
        themselves if they want it to persist.
        """
        current = system_trace_id.get()
        return current if current else TraceContext.generate_trace_id()

    @staticmethod
    def extract_context() -> Dict[str, Any]:
        """Snapshot the current context for hand-off through a queue.

        Fix: the previous version probed the private ``_context`` attribute
        of ContextVar, which does not exist, so 'async_context' was always
        None behind a dead hasattr guard. The key is kept (explicitly None)
        for backward compatibility with any consumer of this dict shape.
        """
        return {
            'system_trace_id': system_trace_id.get(),
            'thread_id': threading.get_ident(),
            'async_context': None,  # ContextVar exposes no public context handle
        }

    @staticmethod
    def restore_context(context_data: Dict[str, Any]) -> None:
        """Restore the trace_id from a snapshot produced by extract_context()."""
        if context_data and 'system_trace_id' in context_data:
            trace_id = context_data['system_trace_id']
            if trace_id:
                system_trace_id.set(trace_id)

    @staticmethod
    def with_trace_context(trace_id: str):
        """Return a context manager that temporarily sets trace_id and
        restores the previous value on exit."""
        return _TraceContextManager(trace_id)
+
+
class _TraceContextManager:
    """Context manager that temporarily overrides the trace_id.

    On entry the given trace_id is pushed onto the ``system_trace_id``
    context variable; on exit the previous value is restored through the
    reset token, so nested usage behaves correctly.
    """

    def __init__(self, trace_id: str):
        self.trace_id = trace_id
        self.token = None  # reset token returned by ContextVar.set()

    def __enter__(self):
        self.token = system_trace_id.set(self.trace_id)
        return self.trace_id

    def __exit__(self, exc_type, exc_val, exc_tb):
        token = self.token
        if token:
            system_trace_id.reset(token)
+
+
def auto_trace(trace_id_param: Optional[str] = 'callback_task_id', generate_if_missing: bool = False):
    """
    Auto-trace decorator - manages the trace_id lifecycle around a call.

    Args:
        trace_id_param: name of the keyword argument to read the trace_id
            from; when None, parameter lookup is skipped entirely and only
            ``generate_if_missing`` applies.
        generate_if_missing: when True, generate a fresh trace_id if none
            was found.

    Works for both sync and async callables. Fix: the extraction logic was
    duplicated verbatim in the sync and async wrappers; it is now shared.
    """
    def _resolve_and_set(args, kwargs):
        """Extract a trace_id from the call arguments and install it, if any."""
        trace_id = None
        if trace_id_param:
            if trace_id_param in kwargs:
                trace_id = kwargs[trace_id_param]
            # NOTE(review): treating a leading positional string as the
            # trace_id is a heuristic; it misfires whenever the first
            # positional argument is any other string - confirm call sites.
            elif args and isinstance(args[0], str):
                trace_id = args[0]

        if not trace_id and generate_if_missing:
            trace_id = TraceContext.generate_trace_id()

        if trace_id:
            TraceContext.set_trace_id(trace_id)

    def decorator(func: Callable):
        if asyncio.iscoroutinefunction(func):
            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                _resolve_and_set(args, kwargs)
                return await func(*args, **kwargs)
            return async_wrapper

        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            _resolve_and_set(args, kwargs)
            return func(*args, **kwargs)
        return sync_wrapper

    return decorator
+
+
class TraceFilter(logging.Filter):
    """Logging filter that stamps every record with the current trace_id.

    Attach to handlers/loggers so the ``%(system_trace_id)s`` format field
    is always populated.
    """

    def filter(self, record: logging.LogRecord) -> bool:
        # Always accept the record; this filter only enriches it.
        record.system_trace_id = TraceContext.get_trace_id()
        return True


# Shared TraceFilter instance for loggers and handlers to attach.
trace_filter = TraceFilter()

Fișier diff suprimat deoarece este prea mare
+ 6 - 5
temp/AI审查结果.json


+ 190 - 0
test/system_trace_id_test.py

@@ -0,0 +1,190 @@
+"""
+系统Trace ID测试
+验证trace_id在异步并发和队列中的正确传播
+"""
+import os
+import sys
+# Add the parent directory (LQAgentPlatform) to sys.path so we can import foundation
+project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+sys.path.append(project_root)
+import asyncio
+import time
+from foundation.trace.trace_context import TraceContext, auto_trace
+from foundation.logger.loggering import server_logger as logger
+
+
class TraceIDTest:
    """Trace ID test suite.

    Verifies TraceContext behaviour: basic set/get, inheritance into
    concurrent asyncio tasks, the auto_trace decorator, the temporary
    context manager, and a simulated Celery hand-off of the trace_id.
    """

    @staticmethod
    async def test_basic_context():
        """Test basic context functionality: set a trace_id and read it back."""
        print("\n=== 测试1: 基础上下文功能 ===")

        # Generate and install a trace_id.
        trace_id = TraceContext.generate_trace_id()
        TraceContext.set_trace_id(trace_id)

        logger.info("测试基础日志,应该包含trace_id")
        logger.info(f"手动设置的trace_id: {trace_id}")
        logger.info(f"自动获取的trace_id: {TraceContext.get_trace_id()}")

        assert TraceContext.get_trace_id() == trace_id, "trace_id设置失败"
        print("[PASS] 基础上下文功能测试通过")

    @staticmethod
    async def test_async_propagation():
        """Test that the trace_id propagates into concurrent async tasks."""
        print("\n=== 测试2: 异步并发传播 ===")

        # Install the parent trace_id.
        main_trace = "main-async-test"
        TraceContext.set_trace_id(main_trace)

        logger.info("主异步任务开始")

        async def concurrent_task(task_id: int):
            """Concurrent child task: reads, then overrides, the trace_id."""
            current_trace = TraceContext.get_trace_id()
            logger.info(f"并发任务 {task_id} 获取到的trace_id: {current_trace}")

            # Overriding the trace_id inside this task must not leak into
            # sibling tasks or the parent (each task gets its own context copy).
            new_trace = f"{main_trace}-subtask-{task_id}"
            TraceContext.set_trace_id(new_trace)

            await asyncio.sleep(0.1)
            logger.info(f"并发任务 {task_id} 修改后的trace_id: {new_trace}")

            return current_trace

        # Run the children concurrently.
        tasks = [concurrent_task(i) for i in range(3)]
        results = await asyncio.gather(*tasks)

        # Every child must have inherited the parent trace_id.
        for i, result in enumerate(results):
            assert result == main_trace, f"任务 {i} 没有继承主trace_id"

        # The parent trace_id must be untouched by the children's overrides.
        assert TraceContext.get_trace_id() == main_trace, "主trace_id被并发任务污染"

        logger.info("主异步任务完成")
        print("[PASS] 异步并发传播测试通过")

    @staticmethod
    @auto_trace('callback_task_id')
    async def test_decorator_auto_trace(callback_task_id: str):
        """Test the auto_trace decorator: it should install the trace_id itself."""
        print(f"\n=== 测试3: 装饰器自动trace ===")

        # No manual set_trace_id here - the decorator is expected to do it.
        current_trace = TraceContext.get_trace_id()
        logger.info("装饰器自动设置的日志")

        assert current_trace == callback_task_id, "装饰器没有正确设置trace_id"

        # The decorator-set trace_id must also reach nested awaits.
        async def nested_task():
            """Nested coroutine that just reads the current trace_id."""
            nested_trace = TraceContext.get_trace_id()
            logger.info("嵌套异步任务")
            return nested_trace

        nested_result = await nested_task()
        assert nested_result == callback_task_id, "嵌套任务没有继承装饰器设置的trace_id"

        print(f"[PASS] 装饰器自动trace测试通过,trace_id: {callback_task_id}")

    @staticmethod
    async def test_context_manager():
        """Test the temporary context manager: set on enter, restore on exit."""
        print("\n=== 测试4: 上下文管理器 ===")

        original_trace = TraceContext.get_trace_id()
        logger.info(f"原始trace_id: {original_trace}")

        # Temporarily override the trace_id inside the with-block.
        temp_trace = "temporary-trace"
        with TraceContext.with_trace_context(temp_trace) as ctx:
            logger.info("上下文管理器内的日志")
            current_trace = TraceContext.get_trace_id()
            assert current_trace == temp_trace, "上下文管理器没有正确设置trace_id"

        # Leaving the with-block must restore the previous trace_id.
        restored_trace = TraceContext.get_trace_id()
        logger.info(f"恢复后的trace_id: {restored_trace}")
        assert restored_trace == original_trace, "上下文管理器没有正确恢复trace_id"

        print("[PASS] 上下文管理器测试通过")

    @staticmethod
    def test_celery_task_simulation():
        """Simulate the Celery hand-off: trace_id passed via _system_trace_id."""
        print("\n=== 测试5: Celery任务trace_id模拟 ===")

        # Trace_id in effect at submission time.
        submit_trace = "celery-submit-test"
        TraceContext.set_trace_id(submit_trace)

        logger.info("准备提交Celery任务")

        # Stand-in for the worker-side task body.
        def simulate_celery_task_execution(file_info: dict, _system_trace_id=None):
            """Simulated Celery task: restores the trace_id, then runs."""
            if _system_trace_id:
                TraceContext.set_trace_id(_system_trace_id)

            current_trace = TraceContext.get_trace_id()
            logger.info("Celery任务执行中")
            logger.info(f"文件ID: {file_info.get('file_id')}")

            return current_trace

        # "Submit" the task: capture the current trace_id for hand-off.
        file_info = {'file_id': 'test-file-123'}
        extracted_trace = TraceContext.get_trace_id()

        # "Execute" the task with the handed-off trace_id.
        task_trace = simulate_celery_task_execution(
            file_info,
            _system_trace_id=extracted_trace
        )

        assert task_trace == submit_trace, "Celery任务没有正确获取到trace_id"

        print("[PASS] Celery任务trace_id模拟测试通过")
+
+
async def run_all_tests():
    """Run the full trace-id test suite sequentially.

    Returns:
        True when every test passes, False when any test raises.
    """
    print("开始运行系统Trace ID测试...\n")

    try:
        await TraceIDTest.test_basic_context()
        await TraceIDTest.test_async_propagation()
        await TraceIDTest.test_decorator_auto_trace("decorator-test-123")
        await TraceIDTest.test_context_manager()
        TraceIDTest.test_celery_task_simulation()
    except Exception as e:
        print(f"\n[FAIL] 测试失败: {str(e)}")
        import traceback
        traceback.print_exc()
        return False

    print("\n[SUCCESS] 所有测试通过!系统Trace ID机制工作正常")
    return True


if __name__ == "__main__":
    # Process exit code mirrors the suite result: 0 on success, 1 on failure.
    success = asyncio.run(run_all_tests())
    exit(0 if success else 1)

+ 13 - 7
views/construction_review/file_upload.py

@@ -7,16 +7,17 @@ import traceback
 import uuid
 import time
 from datetime import datetime
-from fastapi import APIRouter, UploadFile, File, Form, HTTPException
+
 from pydantic import BaseModel
 from typing import Optional,List
 from foundation.utils import md5
-from core.base.redis_duplicate_checker import RedisDuplicateChecker
-from core.base.workflow_manager import WorkflowManager
-from foundation.logger.loggering import server_logger as logger
 from foundation.base.config import config_handler
 from .schemas.error_schemas import FileUploadErrors
-
+from core.base.workflow_manager import WorkflowManager
+from foundation.logger.loggering import server_logger as logger
+from fastapi import APIRouter, UploadFile, File, Form, HTTPException
+from core.base.redis_duplicate_checker import RedisDuplicateChecker
+from foundation.trace.trace_context import TraceContext, auto_trace
 
 
 file_upload_router = APIRouter(prefix="/sgsc", tags=["文档上传"])
@@ -70,11 +71,12 @@ def validate_file(file: UploadFile, file_content: bytes = None) -> None:
     logger.info(f"文件类型验证通过: {actual_file_type} (扩展名: {file_extension}, MIME: {file.content_type})")
 
 @file_upload_router.post("/file_upload", response_model=FileUploadResponse)
+@auto_trace(generate_if_missing=True)  # 不查找参数,直接生成初始trace_id
 async def file_upload(
-    file: List[UploadFile] = File([]),  
+    file: List[UploadFile] = File([]),
     callback_url: str = Form(None),
     project_plan_type: str = Form(None),
-    user: str = Form(None)  
+    user: str = Form(None)
 ):
     """
     文件上传接口
@@ -160,6 +162,10 @@ async def file_upload(
         # 生成回调任务ID
         callback_task_id = f"{file_id}-{int(datetime.now().timestamp())}"
 
+        # 更新trace_id为正式的callback_task_id
+        TraceContext.set_trace_id(callback_task_id)
+        logger.info(f"更新trace_id为正式callback_task_id: {callback_task_id}")
+
         # 记录文件信息
         file_info = {
                 'file_id': file_id,

+ 1 - 0
views/construction_review/schemas/error_schemas.py

@@ -290,6 +290,7 @@ class FileUploadErrors:
         logger.error(ErrorCodes.WJSC010)
         return create_http_exception(ErrorCodes.WJSC010)
 
+    
     @staticmethod
     def internal_error(original_error: Exception):
         logger.error(ErrorCodes.WJSC011)

Unele fișiere nu au fost afișate deoarece prea multe fișiere au fost modificate în acest diff