application_task.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
  1. # coding=utf-8
  2. """
  3. @project: MaxKB
  4. @Author:虎虎
  5. @file: application_task.py
  6. @date:2026/1/14 19:14
  7. @desc:
  8. """
  9. import json
  10. import time
  11. import traceback
  12. import uuid_utils.compat as uuid
  13. from django.db.models import QuerySet
  14. from application.models import ChatUserType, Chat, ChatRecord, ChatSourceChoices, Application
  15. from chat.serializers.chat import ChatSerializers
  16. from common.utils.logger import maxkb_logger
  17. from knowledge.models.knowledge_action import State
  18. from trigger.handler.base_task import BaseTriggerTask
  19. from trigger.models import TaskRecord, TriggerTask
  20. def get_reference(fields, obj):
  21. for field in fields:
  22. value = obj.get(field)
  23. if value is None:
  24. return None
  25. else:
  26. obj = value
  27. return obj
  28. def conversion_custom_value(value, _type):
  29. if ['array', 'dict', 'float', 'int', 'boolean', 'any'].__contains__(_type):
  30. try:
  31. return json.loads(value)
  32. except Exception as e:
  33. pass
  34. return value
  35. def valid_value_type(value, _type):
  36. if _type == 'array':
  37. return isinstance(value, list)
  38. if _type == 'dict':
  39. return isinstance(value, dict)
  40. if _type == 'float':
  41. return isinstance(value, float)
  42. if _type == 'int':
  43. return isinstance(value, int)
  44. if _type == 'boolean':
  45. return isinstance(value, bool)
  46. return isinstance(value, str)
  47. def get_field_value(value, kwargs, _type, required, default_value, field):
  48. source = value.get('source')
  49. if source == 'custom':
  50. _value = value.get('value')
  51. if _value:
  52. _value = conversion_custom_value(_value, _type)
  53. else:
  54. if default_value:
  55. return default_value
  56. if required:
  57. raise Exception(f'{field} is required')
  58. else:
  59. return None
  60. else:
  61. _value = get_reference(value.get('value'), kwargs)
  62. valid = valid_value_type(_value, _type)
  63. if not valid:
  64. raise Exception(f'{field} type error')
  65. return _value
  66. def get_application_execute_parameters(parameter_setting, application_parameters_setting, kwargs):
  67. many_field = ['api_input_field_list', 'user_input_field_list']
  68. parameters = {'form_data': {}}
  69. for key, value in application_parameters_setting.items():
  70. setting = parameter_setting.get(key)
  71. if setting:
  72. if many_field.__contains__(key):
  73. for ck, cv in value.items():
  74. _setting = setting.get(ck)
  75. if _setting:
  76. _value = get_field_value(_setting, kwargs, cv.get('type'), cv.get('required'),
  77. cv.get('default_value'), ck)
  78. parameters['form_data'][ck] = _value
  79. else:
  80. if cv.get('default_value'):
  81. parameters['form_data'][ck] = cv.get('default_value')
  82. else:
  83. if cv.get('required'):
  84. raise Exception(f'{ck} is required')
  85. else:
  86. value = get_field_value(setting, kwargs, value.get('type'), value.get('required'),
  87. value.get('default_value'), key)
  88. parameters['message' if key == 'question' else key] = value
  89. else:
  90. if value.get('default_value'):
  91. parameters['message' if key == 'question' else key] = value.get('default_value')
  92. else:
  93. if value.get('required'):
  94. raise Exception(f'{"message" if key == "question" else key} is required')
  95. return parameters
  96. def get_loop_workflow_node(node_list):
  97. result = []
  98. for item in node_list:
  99. if item.get('type') == 'loop-node':
  100. for loop_item in item.get('loop_node_data') or []:
  101. for inner_item in loop_item.values():
  102. result.append(inner_item)
  103. return result
  104. def get_workflow_state(details):
  105. node_list = details.values()
  106. all_node = [*node_list, *get_loop_workflow_node(node_list)]
  107. err = any([True for value in all_node if value.get('status') == 500 and not value.get('enableException')])
  108. if err:
  109. return State.FAILURE
  110. return State.SUCCESS
  111. def get_user_field_component_input_type(input_type):
  112. if input_type == "MultiRow":
  113. return 'array'
  114. if input_type == "SwitchInput":
  115. return 'boolean'
  116. return 'string'
  117. def get_application_parameters_setting(application):
  118. application_parameter_setting = {'question': {
  119. 'required': True,
  120. 'type': 'string'
  121. }}
  122. if application.type == 'SIMPLE':
  123. return application_parameter_setting
  124. else:
  125. base_node_list = [n for n in application.work_flow.get('nodes') if n.get('type') == "base-node"]
  126. if len(base_node_list) == 0:
  127. raise Exception('Incorrect application workflow information')
  128. base_node = base_node_list[0]
  129. api_input_field_list = base_node.get('properties').get('api_input_field_list') or []
  130. api_input_field_list = {user_field.get('variable'): {
  131. 'required': user_field.get('is_required'),
  132. 'default_value': user_field.get('default_value'),
  133. 'type': 'string'
  134. } for user_field in api_input_field_list}
  135. user_input_field_list = base_node.get('properties').get('user_input_field_list') or []
  136. user_input_field_list = {user_field.get('field'): {
  137. 'required': user_field.get('required'),
  138. 'default_value': user_field.get('default_value'),
  139. 'type': get_user_field_component_input_type(user_field.get('input_type'))
  140. } for user_field in user_input_field_list}
  141. application_parameter_setting['api_input_field_list'] = api_input_field_list
  142. application_parameter_setting['user_input_field_list'] = user_input_field_list
  143. node_data = base_node.get('properties').get('node_data') or {}
  144. file_upload_enable = node_data.get('file_upload_enable')
  145. if file_upload_enable:
  146. file_upload_setting = node_data.get('file_upload_setting') or {}
  147. for field in ['audio', 'document', 'image', 'other', 'video']:
  148. v = file_upload_setting.get(field)
  149. if v:
  150. application_parameter_setting[field + '_list'] = {'required': False, 'default_value': [],
  151. 'type': 'array'}
  152. return application_parameter_setting
  153. class ApplicationTask(BaseTriggerTask):
  154. def support(self, trigger_task, **kwargs):
  155. return trigger_task.get('source_type') == 'APPLICATION'
  156. def execute(self, trigger_task, **kwargs):
  157. parameter_setting = trigger_task.get('parameter')
  158. task_record_id = uuid.uuid7()
  159. start_time = time.time()
  160. try:
  161. application = QuerySet(Application).filter(id=trigger_task.get('source_id')).only('type',
  162. 'work_flow').first()
  163. if application is None:
  164. QuerySet(TriggerTask).filter(id=trigger_task.get('id')).delete()
  165. return
  166. application_id = trigger_task.get('source_id')
  167. chat_id = uuid.uuid7()
  168. chat_user_id = str(uuid.uuid7())
  169. chat_record_id = str(uuid.uuid7())
  170. TaskRecord(id=task_record_id, trigger_id=trigger_task.get('trigger'),
  171. trigger_task_id=trigger_task.get('id'),
  172. source_type="APPLICATION",
  173. source_id=application_id,
  174. task_record_id=chat_record_id,
  175. meta={'chat_id': chat_id},
  176. state=State.STARTED).save()
  177. application_parameters_setting = get_application_parameters_setting(application)
  178. parameters = get_application_execute_parameters(parameter_setting, application_parameters_setting, kwargs)
  179. parameters['re_chat'] = False
  180. parameters['stream'] = True
  181. parameters['chat_record_id'] = chat_record_id
  182. message = parameters.get('message')
  183. ip_address = '-'
  184. if kwargs.get('body') is not None:
  185. ip_address = kwargs.get('body').get('ip_address')
  186. Chat.objects.get_or_create(id=chat_id, defaults={
  187. 'application_id': application_id,
  188. 'abstract': message,
  189. 'chat_user_id': chat_user_id,
  190. 'chat_user_type': ChatUserType.ANONYMOUS_USER.value,
  191. 'asker': {'username': "游客"},
  192. 'ip_address': ip_address,
  193. 'source': {
  194. 'type': ChatSourceChoices.TRIGGER.value
  195. },
  196. })
  197. list(ChatSerializers(data={
  198. "chat_id": chat_id,
  199. "chat_user_id": chat_user_id,
  200. 'chat_user_type': ChatUserType.ANONYMOUS_USER.value,
  201. 'application_id': application_id,
  202. 'ip_address': ip_address,
  203. 'source': {
  204. 'type': ChatSourceChoices.TRIGGER.value
  205. },
  206. 'debug': False
  207. }).chat(instance=parameters))
  208. chat_record = QuerySet(ChatRecord).filter(id=chat_record_id).first()
  209. if chat_record:
  210. state = get_workflow_state(chat_record.details)
  211. QuerySet(TaskRecord).filter(id=task_record_id).update(state=state, run_time=chat_record.run_time,
  212. meta={'parameter_setting': parameter_setting,
  213. 'input': parameters, 'output': None})
  214. else:
  215. QuerySet(TaskRecord).filter(id=task_record_id).update(state=State.FAILURE,
  216. run_time=time.time() - start_time,
  217. meta={'parameter_setting': parameter_setting,
  218. 'input': parameters, 'output': None,
  219. 'err_message': 'Error: An unknown error occurred during the execution of the conversation'})
  220. except Exception as e:
  221. maxkb_logger.error(f"Application execution error: {traceback.format_exc()}")
  222. QuerySet(TaskRecord).filter(id=task_record_id).update(
  223. state=State.FAILURE,
  224. run_time=time.time() - start_time,
  225. meta={'input': {'parameter_setting': parameter_setting, **kwargs}, 'output': None,
  226. 'err_message': 'Error: ' + str(e)}
  227. )