trace_context.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. """
  2. Trace Context Manager
  3. 负责管理系统级别的trace_id上下文,支持异步并发和队列传播
  4. """
  5. import contextvars
  6. import uuid
  7. import asyncio
  8. import threading
  9. from typing import Optional, Dict, Any, Callable
  10. from functools import wraps
  11. import logging
  12. # 全局trace_id上下文变量 - 自动跨异步传播
  13. system_trace_id: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar('system_trace_id', default=None)
  14. class TraceContext:
  15. """Trace上下文管理器"""
  16. @staticmethod
  17. def set_trace_id(trace_id: str) -> None:
  18. """设置系统级trace_id"""
  19. if trace_id:
  20. system_trace_id.set(trace_id)
  21. @staticmethod
  22. def get_trace_id() -> str:
  23. """获取当前trace_id"""
  24. return system_trace_id.get() or 'no-trace'
  25. @staticmethod
  26. def generate_trace_id() -> str:
  27. """生成新的trace_id"""
  28. return str(uuid.uuid4())[:8]
  29. @staticmethod
  30. def get_or_generate_trace_id() -> str:
  31. """获取当前trace_id,如果不存在则生成新的"""
  32. current = system_trace_id.get()
  33. return current if current else TraceContext.generate_trace_id()
  34. @staticmethod
  35. def extract_context() -> Dict[str, Any]:
  36. """提取当前上下文信息,用于队列传递"""
  37. return {
  38. 'system_trace_id': system_trace_id.get(),
  39. 'thread_id': threading.get_ident(),
  40. 'async_context': str(system_trace_id._context) if hasattr(system_trace_id, '_context') else None
  41. }
  42. @staticmethod
  43. def restore_context(context_data: Dict[str, Any]) -> None:
  44. """从队列任务中恢复trace_id上下文"""
  45. if context_data and 'system_trace_id' in context_data:
  46. trace_id = context_data['system_trace_id']
  47. if trace_id:
  48. system_trace_id.set(trace_id)
  49. @staticmethod
  50. def with_trace_context(trace_id: str):
  51. """上下文管理器 - 临时设置trace_id"""
  52. return _TraceContextManager(trace_id)
  53. class _TraceContextManager:
  54. """临时trace上下文管理器"""
  55. def __init__(self, trace_id: str):
  56. self.trace_id = trace_id
  57. self.token = None
  58. def __enter__(self):
  59. self.token = system_trace_id.set(self.trace_id)
  60. return self.trace_id
  61. def __exit__(self, exc_type, exc_val, exc_tb):
  62. if self.token:
  63. system_trace_id.reset(self.token)
  64. def auto_trace(trace_id_param: Optional[str] = 'callback_task_id', generate_if_missing: bool = False):
  65. """
  66. 自动trace装饰器 - 自动管理trace_id生命周期
  67. Args:
  68. trace_id_param: 参数名,用于从函数参数中提取trace_id,如果为None则只使用generate_if_missing
  69. generate_if_missing: 如果为True,当没有trace_id时自动生成
  70. """
  71. def decorator(func: Callable):
  72. if asyncio.iscoroutinefunction(func):
  73. @wraps(func)
  74. async def async_wrapper(*args, **kwargs):
  75. # 尝试从参数中提取trace_id
  76. trace_id = None
  77. # 只有当trace_id_param不为None时才从参数中查找
  78. if trace_id_param:
  79. # 从kwargs中查找
  80. if trace_id_param in kwargs:
  81. trace_id = kwargs[trace_id_param]
  82. # 从位置参数中查找
  83. elif args and isinstance(args[0], str):
  84. trace_id = args[0]
  85. # 如果还是没有找到且允许自动生成
  86. if not trace_id and generate_if_missing:
  87. trace_id = TraceContext.generate_trace_id()
  88. # 设置trace_id
  89. if trace_id:
  90. TraceContext.set_trace_id(trace_id)
  91. return await func(*args, **kwargs)
  92. return async_wrapper
  93. else:
  94. @wraps(func)
  95. def sync_wrapper(*args, **kwargs):
  96. # 同步函数的逻辑类似
  97. trace_id = None
  98. # 只有当trace_id_param不为None时才从参数中查找
  99. if trace_id_param:
  100. if trace_id_param in kwargs:
  101. trace_id = kwargs[trace_id_param]
  102. elif args and isinstance(args[0], str):
  103. trace_id = args[0]
  104. if not trace_id and generate_if_missing:
  105. trace_id = TraceContext.generate_trace_id()
  106. if trace_id:
  107. TraceContext.set_trace_id(trace_id)
  108. return func(*args, **kwargs)
  109. return sync_wrapper
  110. return decorator
  111. class TraceFilter(logging.Filter):
  112. """
  113. 自定义Logger Filter - 自动注入system_trace_id到日志记录
  114. """
  115. def filter(self, record: logging.LogRecord) -> bool:
  116. """为日志记录添加system_trace_id字段"""
  117. record.system_trace_id = TraceContext.get_trace_id()
  118. return True
  119. # 全局TraceFilter实例,供logger使用
  120. trace_filter = TraceFilter()