skill_dispatcher.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. # -*- coding: utf-8 -*-
  2. """Skill registry and dispatcher for document chat."""
  3. from dataclasses import dataclass
  4. from pathlib import Path
  5. from typing import Callable, Dict, List, Type
  6. import yaml
  7. from core.document_chat.schemas import DocumentChatSkillInput, DocumentChatSkillOutput
  8. from core.document_chat.skills.base import BaseDocumentChatSkill
  9. from core.document_chat.skills.document_answer import DocumentAnswerSkill
  10. from core.document_chat.skills.document_modify import DocumentModifySkill
  11. @dataclass(frozen=True)
  12. class SkillDefinition:
  13. name: str
  14. description: str
  15. intent: str
  16. function_name: str
  17. handler_class: Type[BaseDocumentChatSkill]
  18. response_type: str
  19. def to_registry_item(self) -> Dict[str, str]:
  20. return {
  21. "name": self.name,
  22. "description": self.description,
  23. "intent": self.intent,
  24. "function_name": self.function_name,
  25. "handler_class": self.handler_class.__name__,
  26. "response_type": self.response_type,
  27. }
  28. class SkillDispatcher:
  29. """Allowlist-backed skill dispatcher."""
  30. _HANDLER_CLASSES: Dict[str, Type[BaseDocumentChatSkill]] = {
  31. "DocumentModifySkill": DocumentModifySkill,
  32. "DocumentAnswerSkill": DocumentAnswerSkill,
  33. }
  34. def __init__(self):
  35. self._definitions: Dict[str, SkillDefinition] = self._load_definitions()
  36. self._instances: Dict[str, BaseDocumentChatSkill] = {}
  37. def registry_for_prompt(self) -> List[Dict[str, str]]:
  38. return [definition.to_registry_item() for definition in self._definitions.values()]
  39. def has_skill(self, skill_name: str) -> bool:
  40. return skill_name in self._definitions
  41. async def run_skill(
  42. self,
  43. skill_name: str,
  44. skill_input: DocumentChatSkillInput,
  45. ) -> DocumentChatSkillOutput:
  46. if skill_name not in self._definitions:
  47. raise ValueError(f"Unsupported document chat skill: {skill_name}")
  48. skill = self._get_instance(skill_name)
  49. return await skill.run(skill_input)
  50. async def run_skill_stream(
  51. self,
  52. skill_name: str,
  53. skill_input: DocumentChatSkillInput,
  54. on_chunk: Callable[[str], None],
  55. ) -> DocumentChatSkillOutput:
  56. if skill_name not in self._definitions:
  57. raise ValueError(f"Unsupported document chat skill: {skill_name}")
  58. skill = self._get_instance(skill_name)
  59. return await skill.run_stream(skill_input, on_chunk)
  60. def _get_instance(self, skill_name: str) -> BaseDocumentChatSkill:
  61. if skill_name not in self._instances:
  62. definition = self._definitions[skill_name]
  63. self._instances[skill_name] = definition.handler_class(
  64. name=definition.name,
  65. function_name=definition.function_name,
  66. )
  67. return self._instances[skill_name]
  68. def _load_definitions(self) -> Dict[str, SkillDefinition]:
  69. skills_root = Path(__file__).resolve().parents[1] / "skills"
  70. definitions: Dict[str, SkillDefinition] = {}
  71. for skill_yaml in sorted(skills_root.glob("*/skill.yaml")):
  72. with open(skill_yaml, "r", encoding="utf-8") as handle:
  73. data = yaml.safe_load(handle) or {}
  74. definition = self._definition_from_yaml(data, skill_yaml)
  75. definitions[definition.name] = definition
  76. return definitions
  77. def _definition_from_yaml(self, data: dict, source: Path) -> SkillDefinition:
  78. required_fields = ["name", "description", "intent", "function_name", "handler_class", "response_type"]
  79. missing = [field for field in required_fields if not data.get(field)]
  80. if missing:
  81. raise ValueError(f"Skill配置缺少字段 {missing}: {source}")
  82. handler_name = str(data["handler_class"])
  83. handler_class = self._HANDLER_CLASSES.get(handler_name)
  84. if handler_class is None:
  85. raise ValueError(f"Skill配置使用了未注册的 handler_class: {handler_name}, source={source}")
  86. return SkillDefinition(
  87. name=str(data["name"]),
  88. description=str(data["description"]),
  89. intent=str(data["intent"]),
  90. function_name=str(data["function_name"]),
  91. handler_class=handler_class,
  92. response_type=str(data["response_type"]),
  93. )