skill_dispatcher.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. # -*- coding: utf-8 -*-
  2. """文档对话技能注册表与分发器。"""
  3. from dataclasses import dataclass
  4. from pathlib import Path
  5. from typing import Callable, Dict, List, Type
  6. import yaml
  7. from core.document_chat.schemas import DocumentChatSkillInput, DocumentChatSkillOutput
  8. from core.document_chat.skills.base import BaseDocumentChatSkill
  9. from core.document_chat.skills.document_answer import DocumentAnswerSkill
  10. from core.document_chat.skills.document_modify import DocumentModifySkill
  11. @dataclass(frozen=True)
  12. class SkillDefinition:
  13. name: str
  14. description: str
  15. intent: str
  16. function_name: str
  17. handler_class: Type[BaseDocumentChatSkill]
  18. response_type: str
  19. def to_registry_item(self) -> Dict[str, str]:
  20. return {
  21. "name": self.name,
  22. "description": self.description,
  23. "intent": self.intent,
  24. "function_name": self.function_name,
  25. "handler_class": self.handler_class.__name__,
  26. "response_type": self.response_type,
  27. }
  28. class SkillDispatcher:
  29. """基于白名单的技能分发器。"""
  30. _HANDLER_CLASSES: Dict[str, Type[BaseDocumentChatSkill]] = {
  31. "DocumentModifySkill": DocumentModifySkill,
  32. "DocumentAnswerSkill": DocumentAnswerSkill,
  33. }
  34. def __init__(self):
  35. self._definitions: Dict[str, SkillDefinition] = self._load_definitions()
  36. self._instances: Dict[str, BaseDocumentChatSkill] = {}
  37. def registry_for_prompt(self) -> List[Dict[str, str]]:
  38. return [definition.to_registry_item() for definition in self._definitions.values()]
  39. def has_skill(self, skill_name: str) -> bool:
  40. return skill_name in self._definitions
  41. async def run_skill(
  42. self,
  43. skill_name: str,
  44. skill_input: DocumentChatSkillInput,
  45. ) -> DocumentChatSkillOutput:
  46. if skill_name not in self._definitions:
  47. raise ValueError(f"Unsupported document chat skill: {skill_name}")
  48. skill = self._get_instance(skill_name)
  49. return await skill.run(skill_input)
  50. async def run_skill_stream(
  51. self,
  52. skill_name: str,
  53. skill_input: DocumentChatSkillInput,
  54. on_chunk: Callable[[str], None],
  55. ) -> DocumentChatSkillOutput:
  56. if skill_name not in self._definitions:
  57. raise ValueError(f"Unsupported document chat skill: {skill_name}")
  58. skill = self._get_instance(skill_name)
  59. return await skill.run_stream(skill_input, on_chunk)
  60. def _get_instance(self, skill_name: str) -> BaseDocumentChatSkill:
  61. if skill_name not in self._instances:
  62. definition = self._definitions[skill_name]
  63. self._instances[skill_name] = definition.handler_class(
  64. name=definition.name,
  65. function_name=definition.function_name,
  66. )
  67. return self._instances[skill_name]
  68. def _load_definitions(self) -> Dict[str, SkillDefinition]:
  69. skills_root = Path(__file__).resolve().parents[1] / "skills"
  70. definitions: Dict[str, SkillDefinition] = {}
  71. for skill_yaml in sorted(skills_root.glob("*/skill.yaml")):
  72. with open(skill_yaml, "r", encoding="utf-8") as handle:
  73. data = yaml.safe_load(handle) or {}
  74. definition = self._definition_from_yaml(data, skill_yaml)
  75. definitions[definition.name] = definition
  76. return definitions
  77. def _definition_from_yaml(self, data: dict, source: Path) -> SkillDefinition:
  78. required_fields = ["name", "description", "intent", "function_name", "handler_class", "response_type"]
  79. missing = [field for field in required_fields if not data.get(field)]
  80. if missing:
  81. raise ValueError(f"Skill配置缺少字段 {missing}: {source}")
  82. handler_name = str(data["handler_class"])
  83. handler_class = self._HANDLER_CLASSES.get(handler_name)
  84. if handler_class is None:
  85. raise ValueError(f"Skill配置使用了未注册的 handler_class: {handler_name}, source={source}")
  86. return SkillDefinition(
  87. name=str(data["name"]),
  88. description=str(data["description"]),
  89. intent=str(data["intent"]),
  90. function_name=str(data["function_name"]),
  91. handler_class=handler_class,
  92. response_type=str(data["response_type"]),
  93. )