common.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284
  1. # coding=utf-8
  2. """
  3. @project: MaxKB
  4. @Author:虎
  5. @file: common.py
  6. @date:2024/12/11 17:57
  7. @desc:
  8. """
  9. from enum import Enum
  10. from typing import List, Dict
  11. from django.db.models import QuerySet
  12. from django.utils.translation import gettext as _
  13. from rest_framework.exceptions import ErrorDetail, ValidationError
  14. from common.exception.app_exception import AppApiException
  15. from common.utils.common import group_by
  16. from models_provider.models import Model
  17. from models_provider.tools import get_model_credential
  18. from tools.models.tool import Tool
  19. end_nodes = ['ai-chat-node', 'reply-node', 'function-node', 'function-lib-node', 'application-node',
  20. 'image-understand-node', 'speech-to-text-node', 'text-to-speech-node', 'image-generate-node',
  21. 'variable-assign-node']
  22. class Answer:
  23. def __init__(self, content, view_type, runtime_node_id, chat_record_id, child_node, real_node_id,
  24. reasoning_content):
  25. self.view_type = view_type
  26. self.content = content
  27. self.reasoning_content = reasoning_content
  28. self.runtime_node_id = runtime_node_id
  29. self.chat_record_id = chat_record_id
  30. self.child_node = child_node
  31. self.real_node_id = real_node_id
  32. def to_dict(self):
  33. return {'view_type': self.view_type, 'content': self.content, 'runtime_node_id': self.runtime_node_id,
  34. 'chat_record_id': self.chat_record_id,
  35. 'child_node': self.child_node,
  36. 'reasoning_content': self.reasoning_content,
  37. 'real_node_id': self.real_node_id}
  38. class NodeChunk:
  39. def __init__(self):
  40. self.status = 0
  41. self.chunk_list = []
  42. def add_chunk(self, chunk):
  43. self.chunk_list.append(chunk)
  44. def end(self, chunk=None):
  45. if chunk is not None:
  46. self.add_chunk(chunk)
  47. self.status = 200
  48. def is_end(self):
  49. return self.status == 200
  50. class Edge:
  51. def __init__(self, _id: str, _type: str, sourceNodeId: str, targetNodeId: str, **keywords):
  52. self.id = _id
  53. self.type = _type
  54. self.sourceNodeId = sourceNodeId
  55. self.targetNodeId = targetNodeId
  56. for keyword in keywords:
  57. self.__setattr__(keyword, keywords.get(keyword))
  58. class Node:
  59. def __init__(self, _id: str, _type: str, x: int, y: int, properties: dict, **kwargs):
  60. self.id = _id
  61. self.type = _type
  62. self.x = x
  63. self.y = y
  64. self.properties = properties
  65. for keyword in kwargs:
  66. self.__setattr__(keyword, kwargs.get(keyword))
  67. class EdgeNode:
  68. edge: Edge
  69. node: Node
  70. def __init__(self, edge, node):
  71. self.edge = edge
  72. self.node = node
  73. class WorkflowMode(Enum):
  74. APPLICATION = "application"
  75. APPLICATION_LOOP = "application-loop"
  76. KNOWLEDGE = "knowledge"
  77. KNOWLEDGE_LOOP = "knowledge-loop"
  78. TOOL = "tool"
  79. TOOL_LOOP = "tool-loop"
  80. class Workflow:
  81. """
  82. 节点列表
  83. """
  84. nodes: List[Node]
  85. """
  86. 线列表
  87. """
  88. edges: List[Edge]
  89. """
  90. 节点id:node
  91. """
  92. node_map: Dict[str, Node]
  93. """
  94. 节点id:当前节点id上面的所有节点
  95. """
  96. up_node_map: Dict[str, List[EdgeNode]]
  97. """
  98. 节点id:当前节点id下面的所有节点
  99. """
  100. next_node_map: Dict[str, List[EdgeNode]]
  101. workflow_mode: WorkflowMode
  102. def __init__(self, nodes: List[Node], edges: List[Edge],
  103. workflow_mode: WorkflowMode = WorkflowMode.APPLICATION.value):
  104. self.nodes = nodes
  105. self.edges = edges
  106. self.node_map = {node.id: node for node in nodes}
  107. self.up_node_map = {key: [EdgeNode(edge, self.node_map.get(edge.sourceNodeId)) for
  108. edge in edges] for
  109. key, edges in
  110. group_by(edges, key=lambda edge: edge.targetNodeId).items()}
  111. self.next_node_map = {key: [EdgeNode(edge, self.node_map.get(edge.targetNodeId)) for edge in edges] for
  112. key, edges in
  113. group_by(edges, key=lambda edge: edge.sourceNodeId).items()}
  114. self.workflow_mode = workflow_mode
  115. def get_node(self, node_id):
  116. """
  117. 根据node_id 获取节点信息
  118. @param node_id: node_id
  119. @return: 节点信息
  120. """
  121. return self.node_map.get(node_id)
  122. def get_up_edge_nodes(self, node_id) -> List[EdgeNode]:
  123. """
  124. 根据节点id 获取当前连接前置节点和连线
  125. @param node_id: 节点id
  126. @return: 节点连线列表
  127. """
  128. return self.up_node_map.get(node_id)
  129. def get_next_edge_nodes(self, node_id) -> List[EdgeNode]:
  130. """
  131. 根据节点id 获取当前连接目标节点和连线
  132. @param node_id: 节点id
  133. @return: 节点连线列表
  134. """
  135. return self.next_node_map.get(node_id)
  136. def get_up_nodes(self, node_id) -> List[Node]:
  137. """
  138. 根据节点id 获取当前连接前置节点
  139. @param node_id: 节点id
  140. @return: 节点列表
  141. """
  142. return [en.node for en in (self.up_node_map.get(node_id) or [])]
  143. def get_next_nodes(self, node_id) -> List[Node]:
  144. """
  145. 根据节点id 获取当前连接目标节点
  146. @param node_id: 节点id
  147. @return: 节点列表
  148. """
  149. return [en.node for en in self.next_node_map.get(node_id, [])]
  150. @staticmethod
  151. def new_instance(flow_obj: Dict, workflow_mode: WorkflowMode = WorkflowMode.APPLICATION):
  152. nodes = flow_obj.get('nodes')
  153. edges = flow_obj.get('edges')
  154. nodes = [Node(node.get('id'), node.get('type'), **node)
  155. for node in nodes]
  156. edges = [Edge(edge.get('id'), edge.get('type'), **edge) for edge in edges]
  157. return Workflow(nodes, edges, workflow_mode)
  158. def get_start_node(self):
  159. return self.get_node('start-node')
  160. def get_search_node(self):
  161. return [node for node in self.nodes if node.type == 'search-dataset-node']
  162. def is_valid(self):
  163. """
  164. 校验工作流数据
  165. """
  166. self.is_valid_model_params()
  167. self.is_valid_start_node()
  168. self.is_valid_base_node()
  169. self.is_valid_work_flow()
  170. def is_valid_node_params(self, node: Node):
  171. from application.flow.step_node import get_node
  172. get_node(node.type, self.workflow_mode)(node, None, None)
  173. def is_valid_node(self, node: Node):
  174. self.is_valid_node_params(node)
  175. if node.type == 'condition-node':
  176. branch_list = node.properties.get('node_data').get('branch')
  177. for branch in branch_list:
  178. source_anchor_id = f"{node.id}_{branch.get('id')}_right"
  179. edge_list = [edge for edge in self.edges if edge.sourceAnchorId == source_anchor_id]
  180. if len(edge_list) == 0:
  181. raise AppApiException(500,
  182. _('The branch {branch} of the {node} node needs to be connected').format(
  183. node=node.properties.get("stepName"), branch=branch.get("type")))
  184. else:
  185. edge_list = [edge for edge in self.edges if edge.sourceNodeId == node.id]
  186. if len(edge_list) == 0 and not end_nodes.__contains__(node.type):
  187. raise AppApiException(500, _("{node} Nodes cannot be considered as end nodes").format(
  188. node=node.properties.get("stepName")))
  189. def is_valid_work_flow(self, up_node=None):
  190. if up_node is None:
  191. up_node = self.get_start_node()
  192. self.is_valid_node(up_node)
  193. next_nodes = self.get_next_nodes(up_node)
  194. for next_node in next_nodes:
  195. self.is_valid_work_flow(next_node)
  196. def is_valid_start_node(self):
  197. start_node_list = [node for node in self.nodes if node.id == 'start-node']
  198. if len(start_node_list) == 0:
  199. raise AppApiException(500, _('The starting node is required'))
  200. if len(start_node_list) > 1:
  201. raise AppApiException(500, _('There can only be one starting node'))
  202. def is_valid_model_params(self):
  203. node_list = [node for node in self.nodes if (
  204. node.type == 'ai-chat-node' or node.type == 'question-node' or node.type == 'parameter-extraction-node')]
  205. for node in node_list:
  206. if (node.properties.get('node_data', {}).get('model_id_type') or 'custom') == 'reference':
  207. continue
  208. model = QuerySet(Model).filter(id=node.properties.get('node_data', {}).get('model_id')).first()
  209. if model is None:
  210. raise ValidationError(ErrorDetail(
  211. _('The node {node} model does not exist').format(node=node.properties.get("stepName"))))
  212. credential = get_model_credential(model.provider, model.model_type, model.model_name)
  213. model_params_setting = node.properties.get('node_data', {}).get('model_params_setting')
  214. model_params_setting_form = credential.get_model_params_setting_form(
  215. model.model_name)
  216. if model_params_setting is None:
  217. model_params_setting = model_params_setting_form.get_default_form_data()
  218. node.properties.get('node_data', {})['model_params_setting'] = model_params_setting
  219. if node.properties.get('status', 200) != 200:
  220. raise ValidationError(
  221. ErrorDetail(_("Node {node} is unavailable").format(node=node.properties.get("stepName"))))
  222. node_list = [node for node in self.nodes if (node.type == 'function-lib-node')]
  223. for node in node_list:
  224. function_lib_id = node.properties.get('node_data', {}).get('function_lib_id')
  225. if function_lib_id is None:
  226. raise ValidationError(ErrorDetail(
  227. _('The library ID of node {node} cannot be empty').format(node=node.properties.get("stepName"))))
  228. f_lib = QuerySet(Tool).filter(id=function_lib_id).first()
  229. if f_lib is None:
  230. raise ValidationError(ErrorDetail(_("The function library for node {node} is not available").format(
  231. node=node.properties.get("stepName"))))
  232. def is_valid_base_node(self):
  233. base_node_list = [node for node in self.nodes if node.id == 'base-node']
  234. if len(base_node_list) == 0:
  235. raise AppApiException(500, _('Basic information node is required'))
  236. if len(base_node_list) > 1:
  237. raise AppApiException(500, _('There can only be one basic information node'))