# debug_api.py (27 KB)
  1. """
  2. 审查链路调试 API 端点
  3. 提供审查链路调试的 SSE 流式执行端点,以及完整的 Pydantic 请求/响应模型定义。
  4. 功能:
  5. 1. POST /debug/api/review/execute — 执行审查调试,SSE 流式返回进度
  6. 2. GET /debug/api/review/stream/{task_id} — SSE 断线重连
  7. 3. 并发控制(最大 5 个并发调试任务)
  8. 4. 全局超时控制(默认 180s)
  9. 5. 环节隔离模式支持
  10. 6. 调用记录自动保存
  11. """
  12. import asyncio
  13. import json
  14. import logging
  15. from datetime import datetime
  16. from typing import Optional, List, Dict, AsyncGenerator, ClassVar
  17. from fastapi import APIRouter, HTTPException, Path
  18. from fastapi.responses import StreamingResponse
  19. from pydantic import BaseModel, Field, field_validator
  20. from core.debug.sse_utils import (
  21. format_sse_event,
  22. sse_generator,
  23. debug_semaphore,
  24. _running_tasks,
  25. MAX_CONCURRENT_DEBUG_TASKS,
  26. DEBUG_GLOBAL_TIMEOUT,
  27. make_trace_id,
  28. make_record_id,
  29. CHAIN_NAMES,
  30. CHAIN_STEPS_COUNT,
  31. )
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
  33. # ============ 枚举常量 ============
  34. class ChainId:
  35. """链路标识枚举"""
  36. COMPLETENESS = "completeness"
  37. TIMELINESS = "timeliness"
  38. REFERENCE = "reference"
  39. SENSITIVE = "sensitive"
  40. SEMANTIC = "semantic"
  41. GRAMMAR = "grammar"
  42. PROFESSIONAL = "professional"
  43. class StepStatus:
  44. """步骤状态枚举"""
  45. PENDING = "pending"
  46. RUNNING = "running"
  47. SUCCESS = "success"
  48. ERROR = "error"
  49. class RecordStatus:
  50. """记录状态枚举"""
  51. SUCC = "succ"
  52. FAIL = "fail"
  53. TIMEOUT = "timeout"
  54. class ReviewType:
  55. """审查类型枚举"""
  56. BOTH = "both"
  57. NON_PARAMETER = "non_parameter"
  58. PARAMETER = "parameter"
  59. # ============ 审查链路调试 - 请求/响应模型 ============
  60. class RagParams(BaseModel):
  61. """专业性审查 RAG 参数"""
  62. review_type: str = Field(
  63. default="both",
  64. description="审查类型: both(全部) / non_parameter(非参数) / parameter(参数)",
  65. )
  66. collection_name: str = Field(
  67. default="construction_specs",
  68. description="Milvus 集合名称",
  69. )
  70. top_k: int = Field(
  71. default=5,
  72. ge=1,
  73. le=50,
  74. description="向量检索 top_k",
  75. )
  76. hybrid_top_k: int = Field(
  77. default=20,
  78. ge=1,
  79. le=100,
  80. description="混合检索 top_k",
  81. )
  82. dense_weight: float = Field(
  83. default=0.5,
  84. ge=0.0,
  85. le=1.0,
  86. description="稠密向量权重(0-1)",
  87. )
  88. parent_threshold: float = Field(
  89. default=0.3,
  90. ge=0.0,
  91. le=1.0,
  92. description="父文档增强阈值",
  93. )
  94. class DebugExecuteRequest(BaseModel):
  95. """审查调试执行请求"""
  96. chain_id: str = Field(
  97. ...,
  98. description="链路标识: completeness / timeliness / reference / sensitive / semantic / grammar / professional",
  99. )
  100. content: str = Field(
  101. ...,
  102. description="待审查的方案内容",
  103. )
  104. reference: Optional[str] = Field(
  105. default=None,
  106. description="审查参考依据",
  107. )
  108. model: Optional[str] = Field(
  109. default=None,
  110. description="覆盖模型名称(如 deepseek_v3),与 function_name 互斥",
  111. )
  112. function_name: Optional[str] = Field(
  113. default=None,
  114. description="覆盖功能名称(如 completeness_review_generate),与 model 互斥",
  115. )
  116. prompt_version: Optional[str] = Field(
  117. default=None,
  118. description="指定提示词版本,null 则使用当前激活版本",
  119. )
  120. timeout: int = Field(
  121. default=60,
  122. ge=10,
  123. le=600,
  124. description="模型调用超时(秒)",
  125. )
  126. rag_params: Optional[RagParams] = Field(
  127. default=None,
  128. description="专业性审查专属参数,其他链路忽略",
  129. )
  130. isolation_mode: bool = Field(
  131. default=False,
  132. description="是否启用环节隔离模式",
  133. )
  134. isolation_steps: List[int] = Field(
  135. default_factory=list,
  136. description="隔离模式下要执行的步骤索引列表",
  137. )
  138. manual_inputs: Dict[str, str] = Field(
  139. default_factory=dict,
  140. description="隔离模式下各步骤的手动输入,key 为步骤索引,value 为输入内容",
  141. )
  142. _VALID_CHAIN_IDS: ClassVar[set] = {
  143. ChainId.COMPLETENESS,
  144. ChainId.TIMELINESS,
  145. ChainId.REFERENCE,
  146. ChainId.SENSITIVE,
  147. ChainId.SEMANTIC,
  148. ChainId.GRAMMAR,
  149. ChainId.PROFESSIONAL,
  150. }
  151. @field_validator("chain_id")
  152. @classmethod
  153. def check_chain_id(cls, v: str) -> str:
  154. """验证 chain_id 是否为合法枚举值"""
  155. if v not in cls._VALID_CHAIN_IDS:
  156. raise ValueError(
  157. f"chain_id 必须为以下值之一: "
  158. f"{json.dumps(sorted(cls._VALID_CHAIN_IDS), ensure_ascii=False)}"
  159. )
  160. return v
  161. @field_validator("content")
  162. @classmethod
  163. def check_content_not_empty(cls, v: str) -> str:
  164. """验证 content 不能为空"""
  165. if not v or not v.strip():
  166. raise ValueError("content 不能为空")
  167. return v
  168. # ============ 提示词管理 - 请求/响应模型 ============
  169. class PromptItem(BaseModel):
  170. """提示词列表项"""
  171. name: str = Field(default="", description="提示词名称")
  172. version: str = Field(default="", description="版本号")
  173. time: str = Field(default="", description="创建时间")
  174. chain: str = Field(default="", description="所属链路")
  175. is_current: bool = Field(default=False, description="是否为当前激活版本")
  176. note: str = Field(default="", description="版本说明")
  177. class PromptListResponse(BaseModel):
  178. """提示词列表响应"""
  179. status: str = Field(default="not_implemented", description="状态")
  180. total: int = Field(default=0, description="总数")
  181. page: int = Field(default=1, description="当前页码")
  182. page_size: int = Field(default=50, description="每页条数")
  183. items: List[PromptItem] = Field(default_factory=list, description="提示词列表")
  184. chains: List[str] = Field(default_factory=list, description="可筛选的链路列表")
  185. class PromptDetailResponse(BaseModel):
  186. """提示词详情响应"""
  187. status: str = Field(default="not_implemented", description="状态")
  188. name: str = Field(default="", description="提示词名称")
  189. version: str = Field(default="", description="版本号")
  190. time: str = Field(default="", description="创建时间")
  191. chain: str = Field(default="", description="所属链路")
  192. is_current: bool = Field(default=False, description="是否为当前激活版本")
  193. system_prompt: str = Field(default="", description="系统提示词")
  194. user_prompt: str = Field(default="", description="用户提示词模板")
  195. note: str = Field(default="", description="版本说明")
  196. variables: List[str] = Field(default_factory=list, description="模板变量列表")
  197. based_on: Optional[str] = Field(default=None, description="基于哪个版本")
  198. file_path: str = Field(default="", description="文件路径")
  199. class PromptSaveRequest(BaseModel):
  200. """保存新版本请求"""
  201. name: str = Field(..., description="提示词名称")
  202. system_prompt: str = Field(..., description="系统提示词内容")
  203. user_prompt: str = Field(..., description="用户提示词模板")
  204. note: str = Field(default="", description="版本说明")
  205. set_current: bool = Field(default=True, description="是否设为当前激活版本")
  206. class PromptSaveResponse(BaseModel):
  207. """保存新版本响应"""
  208. success: bool = Field(default=True, description="是否成功")
  209. name: str = Field(default="", description="提示词名称")
  210. version: str = Field(default="", description="新版本号")
  211. time: str = Field(default="", description="创建时间")
  212. message: str = Field(default="", description="消息说明")
  213. class PromptCompareRequest(BaseModel):
  214. """版本对比请求"""
  215. name: str = Field(..., description="提示词名称")
  216. base_version: str = Field(..., description="基准版本号")
  217. target_version: str = Field(..., description="目标版本号")
  218. class PromptCompareResponse(BaseModel):
  219. """版本对比响应"""
  220. status: str = Field(default="not_implemented", description="状态")
  221. name: str = Field(default="", description="提示词名称")
  222. base_version: str = Field(default="", description="基准版本")
  223. target_version: str = Field(default="", description="目标版本")
  224. diffs: List[dict] = Field(default_factory=list, description="差异列表")
  225. class PromptActivateRequest(BaseModel):
  226. """激活版本请求"""
  227. name: str = Field(..., description="提示词名称")
  228. version: str = Field(..., description="版本号")
  229. class PromptActivateResponse(BaseModel):
  230. """激活版本响应"""
  231. success: bool = Field(default=True, description="是否成功")
  232. name: str = Field(default="", description="提示词名称")
  233. version: str = Field(default="", description="版本号")
  234. message: str = Field(default="", description="消息说明")
  235. class PromptVersionsResponse(BaseModel):
  236. """版本列表响应"""
  237. status: str = Field(default="not_implemented", description="状态")
  238. name: str = Field(default="", description="提示词名称")
  239. chain: str = Field(default="", description="所属链路")
  240. current_version: str = Field(default="", description="当前激活版本")
  241. versions: List[dict] = Field(default_factory=list, description="版本列表")
  242. # ============ 调用记录 - 请求/响应模型 ============
  243. class CallRecordItem(BaseModel):
  244. """调用记录列表项"""
  245. id: str = Field(default="", description="记录 ID")
  246. time: str = Field(default="", description="调用时间")
  247. chain: str = Field(default="", description="链路标识")
  248. chain_name: str = Field(default="", description="链路名称")
  249. doc_ref: str = Field(default="", description="文档引用")
  250. duration: str = Field(default="", description="持续时间(格式化)")
  251. duration_ms: int = Field(default=0, description="持续时间(毫秒)")
  252. status: str = Field(default="", description="状态: succ/fail/timeout")
  253. model: str = Field(default="", description="模型名称")
  254. prompt_ver: str = Field(default="", description="提示词版本")
  255. tokens: int = Field(default=0, description="Token 消耗")
  256. result_preview: str = Field(default="", description="结果预览")
  257. class CallRecordListResponse(BaseModel):
  258. """调用记录列表响应"""
  259. status: str = Field(default="not_implemented", description="状态")
  260. total: int = Field(default=0, description="总数")
  261. page: int = Field(default=1, description="当前页码")
  262. page_size: int = Field(default=20, description="每页条数")
  263. total_pages: int = Field(default=0, description="总页数")
  264. items: List[CallRecordItem] = Field(default_factory=list, description="记录列表")
  265. chains: List[str] = Field(default_factory=list, description="可筛选的链路列表")
  266. status_counts: Dict[str, int] = Field(default_factory=dict, description="状态统计")
  267. class StepDetail(BaseModel):
  268. """步骤详情"""
  269. index: int = Field(default=0, description="步骤索引")
  270. name: str = Field(default="", description="步骤名称")
  271. duration_ms: int = Field(default=0, description="持续时间(毫秒)")
  272. status: str = Field(default="", description="状态: succ/fail/timeout")
  273. input: dict = Field(default_factory=dict, description="步骤输入")
  274. output: dict = Field(default_factory=dict, description="步骤输出")
  275. class CallRecordDetailResponse(BaseModel):
  276. """调用记录详情响应"""
  277. id: str = Field(default="", description="记录 ID")
  278. time: str = Field(default="", description="调用时间")
  279. chain: str = Field(default="", description="链路标识")
  280. chain_name: str = Field(default="", description="链路名称")
  281. doc_ref: str = Field(default="", description="文档引用")
  282. status: str = Field(default="", description="状态: succ/fail/timeout")
  283. duration_ms: int = Field(default=0, description="持续时间(毫秒)")
  284. model: str = Field(default="", description="模型名称")
  285. function_name: Optional[str] = Field(default=None, description="功能名称")
  286. prompt_ver: str = Field(default="", description="提示词版本")
  287. prompt_name: str = Field(default="", description="提示词名称")
  288. tokens: int = Field(default=0, description="Token 消耗")
  289. params: dict = Field(default_factory=dict, description="请求参数")
  290. steps: List[StepDetail] = Field(default_factory=list, description="步骤列表")
  291. result: str = Field(default="", description="审查结果")
  292. error_message: Optional[str] = Field(default=None, description="错误信息")
  293. class OverrideParams(BaseModel):
  294. """回放覆盖参数"""
  295. model: Optional[str] = Field(default=None, description="覆盖模型名称")
  296. prompt_version: Optional[str] = Field(default=None, description="覆盖提示词版本")
  297. rag_params: Optional[RagParams] = Field(default=None, description="覆盖 RAG 参数")
  298. class ReplayRequest(BaseModel):
  299. """回放调用请求"""
  300. record_id: str = Field(..., description="要回放的调用记录 ID")
  301. override_params: Optional[OverrideParams] = Field(
  302. default=None,
  303. description="覆盖原始参数,不指定则使用原始参数",
  304. )
  305. class ExportRequest(BaseModel):
  306. """导出调用记录请求"""
  307. record_ids: List[str] = Field(..., description="要导出的记录 ID 列表")
  308. format: str = Field(default="json", description="导出格式")
  309. # ============ 路由定义 ============
# Router for the review-debug workbench; every route below is served under "/debug".
debug_router = APIRouter(prefix="/debug", tags=["审查调试工作台"])
# Register the routes of the sub-modules on the shared router.
# NOTE(review): these imports sit after the model definitions rather than at the
# top of the file — presumably because the sub-modules import names from this
# module (circular import); confirm before moving them.
from .prompt_api import register_routes as register_prompt_routes
from .record_api import register_routes as register_record_routes
from .rag_debug_api import register_routes as register_rag_debug_routes
# Each sub-module attaches its own endpoints to `debug_router`.
register_prompt_routes(debug_router)
register_record_routes(debug_router)
register_rag_debug_routes(debug_router)
  318. # ============ 审查链路调试端点 ============
  319. @debug_router.post("/api/review/execute")
  320. async def execute_review(request: DebugExecuteRequest):
  321. """
  322. 执行审查调试(SSE 流式返回)
  323. 支持全部 7 个审查链路:
  324. - completeness: 完整性审查
  325. - timeliness: 时效性审查
  326. - reference: 规范性审查
  327. - sensitive: 敏感词检查
  328. - semantic: 语义逻辑检查
  329. - grammar: 语法检查
  330. - professional: 专业性审查
  331. 支持环节隔离模式,可单独执行指定步骤。
  332. 调试执行不进 Celery 队列,直接在请求协程中执行,支持实时 SSE 推送。
  333. 并发限制:最多 5 个调试任务同时执行。
  334. 超时控制:全局默认 180s。
  335. """
  336. # ---- 检查并发上限 ----
  337. if debug_semaphore.locked():
  338. raise HTTPException(
  339. status_code=429,
  340. detail=(
  341. f"并发调试任务数已达上限 ({MAX_CONCURRENT_DEBUG_TASKS}),"
  342. f"请等待其他任务完成后再试"
  343. ),
  344. )
  345. chain_id = request.chain_id
  346. total_steps = CHAIN_STEPS_COUNT.get(chain_id, 3)
  347. async def event_generator():
  348. async with debug_semaphore:
  349. event_queue: asyncio.Queue = asyncio.Queue()
  350. task_id = make_trace_id(chain_id)
  351. record_id = make_record_id()
  352. # 缓存任务队列供断线重连
  353. _running_tasks[task_id] = event_queue
  354. # ---- 发送 started 事件 ----
  355. yield format_sse_event("started", {
  356. "task_id": task_id,
  357. "chain_id": chain_id,
  358. "total_steps": total_steps,
  359. })
  360. # ---- 后台执行任务 ----
  361. exec_task = asyncio.create_task(
  362. _run_debug_execution(request, event_queue, task_id, record_id)
  363. )
  364. completed_data = None
  365. error_occurred = None
  366. steps_collected: List[dict] = []
  367. try:
  368. # ---- 消费队列事件 ----
  369. while True:
  370. try:
  371. event_type, data = await asyncio.wait_for(
  372. event_queue.get(),
  373. timeout=DEBUG_GLOBAL_TIMEOUT,
  374. )
  375. except asyncio.TimeoutError:
  376. error_occurred = "执行超时"
  377. yield format_sse_event("error", {
  378. "task_id": task_id,
  379. "message": f"执行超时 (>{DEBUG_GLOBAL_TIMEOUT}s)",
  380. })
  381. exec_task.cancel()
  382. break
  383. if event_type == "__end__":
  384. break
  385. # 跳过 executor 的 started(我们已发送自定义 started)
  386. if event_type == "started":
  387. continue
  388. # 收集步骤数据用于保存记录
  389. if event_type == "step_result":
  390. steps_collected.append({
  391. "index": data.get("step_index"),
  392. "name": data.get("step_name"),
  393. "status": data.get("status"),
  394. "duration_ms": int((data.get("duration") or 0) * 1000),
  395. "input": data.get("input", {}),
  396. "output": data.get("output", {}),
  397. })
  398. if event_type == "error":
  399. error_occurred = data.get("message", "未知错误")
  400. if event_type == "completed":
  401. completed_data = data
  402. yield format_sse_event(event_type, data)
  403. if event_type == "completed":
  404. break
  405. except asyncio.CancelledError:
  406. yield format_sse_event("error", {
  407. "task_id": task_id,
  408. "message": "连接已断开,任务被取消",
  409. })
  410. except Exception as exc:
  411. logger.exception("[execute_review] 事件处理异常")
  412. yield format_sse_event("error", {
  413. "task_id": task_id,
  414. "message": str(exc),
  415. })
  416. finally:
  417. _running_tasks.pop(task_id, None)
  418. if not exec_task.done():
  419. exec_task.cancel()
  420. # ---- 自动保存调用记录 ----
  421. try:
  422. await _save_debug_record(
  423. request=request,
  424. task_id=task_id,
  425. record_id=record_id,
  426. chain_id=chain_id,
  427. completed_data=completed_data,
  428. error_message=error_occurred,
  429. steps=steps_collected,
  430. )
  431. except Exception as exc:
  432. logger.warning("[execute_review] 保存调用记录失败: %s", exc)
  433. return StreamingResponse(
  434. event_generator(),
  435. media_type="text/event-stream",
  436. headers={
  437. "Cache-Control": "no-cache",
  438. "X-Accel-Buffering": "no",
  439. },
  440. )
  441. # ============ SSE 断线重连端点 ============
  442. @debug_router.get("/api/review/stream/{task_id}")
  443. async def stream_review_progress(task_id: str = Path(..., description="任务 ID")):
  444. """
  445. 重新连接获取审查调试进度(SSE 流式)
  446. 当 SSE 连接断开时,用于重新连接获取仍在执行中的任务进度。
  447. 从当前进度继续推送,不重放已完成的 step 事件。
  448. """
  449. async def event_generator():
  450. queue = _running_tasks.get(task_id)
  451. if queue is None:
  452. yield format_sse_event("error", {
  453. "task_id": task_id,
  454. "message": "任务不存在或已完成",
  455. })
  456. return
  457. # 发送 resumed 标记
  458. yield format_sse_event("started", {
  459. "task_id": task_id,
  460. "resumed": True,
  461. })
  462. # 继续从原队列消费事件
  463. while True:
  464. try:
  465. event_type, data = await asyncio.wait_for(
  466. queue.get(),
  467. timeout=DEBUG_GLOBAL_TIMEOUT,
  468. )
  469. except asyncio.TimeoutError:
  470. yield format_sse_event("error", {
  471. "task_id": task_id,
  472. "message": "重连等待超时",
  473. })
  474. break
  475. if event_type == "__end__":
  476. break
  477. yield format_sse_event(event_type, data)
  478. if event_type == "completed":
  479. break
  480. queue.task_done()
  481. return StreamingResponse(
  482. event_generator(),
  483. media_type="text/event-stream",
  484. headers={
  485. "Cache-Control": "no-cache",
  486. "X-Accel-Buffering": "no",
  487. },
  488. )
  489. # ============================================================
  490. # 内部辅助函数
  491. # ============================================================
  492. async def _run_debug_execution(
  493. request: DebugExecuteRequest,
  494. event_queue: asyncio.Queue,
  495. task_id: str,
  496. record_id: str,
  497. ) -> None:
  498. """
  499. 执行审查调试后台任务。
  500. 根据 isolation_mode 选择执行路径:
  501. - True: 使用 IsolationRunner.run_selected_steps()
  502. - False: 使用 DebugExecutor.execute()
  503. """
  504. try:
  505. if request.isolation_mode:
  506. await _run_isolation_mode(request, event_queue, task_id)
  507. else:
  508. await _run_normal_mode(request, event_queue, task_id, record_id)
  509. except asyncio.CancelledError:
  510. await event_queue.put(("error", {
  511. "task_id": task_id,
  512. "message": "执行被取消",
  513. }))
  514. except Exception as exc:
  515. logger.exception("[_run_debug_execution] 执行异常")
  516. await event_queue.put(("error", {
  517. "task_id": task_id,
  518. "message": str(exc),
  519. }))
  520. async def _run_normal_mode(
  521. request: DebugExecuteRequest,
  522. event_queue: asyncio.Queue,
  523. task_id: str,
  524. record_id: str,
  525. ) -> None:
  526. """正常模式:使用 DebugExecutor 执行"""
  527. from core.debug.executor import DebugExecutor
  528. executor = DebugExecutor()
  529. await executor.execute(request, event_queue)
  530. async def _run_isolation_mode(
  531. request: DebugExecuteRequest,
  532. event_queue: asyncio.Queue,
  533. task_id: str,
  534. ) -> None:
  535. """隔离模式:使用 IsolationRunner 执行选中的步骤"""
  536. from core.debug.isolation_runner import IsolationRunner
  537. from core.debug.step_dispatcher import CHAIN_STEPS
  538. runner = IsolationRunner()
  539. chain_id = request.chain_id
  540. params = {
  541. "content": request.content,
  542. "reference": request.reference or "",
  543. "model": request.model,
  544. "function_name": request.function_name,
  545. "timeout": request.timeout,
  546. }
  547. if request.rag_params:
  548. params["rag_params"] = request.rag_params
  549. manual_inputs = request.manual_inputs or {}
  550. selected_indices = list(request.isolation_steps) if request.isolation_steps else []
  551. step_result_dicts = await runner.run_selected_steps(
  552. chain_id=chain_id,
  553. selected_indices=selected_indices,
  554. manual_inputs=manual_inputs,
  555. **params,
  556. )
  557. # 从 step_def 获取 phase 信息
  558. step_defs = CHAIN_STEPS.get(chain_id, [])
  559. current_phase = None
  560. for sr in step_result_dicts:
  561. si = sr.get("index", 0)
  562. # 获取 phase 信息
  563. phase = None
  564. if si < len(step_defs):
  565. sd = step_defs[si]
  566. if hasattr(sd, "phase"):
  567. phase = sd.phase
  568. elif isinstance(sd, dict):
  569. phase = sd.get("phase")
  570. # phase_label:阶段切换时推送
  571. if phase and phase != current_phase:
  572. current_phase = phase
  573. await event_queue.put(("phase_label", {
  574. "task_id": task_id,
  575. "label": phase,
  576. }))
  577. # step_progress
  578. await event_queue.put(("step_progress", {
  579. "task_id": task_id,
  580. "step_index": si,
  581. "step_name": sr.get("name", ""),
  582. "status": "running",
  583. "duration": None,
  584. }))
  585. # step_result
  586. await event_queue.put(("step_result", {
  587. "task_id": task_id,
  588. "step_index": si,
  589. "step_name": sr.get("name", ""),
  590. "status": sr.get("status", ""),
  591. "duration": sr.get("duration", 0),
  592. "input": sr.get("input", {}),
  593. "output": sr.get("output", {}),
  594. "error": sr.get("error"),
  595. }))
  596. # completed 事件
  597. total_duration = sum(
  598. sr.get("duration", 0) or 0 for sr in step_result_dicts
  599. )
  600. success_count = sum(
  601. 1 for sr in step_result_dicts if sr.get("status") == "success"
  602. )
  603. error_count = sum(
  604. 1 for sr in step_result_dicts if sr.get("status") == "error"
  605. )
  606. skipped_count = sum(
  607. 1 for sr in step_result_dicts if sr.get("status") == "skipped"
  608. )
  609. await event_queue.put(("completed", {
  610. "task_id": task_id,
  611. "chain_id": chain_id,
  612. "total_duration": round(total_duration, 3),
  613. "record_id": "",
  614. "final_result": {
  615. "summary": (
  616. f"{success_count}/{len(step_result_dicts)} 步骤成功, "
  617. f"{error_count} 错误, {skipped_count} 跳过"
  618. ),
  619. "success_count": success_count,
  620. "error_count": error_count,
  621. "skipped_count": skipped_count,
  622. "total_steps": len(step_result_dicts),
  623. },
  624. }))
  625. async def _save_debug_record(
  626. request: DebugExecuteRequest,
  627. task_id: str,
  628. record_id: str,
  629. chain_id: str,
  630. completed_data: Optional[dict],
  631. error_message: Optional[str],
  632. steps: List[dict],
  633. ) -> None:
  634. """通过 RecordManager 保存调用记录"""
  635. from core.debug.record_manager import RecordManager
  636. duration_ms = 0
  637. final_result = ""
  638. if completed_data:
  639. duration_sec = completed_data.get("total_duration", 0) or 0
  640. duration_ms = int(duration_sec * 1000)
  641. final_result = str(completed_data.get("final_result", {}).get("summary", ""))
  642. status = "succ"
  643. if error_message:
  644. if "超时" in str(error_message):
  645. status = "timeout"
  646. else:
  647. status = "fail"
  648. record_data = {
  649. "id": record_id,
  650. "time": datetime.now().isoformat(),
  651. "chain": chain_id,
  652. "chain_name": CHAIN_NAMES.get(chain_id, ""),
  653. "doc_ref": "",
  654. "status": status,
  655. "duration_ms": duration_ms,
  656. "model": request.model or "default",
  657. "function_name": request.function_name or "",
  658. "prompt_ver": request.prompt_version or "",
  659. "prompt_name": "",
  660. "tokens": 0,
  661. "params": {
  662. "review_content": request.content,
  663. "review_references": request.reference or "",
  664. "model_override": request.model,
  665. "function_name": request.function_name,
  666. "timeout": request.timeout,
  667. },
  668. "execution_params": {
  669. "isolation_mode": request.isolation_mode,
  670. "isolation_steps": list(request.isolation_steps),
  671. "rag_params": (
  672. request.rag_params.model_dump()
  673. if request.rag_params else None
  674. ),
  675. },
  676. "steps": steps,
  677. "result": final_result,
  678. "error_message": error_message,
  679. }
  680. rm = RecordManager()
  681. await rm.save_record(record_data)