# debug_api.py
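"""
Review-chain debugging API endpoints.

Provides the endpoints that start a review-chain debug run and stream its
progress over SSE, together with the Pydantic request/response models.

Features:
1. POST /debug/api/review/execute — start a debug run and return its task_id
2. GET /debug/api/review/stream/{task_id} — stream progress over SSE (also used to reconnect)
3. Concurrency control (at most 5 concurrent debug tasks)
4. Global timeout control (default 180s)
5. Step-isolation mode support
6. Automatic saving of call records
"""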
import asyncio
import json
import logging
from collections import Counter
from datetime import datetime
from typing import Optional, List, Dict, AsyncGenerator, ClassVar

from fastapi import APIRouter, HTTPException, Path
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field, field_validator

from core.debug.sse_utils import (
    format_sse_event,
    sse_generator,
    debug_semaphore,
    _running_tasks,
    MAX_CONCURRENT_DEBUG_TASKS,
    DEBUG_GLOBAL_TIMEOUT,
    make_trace_id,
    make_record_id,
    CHAIN_NAMES,
    CHAIN_STEPS_COUNT,
)
  32. logger = logging.getLogger(__name__)
# ============ Enum-like constants ============
  34. class ChainId:
  35. """链路标识枚举"""
  36. COMPLETENESS = "completeness"
  37. TIMELINESS = "timeliness"
  38. REFERENCE = "reference"
  39. SENSITIVE = "sensitive"
  40. SEMANTIC = "semantic"
  41. GRAMMAR = "grammar"
  42. PROFESSIONAL = "professional"
  43. class StepStatus:
  44. """步骤状态枚举"""
  45. PENDING = "pending"
  46. RUNNING = "running"
  47. SUCCESS = "success"
  48. ERROR = "error"
  49. class RecordStatus:
  50. """记录状态枚举"""
  51. SUCC = "succ"
  52. FAIL = "fail"
  53. TIMEOUT = "timeout"
  54. class ReviewType:
  55. """审查类型枚举"""
  56. BOTH = "both"
  57. NON_PARAMETER = "non_parameter"
  58. PARAMETER = "parameter"
# ============ Review-chain debugging: request/response models ============
  60. class RagParams(BaseModel):
  61. """专业性审查 RAG 参数"""
  62. review_type: str = Field(
  63. default="both",
  64. description="审查类型: both(全部) / non_parameter(非参数) / parameter(参数)",
  65. )
  66. collection_name: str = Field(
  67. default="construction_specs",
  68. description="Milvus 集合名称",
  69. )
  70. top_k: int = Field(
  71. default=5,
  72. ge=1,
  73. le=50,
  74. description="向量检索 top_k",
  75. )
  76. hybrid_top_k: int = Field(
  77. default=20,
  78. ge=1,
  79. le=100,
  80. description="混合检索 top_k",
  81. )
  82. dense_weight: float = Field(
  83. default=0.5,
  84. ge=0.0,
  85. le=1.0,
  86. description="稠密向量权重(0-1)",
  87. )
  88. parent_threshold: float = Field(
  89. default=0.3,
  90. ge=0.0,
  91. le=1.0,
  92. description="父文档增强阈值",
  93. )
  94. class DebugExecuteRequest(BaseModel):
  95. """审查调试执行请求"""
  96. chain_id: str = Field(
  97. ...,
  98. description="链路标识: completeness / timeliness / reference / sensitive / semantic / grammar / professional",
  99. )
  100. content: str = Field(
  101. ...,
  102. description="待审查的方案内容",
  103. )
  104. reference: Optional[str] = Field(
  105. default=None,
  106. description="审查参考依据",
  107. )
  108. model: Optional[str] = Field(
  109. default=None,
  110. description="覆盖模型名称(如 deepseek_v3),与 function_name 互斥",
  111. )
  112. function_name: Optional[str] = Field(
  113. default=None,
  114. description="覆盖功能名称(如 completeness_review_generate),与 model 互斥",
  115. )
  116. prompt_version: Optional[str] = Field(
  117. default=None,
  118. description="指定提示词版本,null 则使用当前激活版本",
  119. )
  120. timeout: int = Field(
  121. default=60,
  122. ge=10,
  123. le=600,
  124. description="模型调用超时(秒)",
  125. )
  126. rag_params: Optional[RagParams] = Field(
  127. default=None,
  128. description="专业性审查专属参数,其他链路忽略",
  129. )
  130. isolation_mode: bool = Field(
  131. default=False,
  132. description="是否启用环节隔离模式",
  133. )
  134. isolation_steps: List[int] = Field(
  135. default_factory=list,
  136. description="隔离模式下要执行的步骤索引列表",
  137. )
  138. manual_inputs: Dict[str, str] = Field(
  139. default_factory=dict,
  140. description="隔离模式下各步骤的手动输入,key 为步骤索引,value 为输入内容",
  141. )
  142. _VALID_CHAIN_IDS: ClassVar[set] = {
  143. ChainId.COMPLETENESS,
  144. ChainId.TIMELINESS,
  145. ChainId.REFERENCE,
  146. ChainId.SENSITIVE,
  147. ChainId.SEMANTIC,
  148. ChainId.GRAMMAR,
  149. ChainId.PROFESSIONAL,
  150. }
  151. @field_validator("chain_id")
  152. @classmethod
  153. def check_chain_id(cls, v: str) -> str:
  154. """验证 chain_id 是否为合法枚举值"""
  155. if v not in cls._VALID_CHAIN_IDS:
  156. raise ValueError(
  157. f"chain_id 必须为以下值之一: "
  158. f"{json.dumps(sorted(cls._VALID_CHAIN_IDS), ensure_ascii=False)}"
  159. )
  160. return v
  161. @field_validator("content")
  162. @classmethod
  163. def check_content_not_empty(cls, v: str) -> str:
  164. """验证 content 不能为空"""
  165. if not v or not v.strip():
  166. raise ValueError("content 不能为空")
  167. return v
# ============ Prompt management: request/response models ============
  169. class PromptItem(BaseModel):
  170. """提示词列表项"""
  171. name: str = Field(default="", description="提示词名称")
  172. version: str = Field(default="", description="版本号")
  173. time: str = Field(default="", description="创建时间")
  174. chain: str = Field(default="", description="所属链路")
  175. is_current: bool = Field(default=False, description="是否为当前激活版本")
  176. note: str = Field(default="", description="版本说明")
  177. class PromptListResponse(BaseModel):
  178. """提示词列表响应"""
  179. status: str = Field(default="not_implemented", description="状态")
  180. total: int = Field(default=0, description="总数")
  181. page: int = Field(default=1, description="当前页码")
  182. page_size: int = Field(default=50, description="每页条数")
  183. items: List[PromptItem] = Field(default_factory=list, description="提示词列表")
  184. chains: List[str] = Field(default_factory=list, description="可筛选的链路列表")
  185. class PromptDetailResponse(BaseModel):
  186. """提示词详情响应"""
  187. status: str = Field(default="not_implemented", description="状态")
  188. name: str = Field(default="", description="提示词名称")
  189. version: str = Field(default="", description="版本号")
  190. time: str = Field(default="", description="创建时间")
  191. chain: str = Field(default="", description="所属链路")
  192. is_current: bool = Field(default=False, description="是否为当前激活版本")
  193. system_prompt: str = Field(default="", description="系统提示词")
  194. user_prompt: str = Field(default="", description="用户提示词模板")
  195. note: str = Field(default="", description="版本说明")
  196. variables: List[str] = Field(default_factory=list, description="模板变量列表")
  197. based_on: Optional[str] = Field(default=None, description="基于哪个版本")
  198. file_path: str = Field(default="", description="文件路径")
  199. class PromptSaveRequest(BaseModel):
  200. """保存新版本请求"""
  201. name: str = Field(..., description="提示词名称")
  202. system_prompt: str = Field(..., description="系统提示词内容")
  203. user_prompt: str = Field(..., description="用户提示词模板")
  204. note: str = Field(default="", description="版本说明")
  205. set_current: bool = Field(default=True, description="是否设为当前激活版本")
  206. class PromptSaveResponse(BaseModel):
  207. """保存新版本响应"""
  208. success: bool = Field(default=True, description="是否成功")
  209. name: str = Field(default="", description="提示词名称")
  210. version: str = Field(default="", description="新版本号")
  211. time: str = Field(default="", description="创建时间")
  212. message: str = Field(default="", description="消息说明")
  213. class PromptCompareRequest(BaseModel):
  214. """版本对比请求"""
  215. name: str = Field(..., description="提示词名称")
  216. base_version: str = Field(..., description="基准版本号")
  217. target_version: str = Field(..., description="目标版本号")
  218. class PromptCompareResponse(BaseModel):
  219. """版本对比响应"""
  220. status: str = Field(default="not_implemented", description="状态")
  221. name: str = Field(default="", description="提示词名称")
  222. base_version: str = Field(default="", description="基准版本")
  223. target_version: str = Field(default="", description="目标版本")
  224. diffs: List[dict] = Field(default_factory=list, description="差异列表")
  225. class PromptActivateRequest(BaseModel):
  226. """激活版本请求"""
  227. name: str = Field(..., description="提示词名称")
  228. version: str = Field(..., description="版本号")
  229. class PromptActivateResponse(BaseModel):
  230. """激活版本响应"""
  231. success: bool = Field(default=True, description="是否成功")
  232. name: str = Field(default="", description="提示词名称")
  233. version: str = Field(default="", description="版本号")
  234. message: str = Field(default="", description="消息说明")
  235. class PromptVersionsResponse(BaseModel):
  236. """版本列表响应"""
  237. status: str = Field(default="not_implemented", description="状态")
  238. name: str = Field(default="", description="提示词名称")
  239. chain: str = Field(default="", description="所属链路")
  240. current_version: str = Field(default="", description="当前激活版本")
  241. versions: List[dict] = Field(default_factory=list, description="版本列表")
# ============ Call records: request/response models ============
  243. class CallRecordItem(BaseModel):
  244. """调用记录列表项"""
  245. id: str = Field(default="", description="记录 ID")
  246. time: str = Field(default="", description="调用时间")
  247. chain: str = Field(default="", description="链路标识")
  248. chain_name: str = Field(default="", description="链路名称")
  249. doc_ref: str = Field(default="", description="文档引用")
  250. duration: str = Field(default="", description="持续时间(格式化)")
  251. duration_ms: int = Field(default=0, description="持续时间(毫秒)")
  252. status: str = Field(default="", description="状态: succ/fail/timeout")
  253. model: str = Field(default="", description="模型名称")
  254. prompt_ver: str = Field(default="", description="提示词版本")
  255. tokens: int = Field(default=0, description="Token 消耗")
  256. result_preview: str = Field(default="", description="结果预览")
  257. class CallRecordListResponse(BaseModel):
  258. """调用记录列表响应"""
  259. status: str = Field(default="not_implemented", description="状态")
  260. total: int = Field(default=0, description="总数")
  261. page: int = Field(default=1, description="当前页码")
  262. page_size: int = Field(default=20, description="每页条数")
  263. total_pages: int = Field(default=0, description="总页数")
  264. items: List[CallRecordItem] = Field(default_factory=list, description="记录列表")
  265. chains: List[str] = Field(default_factory=list, description="可筛选的链路列表")
  266. status_counts: Dict[str, int] = Field(default_factory=dict, description="状态统计")
  267. class StepDetail(BaseModel):
  268. """步骤详情"""
  269. index: int = Field(default=0, description="步骤索引")
  270. name: str = Field(default="", description="步骤名称")
  271. duration_ms: int = Field(default=0, description="持续时间(毫秒)")
  272. status: str = Field(default="", description="状态: succ/fail/timeout")
  273. input: dict = Field(default_factory=dict, description="步骤输入")
  274. output: dict = Field(default_factory=dict, description="步骤输出")
  275. class CallRecordDetailResponse(BaseModel):
  276. """调用记录详情响应"""
  277. id: str = Field(default="", description="记录 ID")
  278. time: str = Field(default="", description="调用时间")
  279. chain: str = Field(default="", description="链路标识")
  280. chain_name: str = Field(default="", description="链路名称")
  281. doc_ref: str = Field(default="", description="文档引用")
  282. status: str = Field(default="", description="状态: succ/fail/timeout")
  283. duration_ms: int = Field(default=0, description="持续时间(毫秒)")
  284. model: str = Field(default="", description="模型名称")
  285. function_name: Optional[str] = Field(default=None, description="功能名称")
  286. prompt_ver: str = Field(default="", description="提示词版本")
  287. prompt_name: str = Field(default="", description="提示词名称")
  288. tokens: int = Field(default=0, description="Token 消耗")
  289. params: dict = Field(default_factory=dict, description="请求参数")
  290. steps: List[StepDetail] = Field(default_factory=list, description="步骤列表")
  291. result: str = Field(default="", description="审查结果")
  292. error_message: Optional[str] = Field(default=None, description="错误信息")
  293. class OverrideParams(BaseModel):
  294. """回放覆盖参数"""
  295. model: Optional[str] = Field(default=None, description="覆盖模型名称")
  296. prompt_version: Optional[str] = Field(default=None, description="覆盖提示词版本")
  297. rag_params: Optional[RagParams] = Field(default=None, description="覆盖 RAG 参数")
  298. class ReplayRequest(BaseModel):
  299. """回放调用请求"""
  300. record_id: str = Field(..., description="要回放的调用记录 ID")
  301. override_params: Optional[OverrideParams] = Field(
  302. default=None,
  303. description="覆盖原始参数,不指定则使用原始参数",
  304. )
  305. class ExportRequest(BaseModel):
  306. """导出调用记录请求"""
  307. record_ids: List[str] = Field(..., description="要导出的记录 ID 列表")
  308. format: str = Field(default="json", description="导出格式")
# ============ Router definition ============
  310. debug_router = APIRouter(prefix="/debug", tags=["审查调试工作台"])
  311. # 注册子模块路由
  312. from .prompt_api import register_routes as register_prompt_routes
  313. from .record_api import register_routes as register_record_routes
  314. from .rag_debug_api import register_routes as register_rag_debug_routes
  315. register_prompt_routes(debug_router)
  316. register_record_routes(debug_router)
  317. register_rag_debug_routes(debug_router)
# ============ Review-chain debug endpoints ============
  319. @debug_router.post("/api/review/execute")
  320. async def execute_review(request: DebugExecuteRequest):
  321. """
  322. 启动审查调试任务,返回 task_id。
  323. 前端拿到 task_id 后通过 GET /debug/api/review/stream/{task_id} (EventSource)
  324. 接收 SSE 实时进度推送。
  325. 支持全部 7 个审查链路。
  326. 并发限制:最多 5 个调试任务同时执行。
  327. """
  328. # ---- 检查并发上限 ----
  329. if debug_semaphore.locked():
  330. raise HTTPException(
  331. status_code=429,
  332. detail=(
  333. f"并发调试任务数已达上限 ({MAX_CONCURRENT_DEBUG_TASKS}),"
  334. f"请等待其他任务完成后再试"
  335. ),
  336. )
  337. chain_id = request.chain_id
  338. total_steps = CHAIN_STEPS_COUNT.get(chain_id, 3)
  339. event_queue: asyncio.Queue = asyncio.Queue()
  340. task_id = make_trace_id(chain_id)
  341. record_id = make_record_id()
  342. _running_tasks[task_id] = event_queue
  343. # 后台启动执行任务(不 await,让它在后台运行)
  344. asyncio.create_task(_background_execute(
  345. request, event_queue, task_id, record_id, chain_id,
  346. ))
  347. return {
  348. "task_id": task_id,
  349. "chain_id": chain_id,
  350. "total_steps": total_steps,
  351. }
  352. async def _background_execute(
  353. request: DebugExecuteRequest,
  354. event_queue: asyncio.Queue,
  355. task_id: str,
  356. record_id: str,
  357. chain_id: str,
  358. ) -> None:
  359. """后台执行审查调试,通过 event_queue 推送进度给 SSE 端点。
  360. 不消费 event_queue —— 仅启动 executor 并等待其完成,
  361. 将结果持久化到调用记录。SSE 事件由 GET stream 端点独立消费。
  362. """
  363. async with debug_semaphore:
  364. try:
  365. from core.debug.executor import DebugExecutor
  366. await event_queue.put(("started", {
  367. "task_id": task_id,
  368. "chain_id": chain_id,
  369. "total_steps": CHAIN_STEPS_COUNT.get(chain_id, 3),
  370. }))
  371. executor = DebugExecutor()
  372. result = await executor.execute(request, event_queue)
  373. # 保存调用记录
  374. try:
  375. await _save_debug_record(
  376. request=request,
  377. task_id=task_id,
  378. record_id=record_id,
  379. chain_id=chain_id,
  380. completed_data=result.get("completed_data"),
  381. error_message=result.get("error_occurred"),
  382. steps=result.get("steps_collected", []),
  383. )
  384. except Exception as exc:
  385. logger.warning("[execute_review] 保存调用记录失败: %s", exc)
  386. except asyncio.CancelledError:
  387. await event_queue.put(("error", {
  388. "task_id": task_id,
  389. "message": "任务被取消",
  390. }))
  391. except Exception as exc:
  392. logger.exception("[_background_execute] 执行异常")
  393. await event_queue.put(("error", {
  394. "task_id": task_id,
  395. "message": str(exc),
  396. }))
  397. finally:
  398. # 不立即清理 _running_tasks,让 GET SSE 端点消费事件后自行清理。
  399. # 兜底:30s 后如仍无人消费则清理。
  400. async def _delayed_cleanup():
  401. await asyncio.sleep(30)
  402. if task_id in _running_tasks:
  403. logger.warning("[_background_execute] 任务 %s 30s 未被消费,强制清理", task_id)
  404. _running_tasks.pop(task_id, None)
  405. asyncio.create_task(_delayed_cleanup())
# ============ SSE streaming / reconnect endpoint ============
  407. @debug_router.get("/api/review/stream/{task_id}")
  408. async def stream_review_progress(task_id: str = Path(..., description="任务 ID")):
  409. """
  410. 重新连接获取审查调试进度(SSE 流式)
  411. 当 SSE 连接断开时,用于重新连接获取仍在执行中的任务进度。
  412. 从当前进度继续推送,不重放已完成的 step 事件。
  413. """
  414. async def event_generator():
  415. queue = _running_tasks.get(task_id)
  416. if queue is None:
  417. yield format_sse_event("error", {
  418. "task_id": task_id,
  419. "message": "任务不存在或已完成",
  420. })
  421. return
  422. # 发送 resumed 标记
  423. yield format_sse_event("started", {
  424. "task_id": task_id,
  425. "resumed": True,
  426. })
  427. # 继续从原队列消费事件
  428. try:
  429. while True:
  430. try:
  431. event_type, data = await asyncio.wait_for(
  432. queue.get(),
  433. timeout=DEBUG_GLOBAL_TIMEOUT,
  434. )
  435. except asyncio.TimeoutError:
  436. yield format_sse_event("error", {
  437. "task_id": task_id,
  438. "message": "重连等待超时",
  439. })
  440. break
  441. if event_type == "__end__":
  442. break
  443. yield format_sse_event(event_type, data)
  444. if event_type in ("completed", "error"):
  445. break
  446. finally:
  447. _running_tasks.pop(task_id, None)
  448. return StreamingResponse(
  449. event_generator(),
  450. media_type="text/event-stream",
  451. headers={
  452. "Cache-Control": "no-cache",
  453. "X-Accel-Buffering": "no",
  454. },
  455. )
# ============================================================
# Internal helpers
# ============================================================
  459. async def _run_debug_execution(
  460. request: DebugExecuteRequest,
  461. event_queue: asyncio.Queue,
  462. task_id: str,
  463. record_id: str,
  464. ) -> None:
  465. """
  466. 执行审查调试后台任务。
  467. 根据 isolation_mode 选择执行路径:
  468. - True: 使用 IsolationRunner.run_selected_steps()
  469. - False: 使用 DebugExecutor.execute()
  470. """
  471. try:
  472. if request.isolation_mode:
  473. await _run_isolation_mode(request, event_queue, task_id)
  474. else:
  475. await _run_normal_mode(request, event_queue, task_id, record_id)
  476. except asyncio.CancelledError:
  477. await event_queue.put(("error", {
  478. "task_id": task_id,
  479. "message": "执行被取消",
  480. }))
  481. except Exception as exc:
  482. logger.exception("[_run_debug_execution] 执行异常")
  483. await event_queue.put(("error", {
  484. "task_id": task_id,
  485. "message": str(exc),
  486. }))
  487. async def _run_normal_mode(
  488. request: DebugExecuteRequest,
  489. event_queue: asyncio.Queue,
  490. task_id: str,
  491. record_id: str,
  492. ) -> None:
  493. """正常模式:使用 DebugExecutor 执行"""
  494. from core.debug.executor import DebugExecutor
  495. executor = DebugExecutor()
  496. await executor.execute(request, event_queue)
  497. async def _run_isolation_mode(
  498. request: DebugExecuteRequest,
  499. event_queue: asyncio.Queue,
  500. task_id: str,
  501. ) -> None:
  502. """隔离模式:使用 IsolationRunner 执行选中的步骤"""
  503. from core.debug.isolation_runner import IsolationRunner
  504. from core.debug.step_dispatcher import CHAIN_STEPS
  505. runner = IsolationRunner()
  506. chain_id = request.chain_id
  507. params = {
  508. "content": request.content,
  509. "reference": request.reference or "",
  510. "model": request.model,
  511. "function_name": request.function_name,
  512. "timeout": request.timeout,
  513. }
  514. if request.rag_params:
  515. params["rag_params"] = request.rag_params
  516. manual_inputs = request.manual_inputs or {}
  517. selected_indices = list(request.isolation_steps) if request.isolation_steps else []
  518. step_result_dicts = await runner.run_selected_steps(
  519. chain_id=chain_id,
  520. selected_indices=selected_indices,
  521. manual_inputs=manual_inputs,
  522. **params,
  523. )
  524. # 从 step_def 获取 phase 信息
  525. step_defs = CHAIN_STEPS.get(chain_id, [])
  526. current_phase = None
  527. for sr in step_result_dicts:
  528. si = sr.get("index", 0)
  529. # 获取 phase 信息
  530. phase = None
  531. if si < len(step_defs):
  532. sd = step_defs[si]
  533. if hasattr(sd, "phase"):
  534. phase = sd.phase
  535. elif isinstance(sd, dict):
  536. phase = sd.get("phase")
  537. # phase_label:阶段切换时推送
  538. if phase and phase != current_phase:
  539. current_phase = phase
  540. await event_queue.put(("phase_label", {
  541. "task_id": task_id,
  542. "label": phase,
  543. }))
  544. # step_progress
  545. await event_queue.put(("step_progress", {
  546. "task_id": task_id,
  547. "step_index": si,
  548. "step_name": sr.get("name", ""),
  549. "status": "running",
  550. "duration": None,
  551. }))
  552. # step_result
  553. await event_queue.put(("step_result", {
  554. "task_id": task_id,
  555. "step_index": si,
  556. "step_name": sr.get("name", ""),
  557. "status": sr.get("status", ""),
  558. "duration": sr.get("duration", 0),
  559. "input": sr.get("input", {}),
  560. "output": sr.get("output", {}),
  561. "error": sr.get("error"),
  562. }))
  563. # completed 事件
  564. total_duration = sum(
  565. sr.get("duration", 0) or 0 for sr in step_result_dicts
  566. )
  567. success_count = sum(
  568. 1 for sr in step_result_dicts if sr.get("status") == "success"
  569. )
  570. error_count = sum(
  571. 1 for sr in step_result_dicts if sr.get("status") == "error"
  572. )
  573. skipped_count = sum(
  574. 1 for sr in step_result_dicts if sr.get("status") == "skipped"
  575. )
  576. await event_queue.put(("completed", {
  577. "task_id": task_id,
  578. "chain_id": chain_id,
  579. "total_duration": round(total_duration, 3),
  580. "record_id": "",
  581. "final_result": {
  582. "summary": (
  583. f"{success_count}/{len(step_result_dicts)} 步骤成功, "
  584. f"{error_count} 错误, {skipped_count} 跳过"
  585. ),
  586. "success_count": success_count,
  587. "error_count": error_count,
  588. "skipped_count": skipped_count,
  589. "total_steps": len(step_result_dicts),
  590. },
  591. }))
  592. async def _save_debug_record(
  593. request: DebugExecuteRequest,
  594. task_id: str,
  595. record_id: str,
  596. chain_id: str,
  597. completed_data: Optional[dict],
  598. error_message: Optional[str],
  599. steps: List[dict],
  600. ) -> None:
  601. """通过 RecordManager 保存调用记录"""
  602. from core.debug.record_manager import RecordManager
  603. duration_ms = 0
  604. final_result = ""
  605. if completed_data:
  606. duration_sec = completed_data.get("total_duration", 0) or 0
  607. duration_ms = int(duration_sec * 1000)
  608. final_result = str(completed_data.get("final_result", {}).get("summary", ""))
  609. status = "succ"
  610. if error_message:
  611. if "超时" in str(error_message):
  612. status = "timeout"
  613. else:
  614. status = "fail"
  615. record_data = {
  616. "id": record_id,
  617. "time": datetime.now().isoformat(),
  618. "chain": chain_id,
  619. "chain_name": CHAIN_NAMES.get(chain_id, ""),
  620. "doc_ref": "",
  621. "status": status,
  622. "duration_ms": duration_ms,
  623. "model": request.model or "default",
  624. "function_name": request.function_name or "",
  625. "prompt_ver": request.prompt_version or "",
  626. "prompt_name": "",
  627. "tokens": 0,
  628. "params": {
  629. "review_content": request.content,
  630. "review_references": request.reference or "",
  631. "model_override": request.model,
  632. "function_name": request.function_name,
  633. "timeout": request.timeout,
  634. },
  635. "execution_params": {
  636. "isolation_mode": request.isolation_mode,
  637. "isolation_steps": list(request.isolation_steps),
  638. "rag_params": (
  639. request.rag_params.model_dump()
  640. if request.rag_params else None
  641. ),
  642. },
  643. "steps": steps,
  644. "result": final_result,
  645. "error_message": error_message,
  646. }
  647. rm = RecordManager()
  648. await rm.save_record(record_data)