task_source_trigger.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. # coding=utf-8
  2. """
  3. @project: MaxKB
  4. @Author:niu
  5. @file: task_source_trigger.py
  6. @date:2026/1/22 16:18
  7. @desc:
  8. """
  9. from typing import Dict
  10. from django.db import transaction
  11. from django.db.models import QuerySet
  12. from django.utils.translation import gettext_lazy as _
  13. from rest_framework import serializers
  14. from application.models import Application
  15. from common.exception.app_exception import AppApiException
  16. from tools.models import Tool
  17. from trigger.models import TriggerTypeChoices, Trigger, TriggerTaskTypeChoices, TriggerTask
  18. from trigger.serializers.trigger import TriggerModelSerializer, TriggerSerializer, ApplicationTriggerTaskSerializer, \
  19. ToolTriggerTaskSerializer, TriggerTaskModelSerializer
  20. class TaskSourceTriggerTaskEditRequest(serializers.Serializer):
  21. meta = serializers.DictField(default=dict, required=False)
  22. parameter = serializers.DictField(default=dict, required=False)
  23. class TaskSourceTriggerEditRequest(serializers.Serializer):
  24. name = serializers.CharField(required=False, label=_('trigger name'))
  25. desc = serializers.CharField(required=False, allow_null=True, allow_blank=True, label=_('trigger description'))
  26. trigger_type = serializers.ChoiceField(required=False, choices=TriggerTypeChoices)
  27. trigger_setting = serializers.DictField(required=False, label=_("trigger setting"))
  28. meta = serializers.DictField(default=dict, required=False)
  29. trigger_task = TaskSourceTriggerTaskEditRequest(many=True, required=False)
  30. class TaskSourceTriggerSerializer(serializers.Serializer):
  31. workspace_id = serializers.CharField(required=True, label=_('workspace id'))
  32. user_id = serializers.UUIDField(required=True, label=_("User ID"))
  33. def insert(self, instance, with_valid=True):
  34. if with_valid:
  35. self.is_valid(raise_exception=True)
  36. if not len(instance.get("trigger_task")) == 1:
  37. raise AppApiException(500, _('Trigger task number must be one'))
  38. source_id = instance.get('source_id')
  39. source_type = instance.get('source_type')
  40. source_trigger_task = instance.get('trigger_task')[0]
  41. if not (instance.get('source_id') == source_id and source_trigger_task.get('source_type') == source_type):
  42. raise AppApiException(500, _('Incorrect trigger task'))
  43. return TriggerSerializer(data={
  44. 'workspace_id': self.data.get('workspace_id'),
  45. 'user_id': self.data.get('user_id')
  46. }).insert(instance, with_valid=True)
  47. class TaskSourceTriggerOperateSerializer(serializers.Serializer):
  48. trigger_id = serializers.UUIDField(required=True, label=_('trigger id'))
  49. workspace_id = serializers.CharField(required=False, label=_('workspace id'))
  50. source_type = serializers.CharField(required=True, label=_('source type'))
  51. source_id = serializers.CharField(required=True, label=_('source id'))
  52. def is_valid(self, *, raise_exception=False):
  53. super().is_valid(raise_exception=True)
  54. workspace_id = self.data.get('workspace_id')
  55. query_set = QuerySet(Trigger).filter(id=self.data.get('trigger_id'))
  56. if workspace_id:
  57. query_set = query_set.filter(workspace_id=workspace_id)
  58. if not query_set.exists():
  59. raise AppApiException(500, _('Trigger id does not exist'))
  60. def one(self, with_valid=True):
  61. if with_valid:
  62. self.is_valid()
  63. trigger_id = self.data.get('trigger_id')
  64. workspace_id = self.data.get('workspace_id')
  65. source_id = self.data.get('source_id')
  66. source_type = self.data.get('source_type')
  67. trigger = QuerySet(Trigger).filter(workspace_id=workspace_id, id=trigger_id).first()
  68. trigger_task = TriggerTaskModelSerializer(TriggerTask.objects.filter(
  69. trigger_id=trigger_id, source_id=source_id, source_type=source_type).first()).data
  70. if source_type == TriggerTaskTypeChoices.APPLICATION:
  71. application_task = ApplicationTriggerTaskSerializer(
  72. Application.objects.filter(workspace_id=workspace_id, id=source_id).first()).data
  73. return {
  74. **TriggerModelSerializer(trigger).data,
  75. 'trigger_task': trigger_task,
  76. 'application_task': application_task,
  77. }
  78. if source_type == TriggerTaskTypeChoices.TOOL:
  79. tool_task = ToolTriggerTaskSerializer(
  80. Tool.objects.filter(workspace_id=workspace_id, id=source_id).first()).data
  81. return {
  82. **TriggerModelSerializer(trigger).data,
  83. 'trigger_task': trigger_task,
  84. 'tool_task': tool_task,
  85. }
  86. @transaction.atomic
  87. def edit(self, instance: Dict, with_valid=True):
  88. from trigger.handler.simple_tools import deploy, undeploy
  89. if with_valid:
  90. self.is_valid(raise_exception=True)
  91. serializer = TaskSourceTriggerEditRequest(data=instance)
  92. serializer.is_valid(raise_exception=True)
  93. valid_data = serializer.validated_data
  94. trigger_id = self.data.get('trigger_id')
  95. workspace_id = self.data.get('workspace_id')
  96. source_id = self.data.get('source_id')
  97. source_type = self.data.get('source_type')
  98. trigger = Trigger.objects.filter(workspace_id=workspace_id, id=trigger_id).first()
  99. if not trigger:
  100. raise serializers.ValidationError(_('Trigger not found'))
  101. task_source_trigger_edit_field_list = ['name', 'desc', 'trigger_type', 'trigger_setting', 'meta']
  102. trigger_deploy_edit_field_list = ['trigger_type', 'trigger_setting']
  103. need_redeploy = any(field in instance for field in trigger_deploy_edit_field_list)
  104. for field in task_source_trigger_edit_field_list:
  105. if field in valid_data:
  106. setattr(trigger, field, valid_data.get(field))
  107. trigger.save()
  108. trigger_task = valid_data.get('trigger_task')
  109. if trigger_task is not None:
  110. # 检查是否为空列表
  111. if not trigger_task:
  112. raise serializers.ValidationError(_('Trigger must have at least one task'))
  113. TriggerTask.objects.filter(
  114. source_id=source_id,
  115. source_type=source_type,
  116. trigger_id=trigger_id
  117. ).update(parameter=trigger_task[0].get("parameter"), meta=trigger_task[0].get("meta"))
  118. else:
  119. # 用户没提交 trigger_task 字段,确保数据库中有 task
  120. if not TriggerTask.objects.filter(trigger_id=trigger_id).exists():
  121. raise serializers.ValidationError(_('Trigger must have at least one task'))
  122. if need_redeploy:
  123. if trigger.is_active and trigger.trigger_type == 'SCHEDULED':
  124. deploy(TriggerModelSerializer(trigger).data, **{})
  125. else:
  126. undeploy(TriggerModelSerializer(trigger).data, **{})
  127. return self.one()
  128. # 删除的是当前trigger_id+source_id+source_type对应的task
  129. @transaction.atomic
  130. def delete(self):
  131. from trigger.handler.simple_tools import undeploy
  132. self.is_valid(raise_exception=True)
  133. trigger_id = self.data.get('trigger_id')
  134. workspace_id = self.data.get('workspace_id')
  135. source_id = self.data.get('source_id')
  136. source_type = self.data.get('source_type')
  137. trigger = Trigger.objects.filter(workspace_id=workspace_id, id=trigger_id).first()
  138. if not trigger:
  139. raise AppApiException(404, _('Trigger not found'))
  140. delete_count = TriggerTask.objects.filter(trigger_id=trigger_id, source_id=source_id,
  141. source_type=source_type).delete()[0]
  142. if delete_count == 0:
  143. raise AppApiException(404, _('Task not found'))
  144. has_other_tasks = TriggerTask.objects.filter(trigger_id=trigger_id).exists()
  145. undeploy(TriggerModelSerializer(trigger).data, **{})
  146. if not has_other_tasks:
  147. trigger.delete()
  148. return True
  149. class TaskSourceTriggerListSerializer(serializers.Serializer):
  150. workspace_id = serializers.CharField(required=True, label=_('workspace id'))
  151. source_type = serializers.CharField(required=True, label=_('source type'))
  152. source_id = serializers.CharField(required=True, label=_('source id'))
  153. def list(self, with_valid=True):
  154. if with_valid:
  155. self.is_valid(raise_exception=True)
  156. triggers = Trigger.objects.filter(workspace_id=self.data.get("workspace_id"),
  157. triggertask__source_id=self.data.get("source_id"),
  158. triggertask__source_type=self.data.get("source_type"),
  159. is_active=True
  160. ).distinct()
  161. return [TriggerModelSerializer(trigger).data for trigger in triggers]