i_step_node.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369
  1. # coding=utf-8
  2. """
  3. @project: maxkb
  4. @Author:虎
  5. @file: i_step_node.py
  6. @date:2024/6/3 14:57
  7. @desc:
  8. """
  9. import time
  10. import uuid
  11. from abc import abstractmethod
  12. from hashlib import sha1
  13. from typing import Type, Dict, List
  14. from django.core import cache
  15. from django.db.models import QuerySet
  16. from rest_framework import serializers
  17. from rest_framework.exceptions import ValidationError, ErrorDetail
  18. from application.flow.common import Answer, NodeChunk
  19. from application.models import ApplicationChatUserStats
  20. from application.models import ChatRecord, ChatUserType
  21. from common.field.common import InstanceField
  22. from knowledge.models.knowledge_action import KnowledgeAction, State
  23. from tools.models import ToolRecord
# Module-level alias named `chat_cache`.
# NOTE(review): this binds the `django.core.cache` *module* itself (see the
# `from django.core import cache` import), not a specific configured backend such as
# `cache.caches[...]` — confirm that callers expect the module here.
chat_cache = cache
  25. def write_context(step_variable: Dict, global_variable: Dict, node, workflow):
  26. if step_variable is not None:
  27. for key in step_variable:
  28. node.context[key] = step_variable[key]
  29. if workflow.is_result(node, NodeResult(step_variable, global_variable)) and 'answer' in step_variable:
  30. answer = step_variable['answer']
  31. yield answer
  32. node.answer_text = answer
  33. if global_variable is not None:
  34. for key in global_variable:
  35. workflow.context[key] = global_variable[key]
  36. node.context['run_time'] = time.time() - node.context['start_time']
  37. def is_interrupt(node, step_variable: Dict, global_variable: Dict):
  38. return node.type == 'form-node' and not node.context.get('is_submit', False)
  39. class WorkFlowPostHandler:
  40. def __init__(self, chat_info):
  41. self.chat_info = chat_info
  42. def handler(self, workflow):
  43. workflow_body = workflow.get_body()
  44. question = workflow_body.get('question')
  45. chat_record_id = workflow_body.get('chat_record_id')
  46. chat_id = workflow_body.get('chat_id')
  47. details = workflow.get_runtime_details()
  48. message_tokens = sum([row.get('message_tokens') for row in details.values() if
  49. 'message_tokens' in row and row.get('message_tokens') is not None])
  50. answer_tokens = sum([row.get('answer_tokens') for row in details.values() if
  51. 'answer_tokens' in row and row.get('answer_tokens') is not None])
  52. answer_text_list = workflow.get_answer_text_list()
  53. answer_text = '\n\n'.join(
  54. '\n\n'.join([a.get('content') for a in answer]) for answer in
  55. answer_text_list)
  56. if workflow.chat_record is not None:
  57. chat_record = workflow.chat_record
  58. chat_record.problem_text = question
  59. chat_record.answer_text = answer_text
  60. chat_record.details = details
  61. chat_record.message_tokens = message_tokens
  62. chat_record.answer_tokens = answer_tokens
  63. chat_record.answer_text_list = answer_text_list
  64. chat_record.run_time = time.time() - workflow.context['start_time']
  65. else:
  66. chat_record = ChatRecord(id=chat_record_id,
  67. chat_id=chat_id,
  68. problem_text=question,
  69. answer_text=answer_text,
  70. details=details,
  71. message_tokens=message_tokens,
  72. answer_tokens=answer_tokens,
  73. answer_text_list=answer_text_list,
  74. run_time=time.time() - workflow.context.get('start_time') if workflow.context.get(
  75. 'start_time') is not None else 0,
  76. index=0,
  77. ip_address=self.chat_info.ip_address,
  78. source=self.chat_info.source)
  79. self.chat_info.append_chat_record(chat_record)
  80. self.chat_info.set_cache()
  81. if not self.chat_info.debug and [ChatUserType.ANONYMOUS_USER.value, ChatUserType.CHAT_USER.value].__contains__(
  82. workflow_body.get('chat_user_type')):
  83. application_public_access_client = (QuerySet(ApplicationChatUserStats)
  84. .filter(chat_user_id=workflow_body.get('chat_user_id'),
  85. chat_user_type=workflow_body.get('chat_user_type'),
  86. application_id=self.chat_info.application_id).first())
  87. if application_public_access_client is not None:
  88. application_public_access_client.access_num = application_public_access_client.access_num + 1
  89. application_public_access_client.intraday_access_num = application_public_access_client.intraday_access_num + 1
  90. application_public_access_client.save()
  91. self.chat_info = None
  92. class KnowledgeWorkflowPostHandler(WorkFlowPostHandler):
  93. def __init__(self, chat_info, knowledge_action_id):
  94. super().__init__(chat_info)
  95. self.knowledge_action_id = knowledge_action_id
  96. def handler(self, workflow):
  97. state = get_workflow_state(workflow)
  98. QuerySet(KnowledgeAction).filter(id=self.knowledge_action_id).update(
  99. state=state,
  100. run_time=time.time() - workflow.context.get('start_time') if workflow.context.get(
  101. 'start_time') is not None else 0)
  102. def get_tool_workflow_state(workflow):
  103. if workflow.is_the_task_interrupted():
  104. return State.REVOKED
  105. details = workflow.get_runtime_details()
  106. node_list = details.values()
  107. all_node = [*node_list, *get_loop_workflow_node(node_list)]
  108. err = any([True for value in all_node if value.get('status') == 500 and not value.get('enableException')])
  109. if err:
  110. return State.FAILURE
  111. return State.SUCCESS
  112. class ToolWorkflowCallPostHandler(WorkFlowPostHandler):
  113. def __init__(self, chat_info, tool_id):
  114. super().__init__(chat_info)
  115. self.tool_id = tool_id
  116. def handler(self, workflow):
  117. self.chat_info = None
  118. self.tool_id = None
  119. class ToolWorkflowPostHandler(WorkFlowPostHandler):
  120. def __init__(self, chat_info, tool_id):
  121. super().__init__(chat_info)
  122. self.tool_id = tool_id
  123. def handler(self, workflow):
  124. state = get_tool_workflow_state(workflow)
  125. record = ToolRecord(id=self.chat_info.tool_record_id, tool_id=self.tool_id,
  126. workspace_id=self.chat_info.workspace_id,
  127. source_type=self.chat_info.source_type,
  128. source_id=self.chat_info.source_id,
  129. state=state,
  130. run_time=time.time() - workflow.context.get('start_time') if workflow.context.get(
  131. 'start_time') is not None else 0,
  132. meta={
  133. 'input_field_list': workflow.get_input_field_list(),
  134. 'output_field_list': workflow.get_output_field_list(),
  135. 'input': workflow.get_input(),
  136. 'output': workflow.out_context,
  137. 'details': workflow.get_runtime_details(),
  138. 'answer_text_list': workflow.get_answer_text_list()
  139. })
  140. self.chat_info.set_record(record)
  141. self.chat_info = None
  142. self.tool_id = None
  143. def get_loop_workflow_node(node_list):
  144. result = []
  145. for item in node_list:
  146. if item.get('type') == 'loop-node':
  147. for loop_item in item.get('loop_node_data') or []:
  148. for inner_item in loop_item.values():
  149. result.append(inner_item)
  150. return result
  151. def get_workflow_state(workflow):
  152. if workflow.is_the_task_interrupted():
  153. return State.REVOKED
  154. details = workflow.get_runtime_details()
  155. node_list = details.values()
  156. all_node = [*node_list, *get_loop_workflow_node(node_list)]
  157. err = any([True for value in all_node if value.get('status') == 500 and not value.get('enableException')])
  158. if err:
  159. return State.FAILURE
  160. write_is_exist = any([True for value in all_node if value.get('type') == 'knowledge-write-node'])
  161. if not write_is_exist:
  162. return State.FAILURE
  163. return State.SUCCESS
  164. class NodeResult:
  165. def __init__(self, node_variable: Dict, workflow_variable: Dict,
  166. _write_context=write_context, _is_interrupt=is_interrupt):
  167. self._write_context = _write_context
  168. self.node_variable = node_variable
  169. self.workflow_variable = workflow_variable
  170. self._is_interrupt = _is_interrupt
  171. def write_context(self, node, workflow):
  172. return self._write_context(self.node_variable, self.workflow_variable, node, workflow)
  173. def is_assertion_result(self):
  174. return 'branch_id' in self.node_variable
  175. def is_interrupt_exec(self, current_node):
  176. """
  177. 是否中断执行
  178. @param current_node:
  179. @return:
  180. """
  181. return self._is_interrupt(current_node, self.node_variable, self.workflow_variable)
class ReferenceAddressSerializer(serializers.Serializer):
    """Validates a reference to another node's output: a node id plus a list of field names."""
    # Id of the referenced node (required).
    node_id = serializers.CharField(required=True, label="节点id")
    # Field-name path within that node (required; each element is a required string).
    fields = serializers.ListField(
        child=serializers.CharField(required=True, label="节点字段"), required=True,
        label="节点字段数组")
class FlowParamsSerializer(serializers.Serializer):
    """Validates the workflow-level parameters of a chat (application) workflow run."""
    # Chat history: previous ChatRecord instances of this conversation
    history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True),
                                                label="历史对答")
    # The user's question for this run
    question = serializers.CharField(required=True, label="用户问题")
    chat_id = serializers.CharField(required=True, label="对话id")
    chat_record_id = serializers.CharField(required=True, label="对话记录id")
    # Whether output is streamed
    stream = serializers.BooleanField(required=True, label="流式输出")
    # Optional: identity of the chatting user
    chat_user_id = serializers.CharField(required=False, label="对话用户id")
    chat_user_type = serializers.CharField(required=False, label="对话用户类型")
    workspace_id = serializers.CharField(required=True, label="工作空间id")
    application_id = serializers.CharField(required=True, label="应用id")
    # "Regenerate answer" flag
    re_chat = serializers.BooleanField(required=True, label="换个答案")
    # Whether this is a debug run
    debug = serializers.BooleanField(required=True, label="是否debug")
class KnowledgeFlowParamsSerializer(serializers.Serializer):
    """Validates the workflow-level parameters of a knowledge-base workflow run."""
    knowledge_id = serializers.UUIDField(required=True, label="知识库id")
    workspace_id = serializers.CharField(required=True, label="工作空间id")
    # Id of the knowledge action (task executor) record for this run
    knowledge_action_id = serializers.UUIDField(required=True, label="知识库任务执行器id")
    # Data-source configuration (required dict)
    data_source = serializers.DictField(required=True, label="数据源")
    # Optional knowledge-base settings
    knowledge_base = serializers.DictField(required=False, label="知识库设置")
class ToolFlowParamsSerializer(serializers.Serializer):
    """Validates the workflow-level parameters of a tool workflow run."""
    tool_id = serializers.UUIDField(required=True, label="工具id")
    workspace_id = serializers.CharField(required=True, label="工作空间id")
  210. class INode:
  211. view_type = 'many_view'
  212. @abstractmethod
  213. def save_context(self, details, workflow_manage):
  214. pass
  215. def get_answer_list(self) -> List[Answer] | None:
  216. if self.answer_text is None:
  217. return None
  218. reasoning_content_enable = self.context.get('model_setting', {}).get('reasoning_content_enable', False)
  219. return [
  220. Answer(self.answer_text, self.view_type, self.runtime_node_id, self.workflow_params.get('chat_record_id'),
  221. {},
  222. self.runtime_node_id, self.context.get('reasoning_content', '') if reasoning_content_enable else '')]
  223. def __init__(self, node, workflow_params, workflow_manage, up_node_id_list=None,
  224. get_node_params=lambda node: node.properties.get('node_data'), salt=None):
  225. # 当前步骤上下文,用于存储当前步骤信息
  226. self.status = 200
  227. self.err_message = ''
  228. self.node = node
  229. self.node_params = get_node_params(node)
  230. self.workflow_params = workflow_params
  231. self.workflow_manage = workflow_manage
  232. self.node_params_serializer = None
  233. self.flow_params_serializer = None
  234. self.context = {}
  235. self.answer_text = None
  236. self.id = node.id
  237. if up_node_id_list is None:
  238. up_node_id_list = []
  239. self.up_node_id_list = up_node_id_list
  240. self.node_chunk = NodeChunk()
  241. self.runtime_node_id = sha1(uuid.NAMESPACE_DNS.bytes + bytes(str(uuid.uuid5(uuid.NAMESPACE_DNS,
  242. "".join([*sorted(up_node_id_list),
  243. node.id]))),
  244. "utf-8")).hexdigest() + (
  245. "__" + str(salt) if salt is not None else '')
  246. self.extra = {}
  247. def valid_args(self, node_params, flow_params):
  248. flow_params_serializer_class = self.get_flow_params_serializer_class()
  249. node_params_serializer_class = self.get_node_params_serializer_class()
  250. if flow_params_serializer_class is not None and flow_params is not None:
  251. self.flow_params_serializer = flow_params_serializer_class(data=flow_params)
  252. self.flow_params_serializer.is_valid(raise_exception=True)
  253. if node_params_serializer_class is not None:
  254. self.node_params_serializer = node_params_serializer_class(data=node_params)
  255. self.node_params_serializer.is_valid(raise_exception=True)
  256. if self.node.properties.get('status', 200) != 200:
  257. raise ValidationError(ErrorDetail(f'节点{self.node.properties.get("stepName")} 不可用'))
  258. def get_reference_field(self, fields: List[str]):
  259. return self.get_field(self.context, fields)
  260. @staticmethod
  261. def get_field(obj, fields: List[str]):
  262. for field in fields:
  263. value = obj.get(field)
  264. if value is None:
  265. return None
  266. else:
  267. obj = value
  268. return obj
  269. @abstractmethod
  270. def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
  271. pass
  272. def get_flow_params_serializer_class(self) -> Type[serializers.Serializer]:
  273. return self.workflow_manage.get_params_serializer_class()
  274. def get_write_error_context(self, e):
  275. self.status = 500
  276. self.answer_text = str(e)
  277. self.err_message = str(e)
  278. current_time = time.time()
  279. self.context['run_time'] = current_time - (self.context.get('start_time') or current_time)
  280. def write_error_context(answer, status=200):
  281. pass
  282. return write_error_context
  283. def run(self) -> NodeResult:
  284. """
  285. :return: 执行结果
  286. """
  287. start_time = time.time()
  288. self.context['start_time'] = start_time
  289. result = self._run()
  290. self.context['run_time'] = time.time() - start_time
  291. return result
  292. def _run(self):
  293. result = self.execute()
  294. return result
  295. def execute(self, **kwargs) -> NodeResult:
  296. pass
  297. def get_details(self, index: int, **kwargs):
  298. """
  299. 运行详情
  300. :return: 步骤详情
  301. """
  302. return {}