sse_utils.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. """
  2. SSE 格式化工具
  3. 提供 SSE 事件的格式化和生成功能,用于审查调试 API 的流式响应。
  4. 支持 started / step_progress / step_result / phase_label / completed / error / replay_comparison
  5. 共 7 种事件类型。
  6. """
  7. import asyncio
  8. import json
  9. from typing import Any, Dict, List, Optional, AsyncGenerator, Tuple
  10. # ============================================================
  11. # 常量
  12. # ============================================================
  13. SSE_EVENT_TYPES = [
  14. "started",
  15. "step_progress",
  16. "step_result",
  17. "phase_label",
  18. "completed",
  19. "error",
  20. "replay_comparison",
  21. ]
  22. MAX_CONCURRENT_DEBUG_TASKS = 5
  23. """最大并发调试任务数"""
  24. DEBUG_GLOBAL_TIMEOUT = 180
  25. """全局默认超时(秒)"""
  26. # ============================================================
  27. # 并发控制
  28. # ============================================================
  29. debug_semaphore = asyncio.Semaphore(MAX_CONCURRENT_DEBUG_TASKS)
  30. """全局调试任务信号量,控制并发上限"""
  31. _running_tasks: Dict[str, asyncio.Queue] = {}
  32. """正在执行的任务队列缓存,key=task_id, value=event_queue,用于断线重连"""
  33. # ============================================================
  34. # SSE 格式化
  35. # ============================================================
  36. def format_sse_event(event: str, data: dict) -> str:
  37. """
  38. 格式化为 SSE 规范的事件字符串。
  39. Args:
  40. event: 事件类型(started, step_progress, step_result, phase_label,
  41. completed, error, replay_comparison)
  42. data: 事件数据字典
  43. Returns:
  44. 符合 SSE 协议的文本行(含两个末尾换行)
  45. """
  46. return (
  47. f"event: {event}\n"
  48. f"data: {json.dumps(data, ensure_ascii=False, default=str)}\n\n"
  49. )
  50. async def sse_generator(event_queue: asyncio.Queue) -> AsyncGenerator[str, None]:
  51. """
  52. SSE 事件生成器,从队列消费事件并格式化为 SSE 文本流。
  53. 队列元素格式: (event_type: str, data: dict)
  54. 当收到 ("__end__", None) 时停止生成。
  55. Args:
  56. event_queue: asyncio.Queue,由执行器填充事件
  57. Yields:
  58. SSE 格式的文本行
  59. """
  60. while True:
  61. event_type, data = await event_queue.get()
  62. if event_type == "__end__":
  63. break
  64. yield format_sse_event(event_type, data)
  65. event_queue.task_done()
  66. def make_trace_id(chain_id: str) -> str:
  67. """生成 trace_id,添加 debug_ 前缀实现生产隔离"""
  68. from datetime import datetime
  69. import uuid
  70. return (
  71. f"debug_{chain_id}_"
  72. f"{datetime.now().strftime('%H%M%S')}_"
  73. f"{uuid.uuid4().hex[:8]}"
  74. )
  75. def make_record_id() -> str:
  76. """生成记录 ID: call-{YYYYMMDD}-{HHMMSS}-{hex}"""
  77. from datetime import datetime
  78. import uuid
  79. return (
  80. f"call-{datetime.now().strftime('%Y%m%d')}-"
  81. f"{datetime.now().strftime('%H%M%S')}-"
  82. f"{uuid.uuid4().hex[:6]}"
  83. )
  84. # ============================================================
  85. # 链名称映射(避免循环引用 executor)
  86. # ============================================================
  87. CHAIN_NAMES = {
  88. "completeness": "完整性审查",
  89. "timeliness": "时效性审查",
  90. "reference": "规范性审查",
  91. "sensitive": "敏感词检查",
  92. "semantic": "语义逻辑检查",
  93. "grammar": "语法检查",
  94. "professional": "专业性审查",
  95. }
  96. CHAIN_STEPS_COUNT = {
  97. "completeness": 3,
  98. "timeliness": 3,
  99. "reference": 3,
  100. "sensitive": 3,
  101. "semantic": 3,
  102. "grammar": 3,
  103. "professional": 7,
  104. }