tool_workflow.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440
  1. # coding=utf-8
  2. """
  3. @project: MaxKB
  4. @Author:虎虎
  5. @file: tool_workflow.py
  6. @date:2026/3/6 13:59
  7. @desc:
  8. """
  9. import asyncio
  10. import json
  11. import os
  12. # coding=utf-8
  13. import pickle
  14. import tempfile
  15. import zipfile
  16. from functools import reduce
  17. from typing import Dict, List
  18. import requests
  19. import uuid_utils.compat as uuid
  20. from django.db import transaction
  21. from django.db.models import QuerySet, Q
  22. from django.http import HttpResponse
  23. from django.utils import timezone
  24. from django.utils.translation import gettext_lazy as _, gettext
  25. from rest_framework import serializers, status
  26. from rest_framework.utils.formatting import lazy_format
  27. from application.flow.common import Workflow, WorkflowMode
  28. from application.flow.i_step_node import ToolWorkflowPostHandler
  29. from application.flow.tool_workflow_manage import ToolWorkflowManage
  30. from application.models import ChatRecord
  31. from application.serializers.application import McpServersSerializer, get_mcp_tools
  32. from application.serializers.common import ToolExecute
  33. from common.database_model_manage.database_model_manage import DatabaseModelManage
  34. from common.exception.app_exception import AppApiException
  35. from common.field.common import UploadedFileField
  36. from common.result import result
  37. from common.utils.common import bytes_to_uploaded_file
  38. from common.utils.common import restricted_loads, generate_uuid
  39. from common.utils.logger import maxkb_logger
  40. from common.utils.tool_code import ToolExecutor
  41. from knowledge.models import KnowledgeWorkflow, Knowledge, KnowledgeScope
  42. from knowledge.serializers.knowledge import KnowledgeModelSerializer, KnowledgeSerializer
  43. from maxkb.const import CONFIG
  44. from system_manage.models import AuthTargetType
  45. from system_manage.models.resource_mapping import ResourceMapping
  46. from system_manage.serializers.user_resource_permission import UserResourcePermissionSerializer
  47. from tools.models import Tool, ToolScope, ToolWorkflow, ToolWorkflowVersion
  48. from tools.serializers.tool import ToolExportModelSerializer, ToolSerializer
  49. from users.models import User
# Module-level ToolExecutor instance, created once when this module is imported.
# Nothing in the visible part of this file references it.
tool_executor = ToolExecutor()
  51. def is_valid_tool_workflow_circular_dependency(workflow, _id, visited=None, stack=None):
  52. """
  53. workflow: 当前要检查的 workflow 对象
  54. visited: 全局已经访问过的 workflow id
  55. stack: 当前递归栈里的 workflow id
  56. """
  57. if visited is None:
  58. visited = set()
  59. if stack is None:
  60. stack = set()
  61. if _id in stack:
  62. return False
  63. if _id in visited:
  64. return True
  65. stack.add(_id)
  66. for node in workflow.get('nodes', []):
  67. child_tool_ids = []
  68. if node.get('type') == 'ai-chat-node':
  69. node_data = node.get('properties', {}).get('node_data', {})
  70. child_tool_ids = node_data.get('tool_ids') or []
  71. if node.get('type') == 'tool-workflow-lib-node':
  72. child_tool_id = node.get('properties', {}).get('node_data', {}).get('tool_lib_id')
  73. child_tool_ids.append(child_tool_id)
  74. for child_tool_id in child_tool_ids:
  75. if child_tool_id:
  76. child_workflow = QuerySet(ToolWorkflow).filter(tool_id=child_tool_id).first()
  77. if child_workflow:
  78. if not is_valid_tool_workflow_circular_dependency(child_workflow.work_flow, str(child_tool_id),
  79. visited,
  80. stack):
  81. return False
  82. stack.remove(_id)
  83. visited.add(_id)
  84. return True
  85. def hand_node(node, update_tool_map):
  86. if node.get('type') == 'tool-lib-node':
  87. tool_lib_id = (node.get('properties', {}).get('node_data', {}).get('tool_lib_id') or '')
  88. node.get('properties', {}).get('node_data', {})['tool_lib_id'] = update_tool_map.get(tool_lib_id, tool_lib_id)
  89. if node.get('type') == 'search-knowledge-node':
  90. node.get('properties', {}).get('node_data', {})['knowledge_id_list'] = []
  91. if node.get('type') == 'ai-chat-node':
  92. node_data = node.get('properties', {}).get('node_data', {})
  93. mcp_tool_ids = node_data.get('mcp_tool_ids') or []
  94. node_data['mcp_tool_ids'] = [update_tool_map.get(tool_id,
  95. tool_id) for tool_id in mcp_tool_ids]
  96. tool_ids = node_data.get('tool_ids') or []
  97. node_data['tool_ids'] = [update_tool_map.get(tool_id,
  98. tool_id) for tool_id in tool_ids]
  99. if node.get('type') == 'mcp-node':
  100. mcp_tool_id = (node.get('properties', {}).get('node_data', {}).get('mcp_tool_id') or '')
  101. node.get('properties', {}).get('node_data', {})['mcp_tool_id'] = update_tool_map.get(mcp_tool_id,
  102. mcp_tool_id)
  103. class ToolWorkflowModelSerializer(serializers.ModelSerializer):
  104. class Meta:
  105. model = ToolWorkflow
  106. fields = '__all__'
  107. class ToolWorkflowImportRequest(serializers.Serializer):
  108. file = UploadedFileField(required=True, label=_("file"))
  109. class ToolWorkflowActionListQuerySerializer(serializers.Serializer):
  110. user_name = serializers.CharField(required=False, label=_('Name'), allow_blank=True, allow_null=True)
  111. state = serializers.CharField(required=False, label=_("State"), allow_blank=True, allow_null=True)
  112. class ToolWorkflowInstance:
  113. def __init__(self, knowledge_workflow: dict, version: str, tool_list: List[dict]):
  114. self.knowledge_workflow = knowledge_workflow
  115. self.version = version
  116. self.tool_list = tool_list
  117. def get_tool_list(self):
  118. return self.tool_list or []
  119. class ToolWorkflowSerializer(serializers.Serializer):
  120. class Operate(serializers.Serializer):
  121. user_id = serializers.UUIDField(required=True, label=_('user id'))
  122. workspace_id = serializers.CharField(required=False, label=_('workspace id'))
  123. tool_id = serializers.UUIDField(required=True, label=_('tool id'))
  124. def is_valid(self, *, raise_exception=False):
  125. super().is_valid(raise_exception=True)
  126. workspace_id = self.data.get('workspace_id')
  127. query_set = QuerySet(Tool).filter(id=self.data.get('tool_id'))
  128. if workspace_id:
  129. query_set = query_set.filter(workspace_id=workspace_id)
  130. if not query_set.exists():
  131. raise AppApiException(500, _('Tool id does not exist'))
  132. def debug(self, instance: Dict, user, with_valid=True):
  133. if with_valid:
  134. self.is_valid(raise_exception=True)
  135. tool_workflow = QuerySet(ToolWorkflow).filter(tool_id=self.data.get("tool_id")).first()
  136. workspace_id = tool_workflow.workspace_id
  137. tool_record_id = instance.get('chat_record_id') or str(uuid.uuid7())
  138. took_execute = ToolExecute(self.data.get("tool_id"), tool_record_id,
  139. workspace_id,
  140. None,
  141. None,
  142. True)
  143. record = took_execute.get_record()
  144. work_flow_manage = ToolWorkflowManage(
  145. Workflow.new_instance(tool_workflow.work_flow, WorkflowMode.TOOL),
  146. {
  147. 'chat_record_id': tool_record_id,
  148. 'tool_id': self.data.get("tool_id"),
  149. 'stream': True,
  150. 'workspace_id': workspace_id,
  151. **instance},
  152. ToolWorkflowPostHandler(took_execute, self.data.get("tool_id")),
  153. is_the_task_interrupted=lambda: False,
  154. child_node=instance.get('child_node'),
  155. start_node_id=instance.get('runtime_node_id'),
  156. start_node_data=instance.get('node_data'),
  157. chat_record=self.to_chat_record(record)
  158. )
  159. r = work_flow_manage.run()
  160. return r
  161. @staticmethod
  162. def to_chat_record(record):
  163. if record is None:
  164. return None
  165. return ChatRecord(
  166. answer_text_list=record.meta.get('answer_text_list'),
  167. details=record.meta.get('details'),
  168. answer_text='',
  169. )
  170. def publish(self, with_valid=True):
  171. if with_valid:
  172. self.is_valid()
  173. user_id = self.data.get('user_id')
  174. user = QuerySet(User).filter(id=user_id).first()
  175. tool_workflow = QuerySet(ToolWorkflow).filter(tool_id=self.data.get("tool_id")).first()
  176. workspace_id = tool_workflow.workspace_id
  177. work_flow_version = ToolWorkflowVersion(work_flow=tool_workflow.work_flow,
  178. tool_id=self.data.get("tool_id"),
  179. name=timezone.localtime(timezone.now()).strftime(
  180. '%Y-%m-%d %H:%M:%S'),
  181. publish_user_id=user_id,
  182. publish_user_name=user.username,
  183. workspace_id=workspace_id)
  184. work_flow_version.save()
  185. QuerySet(ToolWorkflow).filter(
  186. tool_id=self.data.get("tool_id")
  187. ).update(is_publish=True, publish_time=timezone.now())
  188. return True
  189. def list_knowledge(self, with_valid=True):
  190. if with_valid:
  191. self.is_valid(raise_exception=True)
  192. workspace_id = self.data.get("workspace_id")
  193. user_id = self.data.get('user_id')
  194. if workspace_id == 'None':
  195. return [{**KnowledgeModelSerializer(k).data, 'scope': 'SHARED'} for k in
  196. QuerySet(Knowledge).filter(workspace_id='None')]
  197. knowledge_workspace_authorization_model = DatabaseModelManage.get_model('knowledge_workspace_authorization')
  198. share_knowledge_list = []
  199. if knowledge_workspace_authorization_model is not None:
  200. white_list_condition = Q(authentication_type='WHITE_LIST') & Q(
  201. workspace_id_list__contains=[workspace_id])
  202. default_condition = ~Q(authentication_type='WHITE_LIST') & ~Q(
  203. workspace_id_list__contains=[workspace_id])
  204. # 组合查询
  205. query = white_list_condition | default_condition
  206. inner = QuerySet(knowledge_workspace_authorization_model).filter(query)
  207. share_knowledge_list = [{**KnowledgeModelSerializer(k).data, 'scope': 'SHARED'} for k in
  208. QuerySet(Knowledge).filter(id__in=inner)]
  209. workspace_knowledge_list = [{**k, 'scope': 'WORKSPACE'} for k in KnowledgeSerializer.Query(
  210. data={
  211. 'workspace_id': workspace_id,
  212. 'scope': KnowledgeScope.WORKSPACE,
  213. 'user_id': user_id
  214. }
  215. ).list() if k.get('resource_type') == 'knowledge']
  216. return [*workspace_knowledge_list, *share_knowledge_list]
  217. @staticmethod
  218. def get_tool_knowledge_mapping(application_knowledge_id_list, knowledge_id_list, tool_id):
  219. """
  220. @param application_knowledge_id_list: 当前应用可修改的知识库列表
  221. @param knowledge_id_list: 用户修改的知识库列表
  222. @param application_id: 应用id
  223. @return:
  224. """
  225. # 当前知识库和应用已关联列表
  226. knowledge_application_mapping_list = QuerySet(ResourceMapping).filter(source_id=tool_id,
  227. source_type='TOOL',
  228. target_type="KNOWLEDGE",
  229. ).exclude(
  230. target_id__in=application_knowledge_id_list)
  231. edit_knowledge_list = [ResourceMapping(source_id=tool_id, target_id=knowledge_id,
  232. source_type='TOOL',
  233. target_type="KNOWLEDGE")
  234. for knowledge_id in knowledge_id_list]
  235. return list(knowledge_application_mapping_list) + edit_knowledge_list
  236. def edit(self, instance: Dict):
  237. self.is_valid(raise_exception=True)
  238. tool = QuerySet(Tool).filter(id=self.data.get("tool_id")).first()
  239. workflow_id = tool.workspace_id
  240. if instance.get("work_flow"):
  241. dependency = is_valid_tool_workflow_circular_dependency(workflow=instance.get('work_flow'),
  242. _id=str(tool.id))
  243. if not dependency:
  244. raise Exception(gettext('There is a circular dependency in the tool workflow'))
  245. QuerySet(ToolWorkflow).update_or_create(tool_id=self.data.get("tool_id"),
  246. create_defaults={'id': uuid.uuid7(),
  247. 'tool_id': self.data.get(
  248. "tool_id"),
  249. "workspace_id": workflow_id,
  250. 'work_flow': instance.get('work_flow',
  251. {}), },
  252. defaults={
  253. 'tool_id': self.data.get("tool_id"),
  254. 'workspace_id': workflow_id,
  255. 'work_flow': instance.get('work_flow')
  256. })
  257. # 当前用户可修改关联的知识库列表
  258. tool_knowledge_id_list = [str(knowledge.get('id')) for knowledge in
  259. self.list_knowledge(with_valid=False)]
  260. knowledge_id_list = []
  261. if 'knowledge_id_list' in instance:
  262. # 当前用户可修改关联的知识库列表
  263. application_knowledge_id_list = [str(knowledge.get('id')) for knowledge in
  264. self.list_knowledge(with_valid=False)]
  265. knowledge_id_list = instance.get('knowledge_id_list')
  266. for knowledge_id in knowledge_id_list:
  267. if not application_knowledge_id_list.__contains__(knowledge_id):
  268. message = lazy_format(_('Unknown knowledge base id {dataset_id}, unable to associate'),
  269. dataset_id=knowledge_id)
  270. raise AppApiException(500, str(message))
  271. update_resource_mapping_by_tool(self.data.get("tool_id"),
  272. self.get_tool_knowledge_mapping(
  273. tool_knowledge_id_list,
  274. knowledge_id_list,
  275. self.data.get("tool_id")))
  276. return self.one()
  277. if instance.get("work_flow_template"):
  278. template_instance = instance.get('work_flow_template')
  279. download_url = template_instance.get('downloadUrl')
  280. # 查找匹配的版本名称
  281. res = requests.get(download_url, timeout=5)
  282. tool = QuerySet(Tool).filter(id=self.data.get("tool_id")).first()
  283. ToolSerializer.Import(data={
  284. 'user_id': self.data.get('user_id'),
  285. 'workspace_id': workflow_id,
  286. 'folder_id': tool.folder_id,
  287. 'file': bytes_to_uploaded_file(res.content, 'file.tool')
  288. }).update_template_workflow(str(self.data.get('tool_id')))
  289. try:
  290. requests.get(template_instance.get('downloadCallbackUrl'), timeout=5)
  291. except Exception as e:
  292. maxkb_logger.error(f"callback appstore tool download error: {e}")
  293. return self.one()
  294. def one(self):
  295. self.is_valid(raise_exception=True)
  296. workflow = QuerySet(ToolWorkflow).filter(tool_id=self.data.get('tool_id')).first()
  297. return {**ToolWorkflowModelSerializer(workflow).data}
  298. class ToolWorkflowMcpSerializer(serializers.Serializer):
  299. tool_id = serializers.UUIDField(required=True, label=_('Tool id'))
  300. user_id = serializers.UUIDField(required=True, label=_("User ID"))
  301. workspace_id = serializers.CharField(required=False, allow_null=True, allow_blank=True, label=_("Workspace ID"))
  302. def is_valid(self, *, raise_exception=False):
  303. super().is_valid(raise_exception=True)
  304. workspace_id = self.data.get('workspace_id')
  305. query_set = QuerySet(Tool).filter(id=self.data.get('tool_id'))
  306. if workspace_id:
  307. query_set = query_set.filter(workspace_id=workspace_id)
  308. if not query_set.exists():
  309. raise AppApiException(500, _('Tool id does not exist'))
  310. def get_mcp_servers(self, instance, with_valid=True):
  311. if with_valid:
  312. self.is_valid(raise_exception=True)
  313. McpServersSerializer(data=instance).is_valid(raise_exception=True)
  314. servers = json.loads(instance.get('mcp_servers'))
  315. for server, config in servers.items():
  316. if config.get('transport') not in ['sse', 'streamable_http']:
  317. raise AppApiException(500, _('Only support transport=sse or transport=streamable_http'))
  318. tools = []
  319. for server in servers:
  320. tools += [
  321. {
  322. 'server': server,
  323. 'name': tool.name,
  324. 'description': tool.description,
  325. 'args_schema': tool.args_schema,
  326. }
  327. for tool in asyncio.run(get_mcp_tools({server: servers[server]}))]
  328. return tools
  329. class StoreToolWorkflow(serializers.Serializer):
  330. user_id = serializers.UUIDField(required=True, label=_("User ID"))
  331. name = serializers.CharField(required=False, label=_("tool name"), allow_null=True, allow_blank=True)
  332. def get_appstore_templates(self):
  333. self.is_valid(raise_exception=True)
  334. # 下载zip文件
  335. try:
  336. appstore_url = CONFIG.get('APPSTORE_URL', 'https://apps-assets.fit2cloud.com/stable/maxkb.json.zip')
  337. res = requests.get(appstore_url, timeout=5)
  338. res.raise_for_status()
  339. # 创建临时文件保存zip
  340. with tempfile.NamedTemporaryFile(delete=False, suffix='.zip') as temp_zip:
  341. temp_zip.write(res.content)
  342. temp_zip_path = temp_zip.name
  343. try:
  344. # 解压zip文件
  345. with zipfile.ZipFile(temp_zip_path, 'r') as zip_ref:
  346. # 获取zip中的第一个文件(假设只有一个json文件)
  347. json_filename = zip_ref.namelist()[0]
  348. json_content = zip_ref.read(json_filename)
  349. # 将json转换为字典
  350. tool_store = json.loads(json_content.decode('utf-8'))
  351. tag_dict = {tag['name']: tag['key'] for tag in tool_store['additionalProperties']['tags']}
  352. filter_apps = []
  353. for tool in tool_store['apps']:
  354. if self.data.get('name', '') != '':
  355. if self.data.get('name').lower() not in tool.get('name', '').lower():
  356. continue
  357. if not tool['downloadUrl'].endswith('.tool') or not [tag_dict[tag] for tag in
  358. tool.get('tags')].__contains__(
  359. 'workflow_template'):
  360. continue
  361. versions = tool.get('versions', [])
  362. tool['label'] = tag_dict[tool.get('tags')[0]] if tool.get('tags') else ''
  363. tool['version'] = next(
  364. (version.get('name') for version in versions if
  365. version.get('downloadUrl') == tool['downloadUrl']),
  366. )
  367. filter_apps.append(tool)
  368. tool_store['apps'] = filter_apps
  369. return tool_store
  370. finally:
  371. # 清理临时文件
  372. os.unlink(temp_zip_path)
  373. except Exception as e:
  374. maxkb_logger.error(f"fetch appstore tools error: {e}")
  375. return {'apps': [], 'additionalProperties': {'tags': []}}
  376. def update_resource_mapping_by_tool(tool_id: str, other_resource_mapping=None):
  377. from application.flow.tools import get_instance_resource, save_workflow_mapping
  378. from system_manage.models.resource_mapping import ResourceType
  379. if other_resource_mapping is None:
  380. other_resource_mapping = []
  381. tool = QuerySet(ToolWorkflow).filter(tool_id=tool_id).first()
  382. instance_mapping = get_instance_resource(tool, ResourceType.TOOL, str(tool_id),
  383. {})
  384. save_workflow_mapping(tool.work_flow, ResourceType.TOOL, str(tool_id),
  385. instance_mapping + other_resource_mapping)
  386. return