workflow_manage.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804
  1. # coding=utf-8
  2. """
  3. @project: maxkb
  4. @Author:虎
  5. @file: workflow_manage.py
  6. @date:2024/1/9 17:40
  7. @desc:
  8. """
  9. import concurrent
  10. import json
  11. import threading
  12. from concurrent.futures import ThreadPoolExecutor
  13. from functools import reduce
  14. from typing import List, Dict
  15. from django.db import close_old_connections, connection
  16. from django.utils import translation
  17. from django.utils.translation import get_language
  18. from langchain_core.prompts import PromptTemplate
  19. from rest_framework import status
  20. from application.flow import tools
  21. from application.flow.common import Workflow
  22. from application.flow.i_step_node import INode, WorkFlowPostHandler, NodeResult, FlowParamsSerializer
  23. from application.flow.step_node import get_node
  24. from common.handle.base_to_response import BaseToResponse
  25. from common.handle.impl.response.system_to_response import SystemToResponse
  26. from common.utils.logger import maxkb_logger
  # Single thread pool shared by the whole module: WorkflowManage submits node
  # chains and parallel branches to it via executor.submit(...).
  _WORKFLOW_POOL_SIZE = 200
  executor = ThreadPoolExecutor(max_workers=_WORKFLOW_POOL_SIZE)
  class NodeResultFuture:
      """Already-completed stand-in for a ``concurrent.futures.Future``.

      Wraps the outcome of a node run that was executed synchronously so callers
      can read it through the same ``result()`` call used for real futures.
      """

      def __init__(self, r, e, status=200):
          """
          @param r: value produced by the node run (returned when status == 200)
          @param e: exception raised by the node run (re-raised otherwise)
          @param status: 200 for success, any other value for failure
          """
          self.r = r
          self.e = e
          self.status = status

      def result(self, timeout=None):
          """Return the stored value, or re-raise the stored exception.

          @param timeout: accepted for signature compatibility with
              ``concurrent.futures.Future.result`` (callers such as
              ``await_result`` pass it positionally); ignored because the
              outcome is already known.
          """
          if self.status == 200:
              return self.r
          raise self.e
  def await_result(result, timeout=1):
      """Report whether ``result`` is still outstanding after waiting ``timeout`` seconds.

      Returns False as soon as ``result.result(timeout)`` returns normally.
      Returns True whenever that call raises: a timeout, but also an exception
      stored in the result itself.
      """
      try:
          result.result(timeout)
      except Exception:
          return True
      return False
  class NodeChunkManage:
      """Merges the chunk streams of workflow nodes into one ordered stream.

      Node chunks are consumed strictly in registration order. Chunks of the
      current node chunk are handed out first, and the manager only moves on to
      the next node chunk once the current one has ended and is drained.
      """

      def __init__(self, work_flow):
          # Registered node chunks that are not being consumed yet (FIFO).
          self.node_chunk_list = []
          # Node chunk whose chunks are currently being handed out, if any.
          self.current_node_chunk = None
          # Owning workflow; used to emit the separator between node answers.
          self.work_flow = work_flow

      def add_node_chunk(self, node_chunk):
          """Queue ``node_chunk`` behind every previously registered node chunk."""
          self.node_chunk_list.append(node_chunk)

      def contains(self, node_chunk):
          """Return True if ``node_chunk`` is queued and not yet being consumed."""
          return node_chunk in self.node_chunk_list

      def pop(self):
          """Return the next available chunk, or None if nothing is ready yet.

          When the current node chunk has ended and is drained and the workflow's
          current answer segment is not empty, a '\\n\\n' separator chunk is
          returned and the separator is appended to the workflow answer.
          Otherwise the manager advances to the next queued node chunk.
          """
          while True:
              if self.current_node_chunk is None:
                  try:
                      self.current_node_chunk = self.node_chunk_list.pop(0)
                  except IndexError:
                      return None
              node_chunk = self.current_node_chunk
              # Read the end flag BEFORE popping. Producers add their chunks and
              # only then end the stream. Checking afterwards could see "ended"
              # for a chunk added between the failed pop and the check, and that
              # chunk would be dropped.
              ended = node_chunk.is_end()
              try:
                  return node_chunk.chunk_list.pop(0)
              except IndexError:
                  if not ended:
                      # Node still running; nothing buffered right now.
                      return None
              # Current node chunk is finished and fully drained.
              self.current_node_chunk = None
              if self.work_flow.answer_is_not_empty():
                  # Same positional layout as the other to_stream_chunk_response
                  # calls: chat_id, chat_record_id, node_id, up_node_id_list,
                  # content, is_end, completion_tokens, prompt_tokens, other_params.
                  chunk = self.work_flow.base_to_response.to_stream_chunk_response(
                      self.work_flow.params['chat_id'],
                      self.work_flow.params['chat_record_id'],
                      '', [], '\n\n', False, 0, 0, {})
                  self.work_flow.append_answer('\n\n')
                  return chunk
              # Empty answer so far: move straight on to the next node chunk.
  76. class WorkflowManage:
  def __init__(self, flow: Workflow, params, work_flow_post_handler: WorkFlowPostHandler,
               base_to_response: BaseToResponse = SystemToResponse(), form_data=None, image_list=None,
               document_list=None,
               audio_list=None,
               video_list=None,
               other_list=None,
               start_node_id=None,
               start_node_data=None, chat_record=None, child_node=None, is_the_task_interrupted=lambda: False):
      """
      Prepare a workflow run.
      @param flow: workflow definition (nodes, edges, workflow_mode)
      @param params: request parameters; keys such as 'chat_id', 'chat_record_id', 'stream' and
                     'application_id' are read by other methods of this class
      @param work_flow_post_handler: invoked once the run has finished
      @param base_to_response: formats block / stream responses
      @param form_data, image_list, document_list, audio_list, video_list, other_list: request inputs;
             None becomes a fresh empty container, otherwise the caller's object is kept as-is
             (_cleanup later clears these containers in place)
      @param start_node_id: runtime node id to resume from; None starts from the start node
      @param start_node_data: data submitted for the resumed node
      @param chat_record: stored record used when resuming
      @param child_node: child-node information forwarded to the resumed node's params
      @param is_the_task_interrupted: callable polled to detect an interrupted task
      """
      # Request inputs.
      self.form_data = form_data if form_data is not None else {}
      self.image_list = image_list if image_list is not None else []
      self.document_list = document_list if document_list is not None else []
      self.audio_list = audio_list if audio_list is not None else []
      self.video_list = video_list if video_list is not None else []
      self.other_list = other_list if other_list is not None else []

      # Configuration of this run.
      self.flow = flow
      self.params = params
      self.work_flow_post_handler = work_flow_post_handler
      self.base_to_response = base_to_response
      self.chat_record = chat_record
      self.child_node = child_node
      self.is_the_task_interrupted = is_the_task_interrupted
      self.start_node_id = start_node_id
      self.start_node = None

      # Mutable run state.
      self.context = {}
      self.chat_context = {}
      self.current_node = None
      self.current_result = None
      self.answer = ""
      self.answer_list = ['']
      self.status = 200
      self.future_list = []
      self.lock = threading.Lock()
      self.node_chunk_manage = NodeChunkManage(self)

      # Referenceable fields, filled from the flow definition by init_fields.
      self.field_list = []
      self.global_field_list = []
      self.chat_field_list = []
      self.init_fields()

      if start_node_id is None:
          self.node_context = []
      else:
          # Resume: rebuild node state from the chat record's stored node details.
          self.load_node(chat_record, start_node_id, start_node_data)
  def init_fields(self):
      """Collect the referenceable fields declared by every node of the flow.

      Three lists are built: node output fields ('fields'), global variables
      ('globalFields') and chat variables ('chatFields'). Each entry is tagged
      with the declaring node's id and name. Every node also contributes an
      implicit 'exception_message' field.

      Each list is sorted by len(node_name + value), longest first. Presumably
      this lets reset_prompt replace longer labels before shorter ones that may
      be their prefixes.
      """
      def by_label_length(field):
          return len(field.get('node_name') + field.get('value'))

      fields, global_fields, chat_fields = [], [], []
      # Keys of a node's `config` and the list their entries are collected into.
      targets = (('fields', fields), ('globalFields', global_fields), ('chatFields', chat_fields))
      for node in self.flow.nodes:
          properties = node.properties
          owner = {'node_id': node.id, 'node_name': properties.get('stepName')}
          fields.append({'label': '异常信息', 'value': 'exception_message', **owner})
          config = properties.get('config')
          if config is None:
              continue
          for config_key, bucket in targets:
              declared = config.get(config_key)
              if declared is not None:
                  bucket.extend({**item, **owner} for item in declared)
      for bucket in (fields, global_fields, chat_fields):
          bucket.sort(key=by_label_length, reverse=True)
      self.field_list = fields
      self.global_field_list = global_fields
      self.chat_field_list = chat_fields
  def append_answer(self, content):
      """Append ``content`` to the whole answer and to the current answer segment."""
      self.answer = self.answer + content
      self.answer_list[-1] = self.answer_list[-1] + content
  def answer_is_not_empty(self):
      """Return True if the current (last) answer segment already has content."""
      current_segment = self.answer_list[-1]
      return len(current_segment) != 0
  def load_node(self, chat_record, start_node_id, start_node_data):
      """
      Rebuild runtime state from a stored chat record so the flow can resume.

      The nodes recorded in chat_record.details are re-instantiated in their
      recorded 'index' order. The node whose 'runtime_node_id' equals
      start_node_id becomes self.start_node and receives start_node_data as its
      form data; loop and application nodes also get their saved state
      restored. Every other node has its saved context restored and its chunk
      stream marked as ended.
      @param chat_record: stored record providing answer_text, answer_text_list and details
      @param start_node_id: runtime node id to resume from
      @param start_node_data: data submitted for the resumed node
      """
      self.node_context = []
      self.answer = chat_record.answer_text
      # NOTE(review): this aliases chat_record.answer_text_list, so the append below
      # also mutates the record's own list. Confirm that is intended.
      self.answer_list = chat_record.answer_text_list
      self.answer_list.append('')
      # Node types whose resumed instance is flagged with is_result=True.
      result_node_types = ['application-node', 'loop-node', 'tool-workflow-lib-node']

      def get_node_params(n):
          return {**n.properties.get('node_data'), 'form_data': start_node_data, 'node_data': start_node_data,
                  'child_node': self.child_node, 'is_result': n.type in result_node_types}

      ordered_details = sorted(chat_record.details.values(), key=lambda d: d.get('index'))
      for node_details in ordered_details:
          node_id = node_details.get('node_id')
          up_node_id_list = node_details.get('up_node_id_list')
          if node_details.get('runtime_node_id') != start_node_id:
              # Already-finished node: restore its saved context; it produces no more chunks.
              node = self.get_node_cls_by_id(node_id, up_node_id_list)
              node.valid_args(node.node_params, node.workflow_params)
              node.save_context(node_details, self)
              node.node_chunk.end()
              self.node_context.append(node)
              continue

          start_node = self.get_node_cls_by_id(node_id, up_node_id_list, get_node_params=get_node_params)
          self.start_node = start_node
          start_node.valid_args({**start_node.node_params, 'form_data': start_node_data},
                                start_node.workflow_params)
          if start_node.type == 'loop-node':
              loop_node_data = node_details.get('loop_node_data', {})
              # NOTE(review): assumes 'loop_context_data' is always stored for loop nodes;
              # a missing entry raises AttributeError here.
              for key, value in node_details.get('loop_context_data').items():
                  if value is not None:
                      start_node.context[key] = value
              start_node.context['loop_node_data'] = loop_node_data
              start_node.context['current_index'] = node_details.get('current_index')
              start_node.context['current_item'] = node_details.get('current_item')
              start_node.context['loop_answer_data'] = node_details.get('loop_answer_data', {})
          if start_node.type == 'application-node':
              start_node.context['application_node_dict'] = node_details.get('application_node_dict', {})
          self.node_context.append(start_node)
  def run(self):
      """
      Execute the workflow, streaming or blocking depending on params['stream'].

      Stale DB connections are closed first. The caller's active language is
      captured so worker threads can activate it.
      """
      close_old_connections()
      language = get_language()
      if not self.params.get('stream'):
          return self.run_block(language)
      return self.run_stream(self.start_node, None, language)
  def run_block(self, language='zh'):
      """
      Non-streaming execution: wait for every node, then build one response.
      @param language: language activated in the worker threads
      @return: block response built by base_to_response (HTTP 500 status when the run failed)
      """
      try:
          self.run_chain_async(None, None, language)
          # is_run() waits up to its timeout on each call, so this is not a pure spin.
          while self.is_run():
              pass
          details = self.get_runtime_details()
          message_tokens = sum(row.get('message_tokens') for row in details.values()
                               if row.get('message_tokens') is not None)
          answer_tokens = sum(row.get('answer_tokens') for row in details.values()
                              if row.get('answer_tokens') is not None)
          answer_text_list = self.get_answer_text_list()
          answer_text = '\n\n'.join('\n\n'.join(a.get('content') for a in group) for group in answer_text_list)
          answer_list = [item for group in answer_text_list for item in group]
          self.work_flow_post_handler.handler(self)
          http_status = status.HTTP_200_OK if self.status == 200 else status.HTTP_500_INTERNAL_SERVER_ERROR
          return self.base_to_response.to_block_response(self.params['chat_id'],
                                                         self.params['chat_record_id'], answer_text, True,
                                                         message_tokens, answer_tokens,
                                                         _status=http_status,
                                                         other_params={'answer_list': answer_list})
      finally:
          # Release every reference, whether or not the run succeeded.
          self._cleanup()
  def _cleanup(self):
      """Drop every reference held by this run so it can be garbage-collected.

      Containers are emptied in place, including the lists/dicts the caller
      passed in. Object references are reset to None.
      """
      for name in ('future_list', 'field_list', 'global_field_list', 'chat_field_list',
                   'image_list', 'video_list', 'document_list', 'audio_list', 'other_list'):
          getattr(self, name).clear()
      # node_context is assigned last in __init__; guard against a partially built instance.
      if hasattr(self, 'node_context'):
          self.node_context.clear()
      for name in ('context', 'chat_context', 'form_data'):
          getattr(self, name).clear()
      for name in ('node_chunk_manage', 'work_flow_post_handler', 'flow', 'start_node', 'current_node',
                   'current_result', 'chat_record', 'base_to_response', 'params', 'lock'):
          setattr(self, name, None)
  def run_stream(self, current_node, node_result_future, language='zh'):
      """
      Streaming execution: start the chain in the background and stream chunks as they arrive.
      @return: streaming response wrapping the await_result() generator
      """
      self.run_chain_async(current_node, node_result_future, language)
      chunk_iterator = self.await_result()
      return tools.to_stream_response_simple(chunk_iterator)
  def get_body(self):
      """Return the request parameters this workflow was created with."""
      body = self.params
      return body
  def is_run(self, timeout=0.5):
      """
      Report whether the workflow still has work in flight.

      Waits up to ``timeout`` seconds on the known futures. Returns True when any
      of them is still pending, when new futures were submitted during the wait,
      or when waiting raised. Returns False only when all known futures are done
      and none were added.
      """
      known_count = len(self.future_list)
      try:
          outcome = concurrent.futures.wait(self.future_list, timeout)
          if outcome.not_done:
              return True
          return known_count != len(self.future_list)
      except Exception:
          return True
  def await_result(self, is_cleanup=True):
      """
      Yield response chunks as nodes produce them, then a closing end-of-stream chunk.

      While node futures are running, buffered chunks are drained on every
      is_run() interval, and a last drain runs once everything has completed.
      Afterwards the post handler runs and a closing chunk with the summed token
      counts is yielded.

      If the consumer closes the generator early (GeneratorExit), the closing
      chunk is skipped, because yielding then would make close() raise. The
      workflow is still waited for and the post handler still runs.
      @param is_cleanup: release all references via _cleanup() at the end (also on errors)
      """
      generator_closed = False
      try:
          while self.is_run():
              while (chunk := self.node_chunk_manage.pop()) is not None:
                  yield chunk
          # Everything finished: flush whatever is still buffered.
          while (chunk := self.node_chunk_manage.pop()) is not None:
              yield chunk
      except GeneratorExit:
          generator_closed = True
          raise
      finally:
          # Let the remaining nodes finish so details and token counts are complete.
          while self.is_run():
              pass
          try:
              details = self.get_runtime_details()
              message_tokens = sum(row.get('message_tokens') for row in details.values()
                                   if 'message_tokens' in row and row.get('message_tokens') is not None)
              answer_tokens = sum(row.get('answer_tokens') for row in details.values()
                                  if 'answer_tokens' in row and row.get('answer_tokens') is not None)
              self.work_flow_post_handler.handler(self)
              if not generator_closed:
                  yield self.base_to_response.to_stream_chunk_response(self.params.get('chat_id'),
                                                                       self.params.get('chat_record_id'),
                                                                       '',
                                                                       [],
                                                                       '', True, message_tokens, answer_tokens, {})
          finally:
              if is_cleanup:
                  self._cleanup()
  def run_chain_async(self, current_node, node_result_future, language='zh'):
      """Submit run_chain_manage to the shared executor and track the resulting future."""
      self.future_list.append(executor.submit(self.run_chain_manage, current_node, node_result_future, language))
  def run_chain_manage(self, current_node, node_result_future, language='zh'):
      """
      Run the chain of nodes starting at ``current_node`` on the current thread.

      A node of None means "start from the flow's start node". When a node has
      exactly one successor, that successor runs next on this same thread. When
      it has several, each is submitted to the executor (ordered by the node's y
      coordinate) and tracked in future_list. The chain stops when a node
      produces no result.
      @param node_result_future: precomputed result for the first node, or None to run it
      @param language: language activated on this worker thread
      """
      node, pending_future = current_node, node_result_future
      while True:
          translation.activate(language)
          if node is None:
              start_node = self.get_start_node()
              node = get_node(start_node.type, self.flow.workflow_mode)(start_node, self.params, self)
          self.node_chunk_manage.add_node_chunk(node.node_chunk)
          # Register (or replace) the node in node_context.
          self.append_node(node)
          result = self.run_chain(node, pending_future)
          if result is None:
              return
          next_nodes = self.get_next_node_list(node, result)
          if len(next_nodes) == 1:
              # Single successor: continue on this thread.
              node, pending_future = next_nodes[0], None
              continue
          if len(next_nodes) > 1:
              branches = sorted(next_nodes, key=lambda n: n.node.y)
              futures = [executor.submit(self.run_chain_manage, branch, None, language) for branch in branches]
              self.future_list.extend(futures)
          return
  def run_chain(self, current_node, node_result_future=None):
      """
      Run ``current_node`` (unless a result future is supplied) and process its result.

      Streaming requests (params['stream'], default True) go through
      hand_event_node_result, the others through hand_node_result. Any exception
      raised while processing is logged and turned into None.
      """
      future = node_result_future if node_result_future is not None else self.run_node_future(current_node)
      try:
          if self.params.get('stream', True):
              return self.hand_event_node_result(current_node, future)
          return self.hand_node_result(current_node, future)
      except Exception as e:
          maxkb_logger.error(f'Exception: {e}', exc_info=True)
          return None
  def hand_node_result(self, current_node, node_result_future):
      """
      Blocking (non-stream) handling of a node result.

      Writes the result into the workflow and fully consumes the node's output.
      On failure, the error is logged, the workflow status becomes 500, the node
      records the error and the message is appended to the answer; None is
      returned. The node's chunk stream is always ended.
      """
      try:
          node_result = node_result_future.result()
          output = node_result.write_context(current_node, self)
          if output is not None:
              # Exhaust the iterator: block until the node has produced everything.
              for _ in output:
                  pass
          return node_result
      except Exception as e:
          maxkb_logger.error(f'Exception: {e}', exc_info=True)
          self.status = 500
          current_node.get_write_error_context(e)
          self.answer += str(e)
          return None
      finally:
          current_node.node_chunk.end()
  def append_node(self, current_node):
      """Add ``current_node`` to node_context, replacing an entry with the same node id and runtime id."""
      for position, existing in enumerate(self.node_context):
          if current_node.id == existing.node.id and current_node.runtime_node_id == existing.runtime_node_id:
              self.node_context[position] = current_node
              return
      self.node_context.append(current_node)
  def hand_event_node_result(self, current_node, node_result_future):
      """
      Streaming handling of a node result.

      The result is written into the workflow via write_context. When the node's
      output belongs to the answer (see is_result), every item it yields becomes
      a stream chunk queued on the node's chunk stream, followed by a node-end
      chunk. Otherwise the output is consumed without streaming. The node's chunk
      stream is always ended and the thread's DB connection closed afterwards.
      @param current_node: node whose result is handled
      @param node_result_future: object whose result() returns the node result or raises
      @return: the result to route on; an 'exception' branch NodeResult when the node failed
               and enableException is set; None when the chain must stop (failure without an
               exception branch, or the task was interrupted)
      """
      runtime_node_id = current_node.runtime_node_id
      # Replaced by the 'real_node_id' of any dict item that carries one.
      real_node_id = current_node.runtime_node_id
      child_node = {}
      view_type = current_node.view_type
      try:
          current_result = node_result_future.result()
          result = current_result.write_context(current_node, self)
          if result is not None:
              if self.is_result(current_node, current_result):
                  for r in result:
                      reasoning_content = ''
                      content = r
                      child_node = {}
                      node_is_end = False
                      view_type = current_node.view_type
                      node_type = current_node.type
                      if isinstance(r, dict):
                          content = r.get('content')
                          child_node = {'runtime_node_id': r.get('runtime_node_id'),
                                        'chat_record_id': r.get('chat_record_id'),
                                        'child_node': r.get('child_node')}
                          if 'real_node_id' in r:
                              real_node_id = r.get('real_node_id')
                          if 'node_is_end' in r:
                              node_is_end = r.get('node_is_end')
                          if 'node_type' in r:
                              node_type = r.get("node_type")
                          view_type = r.get('view_type')
                          reasoning_content = r.get('reasoning_content')
                      chunk = self.base_to_response.to_stream_chunk_response(
                          self.params.get('chat_id'),
                          self.params.get('chat_record_id'),
                          current_node.id,
                          current_node.up_node_id_list,
                          content, False, 0, 0,
                          {'node_type': node_type,
                           'runtime_node_id': runtime_node_id,
                           'view_type': view_type,
                           'child_node': child_node,
                           'node_is_end': node_is_end,
                           'real_node_id': real_node_id,
                           'reasoning_content': reasoning_content,
                           'node_status': "SUCCESS"})
                      current_node.node_chunk.add_chunk(chunk)
                  # Node-end chunk; reuses view_type / child_node of the last streamed item.
                  chunk = self.base_to_response.to_stream_chunk_response(
                      self.params.get('chat_id'),
                      self.params.get('chat_record_id'),
                      current_node.id,
                      current_node.up_node_id_list,
                      '', False, 0, 0, {'node_is_end': True,
                                        'runtime_node_id': runtime_node_id,
                                        'node_type': current_node.type,
                                        'view_type': view_type,
                                        'child_node': child_node,
                                        'real_node_id': real_node_id,
                                        'reasoning_content': '',
                                        'node_status': "SUCCESS"})
                  current_node.node_chunk.add_chunk(chunk)
              else:
                  # Not part of the answer: consume the output without streaming it.
                  list(result)
          if current_node.status == 500:
              enable_exception = current_node.node.properties.get('enableException')
              if not enable_exception:
                  return None
              # Route through the node's 'exception' branch instead of stopping.
              current_node.context['exception_message'] = current_node.err_message
              current_node.context['branch_id'] = 'exception'
              r = NodeResult({'branch_id': 'exception', 'exception': current_node.err_message}, {},
                             _is_interrupt=lambda node, step_variable, global_variable: False)
              r.write_context(current_node, self)
              return r
          if self.is_the_task_interrupted():
              current_node.status = 201
              return None
          return current_result
      except Exception as e:
          maxkb_logger.error(f'Exception: {e}', exc_info=True)
          enable_exception = current_node.node.properties.get('enableException')
          current_node.get_write_error_context(e)
          self.status = 500
          if self.is_the_task_interrupted():
              current_node.status = 201
              return None
          if not enable_exception:
              # Stream the error to the client as the node's final chunk.
              chunk = self.base_to_response.to_stream_chunk_response(
                  self.params.get('chat_id'),
                  self.params.get('chat_record_id'),
                  current_node.id,
                  current_node.up_node_id_list,
                  'Exception:' + str(e), False, 0, 0,
                  {'node_is_end': True,
                   'runtime_node_id': current_node.runtime_node_id,
                   'node_type': current_node.type,
                   'view_type': current_node.view_type,
                   'child_node': {},
                   'real_node_id': real_node_id,
                   'node_status': 'ERROR'})
              current_node.node_chunk.add_chunk(chunk)
              return None
          current_node.context['exception_message'] = current_node.err_message
          current_node.context['branch_id'] = 'exception'
          return NodeResult({'branch_id': 'exception', 'exception': current_node.err_message}, {},
                            _is_interrupt=lambda node, step_variable, global_variable: False)
      finally:
          current_node.node_chunk.end()
          # Release this worker thread's database connection.
          connection.close()
  def run_node_async(self, node):
      """Submit ``node`` to the shared executor and return its future."""
      return executor.submit(self.run_node, node)
  def run_node_future(self, node):
      """Validate and run ``node`` synchronously, wrapping the outcome in a NodeResultFuture (200 ok / 500 error)."""
      try:
          node.valid_args(node.node_params, node.workflow_params)
          outcome = self.run_node(node)
      except Exception as e:
          return NodeResultFuture(None, e, 500)
      return NodeResultFuture(outcome, None, 200)
  def run_node(self, node):
      """Run ``node`` and return whatever its run() returns."""
      return node.run()
  def is_result(self, current_node, current_node_result):
      """
      Decide whether the node's output is streamed as part of the answer.

      The node's explicit 'is_result' param wins. Without it, the node counts as
      a result node when no next node follows it. Nodes without params never
      count.
      """
      node_params = current_node.node_params
      if node_params is None:
          return False
      return node_params.get('is_result', not self._has_next_node(current_node, current_node_result))
  def get_chat_info(self):
      """Return the chat_info held by the post handler."""
      handler = self.work_flow_post_handler
      return handler.chat_info
  def get_chunk_content(self, chunk, is_end=False):
      """Format ``chunk`` as a 'data: <json>' event line (terminated by a blank line) for this chat record."""
      payload = {'chat_id': self.params['chat_id'], 'id': self.params['chat_record_id'], 'operate': True,
                 'content': chunk, 'is_end': is_end}
      return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
  def _has_next_node(self, current_node, node_result: NodeResult | None):
      """
      Whether a node can run after ``current_node``.

      For assertion (branching) results, True as soon as an outgoing edge leaves
      from the selected branch anchor. Otherwise, and also as a fallback when no
      edge matches the branch, True when the node has any outgoing edge.
      """
      next_edge_node_list = self.flow.get_next_edge_nodes(current_node.id) or []
      if node_result is not None and node_result.is_assertion_result():
          branch_id = node_result.node_variable.get('branch_id')
          for next_edge_node in next_edge_node_list:
              edge = next_edge_node.edge
              if edge.sourceNodeId == current_node.id and \
                      edge.sourceAnchorId == f"{edge.sourceNodeId}_{branch_id}_right":
                  return True
      # NOTE(review): for an assertion result with no edge on the selected branch this still
      # returns True when other branches have edges. Confirm that is intended.
      return len(next_edge_node_list) > 0
  def has_next_node(self, node_result: NodeResult | None):
      """
      Whether a node can run after the current node (or after the start node when none is current yet).
      """
      node = self.current_node if self.current_node is not None else self.get_start_node()
      return self._has_next_node(node, node_result)
  def get_runtime_details(self, get_details=lambda n, index: n.get_details(index)):
      """
      Collect per-node execution details keyed by runtime_node_id.

      When resuming from a chat record, nodes already stored in the record
      (other than the resumed start node) reuse their stored details unchanged.
      Every other node is queried via ``get_details(node, index)`` and annotated
      with node_id, up_node_id_list and runtime_node_id.
      """
      details_result = {}
      resuming = self.chat_record is not None and self.chat_record.details is not None
      for index, node in enumerate(self.node_context):
          if resuming and self.start_node:
              stored = self.chat_record.details.get(node.runtime_node_id)
              if stored is not None and self.start_node.runtime_node_id != node.runtime_node_id:
                  details_result[node.runtime_node_id] = stored
                  continue
          fresh = get_details(node, index)
          fresh['node_id'] = node.id
          fresh['up_node_id_list'] = node.up_node_id_list
          fresh['runtime_node_id'] = node.runtime_node_id
          details_result[node.runtime_node_id] = fresh
      return details_result
  def get_record_answer_list(self):
      """Flatten get_answer_text_list() into a single list of answer dicts."""
      return [answer for group in self.get_answer_text_list() for answer in group]
  def get_answer_text_list(self):
      """
      Group the answers of all executed nodes into display segments.

      Answers with empty content are skipped. An answer starts a new segment when
      it is the first one, when its view_type is 'single_view', or when it is
      'many_view' and follows a 'single_view' answer. Otherwise it joins the
      current segment.
      @return: list of segments, each a list of answer dicts; [[]] when there is no answer
      """
      answers = []
      for node in self.node_context:
          # Query each node once (the previous version called get_answer_list twice per node).
          node_answers = node.get_answer_list()
          if node_answers is not None:
              answers.extend(node_answers)

      segments = []
      previous = None
      for answer in answers:
          if len(answer.content) == 0:
              continue
          starts_segment = (previous is None
                            or answer.view_type == 'single_view'
                            or (answer.view_type == 'many_view' and previous.view_type == 'single_view'))
          if starts_segment:
              segments.append([answer])
          elif segments:
              segments[-1].append(answer)
          else:
              segments.insert(0, [answer])
          previous = answer
      if not segments:
          # No answer at all: respond with a single empty segment.
          return [[]]
      return [[answer.to_dict() for answer in segment] for segment in segments]
  @staticmethod
  def dependent_node(edge, node):
      """
      Whether ``node`` satisfies the dependency expressed by ``edge``.

      The node's chunk stream must have ended. When the node is the edge's source:
      - a node with a selected branch_id counts only if the edge leaves from that branch's anchor;
      - a 'form-node' counts only once form_data has been set;
      - any other node counts.
      Returns None (falsy) when ``node`` is not the edge's source.
      """
      if not node.node_chunk.is_end():
          return False
      if node.id != edge.sourceNodeId:
          return None
      branch_id = node.context.get('branch_id', None)
      if branch_id:
          return edge.sourceAnchorId == f"{node.id}_{branch_id}_right"
      if node.type == 'form-node':
          return node.context.get('form_data', None) is not None
      return True
  def dependent_node_been_executed(self, node_id):
      """
      Whether every incoming edge of ``node_id`` is satisfied by an executed source node.
      @param node_id: id of the node whose dependencies are checked
      @return: True when, for each edge targeting node_id, some node in node_context with the
               edge's source id passes dependent_node (vacuously True without incoming edges)
      """
      incoming_edges = [edge for edge in self.flow.edges if edge.targetNodeId == node_id]
      for edge in incoming_edges:
          sources = [node for node in self.node_context if node.id == edge.sourceNodeId]
          if not any(self.dependent_node(edge, node) for node in sources):
              return False
      return True
  def get_next_node_list(self, current_node, current_node_result):
      """
      Determine the nodes that can run after ``current_node``.
      @param current_node: the node that has just run
      @param current_node_result: its result; for assertion results only edges leaving the selected
             branch anchor ("<source id>_<branch_id>_right") are followed, otherwise edges leaving
             "<source id>_right"
      @return: list of node instances to run next (empty when there is no result or execution
               is interrupted)
      """
      # Stop when there is no result or the result interrupts execution.
      if current_node_result is None or current_node_result.is_interrupt_exec(current_node):
          return []
      is_assertion = current_node_result.is_assertion_result()
      branch_id = current_node_result.node_variable.get('branch_id') if is_assertion else None
      node_list = []
      for edge_node in self.flow.get_next_edge_nodes(current_node.id) or []:
          edge = edge_node.edge
          if is_assertion:
              expected_anchor = f"{edge.sourceNodeId}_{branch_id}_right"
          else:
              expected_anchor = edge.sourceNodeId + '_right'
          if expected_anchor == edge.sourceAnchorId:
              self._append_ready_next_node(node_list, current_node, edge, edge_node.node)
      return node_list

  def _append_ready_next_node(self, node_list, current_node, edge, next_node):
      """
      Append the node targeted by ``edge`` to ``node_list`` if it may run now.

      Nodes with the default 'AND' condition are appended only once
      dependent_node_been_executed reports all their incoming edges satisfied.
      With several upstream nodes, their up_node_id_list is inherited from the
      executed upstream node with the smallest id (presumably so every
      converging branch derives the same list). Nodes with any other condition
      are appended immediately, inheriting current_node's path.
      """
      up_node_id_list = [*current_node.up_node_id_list, current_node.node.id]
      if next_node.properties.get('condition', "AND") == 'AND':
          if not self.dependent_node_been_executed(edge.targetNodeId):
              return
          up_nodes = self.flow.get_up_nodes(edge.targetNodeId)
          if up_nodes and len(up_nodes) > 1:
              # min() picks the same node as sorting by id and taking the first,
              # without reordering the list returned by the flow.
              first = min(up_nodes, key=lambda node: node.id)
              first_executed = [n_c for n_c in self.node_context if n_c.node.id == first.id][0]
              up_node_id_list = [*first_executed.up_node_id_list, first.id]
      node_list.append(self.get_node_cls_by_id(edge.targetNodeId, up_node_id_list))
  def get_reference_field(self, node_id: str, fields: List[str]):
      """
      Resolve a field reference.
      @param node_id: 'global', 'chat', or the id of an executed node
      @param fields: field names passed through to the underlying get_field lookup
      @return: the referenced value, or None when no executed node has this id
      """
      if node_id == 'global':
          return INode.get_field(self.context, fields)
      if node_id == 'chat':
          return INode.get_field(self.chat_context, fields)
      node = self.get_node_by_id(node_id)
      return node.get_reference_field(fields) if node else None
  def get_workflow_content(self):
      """Build the template context: 'global', 'chat', plus each executed node's context keyed by node id."""
      context = {'global': self.context, 'chat': self.chat_context}
      context.update((node.id, node.context) for node in self.node_context)
      return context
  def reset_prompt(self, prompt: str):
      """
      Rewrite human-readable variable labels in ``prompt`` into jinja2 lookups on ``context``.

      '<node name>.<field>' becomes a lookup in that node's context.
      '全局变量.<field>' and 'global.<field>' become lookups in the global context.
      'chat.<field>' becomes a lookup in the chat context.
      Missing values render as ''.
      """
      for field in self.field_list:
          prompt = prompt.replace(
              f"{field.get('node_name')}.{field.get('value')}",
              f"context.get('{field.get('node_id')}',{{}}).get('{field.get('value', '')}','')")
      for field in self.global_field_list:
          global_lookup = f"context.get('global').get('{field.get('value', '')}','')"
          for label in (f"全局变量.{field.get('value')}", f"global.{field.get('value')}"):
              prompt = prompt.replace(label, global_lookup)
      for field in self.chat_field_list:
          prompt = prompt.replace(
              f"chat.{field.get('value')}",
              f"context.get('chat').get('{field.get('value', '')}','')")
      return prompt
  def generate_prompt(self, prompt: str):
      """
      Render ``prompt`` as a jinja2 template against the workflow context.
      @param prompt: prompt text that may reference node/global/chat variables by label
      @return: the rendered prompt
      """
      workflow_context = self.get_workflow_content()
      template = PromptTemplate.from_template(self.reset_prompt(prompt), template_format='jinja2')
      return template.format(context=workflow_context)
  def get_start_node(self):
      """
      Return the flow's first node of type 'start-node' (IndexError if there is none).
      """
      return [node for node in self.flow.nodes if node.type == 'start-node'][0]
  def get_base_node(self):
      """
      Return the flow's first node of type 'base-node' (IndexError if there is none).
      """
      return [node for node in self.flow.nodes if node.type == 'base-node'][0]
  def get_node_cls_by_id(self, node_id, up_node_id_list=None,
                         get_node_params=lambda node: node.properties.get('node_data')):
      """
      Instantiate the runtime node class for the flow node with id ``node_id``.
      @param up_node_id_list: ids of the upstream nodes on the executed path
      @param get_node_params: callable mapping the flow node definition to the node's params
      @return: a new node instance, or None if no flow node has this id
      """
      definition = next((node for node in self.flow.nodes if node.id == node_id), None)
      if definition is None:
          return None
      node_cls = get_node(definition.type, self.flow.workflow_mode)
      return node_cls(definition, self.params, self, up_node_id_list, get_node_params)
  def get_node_by_id(self, node_id):
      """Return the first executed node in node_context with id ``node_id``, or None."""
      return next((node for node in self.node_context if node.id == node_id), None)
  def get_node_reference(self, reference_address: Dict):
      """Read reference_address['node_field'] from the context of the executed node reference_address['node_id']."""
      node = self.get_node_by_id(reference_address.get('node_id'))
      field_name = reference_address.get('node_field')
      return node.context[field_name]
  def get_params_serializer_class(self):
      """Serializer class used to validate this workflow's parameters."""
      serializer_class = FlowParamsSerializer
      return serializer_class
  def get_source_type(self):
      """Source type reported for this workflow: always "APPLICATION"."""
      source_type = "APPLICATION"
      return source_type
  def get_source_id(self):
      """Source id reported for this workflow: the request's 'application_id' parameter."""
      request_params = self.params
      return request_params.get('application_id')