step_dispatcher.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328
"""
Step definitions and dispatcher.

Provides the StepDefinition dataclass, the global CHAIN_STEPS definition table,
and the StepDispatcher.
The step definitions correspond one-to-one with CHAIN_STEPS in executor.py and
extend them with the is_isolatable / requires_previous fields.

Usage:
    steps = StepDispatcher.get_steps("completeness")
    isolated = StepDispatcher.get_isolation_steps("completeness", [0, 2])
    ctx = StepDispatcher.get_step_context("completeness", 1)
"""
  10. from dataclasses import dataclass, field, asdict
  11. from typing import List, Optional, Dict
  12. # ============================================================
  13. # StepDefinition 数据类
  14. # ============================================================
  15. @dataclass
  16. class StepDefinition:
  17. """单步定义"""
  18. index: int # 步骤索引,从 0 开始
  19. name: str # 步骤名称
  20. phase: Optional[str] = None # 阶段名称,如 "RAG 召回阶段" / "AI 审查阶段"
  21. is_isolatable: bool = True # 是否可在隔离模式下独立执行
  22. requires_previous: bool = True # 是否依赖前一步的输出
  23. def to_dict(self) -> dict:
  24. """转为字典(与 executor.py 的 dict 格式兼容)"""
  25. return {
  26. "index": self.index,
  27. "name": self.name,
  28. "phase": self.phase,
  29. }
  30. # ============================================================
  31. # CHAIN_STEPS 定义表
  32. # ============================================================
  33. # 与 executor.py 的 CHAIN_STEPS 结构一一对应,扩展隔离属性
  34. # 依赖关系:直调 LLM 链路各步依次依赖,professional 链路每步依赖上一步
  35. CHAIN_STEPS: Dict[str, List[StepDefinition]] = {
  36. # ---- 直调 LLM 链路(6 个):各 3 步 ----
  37. "completeness": [
  38. StepDefinition(0, "Prompt 渲染", is_isolatable=True, requires_previous=False),
  39. StepDefinition(1, "LLM 调用", is_isolatable=True, requires_previous=True),
  40. StepDefinition(2, "结果解析", is_isolatable=True, requires_previous=True),
  41. ],
  42. "timeliness": [
  43. StepDefinition(0, "Prompt 渲染", is_isolatable=True, requires_previous=False),
  44. StepDefinition(1, "LLM 调用", is_isolatable=True, requires_previous=True),
  45. StepDefinition(2, "结果解析", is_isolatable=True, requires_previous=True),
  46. ],
  47. "reference": [
  48. StepDefinition(0, "Prompt 渲染", is_isolatable=True, requires_previous=False),
  49. StepDefinition(1, "LLM 调用", is_isolatable=True, requires_previous=True),
  50. StepDefinition(2, "结果解析", is_isolatable=True, requires_previous=True),
  51. ],
  52. "sensitive": [
  53. StepDefinition(0, "Prompt 渲染", is_isolatable=True, requires_previous=False),
  54. StepDefinition(1, "LLM 调用", is_isolatable=True, requires_previous=True),
  55. StepDefinition(2, "结果解析", is_isolatable=True, requires_previous=True),
  56. ],
  57. "semantic": [
  58. StepDefinition(0, "Prompt 渲染", is_isolatable=True, requires_previous=False),
  59. StepDefinition(1, "LLM 调用", is_isolatable=True, requires_previous=True),
  60. StepDefinition(2, "结果解析", is_isolatable=True, requires_previous=True),
  61. ],
  62. "grammar": [
  63. StepDefinition(0, "Prompt 渲染", is_isolatable=True, requires_previous=False),
  64. StepDefinition(1, "LLM 调用", is_isolatable=True, requires_previous=True),
  65. StepDefinition(2, "结果解析", is_isolatable=True, requires_previous=True),
  66. ],
  67. # ---- 专业性审查链路(RAG + AI,7 步) ----
  68. "professional": [
  69. StepDefinition(0, "查询提取", phase="RAG 召回阶段", is_isolatable=True, requires_previous=False),
  70. StepDefinition(1, "实体增强检索", phase="RAG 召回阶段", is_isolatable=True, requires_previous=True),
  71. StepDefinition(2, "父文档增强", phase="RAG 召回阶段", is_isolatable=True, requires_previous=True),
  72. StepDefinition(3, "结果提取", phase="RAG 召回阶段", is_isolatable=True, requires_previous=True),
  73. StepDefinition(4, "非参数合规审查", phase="AI 审查阶段", is_isolatable=True, requires_previous=True),
  74. StepDefinition(5, "参数合规审查", phase="AI 审查阶段", is_isolatable=True, requires_previous=True),
  75. StepDefinition(6, "结果汇总", phase="AI 审查阶段", is_isolatable=False, requires_previous=True),
  76. ],
  77. }
  78. VALID_CHAIN_IDS = set(CHAIN_STEPS.keys())
  79. # 步骤依赖映射(与 executor.py 的 _STEP_DEPS 一一对应)
  80. # key=step_index, value=依赖的上一步索引(None 表示无依赖)
  81. _STEP_DEPS: Dict[str, Dict[int, Optional[int]]] = {
  82. chain: {s.index: (s.index - 1 if s.index > 0 else None) for s in steps}
  83. for chain, steps in CHAIN_STEPS.items()
  84. }
  85. # ============================================================
  86. # StepDispatcher 调度器
  87. # ============================================================
  88. class StepDispatcher:
  89. """步骤调度器"""
  90. CHAIN_STEPS = CHAIN_STEPS
  91. @classmethod
  92. def get_steps(cls, chain_id: str) -> List[StepDefinition]:
  93. """
  94. 获取指定链路的全量步骤定义。
  95. Args:
  96. chain_id: 链路标识。
  97. Returns:
  98. List[StepDefinition]: 步骤定义列表。
  99. Raises:
  100. ValueError: 如果 chain_id 不存在。
  101. """
  102. if chain_id not in CHAIN_STEPS:
  103. raise ValueError(
  104. f"未知的 chain_id: '{chain_id}'。"
  105. f" 有效值: {sorted(VALID_CHAIN_IDS)}"
  106. )
  107. return list(CHAIN_STEPS[chain_id])
  108. @classmethod
  109. def get_isolation_steps(
  110. cls,
  111. chain_id: str,
  112. selected_indices: List[int],
  113. ) -> List[StepDefinition]:
  114. """
  115. 获取隔离模式下要执行的步骤定义。
  116. 过滤掉不存在的索引(记录 warning),仅返回选中的 + 自动前向传播的步骤。
  117. Args:
  118. chain_id: 链路标识。
  119. selected_indices: 用户选中的步骤索引列表。
  120. Returns:
  121. List[StepDefinition]: 过滤后的步骤定义列表。
  122. """
  123. all_steps = cls.get_steps(chain_id)
  124. if not selected_indices:
  125. return []
  126. selected = set()
  127. for idx in selected_indices:
  128. # 过滤不存在的索引
  129. if idx < 0 or idx >= len(all_steps):
  130. import logging
  131. logging.getLogger(__name__).warning(
  132. f"[StepDispatcher] isolation_steps 包含超出范围的索引 {idx}"
  133. f"({chain_id} 共 {len(all_steps)} 步),已自动过滤。"
  134. )
  135. continue
  136. selected.add(idx)
  137. # 按索引排序返回
  138. return [all_steps[i] for i in sorted(selected)]
  139. @classmethod
  140. def get_step_context(cls, chain_id: str, step_index: int) -> dict:
  141. """
  142. 获取指定步骤所需的上下文参数说明。
  143. Args:
  144. chain_id: 链路标识。
  145. step_index: 步骤索引。
  146. Returns:
  147. dict: 包含 context_name、required_params 等信息的字典。
  148. Raises:
  149. ValueError: 如果 chain_id 或 step_index 不存在。
  150. """
  151. if chain_id not in CHAIN_STEPS:
  152. raise ValueError(f"未知的 chain_id: '{chain_id}'")
  153. all_steps = CHAIN_STEPS[chain_id]
  154. if step_index < 0 or step_index >= len(all_steps):
  155. raise ValueError(
  156. f"step_index {step_index} 超出范围"
  157. f"({chain_id} 共 {len(all_steps)} 步)"
  158. )
  159. step_def = all_steps[step_index]
  160. # 判断链路类型
  161. is_professional = chain_id == "professional"
  162. if is_professional:
  163. return cls._build_professional_context(step_def, step_index)
  164. else:
  165. return cls._build_direct_llm_context(step_def, step_index)
  166. @classmethod
  167. def _build_direct_llm_context(
  168. cls, step_def: StepDefinition, step_index: int
  169. ) -> dict:
  170. """构建直调 LLM 链路的步骤上下文"""
  171. contexts = {
  172. 0: {
  173. "context_name": "prompt_rendering",
  174. "description": "渲染提示词模板,生成 system_prompt 和 user_prompt",
  175. "required_params": ["review_content", "review_references"],
  176. "optional_params": ["prompt_version"],
  177. "produced_outputs": ["system_prompt", "user_prompt"],
  178. "can_run_in_isolation": step_def.is_isolatable,
  179. },
  180. 1: {
  181. "context_name": "llm_invocation",
  182. "description": "调用 LLM 模型执行审查",
  183. "required_params": ["prompt_template", "trace_id"],
  184. "optional_params": ["model", "function_name", "timeout"],
  185. "produced_outputs": ["raw_response"],
  186. "can_run_in_isolation": step_def.is_isolatable,
  187. },
  188. 2: {
  189. "context_name": "result_parsing",
  190. "description": "解析 LLM 返回的原始响应为结构化审查结果",
  191. "required_params": ["raw_response"],
  192. "optional_params": [],
  193. "produced_outputs": ["parsed_result"],
  194. "can_run_in_isolation": step_def.is_isolatable,
  195. },
  196. }
  197. return contexts.get(step_index, {
  198. "context_name": "unknown",
  199. "description": step_def.name,
  200. "required_params": [],
  201. "optional_params": [],
  202. "produced_outputs": [],
  203. "can_run_in_isolation": step_def.is_isolatable,
  204. })
  205. @classmethod
  206. def _build_professional_context(
  207. cls, step_def: StepDefinition, step_index: int
  208. ) -> dict:
  209. """构建专业性审查的步骤上下文"""
  210. contexts = {
  211. 0: {
  212. "context_name": "query_extraction",
  213. "description": "从审查内容中提取检索查询对",
  214. "required_params": ["review_content"],
  215. "optional_params": [],
  216. "produced_outputs": ["query_pairs"],
  217. "phase": "RAG 召回阶段",
  218. "can_run_in_isolation": step_def.is_isolatable,
  219. },
  220. 1: {
  221. "context_name": "entity_enhanced_retrieval",
  222. "description": "实体增强向量检索,从知识库召回相关规范",
  223. "required_params": ["query_pairs"],
  224. "optional_params": ["top_k", "hybrid_top_k"],
  225. "produced_outputs": ["bfp_results"],
  226. "phase": "RAG 召回阶段",
  227. "can_run_in_isolation": step_def.is_isolatable,
  228. },
  229. 2: {
  230. "context_name": "parent_document_enhancement",
  231. "description": "父文档增强,补充召回片段的上下文",
  232. "required_params": ["bfp_results"],
  233. "optional_params": ["score_threshold", "max_parents_per_pair"],
  234. "produced_outputs": ["enhanced_results"],
  235. "phase": "RAG 召回阶段",
  236. "can_run_in_isolation": step_def.is_isolatable,
  237. },
  238. 3: {
  239. "context_name": "result_extraction",
  240. "description": "从增强结果中提取最终检索对",
  241. "required_params": ["enhanced_results", "query_pairs"],
  242. "optional_params": ["score_threshold"],
  243. "produced_outputs": ["entity_results"],
  244. "phase": "RAG 召回阶段",
  245. "can_run_in_isolation": step_def.is_isolatable,
  246. },
  247. 4: {
  248. "context_name": "non_parameter_compliance",
  249. "description": "非参数合规审查(LLM 审查)",
  250. "required_params": ["entity_results", "review_content"],
  251. "optional_params": ["review_type"],
  252. "produced_outputs": ["non_parameter_results"],
  253. "phase": "AI 审查阶段",
  254. "can_run_in_isolation": step_def.is_isolatable,
  255. },
  256. 5: {
  257. "context_name": "parameter_compliance",
  258. "description": "参数合规审查(LLM 审查)",
  259. "required_params": ["entity_results", "review_content"],
  260. "optional_params": ["review_type"],
  261. "produced_outputs": ["parameter_results"],
  262. "phase": "AI 审查阶段",
  263. "can_run_in_isolation": step_def.is_isolatable,
  264. },
  265. 6: {
  266. "context_name": "result_summary",
  267. "description": "汇总非参数和参数合规审查结果",
  268. "required_params": ["non_parameter_results", "parameter_results"],
  269. "optional_params": [],
  270. "produced_outputs": ["final_summary"],
  271. "phase": "AI 审查阶段",
  272. "can_run_in_isolation": step_def.is_isolatable,
  273. },
  274. }
  275. return contexts.get(step_index, {
  276. "context_name": "unknown",
  277. "description": step_def.name,
  278. "required_params": [],
  279. "optional_params": [],
  280. "produced_outputs": [],
  281. "phase": step_def.phase,
  282. "can_run_in_isolation": step_def.is_isolatable,
  283. })
  284. @classmethod
  285. def validate_chain_id(cls, chain_id: str) -> bool:
  286. """验证 chain_id 是否有效"""
  287. return chain_id in VALID_CHAIN_IDS
  288. @classmethod
  289. def get_step_deps(cls, chain_id: str) -> Dict[int, Optional[int]]:
  290. """获取链路的步骤依赖映射"""
  291. return dict(_STEP_DEPS.get(chain_id, {}))