| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101 |
- # -*- coding: utf-8 -*-
- """Skill registry and dispatcher for document chat."""
- from dataclasses import dataclass
- from pathlib import Path
- from typing import Dict, List, Type
- import yaml
- from core.document_chat.schemas import DocumentChatSkillInput, DocumentChatSkillOutput
- from core.document_chat.skills.base import BaseDocumentChatSkill
- from core.document_chat.skills.document_answer import DocumentAnswerSkill
- from core.document_chat.skills.document_modify import DocumentModifySkill
@dataclass(frozen=True)
class SkillDefinition:
    """Immutable description of a single document-chat skill.

    Instances are built by ``SkillDispatcher._definition_from_yaml`` from a
    ``skill.yaml`` file; ``handler_class`` holds the resolved skill class
    rather than its name.
    """

    # Skill identifier; SkillDispatcher keys its registry by this value.
    name: str
    # Free-text description, exposed through ``to_registry_item``.
    description: str
    # Intent label from configuration (its interpretation is not visible here).
    intent: str
    # Passed to the handler's constructor as ``function_name``.
    function_name: str
    # Concrete skill class that is instantiated to run this skill.
    handler_class: Type[BaseDocumentChatSkill]
    # Response type label from configuration (interpreted outside this module).
    response_type: str

    def to_registry_item(self) -> Dict[str, str]:
        """Return this definition as a plain dict for the prompt registry.

        All attributes are copied as-is, except ``handler_class``, which is
        replaced by the class's ``__name__`` so no class object leaks into
        the result.
        """
        handler_name = self.handler_class.__name__
        return dict(
            name=self.name,
            description=self.description,
            intent=self.intent,
            function_name=self.function_name,
            handler_class=handler_name,
            response_type=self.response_type,
        )
class SkillDispatcher:
    """Allowlist-backed dispatcher for document-chat skills.

    On construction, every ``*/skill.yaml`` file under the ``skills``
    directory is parsed into a ``SkillDefinition``. A definition may only
    reference handler classes listed in ``_HANDLER_CLASSES``, so configuration
    cannot cause arbitrary classes to be instantiated. Handler instances are
    created lazily on first use and cached per skill name.
    """

    # Allowlist: maps the ``handler_class`` string in skill.yaml to the class.
    _HANDLER_CLASSES: Dict[str, Type[BaseDocumentChatSkill]] = {
        "DocumentModifySkill": DocumentModifySkill,
        "DocumentAnswerSkill": DocumentAnswerSkill,
    }

    def __init__(self):
        # Definitions are loaded eagerly, so configuration errors surface when
        # the dispatcher is constructed rather than on first dispatch.
        self._definitions: Dict[str, SkillDefinition] = self._load_definitions()
        # Lazily populated cache of handler instances, keyed by skill name.
        self._instances: Dict[str, BaseDocumentChatSkill] = {}

    def registry_for_prompt(self) -> List[Dict[str, str]]:
        """Return one registry item per loaded skill.

        Items follow the sorted order of the ``skill.yaml`` paths they were
        loaded from.
        """
        return [definition.to_registry_item() for definition in self._definitions.values()]

    def has_skill(self, skill_name: str) -> bool:
        """Return True if ``skill_name`` names a loaded skill definition."""
        return skill_name in self._definitions

    async def run_skill(
        self,
        skill_name: str,
        skill_input: DocumentChatSkillInput,
    ) -> DocumentChatSkillOutput:
        """Run the named skill on ``skill_input`` and return its output.

        Raises:
            ValueError: if ``skill_name`` is not a loaded skill.
        """
        if skill_name not in self._definitions:
            raise ValueError(f"Unsupported document chat skill: {skill_name}")
        skill = self._get_instance(skill_name)
        return await skill.run(skill_input)

    def _get_instance(self, skill_name: str) -> BaseDocumentChatSkill:
        """Return the cached handler for ``skill_name``, creating it on first use.

        The caller must ensure ``skill_name`` is a loaded definition; an
        unknown name raises ``KeyError``.
        """
        if skill_name not in self._instances:
            definition = self._definitions[skill_name]
            self._instances[skill_name] = definition.handler_class(
                name=definition.name,
                function_name=definition.function_name,
            )
        return self._instances[skill_name]

    def _load_definitions(self) -> Dict[str, SkillDefinition]:
        """Parse every ``skills/*/skill.yaml`` file into a SkillDefinition.

        Returns:
            Mapping from skill name to definition, in sorted file-path order.

        Raises:
            ValueError: if a file fails validation in ``_definition_from_yaml``,
                or if two files declare the same skill name.
        """
        # The ``skills`` directory sits under the parent of this module's directory.
        skills_root = Path(__file__).resolve().parents[1] / "skills"
        definitions: Dict[str, SkillDefinition] = {}
        # Remembers which file each name came from, for duplicate reporting.
        sources: Dict[str, Path] = {}
        for skill_yaml in sorted(skills_root.glob("*/skill.yaml")):
            with open(skill_yaml, "r", encoding="utf-8") as handle:
                # safe_load never constructs arbitrary Python objects. An empty
                # file yields None, so fall back to {} and let the
                # missing-field check report it.
                data = yaml.safe_load(handle) or {}
            definition = self._definition_from_yaml(data, skill_yaml)
            previous = sources.get(definition.name)
            if previous is not None:
                # Reject duplicates instead of letting a later file silently
                # shadow an earlier one.
                raise ValueError(
                    f"Skill配置名称重复 {definition.name!r}: {previous}, {skill_yaml}"
                )
            definitions[definition.name] = definition
            sources[definition.name] = skill_yaml
        return definitions

    def _definition_from_yaml(self, data: dict, source: Path) -> SkillDefinition:
        """Validate one parsed skill.yaml document and build its SkillDefinition.

        Args:
            data: Parsed YAML document; must be a mapping.
            source: Path of the YAML file, used in error messages.

        Raises:
            ValueError: if ``data`` is not a mapping, a required field is
                missing or empty, or ``handler_class`` is not allowlisted.
        """
        # A list/scalar top level would otherwise fail on ``data.get`` with an
        # AttributeError that does not name the offending file.
        if not isinstance(data, dict):
            raise ValueError(
                f"Skill配置必须是映射类型, 实际为 {type(data).__name__}: {source}"
            )
        required_fields = ["name", "description", "intent", "function_name", "handler_class", "response_type"]
        # Falsy values (None, "", etc.) count as missing, not just absent keys.
        missing = [field for field in required_fields if not data.get(field)]
        if missing:
            raise ValueError(f"Skill配置缺少字段 {missing}: {source}")
        handler_name = str(data["handler_class"])
        handler_class = self._HANDLER_CLASSES.get(handler_name)
        if handler_class is None:
            raise ValueError(f"Skill配置使用了未注册的 handler_class: {handler_name}, source={source}")
        return SkillDefinition(
            name=str(data["name"]),
            description=str(data["description"]),
            intent=str(data["intent"]),
            function_name=str(data["function_name"]),
            handler_class=handler_class,
            response_type=str(data["response_type"]),
        )
|