| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440 |
- # coding=utf-8
- """
- @project: MaxKB
- @Author:虎虎
- @file: tool_workflow.py
- @date:2026/3/6 13:59
- @desc:
- """
- import asyncio
- import json
- import os
- # coding=utf-8
- import pickle
- import tempfile
- import zipfile
- from functools import reduce
- from typing import Dict, List
- import requests
- import uuid_utils.compat as uuid
- from django.db import transaction
- from django.db.models import QuerySet, Q
- from django.http import HttpResponse
- from django.utils import timezone
- from django.utils.translation import gettext_lazy as _, gettext
- from rest_framework import serializers, status
- from rest_framework.utils.formatting import lazy_format
- from application.flow.common import Workflow, WorkflowMode
- from application.flow.i_step_node import ToolWorkflowPostHandler
- from application.flow.tool_workflow_manage import ToolWorkflowManage
- from application.models import ChatRecord
- from application.serializers.application import McpServersSerializer, get_mcp_tools
- from application.serializers.common import ToolExecute
- from common.database_model_manage.database_model_manage import DatabaseModelManage
- from common.exception.app_exception import AppApiException
- from common.field.common import UploadedFileField
- from common.result import result
- from common.utils.common import bytes_to_uploaded_file
- from common.utils.common import restricted_loads, generate_uuid
- from common.utils.logger import maxkb_logger
- from common.utils.tool_code import ToolExecutor
- from knowledge.models import KnowledgeWorkflow, Knowledge, KnowledgeScope
- from knowledge.serializers.knowledge import KnowledgeModelSerializer, KnowledgeSerializer
- from maxkb.const import CONFIG
- from system_manage.models import AuthTargetType
- from system_manage.models.resource_mapping import ResourceMapping
- from system_manage.serializers.user_resource_permission import UserResourcePermissionSerializer
- from tools.models import Tool, ToolScope, ToolWorkflow, ToolWorkflowVersion
- from tools.serializers.tool import ToolExportModelSerializer, ToolSerializer
- from users.models import User
# Module-level ToolExecutor instance. It is not referenced in the rest of this
# file; NOTE(review): other modules may import it — confirm before removing.
tool_executor: ToolExecutor = ToolExecutor()
def _get_child_tool_ids(node):
    """Return the non-empty tool ids referenced by an ``ai-chat-node`` or ``tool-workflow-lib-node``."""
    node_data = (node.get('properties') or {}).get('node_data') or {}
    node_type = node.get('type')
    if node_type == 'ai-chat-node':
        child_tool_ids = list(node_data.get('tool_ids') or [])
    elif node_type == 'tool-workflow-lib-node':
        child_tool_ids = [node_data.get('tool_lib_id')]
    else:
        child_tool_ids = []
    return [tool_id for tool_id in child_tool_ids if tool_id]


def is_valid_tool_workflow_circular_dependency(workflow, _id, visited=None, stack=None):
    """
    Check that a tool workflow does not (transitively) reference its own tool.

    Tool ids referenced by the workflow's nodes are resolved to their stored
    ``ToolWorkflow`` and checked recursively (depth first).

    :param workflow: workflow definition dict to check; its ``nodes`` list is scanned.
    :param _id: id (string) of the tool that owns ``workflow``.
    :param visited: ids already fully checked and found free of cycles (shared across the recursion).
    :param stack: ids on the current recursion path.
    :return: ``False`` if a circular reference is found, otherwise ``True``.
    """
    if visited is None:
        visited = set()
    if stack is None:
        stack = set()
    if _id in stack:
        return False
    if _id in visited:
        return True
    stack.add(_id)
    for node in (workflow or {}).get('nodes') or []:
        for child_tool_id in _get_child_tool_ids(node):
            child_id = str(child_tool_id)
            # A reference back to a tool on the current path is a cycle even when that
            # tool has no stored workflow row (e.g. a tool referencing itself).
            if child_id in stack:
                return False
            # Already proven cycle-free: no need to query the database again.
            if child_id in visited:
                continue
            child_workflow = QuerySet(ToolWorkflow).filter(tool_id=child_tool_id).first()
            if child_workflow and not is_valid_tool_workflow_circular_dependency(
                    child_workflow.work_flow, child_id, visited, stack):
                return False
    stack.remove(_id)
    visited.add(_id)
    return True
def hand_node(node, update_tool_map):
    """
    Rewrite the resource references of a single workflow node in place.

    Tool ids in ``tool-lib-node``, ``ai-chat-node`` and ``mcp-node`` nodes are
    translated through ``update_tool_map`` (ids missing from the map are kept);
    the ``knowledge_id_list`` of a ``search-knowledge-node`` is cleared.
    Other node types are left untouched.

    :param node: workflow node dict; its ``properties.node_data`` is modified.
    :param update_tool_map: mapping from old tool id to new tool id.
    """
    node_type = node.get('type')
    if node_type not in ('tool-lib-node', 'search-knowledge-node', 'ai-chat-node', 'mcp-node'):
        return

    def remap(tool_id):
        return update_tool_map.get(tool_id, tool_id)

    node_data = node.get('properties', {}).get('node_data', {})
    if node_type == 'search-knowledge-node':
        node_data['knowledge_id_list'] = []
    elif node_type == 'ai-chat-node':
        for list_key in ('mcp_tool_ids', 'tool_ids'):
            node_data[list_key] = [remap(tool_id) for tool_id in node_data.get(list_key) or []]
    else:
        # tool-lib-node and mcp-node each carry a single tool id (missing -> '').
        id_key = 'tool_lib_id' if node_type == 'tool-lib-node' else 'mcp_tool_id'
        node_data[id_key] = remap(node_data.get(id_key) or '')
class ToolWorkflowModelSerializer(serializers.ModelSerializer):
    """Model serializer exposing every field of ``ToolWorkflow``."""

    class Meta:
        fields = '__all__'
        model = ToolWorkflow
class ToolWorkflowImportRequest(serializers.Serializer):
    """Request body holding a single, required uploaded ``file``."""

    file = UploadedFileField(label=_("file"), required=True)
class ToolWorkflowActionListQuerySerializer(serializers.Serializer):
    """Query parameters ``user_name`` and ``state``; both optional and may be blank or null."""

    user_name = serializers.CharField(label=_('Name'), required=False, allow_null=True, allow_blank=True)
    state = serializers.CharField(label=_("State"), required=False, allow_null=True, allow_blank=True)
class ToolWorkflowInstance:
    """Plain holder for a workflow definition, its version label and its tool list."""

    def __init__(self, knowledge_workflow: dict, version: str, tool_list: List[dict]):
        # NOTE(review): the attribute is called ``knowledge_workflow`` even though this
        # holds a tool workflow; kept as-is for compatibility with existing callers.
        self.knowledge_workflow, self.version, self.tool_list = knowledge_workflow, version, tool_list

    def get_tool_list(self):
        """Return the tool list, or a fresh empty list when it is None or empty."""
        return self.tool_list if self.tool_list else []
class ToolWorkflowSerializer(serializers.Serializer):
    """Container for tool-workflow serializers; the operations live in ``Operate``."""

    class Operate(serializers.Serializer):
        """Operations on the workflow of one tool: debug, publish, knowledge listing, edit and read."""
        user_id = serializers.UUIDField(required=True, label=_('user id'))
        workspace_id = serializers.CharField(required=False, label=_('workspace id'))
        tool_id = serializers.UUIDField(required=True, label=_('tool id'))

        def is_valid(self, *, raise_exception=False):
            """
            Validate the fields and check that the tool exists.

            The tool lookup is restricted to ``workspace_id`` when one is given.
            Errors are always raised, whatever ``raise_exception`` says.

            :raises AppApiException: if no matching tool exists.
            """
            super().is_valid(raise_exception=True)
            workspace_id = self.data.get('workspace_id')
            query_set = QuerySet(Tool).filter(id=self.data.get('tool_id'))
            if workspace_id:
                query_set = query_set.filter(workspace_id=workspace_id)
            if not query_set.exists():
                raise AppApiException(500, _('Tool id does not exist'))

        def debug(self, instance: Dict, user, with_valid=True):
            """
            Run the tool's stored workflow (``ToolWorkflow.work_flow``) with ``stream`` enabled.

            :param instance: run parameters. ``chat_record_id`` (a new uuid7 is generated when
                absent), ``child_node``, ``runtime_node_id`` and ``node_data`` are read here, and
                the whole dict is merged into the workflow parameters.
            :param user: not used by this method.
            :param with_valid: validate the serializer first.
            :return: the result of ``ToolWorkflowManage.run()``.
            """
            if with_valid:
                self.is_valid(raise_exception=True)
            tool_id = self.data.get("tool_id")
            # NOTE(review): assumes a ToolWorkflow row exists for this tool.
            tool_workflow = QuerySet(ToolWorkflow).filter(tool_id=tool_id).first()
            workspace_id = tool_workflow.workspace_id
            tool_record_id = instance.get('chat_record_id') or str(uuid.uuid7())
            tool_execute = ToolExecute(tool_id, tool_record_id,
                                       workspace_id,
                                       None,
                                       None,
                                       True)
            record = tool_execute.get_record()
            work_flow_manage = ToolWorkflowManage(
                Workflow.new_instance(tool_workflow.work_flow, WorkflowMode.TOOL),
                {
                    'chat_record_id': tool_record_id,
                    'tool_id': tool_id,
                    'stream': True,
                    'workspace_id': workspace_id,
                    **instance},
                ToolWorkflowPostHandler(tool_execute, tool_id),
                is_the_task_interrupted=lambda: False,
                child_node=instance.get('child_node'),
                start_node_id=instance.get('runtime_node_id'),
                start_node_data=instance.get('node_data'),
                chat_record=self.to_chat_record(record)
            )
            return work_flow_manage.run()

        @staticmethod
        def to_chat_record(record):
            """
            Wrap an execution record's ``meta`` (``answer_text_list``, ``details``) in an unsaved
            ``ChatRecord``; returns None when there is no record.
            """
            if record is None:
                return None
            return ChatRecord(
                answer_text_list=record.meta.get('answer_text_list'),
                details=record.meta.get('details'),
                answer_text='',
            )

        def publish(self, with_valid=True):
            """
            Snapshot the tool's current workflow as a new ``ToolWorkflowVersion`` and mark it published.

            The version is named after the current local time and records the publishing user.

            :return: True.
            """
            if with_valid:
                self.is_valid(raise_exception=True)
            tool_id = self.data.get("tool_id")
            user_id = self.data.get('user_id')
            user = QuerySet(User).filter(id=user_id).first()
            tool_workflow = QuerySet(ToolWorkflow).filter(tool_id=tool_id).first()
            # Save the version and flip the publish flag together, so a failure cannot
            # leave a version without the flag (or the flag without a version).
            with transaction.atomic():
                work_flow_version = ToolWorkflowVersion(work_flow=tool_workflow.work_flow,
                                                        tool_id=tool_id,
                                                        name=timezone.localtime(timezone.now()).strftime(
                                                            '%Y-%m-%d %H:%M:%S'),
                                                        publish_user_id=user_id,
                                                        publish_user_name=user.username,
                                                        workspace_id=tool_workflow.workspace_id)
                work_flow_version.save()
                QuerySet(ToolWorkflow).filter(
                    tool_id=tool_id
                ).update(is_publish=True, publish_time=timezone.now())
            return True

        def list_knowledge(self, with_valid=True):
            """
            List the knowledge bases available for association with this tool.

            For ``workspace_id == 'None'`` every knowledge base of that workspace is returned with
            ``scope='SHARED'``. Otherwise the result is the workspace's knowledge bases listed for
            ``user_id`` (``scope='WORKSPACE'``, only entries whose ``resource_type`` is
            ``'knowledge'``), followed by shared knowledge bases authorized for the workspace
            (``scope='SHARED'``) when the ``knowledge_workspace_authorization`` model is available.

            :return: list of serialized knowledge dicts, each with a ``scope`` key.
            """
            if with_valid:
                self.is_valid(raise_exception=True)
            workspace_id = self.data.get("workspace_id")
            user_id = self.data.get('user_id')
            if workspace_id == 'None':
                return [{**KnowledgeModelSerializer(k).data, 'scope': 'SHARED'} for k in
                        QuerySet(Knowledge).filter(workspace_id='None')]
            knowledge_workspace_authorization_model = DatabaseModelManage.get_model('knowledge_workspace_authorization')
            share_knowledge_list = []
            if knowledge_workspace_authorization_model is not None:
                # Authorized for this workspace: WHITE_LIST with the workspace in
                # workspace_id_list, or any other type with the workspace absent from it.
                white_list_condition = Q(authentication_type='WHITE_LIST') & Q(
                    workspace_id_list__contains=[workspace_id])
                default_condition = ~Q(authentication_type='WHITE_LIST') & ~Q(
                    workspace_id_list__contains=[workspace_id])
                inner = QuerySet(knowledge_workspace_authorization_model).filter(
                    white_list_condition | default_condition)
                # NOTE(review): ``id__in=inner`` compares Knowledge ids with the primary key of the
                # authorization rows; confirm that key is the knowledge id.
                share_knowledge_list = [{**KnowledgeModelSerializer(k).data, 'scope': 'SHARED'} for k in
                                        QuerySet(Knowledge).filter(id__in=inner)]
            workspace_knowledge_list = [{**k, 'scope': 'WORKSPACE'} for k in KnowledgeSerializer.Query(
                data={
                    'workspace_id': workspace_id,
                    'scope': KnowledgeScope.WORKSPACE,
                    'user_id': user_id
                }
            ).list() if k.get('resource_type') == 'knowledge']
            return [*workspace_knowledge_list, *share_knowledge_list]

        @staticmethod
        def get_tool_knowledge_mapping(application_knowledge_id_list, knowledge_id_list, tool_id):
            """
            Build the TOOL -> KNOWLEDGE ``ResourceMapping`` list to save for a tool.

            Existing mappings to knowledge bases outside ``application_knowledge_id_list`` (the
            ones the current user cannot edit) are kept; a new unsaved mapping is created for
            each id in ``knowledge_id_list``.

            @param application_knowledge_id_list: knowledge ids the current user may (dis)associate
            @param knowledge_id_list: knowledge ids selected by the user
            @param tool_id: tool id
            @return: list of ResourceMapping objects
            """
            knowledge_application_mapping_list = QuerySet(ResourceMapping).filter(source_id=tool_id,
                                                                                  source_type='TOOL',
                                                                                  target_type="KNOWLEDGE",
                                                                                  ).exclude(
                target_id__in=application_knowledge_id_list)
            edit_knowledge_list = [ResourceMapping(source_id=tool_id, target_id=knowledge_id,
                                                   source_type='TOOL',
                                                   target_type="KNOWLEDGE")
                                   for knowledge_id in knowledge_id_list]
            return list(knowledge_application_mapping_list) + edit_knowledge_list

        def edit(self, instance: Dict):
            """
            Update the tool's workflow, or replace it with an app-store template.

            - ``work_flow``: rejected if it creates a circular tool dependency; otherwise saved
              (created when missing), the optional ``knowledge_id_list`` is checked against the
              knowledge bases the user may edit, and the tool's resource mappings are rebuilt.
            - ``work_flow_template``: the file at ``downloadUrl`` is downloaded and imported into
              this tool, then ``downloadCallbackUrl`` is requested best-effort.

            :return: the serialized tool workflow (see ``one``).
            :raises AppApiException: on a circular dependency or an unknown knowledge id.
            """
            self.is_valid(raise_exception=True)
            tool_id = self.data.get("tool_id")
            tool = QuerySet(Tool).filter(id=tool_id).first()
            workspace_id = tool.workspace_id
            if instance.get("work_flow"):
                if not is_valid_tool_workflow_circular_dependency(workflow=instance.get('work_flow'),
                                                                  _id=str(tool.id)):
                    raise AppApiException(500, gettext('There is a circular dependency in the tool workflow'))
                QuerySet(ToolWorkflow).update_or_create(tool_id=tool_id,
                                                        create_defaults={'id': uuid.uuid7(),
                                                                         'tool_id': tool_id,
                                                                         "workspace_id": workspace_id,
                                                                         'work_flow': instance.get('work_flow',
                                                                                                   {}), },
                                                        defaults={
                                                            'tool_id': tool_id,
                                                            'workspace_id': workspace_id,
                                                            'work_flow': instance.get('work_flow')
                                                        })
                # Knowledge bases whose association the current user may change (queried once).
                tool_knowledge_id_list = [str(knowledge.get('id')) for knowledge in
                                          self.list_knowledge(with_valid=False)]
                knowledge_id_list = []
                if 'knowledge_id_list' in instance:
                    knowledge_id_list = instance.get('knowledge_id_list')
                    for knowledge_id in knowledge_id_list:
                        if knowledge_id not in tool_knowledge_id_list:
                            message = lazy_format(_('Unknown knowledge base id {dataset_id}, unable to associate'),
                                                  dataset_id=knowledge_id)
                            raise AppApiException(500, str(message))
                update_resource_mapping_by_tool(tool_id,
                                                self.get_tool_knowledge_mapping(
                                                    tool_knowledge_id_list,
                                                    knowledge_id_list,
                                                    tool_id))
                return self.one()
            if instance.get("work_flow_template"):
                template_instance = instance.get('work_flow_template')
                res = requests.get(template_instance.get('downloadUrl'), timeout=5)
                # Fail on an HTTP error instead of importing an error page as a tool file.
                res.raise_for_status()
                ToolSerializer.Import(data={
                    'user_id': self.data.get('user_id'),
                    'workspace_id': workspace_id,
                    'folder_id': tool.folder_id,
                    'file': bytes_to_uploaded_file(res.content, 'file.tool')
                }).update_template_workflow(str(tool_id))
                # Best-effort download notification to the app store; failures are only logged.
                try:
                    requests.get(template_instance.get('downloadCallbackUrl'), timeout=5)
                except Exception as e:
                    maxkb_logger.error(f"callback appstore tool download error: {e}")
            return self.one()

        def one(self):
            """Return the tool's ``ToolWorkflow`` serialized with all model fields."""
            self.is_valid(raise_exception=True)
            workflow = QuerySet(ToolWorkflow).filter(tool_id=self.data.get('tool_id')).first()
            return {**ToolWorkflowModelSerializer(workflow).data}
class ToolWorkflowMcpSerializer(serializers.Serializer):
    """Resolve the tools offered by a set of MCP server configurations for a tool."""
    tool_id = serializers.UUIDField(required=True, label=_('Tool id'))
    user_id = serializers.UUIDField(required=True, label=_("User ID"))
    workspace_id = serializers.CharField(required=False, allow_null=True, allow_blank=True, label=_("Workspace ID"))

    def is_valid(self, *, raise_exception=False):
        """
        Validate the fields and check that the tool exists (within ``workspace_id`` when given).

        Errors are always raised, whatever ``raise_exception`` says.

        :raises AppApiException: if no matching tool exists.
        """
        super().is_valid(raise_exception=True)
        scope_workspace = self.data.get('workspace_id')
        tool_query = QuerySet(Tool).filter(id=self.data.get('tool_id'))
        if scope_workspace:
            tool_query = tool_query.filter(workspace_id=scope_workspace)
        if not tool_query.exists():
            raise AppApiException(500, _('Tool id does not exist'))

    def get_mcp_servers(self, instance, with_valid=True):
        """
        List the tools exposed by each MCP server in ``instance['mcp_servers']`` (a JSON string).

        Every server must use the ``sse`` or ``streamable_http`` transport. Servers are queried one
        at a time, each in its own ``asyncio.run`` call.

        :return: list of dicts with ``server``, ``name``, ``description`` and ``args_schema``.
        :raises AppApiException: if any server uses another transport.
        """
        if with_valid:
            self.is_valid(raise_exception=True)
        McpServersSerializer(data=instance).is_valid(raise_exception=True)
        servers = json.loads(instance.get('mcp_servers'))
        supported_transports = ['sse', 'streamable_http']
        if any(server_config.get('transport') not in supported_transports for server_config in servers.values()):
            raise AppApiException(500, _('Only support transport=sse or transport=streamable_http'))
        collected = []
        for server_name, server_config in servers.items():
            discovered = asyncio.run(get_mcp_tools({server_name: server_config}))
            collected.extend([
                {
                    'server': server_name,
                    'name': mcp_tool.name,
                    'description': mcp_tool.description,
                    'args_schema': mcp_tool.args_schema,
                }
                for mcp_tool in discovered
            ])
        return collected
class StoreToolWorkflow(serializers.Serializer):
    """Fetch workflow-template tools from the app store index."""
    user_id = serializers.UUIDField(required=True, label=_("User ID"))
    name = serializers.CharField(required=False, label=_("tool name"), allow_null=True, allow_blank=True)

    def get_appstore_templates(self):
        """
        Download the app store index and return its workflow-template tools.

        The index is a zip archive (URL from ``APPSTORE_URL``, with a built-in default) whose first
        entry is a JSON document containing ``apps`` and ``additionalProperties.tags``. Only apps
        whose ``downloadUrl`` ends with ``.tool`` and that carry the ``workflow_template`` tag are
        kept, optionally filtered by a case-insensitive substring match on ``name``.

        Any failure is logged and an empty store is returned.
        """
        self.is_valid(raise_exception=True)
        try:
            appstore_url = CONFIG.get('APPSTORE_URL', 'https://apps-assets.fit2cloud.com/stable/maxkb.json.zip')
            res = requests.get(appstore_url, timeout=5)
            res.raise_for_status()
            tool_store = self._read_store_index(res.content)
            tool_store['apps'] = self._filter_workflow_templates(tool_store, self.data.get('name'))
            return tool_store
        except Exception as e:
            maxkb_logger.error(f"fetch appstore tools error: {e}")
            return {'apps': [], 'additionalProperties': {'tags': []}}

    @staticmethod
    def _read_store_index(content):
        """Parse the JSON document stored as the first entry of the zip archive ``content`` (bytes)."""
        import io
        # Read the archive in memory: no temporary file to leak or clean up.
        with zipfile.ZipFile(io.BytesIO(content), 'r') as zip_ref:
            json_filename = zip_ref.namelist()[0]
            return json.loads(zip_ref.read(json_filename).decode('utf-8'))

    @staticmethod
    def _filter_workflow_templates(tool_store, name):
        """
        Return the ``.tool`` workflow-template apps of ``tool_store`` whose name contains ``name``.

        Each returned app gets ``label`` (key of its first tag, '' if none/unknown) and ``version``
        (name of the version whose ``downloadUrl`` equals the app's, '' if none matches).
        Unknown tags, missing names and missing tags no longer abort the whole listing.
        """
        tag_dict = {tag['name']: tag['key'] for tag in tool_store['additionalProperties']['tags']}
        # ``name`` may be None (allow_null) or blank: both mean "no name filter".
        name_filter = (name or '').lower()
        filter_apps = []
        for tool in tool_store['apps']:
            if name_filter and name_filter not in (tool.get('name') or '').lower():
                continue
            tags = tool.get('tags') or []
            tag_keys = [tag_dict.get(tag) for tag in tags]
            if not (tool.get('downloadUrl') or '').endswith('.tool') or 'workflow_template' not in tag_keys:
                continue
            tool['label'] = tag_dict.get(tags[0], '') if tags else ''
            tool['version'] = next(
                (version.get('name') for version in tool.get('versions') or []
                 if version.get('downloadUrl') == tool['downloadUrl']),
                '',
            )
            filter_apps.append(tool)
        return filter_apps
def update_resource_mapping_by_tool(tool_id: str, other_resource_mapping=None):
    """
    Rebuild the resource mappings of a tool from its stored workflow.

    The mappings returned by ``get_instance_resource`` for the tool's ``ToolWorkflow`` are combined
    with ``other_resource_mapping`` and passed, with the workflow definition, to
    ``save_workflow_mapping``.

    :param tool_id: id of the tool whose ``ToolWorkflow`` is loaded.
    :param other_resource_mapping: extra mapping objects to save alongside (defaults to none).
    """
    from application.flow.tools import get_instance_resource, save_workflow_mapping
    from system_manage.models.resource_mapping import ResourceType

    extra_mappings = [] if other_resource_mapping is None else other_resource_mapping
    # NOTE(review): assumes a ToolWorkflow row exists for tool_id; ``.work_flow`` fails otherwise.
    tool_workflow = QuerySet(ToolWorkflow).filter(tool_id=tool_id).first()
    source_id = str(tool_id)
    instance_mapping = get_instance_resource(tool_workflow, ResourceType.TOOL, source_id, {})
    save_workflow_mapping(tool_workflow.work_flow, ResourceType.TOOL, source_id,
                          instance_mapping + extra_mappings)
|