| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153 |
- """
- Trace Context Manager
- 负责管理系统级别的trace_id上下文,支持异步并发和队列传播
- """
- import contextvars
- import uuid
- import asyncio
- import threading
- from typing import Optional, Dict, Any, Callable
- from functools import wraps
- import logging
- # 全局trace_id上下文变量 - 自动跨异步传播
- system_trace_id: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar('system_trace_id', default=None)
- class TraceContext:
- """Trace上下文管理器"""
- @staticmethod
- def set_trace_id(trace_id: str) -> None:
- """设置系统级trace_id"""
- if trace_id:
- system_trace_id.set(trace_id)
- @staticmethod
- def get_trace_id() -> str:
- """获取当前trace_id"""
- return system_trace_id.get() or 'no-trace'
- @staticmethod
- def generate_trace_id() -> str:
- """生成新的trace_id"""
- return str(uuid.uuid4())[:8]
- @staticmethod
- def get_or_generate_trace_id() -> str:
- """获取当前trace_id,如果不存在则生成新的"""
- current = system_trace_id.get()
- return current if current else TraceContext.generate_trace_id()
- @staticmethod
- def extract_context() -> Dict[str, Any]:
- """提取当前上下文信息,用于队列传递"""
- return {
- 'system_trace_id': system_trace_id.get(),
- 'thread_id': threading.get_ident(),
- 'async_context': str(system_trace_id._context) if hasattr(system_trace_id, '_context') else None
- }
- @staticmethod
- def restore_context(context_data: Dict[str, Any]) -> None:
- """从队列任务中恢复trace_id上下文"""
- if context_data and 'system_trace_id' in context_data:
- trace_id = context_data['system_trace_id']
- if trace_id:
- system_trace_id.set(trace_id)
- @staticmethod
- def with_trace_context(trace_id: str):
- """上下文管理器 - 临时设置trace_id"""
- return _TraceContextManager(trace_id)
- class _TraceContextManager:
- """临时trace上下文管理器"""
- def __init__(self, trace_id: str):
- self.trace_id = trace_id
- self.token = None
- def __enter__(self):
- self.token = system_trace_id.set(self.trace_id)
- return self.trace_id
- def __exit__(self, exc_type, exc_val, exc_tb):
- if self.token:
- system_trace_id.reset(self.token)
- def auto_trace(trace_id_param: Optional[str] = 'callback_task_id', generate_if_missing: bool = False):
- """
- 自动trace装饰器 - 自动管理trace_id生命周期
- Args:
- trace_id_param: 参数名,用于从函数参数中提取trace_id,如果为None则只使用generate_if_missing
- generate_if_missing: 如果为True,当没有trace_id时自动生成
- """
- def decorator(func: Callable):
- if asyncio.iscoroutinefunction(func):
- @wraps(func)
- async def async_wrapper(*args, **kwargs):
- # 尝试从参数中提取trace_id
- trace_id = None
- # 只有当trace_id_param不为None时才从参数中查找
- if trace_id_param:
- # 从kwargs中查找
- if trace_id_param in kwargs:
- trace_id = kwargs[trace_id_param]
- # 从位置参数中查找
- elif args and isinstance(args[0], str):
- trace_id = args[0]
- # 如果还是没有找到且允许自动生成
- if not trace_id and generate_if_missing:
- trace_id = TraceContext.generate_trace_id()
- # 设置trace_id
- if trace_id:
- TraceContext.set_trace_id(trace_id)
- return await func(*args, **kwargs)
- return async_wrapper
- else:
- @wraps(func)
- def sync_wrapper(*args, **kwargs):
- # 同步函数的逻辑类似
- trace_id = None
- # 只有当trace_id_param不为None时才从参数中查找
- if trace_id_param:
- if trace_id_param in kwargs:
- trace_id = kwargs[trace_id_param]
- elif args and isinstance(args[0], str):
- trace_id = args[0]
- if not trace_id and generate_if_missing:
- trace_id = TraceContext.generate_trace_id()
- if trace_id:
- TraceContext.set_trace_id(trace_id)
- return func(*args, **kwargs)
- return sync_wrapper
- return decorator
- class TraceFilter(logging.Filter):
- """
- 自定义Logger Filter - 自动注入system_trace_id到日志记录
- """
- def filter(self, record: logging.LogRecord) -> bool:
- """为日志记录添加system_trace_id字段"""
- record.system_trace_id = TraceContext.get_trace_id()
- return True
- # 全局TraceFilter实例,供logger使用
- trace_filter = TraceFilter()
|