""" Trace Context Manager 负责管理系统级别的trace_id上下文,支持异步并发和队列传播 """ import contextvars import uuid import asyncio import threading from typing import Optional, Dict, Any, Callable from functools import wraps import logging # 全局trace_id上下文变量 - 自动跨异步传播 system_trace_id: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar('system_trace_id', default=None) class TraceContext: """Trace上下文管理器""" @staticmethod def set_trace_id(trace_id: str) -> None: """设置系统级trace_id""" if trace_id: system_trace_id.set(trace_id) @staticmethod def get_trace_id() -> str: """获取当前trace_id""" return system_trace_id.get() or 'no-trace' @staticmethod def generate_trace_id() -> str: """生成新的trace_id""" return str(uuid.uuid4())[:8] @staticmethod def get_or_generate_trace_id() -> str: """获取当前trace_id,如果不存在则生成新的""" current = system_trace_id.get() return current if current else TraceContext.generate_trace_id() @staticmethod def extract_context() -> Dict[str, Any]: """提取当前上下文信息,用于队列传递""" return { 'system_trace_id': system_trace_id.get(), 'thread_id': threading.get_ident(), 'async_context': str(system_trace_id._context) if hasattr(system_trace_id, '_context') else None } @staticmethod def restore_context(context_data: Dict[str, Any]) -> None: """从队列任务中恢复trace_id上下文""" if context_data and 'system_trace_id' in context_data: trace_id = context_data['system_trace_id'] if trace_id: system_trace_id.set(trace_id) @staticmethod def with_trace_context(trace_id: str): """上下文管理器 - 临时设置trace_id""" return _TraceContextManager(trace_id) class _TraceContextManager: """临时trace上下文管理器""" def __init__(self, trace_id: str): self.trace_id = trace_id self.token = None def __enter__(self): self.token = system_trace_id.set(self.trace_id) return self.trace_id def __exit__(self, exc_type, exc_val, exc_tb): if self.token: system_trace_id.reset(self.token) def auto_trace(trace_id_param: Optional[str] = 'callback_task_id', generate_if_missing: bool = False): """ 自动trace装饰器 - 自动管理trace_id生命周期 Args: trace_id_param: 参数名,用于从函数参数中提取trace_id,如果为None则只使用generate_if_missing generate_if_missing: 如果为True,当没有trace_id时自动生成 """ def decorator(func: Callable): if asyncio.iscoroutinefunction(func): @wraps(func) async def async_wrapper(*args, **kwargs): # 尝试从参数中提取trace_id trace_id = None # 只有当trace_id_param不为None时才从参数中查找 if trace_id_param: # 从kwargs中查找 if trace_id_param in kwargs: trace_id = kwargs[trace_id_param] # 从位置参数中查找 elif args and isinstance(args[0], str): trace_id = args[0] # 如果还是没有找到且允许自动生成 if not trace_id and generate_if_missing: trace_id = TraceContext.generate_trace_id() # 设置trace_id if trace_id: TraceContext.set_trace_id(trace_id) return await func(*args, **kwargs) return async_wrapper else: @wraps(func) def sync_wrapper(*args, **kwargs): # 同步函数的逻辑类似 trace_id = None # 只有当trace_id_param不为None时才从参数中查找 if trace_id_param: if trace_id_param in kwargs: trace_id = kwargs[trace_id_param] elif args and isinstance(args[0], str): trace_id = args[0] if not trace_id and generate_if_missing: trace_id = TraceContext.generate_trace_id() if trace_id: TraceContext.set_trace_id(trace_id) return func(*args, **kwargs) return sync_wrapper return decorator class TraceFilter(logging.Filter): """ 自定义Logger Filter - 自动注入system_trace_id到日志记录 """ def filter(self, record: logging.LogRecord) -> bool: """为日志记录添加system_trace_id字段""" record.system_trace_id = TraceContext.get_trace_id() return True # 全局TraceFilter实例,供logger使用 trace_filter = TraceFilter()