""" 步骤定义与调度器 提供 StepDefinition 数据类、CHAIN_STEPS 全局定义表和 StepDispatcher 调度器。 步骤定义与 executor.py 的 CHAIN_STEPS 一一对应,并扩展 is_isolatable / requires_previous 字段。 用法: steps = StepDispatcher.get_steps("completeness") isolated = StepDispatcher.get_isolation_steps("completeness", [0, 2]) ctx = StepDispatcher.get_step_context("completeness", 1) """ from dataclasses import dataclass, field, asdict from typing import List, Optional, Dict # ============================================================ # StepDefinition 数据类 # ============================================================ @dataclass class StepDefinition: """单步定义""" index: int # 步骤索引,从 0 开始 name: str # 步骤名称 phase: Optional[str] = None # 阶段名称,如 "RAG 召回阶段" / "AI 审查阶段" is_isolatable: bool = True # 是否可在隔离模式下独立执行 requires_previous: bool = True # 是否依赖前一步的输出 def to_dict(self) -> dict: """转为字典(与 executor.py 的 dict 格式兼容)""" return { "index": self.index, "name": self.name, "phase": self.phase, } # ============================================================ # CHAIN_STEPS 定义表 # ============================================================ # 与 executor.py 的 CHAIN_STEPS 结构一一对应,扩展隔离属性 # 依赖关系:直调 LLM 链路各步依次依赖,professional 链路每步依赖上一步 CHAIN_STEPS: Dict[str, List[StepDefinition]] = { # ---- 直调 LLM 链路(6 个):各 3 步 ---- "completeness": [ StepDefinition(0, "Prompt 渲染", is_isolatable=True, requires_previous=False), StepDefinition(1, "LLM 调用", is_isolatable=True, requires_previous=True), StepDefinition(2, "结果解析", is_isolatable=True, requires_previous=True), ], "timeliness": [ StepDefinition(0, "Prompt 渲染", is_isolatable=True, requires_previous=False), StepDefinition(1, "LLM 调用", is_isolatable=True, requires_previous=True), StepDefinition(2, "结果解析", is_isolatable=True, requires_previous=True), ], "reference": [ StepDefinition(0, "Prompt 渲染", is_isolatable=True, requires_previous=False), StepDefinition(1, "LLM 调用", is_isolatable=True, requires_previous=True), StepDefinition(2, "结果解析", is_isolatable=True, requires_previous=True), ], "sensitive": [ StepDefinition(0, "Prompt 渲染", is_isolatable=True, requires_previous=False), StepDefinition(1, "LLM 调用", is_isolatable=True, requires_previous=True), StepDefinition(2, "结果解析", is_isolatable=True, requires_previous=True), ], "semantic": [ StepDefinition(0, "Prompt 渲染", is_isolatable=True, requires_previous=False), StepDefinition(1, "LLM 调用", is_isolatable=True, requires_previous=True), StepDefinition(2, "结果解析", is_isolatable=True, requires_previous=True), ], "grammar": [ StepDefinition(0, "Prompt 渲染", is_isolatable=True, requires_previous=False), StepDefinition(1, "LLM 调用", is_isolatable=True, requires_previous=True), StepDefinition(2, "结果解析", is_isolatable=True, requires_previous=True), ], # ---- 专业性审查链路(RAG + AI,7 步) ---- "professional": [ StepDefinition(0, "查询提取", phase="RAG 召回阶段", is_isolatable=True, requires_previous=False), StepDefinition(1, "实体增强检索", phase="RAG 召回阶段", is_isolatable=True, requires_previous=True), StepDefinition(2, "父文档增强", phase="RAG 召回阶段", is_isolatable=True, requires_previous=True), StepDefinition(3, "结果提取", phase="RAG 召回阶段", is_isolatable=True, requires_previous=True), StepDefinition(4, "非参数合规审查", phase="AI 审查阶段", is_isolatable=True, requires_previous=True), StepDefinition(5, "参数合规审查", phase="AI 审查阶段", is_isolatable=True, requires_previous=True), StepDefinition(6, "结果汇总", phase="AI 审查阶段", is_isolatable=False, requires_previous=True), ], } VALID_CHAIN_IDS = set(CHAIN_STEPS.keys()) # 步骤依赖映射(与 executor.py 的 _STEP_DEPS 一一对应) # key=step_index, value=依赖的上一步索引(None 表示无依赖) _STEP_DEPS: Dict[str, Dict[int, Optional[int]]] = { chain: {s.index: (s.index - 1 if s.index > 0 else None) for s in steps} for chain, steps in CHAIN_STEPS.items() } # ============================================================ # StepDispatcher 调度器 # ============================================================ class StepDispatcher: """步骤调度器""" CHAIN_STEPS = CHAIN_STEPS @classmethod def get_steps(cls, chain_id: str) -> List[StepDefinition]: """ 获取指定链路的全量步骤定义。 Args: chain_id: 链路标识。 Returns: List[StepDefinition]: 步骤定义列表。 Raises: ValueError: 如果 chain_id 不存在。 """ if chain_id not in CHAIN_STEPS: raise ValueError( f"未知的 chain_id: '{chain_id}'。" f" 有效值: {sorted(VALID_CHAIN_IDS)}" ) return list(CHAIN_STEPS[chain_id]) @classmethod def get_isolation_steps( cls, chain_id: str, selected_indices: List[int], ) -> List[StepDefinition]: """ 获取隔离模式下要执行的步骤定义。 过滤掉不存在的索引(记录 warning),仅返回选中的 + 自动前向传播的步骤。 Args: chain_id: 链路标识。 selected_indices: 用户选中的步骤索引列表。 Returns: List[StepDefinition]: 过滤后的步骤定义列表。 """ all_steps = cls.get_steps(chain_id) if not selected_indices: return [] selected = set() for idx in selected_indices: # 过滤不存在的索引 if idx < 0 or idx >= len(all_steps): import logging logging.getLogger(__name__).warning( f"[StepDispatcher] isolation_steps 包含超出范围的索引 {idx}" f"({chain_id} 共 {len(all_steps)} 步),已自动过滤。" ) continue selected.add(idx) # 按索引排序返回 return [all_steps[i] for i in sorted(selected)] @classmethod def get_step_context(cls, chain_id: str, step_index: int) -> dict: """ 获取指定步骤所需的上下文参数说明。 Args: chain_id: 链路标识。 step_index: 步骤索引。 Returns: dict: 包含 context_name、required_params 等信息的字典。 Raises: ValueError: 如果 chain_id 或 step_index 不存在。 """ if chain_id not in CHAIN_STEPS: raise ValueError(f"未知的 chain_id: '{chain_id}'") all_steps = CHAIN_STEPS[chain_id] if step_index < 0 or step_index >= len(all_steps): raise ValueError( f"step_index {step_index} 超出范围" f"({chain_id} 共 {len(all_steps)} 步)" ) step_def = all_steps[step_index] # 判断链路类型 is_professional = chain_id == "professional" if is_professional: return cls._build_professional_context(step_def, step_index) else: return cls._build_direct_llm_context(step_def, step_index) @classmethod def _build_direct_llm_context( cls, step_def: StepDefinition, step_index: int ) -> dict: """构建直调 LLM 链路的步骤上下文""" contexts = { 0: { "context_name": "prompt_rendering", "description": "渲染提示词模板,生成 system_prompt 和 user_prompt", "required_params": ["review_content", "review_references"], "optional_params": ["prompt_version"], "produced_outputs": ["system_prompt", "user_prompt"], "can_run_in_isolation": step_def.is_isolatable, }, 1: { "context_name": "llm_invocation", "description": "调用 LLM 模型执行审查", "required_params": ["prompt_template", "trace_id"], "optional_params": ["model", "function_name", "timeout"], "produced_outputs": ["raw_response"], "can_run_in_isolation": step_def.is_isolatable, }, 2: { "context_name": "result_parsing", "description": "解析 LLM 返回的原始响应为结构化审查结果", "required_params": ["raw_response"], "optional_params": [], "produced_outputs": ["parsed_result"], "can_run_in_isolation": step_def.is_isolatable, }, } return contexts.get(step_index, { "context_name": "unknown", "description": step_def.name, "required_params": [], "optional_params": [], "produced_outputs": [], "can_run_in_isolation": step_def.is_isolatable, }) @classmethod def _build_professional_context( cls, step_def: StepDefinition, step_index: int ) -> dict: """构建专业性审查的步骤上下文""" contexts = { 0: { "context_name": "query_extraction", "description": "从审查内容中提取检索查询对", "required_params": ["review_content"], "optional_params": [], "produced_outputs": ["query_pairs"], "phase": "RAG 召回阶段", "can_run_in_isolation": step_def.is_isolatable, }, 1: { "context_name": "entity_enhanced_retrieval", "description": "实体增强向量检索,从知识库召回相关规范", "required_params": ["query_pairs"], "optional_params": ["top_k", "hybrid_top_k"], "produced_outputs": ["bfp_results"], "phase": "RAG 召回阶段", "can_run_in_isolation": step_def.is_isolatable, }, 2: { "context_name": "parent_document_enhancement", "description": "父文档增强,补充召回片段的上下文", "required_params": ["bfp_results"], "optional_params": ["score_threshold", "max_parents_per_pair"], "produced_outputs": ["enhanced_results"], "phase": "RAG 召回阶段", "can_run_in_isolation": step_def.is_isolatable, }, 3: { "context_name": "result_extraction", "description": "从增强结果中提取最终检索对", "required_params": ["enhanced_results", "query_pairs"], "optional_params": ["score_threshold"], "produced_outputs": ["entity_results"], "phase": "RAG 召回阶段", "can_run_in_isolation": step_def.is_isolatable, }, 4: { "context_name": "non_parameter_compliance", "description": "非参数合规审查(LLM 审查)", "required_params": ["entity_results", "review_content"], "optional_params": ["review_type"], "produced_outputs": ["non_parameter_results"], "phase": "AI 审查阶段", "can_run_in_isolation": step_def.is_isolatable, }, 5: { "context_name": "parameter_compliance", "description": "参数合规审查(LLM 审查)", "required_params": ["entity_results", "review_content"], "optional_params": ["review_type"], "produced_outputs": ["parameter_results"], "phase": "AI 审查阶段", "can_run_in_isolation": step_def.is_isolatable, }, 6: { "context_name": "result_summary", "description": "汇总非参数和参数合规审查结果", "required_params": ["non_parameter_results", "parameter_results"], "optional_params": [], "produced_outputs": ["final_summary"], "phase": "AI 审查阶段", "can_run_in_isolation": step_def.is_isolatable, }, } return contexts.get(step_index, { "context_name": "unknown", "description": step_def.name, "required_params": [], "optional_params": [], "produced_outputs": [], "phase": step_def.phase, "can_run_in_isolation": step_def.is_isolatable, }) @classmethod def validate_chain_id(cls, chain_id: str) -> bool: """验证 chain_id 是否有效""" return chain_id in VALID_CHAIN_IDS @classmethod def get_step_deps(cls, chain_id: str) -> Dict[int, Optional[int]]: """获取链路的步骤依赖映射""" return dict(_STEP_DEPS.get(chain_id, {}))