base_tool_task.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. # coding=utf-8
  2. """
  3. @project: MaxKB
  4. @Author:虎虎
  5. @file: application_task.py
  6. @date:2026/1/14 19:14
  7. @desc:
  8. """
  9. import json
  10. import time
  11. import traceback
  12. import uuid_utils.compat as uuid
  13. from django.db.models import QuerySet
  14. from django.utils.translation import gettext as _
  15. from common.utils.logger import maxkb_logger
  16. from common.utils.rsa_util import rsa_long_decrypt
  17. from common.utils.tool_code import ToolExecutor
  18. from knowledge.models.knowledge_action import State
  19. from tools.models import ToolRecord, ToolTaskTypeChoices, ToolType
  20. from trigger.handler.impl.task.tool_task.common import BaseToolTriggerTask
  21. from trigger.models import TaskRecord
  # Module-level executor instance; ToolTask.execute runs every tool's code through it.
  executor: ToolExecutor = ToolExecutor()
  def get_reference(fields, obj):
      """Follow the key path ``fields`` through nested mappings starting at ``obj``.

      :param fields: iterable of keys, applied in order.
      :param obj: the mapping to start from.
      :return: the value at the end of the path, or ``None`` as soon as a key is
          missing or maps to ``None``. An empty path returns ``obj`` itself.
      """
      cursor = obj
      for key in fields:
          cursor = cursor.get(key)
          if cursor is None:
              # A missing (or None) hop ends the walk; cursor is already None.
              break
      return cursor
  def get_field_value(value, kwargs):
      """Resolve one parameter setting to its raw value.

      :param value: setting dict with ``source`` and ``value`` keys. When
          ``source`` is ``'custom'`` the literal ``value`` is returned; otherwise
          ``value`` is treated as a key path and looked up in ``kwargs``.
      :param kwargs: mapping that non-custom settings are resolved against.
      :return: the resolved, unconverted value (``None`` if a path does not resolve).
      """
      if value.get('source') == 'custom':
          return value.get('value')
      return get_reference(value.get('value'), kwargs)
  37. def _convert_value(_type, value):
  38. if value is None:
  39. return None
  40. if _type == 'int':
  41. return int(value)
  42. if _type == 'boolean':
  43. value = 0 if ['0', '[]'].__contains__(value) else value
  44. return bool(value)
  45. if _type == 'float':
  46. return float(value)
  47. if _type == 'dict':
  48. v = json.loads(value)
  49. if isinstance(v, dict):
  50. return v
  51. raise Exception(_('type error'))
  52. if _type == 'array':
  53. v = json.loads(value)
  54. if isinstance(v, list):
  55. return v
  56. raise Exception(_('type error'))
  57. return value
  def get_tool_execute_parameters(input_field_list, parameter_setting, kwargs):
      """Resolve and type-convert a trigger's parameter settings for a tool call.

      :param input_field_list: the tool's declared input fields. Each entry's
          ``name``/``type`` pair selects the conversion applied to the parameter
          of that name. ``None`` means "no declared fields".
      :param parameter_setting: mapping of parameter name -> setting dict, as
          consumed by ``get_field_value``. ``None`` means "no parameters"; it
          comes from ``trigger_task.get('parameter')``, which may be absent.
      :param kwargs: mapping that non-``custom`` settings are resolved against.
      :return: dict of parameter name -> converted value, in setting order.
      """
      type_map = {
          field.get("name"): field.get("type")
          for field in (input_field_list or [])
          if field.get("name")
      }
      return {
          key: _convert_value(type_map.get(key), get_field_value(setting, kwargs))
          for key, setting in (parameter_setting or {}).items()
      }
  def get_loop_workflow_node(node_list):
      """Collect the inner node details recorded by every loop node.

      :param node_list: iterable of node-detail dicts.
      :return: flat list of the values of each ``loop_node_data`` entry, taken
          from nodes whose ``type`` is ``'loop-node'``. Entries are visited in
          order, and nodes without ``loop_node_data`` contribute nothing.
      """
      return [
          inner_node
          for node in node_list
          if node.get('type') == 'loop-node'
          for iteration in (node.get('loop_node_data') or [])
          for inner_node in iteration.values()
      ]
  def get_workflow_state(details):
      """Derive an overall run state from per-node workflow details.

      :param details: mapping whose values are node-detail dicts. The inner
          nodes of loop nodes (see ``get_loop_workflow_node``) are checked too.
      :return: ``State.FAILURE`` if any node has ``status == 500`` and no truthy
          ``enableException``; otherwise ``State.SUCCESS``.
      """
      top_level = list(details.values())
      nodes = top_level + get_loop_workflow_node(top_level)
      # NOTE(review): enableException presumably marks a node whose error was
      # handled by the workflow, so its 500 does not fail the run; confirm.
      has_unhandled_error = any(
          node.get('status') == 500 and not node.get('enableException')
          for node in nodes
      )
      return State.FAILURE if has_unhandled_error else State.SUCCESS
  80. def _get_result_detail(result):
  81. if isinstance(result, dict):
  82. result_dict = {k: (str(v)[:500] if len(str(v)) > 500 else v) for k, v in result.items()}
  83. elif isinstance(result, list):
  84. result_dict = [str(item)[:500] if len(str(item)) > 500 else item for item in result]
  85. elif isinstance(result, str):
  86. result_dict = result[:500] if len(result) > 500 else result
  87. else:
  88. result_dict = result
  89. return result_dict
  class ToolTask(BaseToolTriggerTask):
      """Trigger task that runs a custom tool's code and records the run.

      Each execution saves a ``TaskRecord`` and a ``ToolRecord``, sharing one
      id, in the STARTED state. It then runs the tool and updates both rows to
      SUCCESS or FAILURE with the elapsed time and an input/output summary.
      """

      def support(self, tool, trigger_task, **kwargs):
          """Return whether this task handles ``tool``: only custom tools are supported."""
          return tool.tool_type == ToolType.CUSTOM

      def execute(self, tool, trigger_task, **kwargs):
          """Run ``tool`` for ``trigger_task`` and persist the outcome.

          :param tool: the tool model instance to run.
          :param trigger_task: dict describing the trigger task. Reads
              ``parameter``, ``source_id``, ``trigger`` and ``id``.
          :param kwargs: mapping that non-``custom`` parameter settings are
              resolved against.

          Any exception raised during the run is logged and recorded as FAILURE
          on both records instead of being re-raised.
          """
          parameter_setting = trigger_task.get('parameter')
          tool_id = trigger_task.get('source_id')
          task_record_id = uuid.uuid7()
          start_time = time.time()
          # Stays None until inputs are resolved; the failure path then falls
          # back to the raw settings as the recorded input.
          parameters = None
          try:
              TaskRecord(
                  id=task_record_id,
                  trigger_id=trigger_task.get('trigger'),
                  trigger_task_id=trigger_task.get('id'),
                  source_type="TOOL",
                  source_id=tool_id,
                  # Links the task record to the ToolRecord saved below, which uses the same id.
                  task_record_id=task_record_id,
                  meta={'input': parameter_setting, 'output': {}},
                  state=State.STARTED
              ).save()
              ToolRecord(
                  id=task_record_id,
                  workspace_id=tool.workspace_id,
                  tool_id=tool.id,
                  source_type=ToolTaskTypeChoices.TRIGGER,
                  source_id=trigger_task.get('trigger'),
                  meta={'input': parameter_setting, 'output': {}},
                  state=State.STARTED
              ).save()
              parameters = get_tool_execute_parameters(tool.input_field_list, parameter_setting, kwargs)
              all_params = self._build_all_params(tool, parameters)
              result = executor.exec_code(tool.code, all_params)
              result_dict = _get_result_detail(result)
              maxkb_logger.debug(f"Tool execution result: {result}")
              self._finish_records(
                  task_record_id, start_time, State.SUCCESS,
                  task_meta={'input': parameter_setting, 'output': result_dict},
                  tool_meta={'input': parameters, 'output': result_dict},
              )
          except Exception as e:
              maxkb_logger.error(f"Tool execution error: {traceback.format_exc()}")
              error_message = 'Error: ' + str(e)
              error_meta = {'output': error_message, 'err_message': error_message}
              # Record the resolved parameters when available, matching the success path.
              tool_input = parameter_setting if parameters is None else parameters
              self._finish_records(
                  task_record_id, start_time, State.FAILURE,
                  task_meta={'input': parameter_setting, **error_meta},
                  tool_meta={'input': tool_input, **error_meta},
              )

      @staticmethod
      def _build_all_params(tool, parameters):
          """Merge init-field defaults, decrypted stored init params, and call parameters.

          Later sources win on key collisions: call ``parameters`` override stored
          init params, which override declared ``default_value``s.
          """
          init_params_default_value = {
              i["field"]: i.get('default_value') for i in (tool.init_field_list or [])
          }
          if tool.init_params is not None:
              # NOTE(review): decrypted init params are passed straight to the tool
              # code; keep them out of logs and record meta.
              return init_params_default_value | json.loads(rsa_long_decrypt(tool.init_params)) | parameters
          return init_params_default_value | parameters

      @staticmethod
      def _finish_records(task_record_id, start_time, state, task_meta, tool_meta):
          """Write the final state, elapsed time and meta onto both record rows."""
          # Computed once so both records report the same duration.
          run_time = time.time() - start_time
          QuerySet(TaskRecord).filter(id=task_record_id).update(
              state=state,
              run_time=run_time,
              meta=task_meta
          )
          QuerySet(ToolRecord).filter(id=task_record_id).update(
              state=state,
              run_time=run_time,
              meta=tool_meta
          )