loop_workflow_manage.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. # coding=utf-8
  2. """
  3. @project: maxkb
  4. @Author:虎
  5. @file: loop_workflow_manage.py
  6. @date:2024/1/9 17:40
  7. @desc:
  8. """
  9. from concurrent.futures import ThreadPoolExecutor
  10. from typing import List
  11. from django.db import close_old_connections
  12. from django.utils.translation import get_language
  13. from langchain_core.prompts import PromptTemplate
  14. from application.flow.common import Workflow
  15. from application.flow.i_step_node import WorkFlowPostHandler, INode
  16. from application.flow.step_node import get_node
  17. from application.flow.workflow_manage import WorkflowManage
  18. from common.handle.base_to_response import BaseToResponse
  19. from common.handle.impl.response.system_to_response import SystemToResponse
# Module-level thread pool capped at 200 worker threads.
# NOTE(review): nothing else visible in this file references `executor`;
# presumably other modules import it — verify before changing or removing.
executor = ThreadPoolExecutor(max_workers=200)
  21. class NodeResultFuture:
  22. def __init__(self, r, e, status=200):
  23. self.r = r
  24. self.e = e
  25. self.status = status
  26. def result(self):
  27. if self.status == 200:
  28. return self.r
  29. else:
  30. raise self.e
  31. def await_result(result, timeout=1):
  32. try:
  33. result.result(timeout)
  34. return False
  35. except Exception as e:
  36. return True
  37. class NodeChunkManage:
  38. def __init__(self, work_flow):
  39. self.node_chunk_list = []
  40. self.current_node_chunk = None
  41. self.work_flow = work_flow
  42. def add_node_chunk(self, node_chunk):
  43. self.node_chunk_list.append(node_chunk)
  44. def contains(self, node_chunk):
  45. return self.node_chunk_list.__contains__(node_chunk)
  46. def pop(self):
  47. if self.current_node_chunk is None:
  48. try:
  49. current_node_chunk = self.node_chunk_list.pop(0)
  50. self.current_node_chunk = current_node_chunk
  51. except IndexError as e:
  52. pass
  53. if self.current_node_chunk is not None:
  54. try:
  55. chunk = self.current_node_chunk.chunk_list.pop(0)
  56. return chunk
  57. except IndexError as e:
  58. if self.current_node_chunk.is_end():
  59. self.current_node_chunk = None
  60. if self.work_flow.answer_is_not_empty():
  61. chunk = self.work_flow.base_to_response.to_stream_chunk_response(
  62. self.work_flow.params['chat_id'],
  63. self.work_flow.params['chat_record_id'],
  64. '\n\n', False, 0, 0)
  65. self.work_flow.append_answer('\n\n')
  66. return chunk
  67. return self.pop()
  68. return None
  69. class LoopWorkflowManage(WorkflowManage):
  70. def __init__(self, flow: Workflow,
  71. params,
  72. work_flow_post_handler: WorkFlowPostHandler,
  73. parentWorkflowManage,
  74. loop_params,
  75. get_loop_context,
  76. base_to_response: BaseToResponse = SystemToResponse(),
  77. start_node_id=None,
  78. start_node_data=None, chat_record=None, child_node=None, is_the_task_interrupted=lambda: False):
  79. self.parentWorkflowManage = parentWorkflowManage
  80. self.loop_params = loop_params
  81. self.get_loop_context = get_loop_context
  82. self.loop_field_list = []
  83. super().__init__(flow, params, work_flow_post_handler, base_to_response, None, None, None,
  84. None,
  85. None, None, start_node_id, start_node_data, chat_record, child_node, is_the_task_interrupted)
  86. def get_node_cls_by_id(self, node_id, up_node_id_list=None,
  87. get_node_params=lambda node: node.properties.get('node_data')):
  88. for node in self.flow.nodes:
  89. if node.id == node_id:
  90. node_instance = get_node(node.type, self.flow.workflow_mode)(node,
  91. self.params, self, up_node_id_list,
  92. get_node_params,
  93. salt=self.get_index())
  94. return node_instance
  95. return None
  96. def stream(self):
  97. close_old_connections()
  98. language = get_language()
  99. self.run_chain_async(self.start_node, None, language)
  100. return self.await_result(is_cleanup=False)
  101. def get_index(self):
  102. return self.loop_params.get('index')
  103. def get_start_node(self):
  104. start_node_list = [node for node in self.flow.nodes if
  105. ['loop-start-node'].__contains__(node.type)]
  106. return start_node_list[0]
  107. def get_reference_field(self, node_id: str, fields: List[str]):
  108. """
  109. @param node_id: 节点id
  110. @param fields: 字段
  111. @return:
  112. """
  113. if node_id == 'global':
  114. return self.parentWorkflowManage.get_reference_field(node_id, fields)
  115. elif node_id == 'chat':
  116. return self.parentWorkflowManage.get_reference_field(node_id, fields)
  117. elif node_id == 'loop':
  118. loop_context = self.get_loop_context()
  119. return INode.get_field(loop_context, fields)
  120. else:
  121. node = self.get_node_by_id(node_id)
  122. if node:
  123. return node.get_reference_field(fields)
  124. return self.parentWorkflowManage.get_reference_field(node_id, fields)
  125. def get_workflow_content(self):
  126. context = {
  127. 'global': self.context,
  128. 'chat': self.chat_context,
  129. 'loop': self.get_loop_context(),
  130. }
  131. for node in self.node_context:
  132. context[node.id] = node.context
  133. return context
  134. def init_fields(self):
  135. super().init_fields()
  136. loop_field_list = []
  137. loop_start_node = self.flow.get_node('loop-start-node')
  138. loop_input_field_list = loop_start_node.properties.get('loop_input_field_list')
  139. node_name = loop_start_node.properties.get('stepName')
  140. node_id = loop_start_node.id
  141. if loop_input_field_list is not None:
  142. for f in loop_input_field_list:
  143. loop_field_list.append(
  144. {'label': f.get('label'), 'value': f.get('field'), 'node_id': node_id, 'node_name': node_name})
  145. self.loop_field_list = loop_field_list
  146. def reset_prompt(self, prompt: str):
  147. prompt = super().reset_prompt(prompt)
  148. for field in self.loop_field_list:
  149. chatLabel = f"loop.{field.get('value')}"
  150. chatValue = f"context.get('loop').get('{field.get('value', '')}','')"
  151. prompt = prompt.replace(chatLabel, chatValue)
  152. prompt = self.parentWorkflowManage.reset_prompt(prompt)
  153. return prompt
  154. def generate_prompt(self, prompt: str):
  155. """
  156. 格式化生成提示词
  157. @param prompt: 提示词信息
  158. @return: 格式化后的提示词
  159. """
  160. context = {**self.get_workflow_content(), **self.parentWorkflowManage.get_workflow_content()}
  161. prompt = self.reset_prompt(prompt)
  162. prompt_template = PromptTemplate.from_template(prompt, template_format='jinja2')
  163. value = prompt_template.format(context=context)
  164. return value
  165. def get_source_type(self):
  166. return "APPLICATION"
  167. def get_source_id(self):
  168. return self.params.get('application_id')