i_tool_node.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. # coding=utf-8
  2. """
  3. @project: MaxKB
  4. @Author:虎
  5. @file: i_tool_node.py
  6. @date:2024/8/8 16:21
  7. @desc: Parameter serializers and the base node interface (IToolNode) for the workflow tool node.
  8. """
  9. import re
  10. from typing import Type
  11. from django.core import validators
  12. from django.utils.translation import gettext_lazy as _
  13. from rest_framework import serializers
  14. from rest_framework.utils.formatting import lazy_format
  15. from application.flow.common import WorkflowMode
  16. from application.flow.i_step_node import INode, NodeResult
  17. from common.exception.app_exception import AppApiException
  18. from common.field.common import ObjectField
  19. class InputField(serializers.Serializer):
  20. name = serializers.CharField(required=True, label=_('Variable Name'))
  21. is_required = serializers.BooleanField(required=True, label=_("Is this field required"))
  22. type = serializers.CharField(required=True, label=_("type"), validators=[
  23. validators.RegexValidator(regex=re.compile("^string|int|dict|array|float|boolean$"),
  24. message=_("The field only supports string|int|dict|array|float"), code=500)
  25. ])
  26. source = serializers.CharField(required=True, label=_("source"), validators=[
  27. validators.RegexValidator(regex=re.compile("^custom|reference$"),
  28. message=_("The field only supports custom|reference"), code=500)
  29. ])
  30. value = ObjectField(required=True, label=_("Variable Value"), model_type_list=[str, list])
  31. def is_valid(self, *, raise_exception=False):
  32. super().is_valid(raise_exception=True)
  33. is_required = self.data.get('is_required')
  34. if is_required and self.data.get('value') is None:
  35. message = lazy_format(_('{field}, this field is required.'), field=self.data.get("name"))
  36. raise AppApiException(500, message)
  37. class FunctionNodeParamsSerializer(serializers.Serializer):
  38. input_field_list = InputField(required=True, many=True)
  39. code = serializers.CharField(required=True, label=_("function"))
  40. is_result = serializers.BooleanField(required=False,
  41. label=_('Whether to return content'))
  42. def is_valid(self, *, raise_exception=False):
  43. super().is_valid(raise_exception=True)
  44. class IToolNode(INode):
  45. type = 'tool-node'
  46. support = [WorkflowMode.APPLICATION, WorkflowMode.APPLICATION_LOOP, WorkflowMode.KNOWLEDGE,
  47. WorkflowMode.KNOWLEDGE_LOOP, WorkflowMode.TOOL, WorkflowMode.TOOL_LOOP]
  48. def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
  49. return FunctionNodeParamsSerializer
  50. def _run(self):
  51. return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
  52. def execute(self, input_field_list, code, **kwargs) -> NodeResult:
  53. pass