tool_workflow_manage.py 2.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. # coding=utf-8
  2. """
  3. @project: MaxKB
  4. @Author:虎虎
  5. @file: tool_workflow_manage.py
  6. @date:2026/3/12 15:17
  7. @desc:
  8. """
  9. import time
  10. from concurrent.futures import ThreadPoolExecutor
  11. from django.db import close_old_connections
  12. from django.utils.translation import get_language
  13. from application.flow.common import Workflow
  14. from application.flow.i_step_node import WorkFlowPostHandler, ToolFlowParamsSerializer
  15. from application.flow.workflow_manage import WorkflowManage
  16. from common.handle.base_to_response import BaseToResponse
  17. from common.handle.impl.response.system_to_response import SystemToResponse
  18. executor = ThreadPoolExecutor(max_workers=200)
  19. class ToolWorkflowManage(WorkflowManage):
  20. def __init__(self, flow: Workflow, params, work_flow_post_handler: WorkFlowPostHandler,
  21. base_to_response: BaseToResponse = SystemToResponse(), form_data=None,
  22. start_node_id=None,
  23. start_node_data=None, chat_record=None, child_node=None, is_the_task_interrupted=lambda: False):
  24. super().__init__(flow, params, work_flow_post_handler, base_to_response, form_data, None, None, None,
  25. None, None, start_node_id, start_node_data, chat_record, child_node, is_the_task_interrupted)
  26. self.out_context = {}
  27. def get_params_serializer_class(self):
  28. return ToolFlowParamsSerializer
  29. def run(self):
  30. self.context['start_time'] = time.time()
  31. close_old_connections()
  32. language = get_language()
  33. if self.params.get('stream'):
  34. return self.run_stream(self.start_node, None, language)
  35. return self.run_block(language)
  36. def stream(self):
  37. close_old_connections()
  38. language = get_language()
  39. self.run_chain_async(self.start_node, None, language)
  40. return self.await_result(is_cleanup=False)
  41. def get_start_node(self):
  42. return self.flow.get_node('tool-start-node')
  43. def get_base_node(self):
  44. """
  45. 获取基础节点
  46. @return:
  47. """
  48. return self.flow.get_node('tool-base-node')
  49. def get_input_field_list(self):
  50. """
  51. 获取输入字段列表
  52. @return: 输入字段配置
  53. """
  54. base_node = self.get_base_node()
  55. return base_node.properties.get("user_input_field_list") or []
  56. def get_output_field_list(self):
  57. """
  58. 获取输出字段列表配置
  59. @return: 输出字段列表配置
  60. """
  61. base_node = self.get_base_node()
  62. return base_node.properties.get("user_output_field_list") or []
  63. def get_input(self):
  64. """
  65. 获取用户输入
  66. @return: 用户输入
  67. """
  68. input_field_list = self.get_input_field_list()
  69. return {f.get('field'): self.params.get(f.get('field')) for f in input_field_list}
  70. def get_source_type(self):
  71. return "TOOL"
  72. def get_source_id(self):
  73. return self.params.get('tool_id')