| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328 |
- """
- 步骤定义与调度器
- 提供 StepDefinition 数据类、CHAIN_STEPS 全局定义表和 StepDispatcher 调度器。
- 步骤定义与 executor.py 的 CHAIN_STEPS 一一对应,并扩展 is_isolatable / requires_previous 字段。
- 用法:
- steps = StepDispatcher.get_steps("completeness")
- isolated = StepDispatcher.get_isolation_steps("completeness", [0, 2])
- ctx = StepDispatcher.get_step_context("completeness", 1)
- """
- from dataclasses import dataclass, field, asdict
- from typing import List, Optional, Dict
- # ============================================================
- # StepDefinition 数据类
- # ============================================================
@dataclass
class StepDefinition:
    """Definition of a single step within a review chain."""

    # Zero-based position of the step inside its chain.
    index: int
    # Human-readable step name.
    name: str
    # Phase the step belongs to, e.g. "RAG 召回阶段" / "AI 审查阶段"; None when unset.
    phase: Optional[str] = None
    # Whether the step may be executed on its own in isolation mode.
    is_isolatable: bool = True
    # Whether the step consumes the output of the preceding step.
    requires_previous: bool = True

    def to_dict(self) -> dict:
        """Return the step as a plain dict containing only index, name and phase.

        The isolation-related fields are intentionally left out; per the original
        documentation this keeps the shape compatible with executor.py's dicts.
        """
        return dict(index=self.index, name=self.name, phase=self.phase)
- # ============================================================
- # CHAIN_STEPS 定义表
- # ============================================================
# Mirrors the structure of executor.py's CHAIN_STEPS, extended with isolation
# attributes. Within every chain, each step except the first depends on the
# output of the step immediately before it.

# Identifiers of the six direct-LLM chains; they all share the same 3-step shape.
_DIRECT_LLM_CHAIN_IDS = (
    "completeness",
    "timeliness",
    "reference",
    "sensitive",
    "semantic",
    "grammar",
)


def _direct_llm_steps() -> List[StepDefinition]:
    """Build a fresh 3-step definition list for one direct-LLM chain.

    A new list with new StepDefinition objects is returned on every call, so no
    two chains share mutable step objects.
    """
    return [
        StepDefinition(0, "Prompt 渲染", is_isolatable=True, requires_previous=False),
        StepDefinition(1, "LLM 调用", is_isolatable=True, requires_previous=True),
        StepDefinition(2, "结果解析", is_isolatable=True, requires_previous=True),
    ]


CHAIN_STEPS: Dict[str, List[StepDefinition]] = {
    # ---- Direct-LLM chains (6): 3 steps each ----
    **{chain_id: _direct_llm_steps() for chain_id in _DIRECT_LLM_CHAIN_IDS},
    # ---- Professional review chain (RAG + AI, 7 steps) ----
    "professional": [
        StepDefinition(0, "查询提取", phase="RAG 召回阶段", is_isolatable=True, requires_previous=False),
        StepDefinition(1, "实体增强检索", phase="RAG 召回阶段", is_isolatable=True, requires_previous=True),
        StepDefinition(2, "父文档增强", phase="RAG 召回阶段", is_isolatable=True, requires_previous=True),
        StepDefinition(3, "结果提取", phase="RAG 召回阶段", is_isolatable=True, requires_previous=True),
        StepDefinition(4, "非参数合规审查", phase="AI 审查阶段", is_isolatable=True, requires_previous=True),
        StepDefinition(5, "参数合规审查", phase="AI 审查阶段", is_isolatable=True, requires_previous=True),
        # The final summary step is the only one that cannot run in isolation.
        StepDefinition(6, "结果汇总", phase="AI 审查阶段", is_isolatable=False, requires_previous=True),
    ],
}

VALID_CHAIN_IDS = set(CHAIN_STEPS.keys())

# Step dependency map (intended to match executor.py's _STEP_DEPS).
# key = step index, value = index of the step it depends on (None = no dependency).
# Derived from requires_previous so this map cannot drift from the table above;
# for the current table it yields exactly "previous index, or None for step 0".
_STEP_DEPS: Dict[str, Dict[int, Optional[int]]] = {
    chain: {
        step.index: (
            step.index - 1 if step.requires_previous and step.index > 0 else None
        )
        for step in steps
    }
    for chain, steps in CHAIN_STEPS.items()
}
- # ============================================================
- # StepDispatcher 调度器
- # ============================================================
class StepDispatcher:
    """Step dispatcher: read-only queries over the CHAIN_STEPS definition table.

    Every method is a classmethod; the class keeps no per-instance state.
    """

    # Exposes the module-level table itself (the same object, not a copy).
    CHAIN_STEPS = CHAIN_STEPS

    @classmethod
    def _require_chain(cls, chain_id: str) -> List[StepDefinition]:
        """Return the shared step list for chain_id, raising ValueError if unknown.

        Single validation point so every public method reports an unknown
        chain_id with the same message.
        """
        if chain_id not in CHAIN_STEPS:
            raise ValueError(
                f"未知的 chain_id: '{chain_id}'。"
                f" 有效值: {sorted(VALID_CHAIN_IDS)}"
            )
        return CHAIN_STEPS[chain_id]

    @classmethod
    def get_steps(cls, chain_id: str) -> List[StepDefinition]:
        """
        Return every step definition of the given chain.

        Args:
            chain_id: Chain identifier.
        Returns:
            List[StepDefinition]: New list of copied step definitions, in index order.
        Raises:
            ValueError: If chain_id is unknown.
        """
        from dataclasses import replace

        # Copy each definition, not just the list: a plain list() copy would hand
        # callers the shared StepDefinition objects, and mutating one would
        # silently change the global table for every later caller.
        return [replace(step) for step in cls._require_chain(chain_id)]

    @classmethod
    def get_isolation_steps(
        cls,
        chain_id: str,
        selected_indices: List[int],
    ) -> List[StepDefinition]:
        """
        Return the step definitions to execute in isolation mode.

        The result is the in-range subset of selected_indices, de-duplicated and
        sorted by index. Out-of-range indices are dropped and a warning is logged.
        NOTE(review): an earlier docstring also promised "automatically
        forward-propagated" steps, but no propagation happens here, and steps
        with is_isolatable=False are not filtered out either. Confirm whether the
        executor is expected to handle both.

        Args:
            chain_id: Chain identifier.
            selected_indices: Step indices selected by the user.
        Returns:
            List[StepDefinition]: The filtered step definitions.
        Raises:
            ValueError: If chain_id is unknown (checked even when nothing is selected).
        """
        import logging

        all_steps = cls.get_steps(chain_id)
        if not selected_indices:
            return []

        logger = logging.getLogger(__name__)
        total = len(all_steps)
        selected = set()
        for idx in selected_indices:
            if not 0 <= idx < total:
                # Invalid indices are skipped rather than rejected so one bad
                # entry does not abort the whole isolation run.
                logger.warning(
                    "[StepDispatcher] isolation_steps 包含超出范围的索引 %s"
                    "(%s 共 %s 步),已自动过滤。",
                    idx,
                    chain_id,
                    total,
                )
                continue
            selected.add(idx)
        return [all_steps[i] for i in sorted(selected)]

    @classmethod
    def get_step_context(cls, chain_id: str, step_index: int) -> dict:
        """
        Describe the context (inputs/outputs) a given step needs.

        Args:
            chain_id: Chain identifier.
            step_index: Step index.
        Returns:
            dict: Keys include context_name, description, required_params,
            optional_params, produced_outputs and can_run_in_isolation
            (plus phase for the professional chain).
        Raises:
            ValueError: If chain_id is unknown or step_index is out of range.
        """
        all_steps = cls._require_chain(chain_id)
        if not 0 <= step_index < len(all_steps):
            raise ValueError(
                f"step_index {step_index} 超出范围"
                f"({chain_id} 共 {len(all_steps)} 步)"
            )
        step_def = all_steps[step_index]

        # Only "professional" is a RAG + AI chain; every other chain is direct-LLM.
        if chain_id == "professional":
            return cls._build_professional_context(step_def, step_index)
        return cls._build_direct_llm_context(step_def, step_index)

    @classmethod
    def _build_direct_llm_context(
        cls, step_def: StepDefinition, step_index: int
    ) -> dict:
        """Build the step context for a direct-LLM chain (unknown index -> generic entry)."""
        contexts = {
            0: {
                "context_name": "prompt_rendering",
                "description": "渲染提示词模板,生成 system_prompt 和 user_prompt",
                "required_params": ["review_content", "review_references"],
                "optional_params": ["prompt_version"],
                "produced_outputs": ["system_prompt", "user_prompt"],
                "can_run_in_isolation": step_def.is_isolatable,
            },
            1: {
                "context_name": "llm_invocation",
                "description": "调用 LLM 模型执行审查",
                "required_params": ["prompt_template", "trace_id"],
                "optional_params": ["model", "function_name", "timeout"],
                "produced_outputs": ["raw_response"],
                "can_run_in_isolation": step_def.is_isolatable,
            },
            2: {
                "context_name": "result_parsing",
                "description": "解析 LLM 返回的原始响应为结构化审查结果",
                "required_params": ["raw_response"],
                "optional_params": [],
                "produced_outputs": ["parsed_result"],
                "can_run_in_isolation": step_def.is_isolatable,
            },
        }
        return contexts.get(step_index, {
            "context_name": "unknown",
            "description": step_def.name,
            "required_params": [],
            "optional_params": [],
            "produced_outputs": [],
            "can_run_in_isolation": step_def.is_isolatable,
        })

    @classmethod
    def _build_professional_context(
        cls, step_def: StepDefinition, step_index: int
    ) -> dict:
        """Build the step context for the professional (RAG + AI) chain (unknown index -> generic entry)."""
        contexts = {
            0: {
                "context_name": "query_extraction",
                "description": "从审查内容中提取检索查询对",
                "required_params": ["review_content"],
                "optional_params": [],
                "produced_outputs": ["query_pairs"],
                "phase": "RAG 召回阶段",
                "can_run_in_isolation": step_def.is_isolatable,
            },
            1: {
                "context_name": "entity_enhanced_retrieval",
                "description": "实体增强向量检索,从知识库召回相关规范",
                "required_params": ["query_pairs"],
                "optional_params": ["top_k", "hybrid_top_k"],
                "produced_outputs": ["bfp_results"],
                "phase": "RAG 召回阶段",
                "can_run_in_isolation": step_def.is_isolatable,
            },
            2: {
                "context_name": "parent_document_enhancement",
                "description": "父文档增强,补充召回片段的上下文",
                "required_params": ["bfp_results"],
                "optional_params": ["score_threshold", "max_parents_per_pair"],
                "produced_outputs": ["enhanced_results"],
                "phase": "RAG 召回阶段",
                "can_run_in_isolation": step_def.is_isolatable,
            },
            3: {
                "context_name": "result_extraction",
                "description": "从增强结果中提取最终检索对",
                "required_params": ["enhanced_results", "query_pairs"],
                "optional_params": ["score_threshold"],
                "produced_outputs": ["entity_results"],
                "phase": "RAG 召回阶段",
                "can_run_in_isolation": step_def.is_isolatable,
            },
            4: {
                "context_name": "non_parameter_compliance",
                "description": "非参数合规审查(LLM 审查)",
                "required_params": ["entity_results", "review_content"],
                "optional_params": ["review_type"],
                "produced_outputs": ["non_parameter_results"],
                "phase": "AI 审查阶段",
                "can_run_in_isolation": step_def.is_isolatable,
            },
            5: {
                "context_name": "parameter_compliance",
                "description": "参数合规审查(LLM 审查)",
                "required_params": ["entity_results", "review_content"],
                "optional_params": ["review_type"],
                "produced_outputs": ["parameter_results"],
                "phase": "AI 审查阶段",
                "can_run_in_isolation": step_def.is_isolatable,
            },
            6: {
                "context_name": "result_summary",
                "description": "汇总非参数和参数合规审查结果",
                "required_params": ["non_parameter_results", "parameter_results"],
                "optional_params": [],
                "produced_outputs": ["final_summary"],
                "phase": "AI 审查阶段",
                "can_run_in_isolation": step_def.is_isolatable,
            },
        }
        return contexts.get(step_index, {
            "context_name": "unknown",
            "description": step_def.name,
            "required_params": [],
            "optional_params": [],
            "produced_outputs": [],
            "phase": step_def.phase,
            "can_run_in_isolation": step_def.is_isolatable,
        })

    @classmethod
    def validate_chain_id(cls, chain_id: str) -> bool:
        """Return True if chain_id is a known chain identifier."""
        return chain_id in VALID_CHAIN_IDS

    @classmethod
    def get_step_deps(cls, chain_id: str) -> Dict[int, Optional[int]]:
        """Return a copy of the chain's step-dependency map.

        Unlike the other lookups, an unknown chain_id yields an empty dict
        instead of raising ValueError (behavior kept for existing callers).
        """
        return dict(_STEP_DEPS.get(chain_id, {}))
|