knowledge_workflow_manage.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. # coding=utf-8
  2. """
  3. @project: MaxKB
  4. @Author:虎虎
  5. @file: Knowledge_workflow_manage.py
  6. @date:2025/11/13 19:02
  7. @desc:
  8. """
  9. import time
  10. import traceback
  11. from concurrent.futures import ThreadPoolExecutor
  12. from django.db.models import QuerySet
  13. from django.utils.translation import get_language
  14. from application.flow.common import Workflow
  15. from application.flow.i_step_node import WorkFlowPostHandler, KnowledgeFlowParamsSerializer, NodeResult
  16. from application.flow.workflow_manage import WorkflowManage
  17. from common.handle.base_to_response import BaseToResponse
  18. from common.handle.impl.response.system_to_response import SystemToResponse
  19. from knowledge.models.knowledge_action import KnowledgeAction, State
  20. executor = ThreadPoolExecutor(max_workers=200)
  21. class KnowledgeWorkflowManage(WorkflowManage):
  22. def __init__(self, flow: Workflow,
  23. params,
  24. work_flow_post_handler: WorkFlowPostHandler,
  25. base_to_response: BaseToResponse = SystemToResponse(),
  26. start_node_id=None,
  27. start_node_data=None, chat_record=None, child_node=None, is_the_task_interrupted=lambda: False):
  28. super().__init__(flow, params, work_flow_post_handler, base_to_response, None, None, None,
  29. None,
  30. None, None, start_node_id, start_node_data, chat_record, child_node, is_the_task_interrupted)
  31. def get_params_serializer_class(self):
  32. return KnowledgeFlowParamsSerializer
  33. def get_start_node(self):
  34. start_node_list = [node for node in self.flow.nodes if
  35. self.params.get('data_source', {}).get('node_id') == node.id]
  36. return start_node_list[0]
  37. def run(self):
  38. self.context['start_time'] = time.time()
  39. executor.submit(self._run)
  40. def _run(self):
  41. QuerySet(KnowledgeAction).filter(id=self.params.get('knowledge_action_id')).update(
  42. state=State.STARTED)
  43. language = get_language()
  44. self.run_chain_async(self.start_node, None, language)
  45. while self.is_run():
  46. pass
  47. self.work_flow_post_handler.handler(self)
  48. @staticmethod
  49. def get_node_details(current_node, node, index):
  50. if current_node == node:
  51. return {
  52. 'name': node.node.properties.get('stepName'),
  53. "index": index,
  54. 'run_time': 0,
  55. 'type': node.type,
  56. 'status': 202,
  57. 'err_message': ""
  58. }
  59. return node.get_details(index)
  60. def run_chain(self, current_node, node_result_future=None):
  61. QuerySet(KnowledgeAction).filter(id=self.params.get('knowledge_action_id')).update(
  62. details=self.get_runtime_details(lambda node, index: self.get_node_details(current_node, node, index)))
  63. if node_result_future is None:
  64. node_result_future = self.run_node_future(current_node)
  65. try:
  66. result = self.hand_node_result(current_node, node_result_future)
  67. return result
  68. except Exception as e:
  69. traceback.print_exc()
  70. return None
  71. def hand_node_result(self, current_node, node_result_future):
  72. try:
  73. current_result = node_result_future.result()
  74. result = current_result.write_context(current_node, self)
  75. if result is not None:
  76. # 阻塞获取结果
  77. list(result)
  78. if current_node.status == 500:
  79. enableException = current_node.node.properties.get('enableException')
  80. if not enableException:
  81. return None
  82. current_node.context['exception_message'] = current_node.err_message
  83. current_node.context['branch_id'] = 'exception'
  84. r = NodeResult({'branch_id': 'exception', 'exception': current_node.err_message}, {},
  85. _is_interrupt=lambda node, step_variable, global_variable: False)
  86. r.write_context(current_node, self)
  87. return r
  88. if self.is_the_task_interrupted():
  89. current_node.status = 201
  90. return None
  91. return current_result
  92. except Exception as e:
  93. traceback.print_exc()
  94. self.status = 500
  95. current_node.get_write_error_context(e)
  96. self.answer += str(e)
  97. if self.is_the_task_interrupted():
  98. current_node.status = 201
  99. return None
  100. enableException = current_node.node.properties.get('enableException')
  101. if enableException:
  102. current_node.context['exception_message'] = current_node.err_message
  103. current_node.context['branch_id'] = 'exception'
  104. return NodeResult({'branch_id': 'exception', 'exception': current_node.err_message}, {},
  105. _is_interrupt=lambda node, step_variable, global_variable: False)
  106. QuerySet(KnowledgeAction).filter(id=self.params.get('knowledge_action_id')).update(state=State.FAILURE)
  107. finally:
  108. current_node.node_chunk.end()
  109. QuerySet(KnowledgeAction).filter(id=self.params.get('knowledge_action_id')).update(
  110. details=self.get_runtime_details())
  111. def get_source_type(self):
  112. return "KNOWLEDGE"
  113. def get_source_id(self):
  114. return self.params.get('knowledge_id')