celery_trace.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. """
  2. Celery Trace管理
  3. 负责在Celery队列任务中传递和恢复trace_id上下文
  4. """
  5. from celery.signals import task_prerun, task_postrun, task_failure
  6. from .trace_context import TraceContext
  7. class CeleryTraceManager:
  8. """Celery trace上下文管理器"""
  9. @staticmethod
  10. def init_celery_signals():
  11. """初始化Celery信号,自动管理trace_id上下文"""
  12. @task_prerun.connect
  13. def task_prerun_handler(sender=None, task_id=None, task=None, args=None, kwargs=None, **kwds):
  14. """
  15. 任务执行前的信号处理
  16. 从任务参数中提取trace_id并设置到TraceContext
  17. """
  18. # 延迟导入避免循环依赖
  19. from foundation.observability.logger.loggering import server_logger as logger
  20. try:
  21. # 从kwargs中提取trace_id参数
  22. trace_id = kwargs.pop('_system_trace_id', None) or kwargs.pop('callback_task_id', None)
  23. if trace_id:
  24. TraceContext.set_trace_id(trace_id)
  25. logger.info(f"Celery任务恢复trace_id: {trace_id}, 任务ID: {task_id}")
  26. else:
  27. # 如果没有找到trace_id,生成一个临时的
  28. fallback_trace = f"celery-{task_id[:8]}"
  29. TraceContext.set_trace_id(fallback_trace)
  30. logger.warning(f"Celery任务未找到trace_id,使用临时trace: {fallback_trace}")
  31. except Exception as e:
  32. logger.error(f"Celery任务trace_id恢复失败: {str(e)}")
  33. # 生成临时trace_id
  34. fallback_trace = f"celery-error-{task_id[:8]}"
  35. TraceContext.set_trace_id(fallback_trace)
  36. @task_postrun.connect
  37. def task_postrun_handler(sender=None, task_id=None, task=None, args=None, kwargs=None, retval=None, state=None, **kwds):
  38. """
  39. 任务执行后的信号处理
  40. 清理trace_id上下文
  41. """
  42. # 延迟导入避免循环依赖
  43. from foundation.observability.logger.loggering import server_logger as logger
  44. try:
  45. trace_id = TraceContext.get_trace_id()
  46. logger.info(f"Celery任务完成: {trace_id}, 任务ID: {task_id}")
  47. # 可选:清理trace_id
  48. # TraceContext.set_trace_id(None)
  49. except Exception as e:
  50. logger.error(f"Celery任务trace_id清理失败: {str(e)}")
  51. @task_failure.connect
  52. def task_failure_handler(sender=None, task_id=None, exception=None, traceback=None, einfo=None, **kwds):
  53. """
  54. 任务失败时的信号处理
  55. """
  56. # 延迟导入避免循环依赖
  57. from foundation.observability.logger.loggering import server_logger as logger
  58. try:
  59. trace_id = TraceContext.get_trace_id()
  60. logger.error(f"Celery任务失败: {trace_id}, 任务ID: {task_id}, 错误: {str(exception)}")
  61. except Exception as e:
  62. logger.error(f"Celery任务失败trace_id记录失败: {str(e)}, 任务ID: {task_id}")
  63. @staticmethod
  64. def submit_celery_task(task_func, *args, **kwargs):
  65. """
  66. 提交Celery任务时自动传递当前trace_id
  67. Args:
  68. task_func: Celery任务函数
  69. *args: 位置参数
  70. **kwargs: 关键字参数
  71. Returns:
  72. Celery任务结果
  73. """
  74. # 延迟导入避免循环依赖
  75. from foundation.observability.logger.loggering import server_logger as logger
  76. # 获取当前trace_id
  77. current_trace_id = TraceContext.get_trace_id()
  78. # 将trace_id添加到任务参数中
  79. if current_trace_id and current_trace_id != 'no-trace':
  80. kwargs['_system_trace_id'] = current_trace_id
  81. logger.info(f"提交Celery任务")
  82. # 提交任务
  83. return task_func.delay(*args, **kwargs)
  84. def add_trace_to_celery_task(celery_task_func):
  85. """
  86. 装饰器:为Celery任务函数自动添加trace_id支持
  87. Usage:
  88. @add_trace_to_celery_task
  89. @app.task(bind=True)
  90. def my_task(self, file_info: dict):
  91. # 任务逻辑
  92. pass
  93. """
  94. def decorator(*args, **kwargs):
  95. # 获取当前trace_id
  96. current_trace_id = TraceContext.get_trace_id()
  97. if current_trace_id and current_trace_id != 'no-trace':
  98. kwargs['_system_trace_id'] = current_trace_id
  99. return celery_task_func(*args, **kwargs)
  100. return decorator
  101. # 自动初始化Celery信号
  102. def init():
  103. """初始化Celery trace系统"""
  104. # 延迟导入避免循环依赖
  105. try:
  106. from foundation.observability.logger.loggering import server_logger as logger
  107. except ImportError:
  108. import logging
  109. logger = logging.getLogger(__name__)
  110. CeleryTraceManager.init_celery_signals()
  111. try:
  112. logger.info("Celery trace系统初始化完成")
  113. except:
  114. pass # 如果logger不可用,静默继续