# -*- coding: utf-8 -*-
"""Registry and dispatcher for document-chat skills.

Skill definitions are declared in ``skills/*/skill.yaml`` files and bound to
handler classes through an explicit whitelist, so a YAML file can only select
one of the handler classes registered in ``SkillDispatcher._HANDLER_CLASSES``.
"""

from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, ClassVar, Dict, List, Type

import yaml

from core.document_chat.schemas import DocumentChatSkillInput, DocumentChatSkillOutput
from core.document_chat.skills.base import BaseDocumentChatSkill
from core.document_chat.skills.document_answer import DocumentAnswerSkill
from core.document_chat.skills.document_modify import DocumentModifySkill


@dataclass(frozen=True)
class SkillDefinition:
    """Immutable description of one skill, parsed from its ``skill.yaml``."""

    name: str  # key used by the dispatcher to look the skill up
    description: str
    intent: str
    function_name: str
    handler_class: Type[BaseDocumentChatSkill]  # resolved via the whitelist
    response_type: str

    def to_registry_item(self) -> Dict[str, str]:
        """Return this definition as a dict whose values are all strings.

        ``handler_class`` is reduced to its class name so the item can be
        embedded in a prompt without exposing the class object itself.
        """
        return {
            "name": self.name,
            "description": self.description,
            "intent": self.intent,
            "function_name": self.function_name,
            "handler_class": self.handler_class.__name__,
            "response_type": self.response_type,
        }


class SkillDispatcher:
    """Whitelist-based skill dispatcher.

    All skill definitions are loaded once, at construction time. Handler
    instances are created lazily on first use and cached per skill name.
    """

    # Whitelist: the only handler classes a skill.yaml may reference by name.
    _HANDLER_CLASSES: ClassVar[Dict[str, Type[BaseDocumentChatSkill]]] = {
        "DocumentModifySkill": DocumentModifySkill,
        "DocumentAnswerSkill": DocumentAnswerSkill,
    }

    def __init__(self) -> None:
        """Load every skill definition.

        Raises:
            ValueError: if any ``skill.yaml`` is invalid or two files declare
                the same skill name.
        """
        self._definitions: Dict[str, SkillDefinition] = self._load_definitions()
        self._instances: Dict[str, BaseDocumentChatSkill] = {}

    def registry_for_prompt(self) -> List[Dict[str, str]]:
        """Return all loaded skills as string dicts, in load (sorted path) order."""
        return [definition.to_registry_item() for definition in self._definitions.values()]

    def has_skill(self, skill_name: str) -> bool:
        """Return whether ``skill_name`` is a registered skill."""
        return skill_name in self._definitions

    async def run_skill(
        self,
        skill_name: str,
        skill_input: DocumentChatSkillInput,
    ) -> DocumentChatSkillOutput:
        """Run the named skill and await its result.

        Raises:
            ValueError: if ``skill_name`` is not registered.
        """
        skill = self._get_instance(skill_name)
        return await skill.run(skill_input)

    async def run_skill_stream(
        self,
        skill_name: str,
        skill_input: DocumentChatSkillInput,
        on_chunk: Callable[[str], None],
    ) -> DocumentChatSkillOutput:
        """Run the named skill in streaming mode.

        ``on_chunk`` is passed through unchanged to the skill's ``run_stream``.

        Raises:
            ValueError: if ``skill_name`` is not registered.
        """
        skill = self._get_instance(skill_name)
        return await skill.run_stream(skill_input, on_chunk)

    def _get_instance(self, skill_name: str) -> BaseDocumentChatSkill:
        """Return the cached handler for ``skill_name``, creating it on first use.

        Raises:
            ValueError: if ``skill_name`` is not registered.
        """
        definition = self._definitions.get(skill_name)
        if definition is None:
            raise ValueError(f"Unsupported document chat skill: {skill_name}")
        instance = self._instances.get(skill_name)
        if instance is None:
            instance = definition.handler_class(
                name=definition.name,
                function_name=definition.function_name,
            )
            self._instances[skill_name] = instance
        return instance

    def _load_definitions(self) -> Dict[str, SkillDefinition]:
        """Parse every ``skills/*/skill.yaml`` into definitions keyed by skill name.

        Files are processed in sorted path order.

        Raises:
            ValueError: if a file is invalid, or if two files declare the same
                skill name (previously the later file silently won).
        """
        skills_root = Path(__file__).resolve().parents[1] / "skills"
        definitions: Dict[str, SkillDefinition] = {}
        sources: Dict[str, Path] = {}
        for skill_yaml in sorted(skills_root.glob("*/skill.yaml")):
            with skill_yaml.open("r", encoding="utf-8") as handle:
                # safe_load only builds plain YAML types, never arbitrary objects.
                data = yaml.safe_load(handle) or {}
            definition = self._definition_from_yaml(data, skill_yaml)
            previous_source = sources.get(definition.name)
            if previous_source is not None:
                raise ValueError(
                    f"Skill名称重复: {definition.name}, "
                    f"source={skill_yaml}, previous={previous_source}"
                )
            definitions[definition.name] = definition
            sources[definition.name] = skill_yaml
        return definitions

    def _definition_from_yaml(self, data: Any, source: Path) -> SkillDefinition:
        """Validate one parsed ``skill.yaml`` document and build its definition.

        Args:
            data: The parsed YAML document (any YAML value; must be a mapping).
            source: Path of the file, used only in error messages.

        Raises:
            ValueError: if the document is not a mapping, a required field is
                missing or empty, or ``handler_class`` is not whitelisted.
        """
        # A YAML list/scalar root would otherwise fail later with an
        # AttributeError on ``.get`` that does not mention the file.
        if not isinstance(data, dict):
            raise ValueError(f"Skill配置必须是YAML映射(mapping): {source}")
        required_fields = ["name", "description", "intent", "function_name", "handler_class", "response_type"]
        missing = [key for key in required_fields if not data.get(key)]
        if missing:
            raise ValueError(f"Skill配置缺少字段 {missing}: {source}")
        handler_name = str(data["handler_class"])
        handler_class = self._HANDLER_CLASSES.get(handler_name)
        if handler_class is None:
            raise ValueError(f"Skill配置使用了未注册的 handler_class: {handler_name}, source={source}")
        return SkillDefinition(
            name=str(data["name"]),
            description=str(data["description"]),
            intent=str(data["intent"]),
            function_name=str(data["function_name"]),
            handler_class=handler_class,
            response_type=str(data["response_type"]),
        )