| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091 |
- # coding=utf-8
- """
- @project: maxkb
- @Author:虎
- @file: utils.py
- @date:2024/6/6 15:15
- @desc:
- """
- from langchain_core.tools import StructuredTool
- from application.flow.common import Workflow, WorkflowMode
- from application.serializers.common import ToolExecute
- from tools.models import ToolRecord, Tool, ToolScope, ToolWorkflowVersion, ToolType
- from maxkb.const import CONFIG
- from knowledge.models.knowledge_action import State
- from knowledge.models import File
- from common.utils.logger import maxkb_logger
- from common.result import result
- from application.flow.i_step_node import WorkFlowPostHandler, ToolWorkflowPostHandler
- from application.flow.backend.sandbox_shell import SandboxShellBackend
- import asyncio
- import io
- import json
- import os
- import queue
- import re
- import shutil
- import threading
- import zipfile
- from functools import reduce
- from typing import Iterator
- from pydantic import Field, create_model
- import uuid_utils.compat as uuid
- from asgiref.sync import sync_to_async
- from deepagents import create_deep_agent
- from django.db.models import QuerySet, OuterRef, Subquery
- from django.http import StreamingHttpResponse
- from langchain_core.messages import BaseMessageChunk, BaseMessage, ToolMessage, AIMessageChunk, SystemMessage
- from langchain_mcp_adapters.client import MultiServerMCPClient
- from langgraph.checkpoint.memory import MemorySaver
- # ---------------------------------------------------------------------------
- # Fix: qwen's OpenAI-compatible streaming sends id='' (empty string) for
- # intermediate tool_call_chunks while only the first chunk carries the real
- # id ('call_xxx...'). langchain-core's merge_lists treats '' != 'call_xxx' as
- # an ID conflict and _appends_ instead of merging → the accumulated AIMessage
- # ends up with two separate tool_calls (one with empty args, one with empty
- # id) instead of one correct entry. This causes the Qwen API to reject the
- # next request with "function.arguments must be in JSON format".
- #
- # Patch: normalise id='' → None for items that have an 'index' key
- # (i.e. tool_call_chunk dicts). merge_lists treats None as "no id" and will
- # merge with any existing entry, keeping the real id from the first chunk.
- # ---------------------------------------------------------------------------
- import langchain_core.messages.ai as _lc_ai_module
- from langchain_core.utils._merge import merge_lists as _original_merge_lists
def _merge_lists_normalize_empty_tool_chunk_ids(left, *others):
    """Drop-in replacement for langchain-core's ``merge_lists``.

    Every list argument is copied with tool_call_chunk dicts (dicts that carry
    an ``'index'`` key) whose ``id`` is the empty string rewritten to
    ``id=None``; the copies are then handed to the original ``merge_lists``.
    ``merge_lists`` treats ``None`` as "no id", so such chunks merge by index
    into the entry holding the real id instead of being appended separately.
    ``None`` arguments are passed through untouched.
    """

    def _blank_id_to_none(item):
        needs_fix = isinstance(item, dict) and 'index' in item and item.get('id') == ''
        return {**item, 'id': None} if needs_fix else item

    def _normalized(items):
        if items is None:
            return None
        return [_blank_id_to_none(item) for item in items]

    return _original_merge_lists(_normalized(left), *map(_normalized, others))


# Swap the module-level reference that add_ai_message_chunks in
# langchain_core/messages/ai.py resolves, so chunk merging uses the wrapper.
_lc_ai_module.merge_lists = _merge_lists_normalize_empty_tool_chunk_ids
class Reasoning:
    """Split a streamed model answer into reasoning text and answer text.

    Models that inline their chain of thought wrap it in configurable start/end
    tags (for example ``<think>`` ... ``</think>``). Feed every streamed chunk to
    :meth:`get_reasoning_content`. It returns the part of that chunk that is
    reasoning and the part that is answer content, and buffers text that might
    be the beginning of the end tag. Call :meth:`get_end_reasoning_content`
    after the last chunk to flush whatever is still buffered.
    """

    def __init__(self, reasoning_content_start, reasoning_content_end):
        """
        @param reasoning_content_start: tag opening the reasoning section; None/'' means every chunk is answer content
        @param reasoning_content_end: tag closing the reasoning section; None/'' means every chunk is reasoning content
        """
        # Accumulated answer text.
        self.content = ""
        # Accumulated reasoning text.
        self.reasoning_content = ""
        # Every chunk seen so far (only maintained when both tags are configured).
        self.all_content = ""
        self.reasoning_content_start_tag = reasoning_content_start
        self.reasoning_content_end_tag = reasoning_content_end
        self.reasoning_content_start_tag_len = len(
            reasoning_content_start) if reasoning_content_start is not None else 0
        self.reasoning_content_end_tag_len = len(
            reasoning_content_end) if reasoning_content_end is not None else 0
        # First character of the end tag: seeing it means the end tag may be starting.
        self.reasoning_content_end_tag_prefix = reasoning_content_end[
            0] if self.reasoning_content_end_tag_len > 0 else ''
        self.reasoning_content_is_start = False
        self.reasoning_content_is_end = False
        # Reasoning text received but not yet emitted.
        self.reasoning_content_chunk = ""

    def get_end_reasoning_content(self):
        """Flush buffered text once the stream has ended.

        @return: dict with ``content`` and ``reasoning_content``. If the start tag
                 was never detected and the stream was never classified as
                 tag-less, ``all_content`` is returned as content. If the
                 reasoning section never closed, the still-buffered reasoning
                 text is returned. Otherwise both values are empty.
        """
        if not self.reasoning_content_is_start and not self.reasoning_content_is_end:
            r = {'content': self.all_content, 'reasoning_content': ''}
            self.reasoning_content_chunk = ""
            return r
        if self.reasoning_content_is_start and not self.reasoning_content_is_end:
            r = {'content': '', 'reasoning_content': self.reasoning_content_chunk}
            self.reasoning_content_chunk = ""
            return r
        return {'content': '', 'reasoning_content': ''}

    def _normalize_content(self, content):
        """Convert message content of any shape to a plain string.

        ``str`` is returned unchanged. For a list of content blocks, ``text`` blocks
        and bare string items are concatenated in order and other blocks are
        dropped. Anything else goes through ``str()``.
        """
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = []
            for item in content:
                if isinstance(item, str):
                    parts.append(item)
                elif isinstance(item, dict) and item.get('type') == 'text':
                    parts.append(item.get('text', ''))
            return ''.join(parts)
        return str(content)

    @staticmethod
    def _additional_reasoning(chunk):
        """Provider-supplied reasoning from ``chunk.additional_kwargs`` ('' when absent)."""
        return chunk.additional_kwargs.get('reasoning_content', '') if chunk.additional_kwargs else ''

    def get_reasoning_content(self, chunk):
        """Split one streamed chunk into answer content and reasoning content.

        @param chunk: message chunk exposing ``content`` (str or list of content blocks)
                      and ``additional_kwargs``; ``chunk.content`` is normalised to str in place
        @return: dict with keys ``content`` and ``reasoning_content`` for this chunk
        """
        # Normalise first: list-of-blocks content cannot be concatenated onto
        # the str accumulators used by every branch below.
        chunk.content = self._normalize_content(chunk.content)
        # No start tag configured: everything is answer content.
        if self.reasoning_content_start_tag is None or len(self.reasoning_content_start_tag) == 0:
            self.content += chunk.content
            return {'content': chunk.content, 'reasoning_content': ''}
        # No end tag configured: everything is reasoning content.
        if self.reasoning_content_end_tag is None or len(self.reasoning_content_end_tag) == 0:
            return {'content': '', 'reasoning_content': chunk.content}
        self.all_content += chunk.content
        if not self.reasoning_content_is_start and len(self.all_content) >= self.reasoning_content_start_tag_len:
            if self.all_content.startswith(self.reasoning_content_start_tag):
                self.reasoning_content_is_start = True
                self.reasoning_content_chunk = self.all_content[self.reasoning_content_start_tag_len:]
            elif not self.reasoning_content_is_end:
                # The stream does not open with the start tag, so there is no
                # reasoning section: release everything buffered as content.
                self.reasoning_content_is_end = True
                self.content += self.all_content
                return {'content': self.all_content,
                        'reasoning_content': self._additional_reasoning(chunk)}
        elif self.reasoning_content_is_start:
            self.reasoning_content_chunk += chunk.content
        reasoning_content_end_tag_prefix_index = self.reasoning_content_chunk.find(
            self.reasoning_content_end_tag_prefix)
        if self.reasoning_content_is_end:
            self.content += chunk.content
            return {'content': chunk.content, 'reasoning_content': self._additional_reasoning(chunk)}
        if reasoning_content_end_tag_prefix_index == -1:
            # No possible end tag in the buffer: emit all of it as reasoning.
            result = {'content': '', 'reasoning_content': self.reasoning_content_chunk}
            self.reasoning_content += self.reasoning_content_chunk
            self.reasoning_content_chunk = ""
            return result
        if len(self.reasoning_content_chunk) - reasoning_content_end_tag_prefix_index < self.reasoning_content_end_tag_len:
            # Too few characters after the prefix to decide yet; wait for more.
            return {'content': '', 'reasoning_content': ''}
        reasoning_content_end_tag_index = self.reasoning_content_chunk.find(
            self.reasoning_content_end_tag)
        if reasoning_content_end_tag_index > -1:
            reasoning_content_chunk = self.reasoning_content_chunk[0:reasoning_content_end_tag_index]
            content_chunk = self.reasoning_content_chunk[
                reasoning_content_end_tag_index + self.reasoning_content_end_tag_len:]
            self.reasoning_content += reasoning_content_chunk
            self.content += content_chunk
            self.reasoning_content_chunk = ""
            self.reasoning_content_is_end = True
            return {'content': content_chunk, 'reasoning_content': reasoning_content_chunk}
        # The prefix character was a false alarm. Release the buffer up to and
        # including it and keep the remainder buffered. Slicing (not
        # str.replace) so later occurrences of the same text are preserved.
        split_at = reasoning_content_end_tag_prefix_index + 1
        reasoning_content_chunk = self.reasoning_content_chunk[:split_at]
        self.reasoning_content_chunk = self.reasoning_content_chunk[split_at:]
        self.reasoning_content += reasoning_content_chunk
        return {'content': '', 'reasoning_content': reasoning_content_chunk}
def event_content(chat_id, chat_record_id, response, workflow,
                  write_context,
                  post_handler: WorkFlowPostHandler):
    """
    Stream a model response to the client as server-sent events.

    Each chunk's content is forwarded as an SSE ``data:`` line while the full
    answer is accumulated. After the stream ends, the answer is written to the
    node context with status 200, the post handler runs, and a final
    ``is_end`` event is sent. On any exception, the error text becomes the
    answer and is written with status 500, the post handler runs, and the
    error is sent as the final event.

    @param chat_id: conversation id
    @param chat_record_id: chat record id
    @param response: iterable of chunks exposing ``content``
    @param workflow: workflow manager passed to the post handler
    @param write_context: callback ``(answer, status)`` writing the node context
    @param post_handler: post handler invoked with the final answer
    """

    def _sse(content, is_end):
        payload = {'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
                   'content': content, 'is_end': is_end}
        return 'data: ' + json.dumps(payload, ensure_ascii=False) + "\n\n"

    answer = ''
    try:
        for chunk in response:
            answer += chunk.content
            yield _sse(chunk.content, False)
        write_context(answer, 200)
        post_handler.handler(chat_id, chat_record_id, answer, workflow)
        yield _sse('', True)
    except Exception as e:
        error_text = str(e)
        write_context(error_text, 500)
        post_handler.handler(chat_id, chat_record_id, error_text, workflow)
        yield _sse(error_text, True)
def to_stream_response(chat_id, chat_record_id, response: Iterator[BaseMessageChunk], workflow, write_context,
                       post_handler):
    """
    Turn a chunk iterator into a streaming SSE HTTP response.

    The events come from ``event_content``. The HTTP wrapping (content type
    ``text/event-stream;charset=utf-8``, ``Cache-Control: no-cache``) is shared
    with ``to_stream_response_simple``.

    @param chat_id: conversation id
    @param chat_record_id: chat record id
    @param response: iterator of message chunks
    @param workflow: workflow manager
    @param write_context: callback writing the node context
    @param post_handler: post handler invoked with the final answer
    @return: the streaming HTTP response
    """
    events = event_content(chat_id, chat_record_id, response, workflow, write_context, post_handler)
    return to_stream_response_simple(events)
def to_response(chat_id, chat_record_id, response: BaseMessage, workflow, write_context,
                post_handler: WorkFlowPostHandler):
    """
    Turn a complete (non-streamed) model message into the standard success result.

    The message content is first written to the node context. Post handling
    and the response envelope are shared with ``to_response_simple``.

    @param chat_id: conversation id
    @param chat_record_id: chat record id
    @param response: complete model message
    @param workflow: workflow manager
    @param write_context: callback writing the node context (called with the answer only)
    @param post_handler: post handler invoked with the answer
    @return: ``result.success`` payload with ``is_end`` set to True
    """
    write_context(response.content)
    return to_response_simple(chat_id, chat_record_id, response, workflow, post_handler)
def to_response_simple(chat_id, chat_record_id, response: BaseMessage, workflow,
                       post_handler: WorkFlowPostHandler):
    """
    Run the post handler on a complete answer and wrap it in the success envelope.

    @param chat_id: conversation id
    @param chat_record_id: chat record id
    @param response: complete model message
    @param workflow: workflow manager
    @param post_handler: post handler invoked with the answer
    @return: ``result.success`` payload with ``is_end`` set to True
    """
    answer = response.content
    post_handler.handler(chat_id, chat_record_id, answer, workflow)
    payload = {
        'chat_id': str(chat_id),
        'id': str(chat_record_id),
        'operate': True,
        'content': answer,
        'is_end': True,
    }
    return result.success(payload)
def to_stream_response_simple(stream_event):
    """
    Wrap an iterable of SSE strings in a non-cached ``text/event-stream`` response.

    @param stream_event: iterable yielding already formatted SSE lines
    @return: the streaming HTTP response
    """
    http_response = StreamingHttpResponse(streaming_content=stream_event,
                                          content_type='text/event-stream;charset=utf-8',
                                          charset='utf-8')
    http_response['Cache-Control'] = 'no-cache'
    return http_response
def generate_tool_message_complete(icon, name, input_content, output_content):
    """Build the ``<tool_calls_render>`` markup showing one tool call's input and output.

    Non-string input/output values are serialised with ``json.dumps`` (non-ASCII
    kept). The envelope itself uses default ``json.dumps`` settings, so any
    non-ASCII characters in it appear as ``\\uXXXX`` escapes.
    """

    def _as_text(value):
        return value if isinstance(value, str) else json.dumps(value, ensure_ascii=False)

    payload = {
        "icon": icon,
        "title": name,
        "type": "simple-tool-calls",
        "content": {"input": _as_text(input_content), "output": _as_text(output_content)},
    }
    return '<tool_calls_render>' + json.dumps(payload) + '</tool_calls_render>'
# Process-wide event loop shared by every caller of get_global_loop(). It is
# driven by a daemon thread so synchronous code can schedule coroutines on it.
_global_loop = None
_loop_thread = None
_loop_lock = threading.Lock()


def get_global_loop():
    """Return the shared event loop, creating it and its thread on first use.

    Creation happens under ``_loop_lock``. The loop runs ``run_forever()`` in a
    daemon thread named ``GlobalAsyncLoop``.
    """
    global _global_loop, _loop_thread
    with _loop_lock:
        if _global_loop is not None:
            return _global_loop
        new_loop = asyncio.new_event_loop()

        def _serve(event_loop):
            asyncio.set_event_loop(event_loop)
            event_loop.run_forever()

        _global_loop = new_loop
        _loop_thread = threading.Thread(target=_serve, args=(new_loop,), daemon=True, name="GlobalAsyncLoop")
        _loop_thread.start()
        return _global_loop
def _extract_tool_id(raw_id):
    """Return the trailing ``call_...`` segment of a tool-call id.

    Streaming may concatenate several id fragments into one string (for example
    ``'call_abccall_def'``). The id kept is everything from the last ``call_``
    occurrence to the end of the string.

    @param raw_id: raw tool-call id; non-str values are converted with ``str()``
    @return: ``None`` for a falsy ``raw_id``; the substring starting at the last
             ``'call_'`` if present; otherwise the (stringified) ``raw_id``
    """
    if not raw_id:
        return None
    raw_id = str(raw_id)
    start = raw_id.rfind('call_')
    return raw_id[start:] if start != -1 else raw_id
def _grant_group_read_exec(root):
    """Best-effort pure-Python ``chmod -R g+rx root``.

    Adds group read and execute bits to ``root`` and everything below it.
    Symlinks met while walking are left alone, as recursive ``chmod`` does.
    Failures are logged, not raised.
    """
    import stat

    extra = stat.S_IRGRP | stat.S_IXGRP
    try:
        for dirpath, _dirnames, filenames in os.walk(root):
            os.chmod(dirpath, stat.S_IMODE(os.stat(dirpath).st_mode) | extra)
            for filename in filenames:
                path = os.path.join(dirpath, filename)
                if not os.path.islink(path):
                    os.chmod(path, stat.S_IMODE(os.lstat(path).st_mode) | extra)
    except OSError as e:
        maxkb_logger.warning(f"Failed to grant group access on {root}: {e}")


async def _initialize_skills(mcp_servers, temp_dir):
    """Unpack configured skill archives into ``<temp_dir>/skills`` and build the MCP client.

    ``mcp_servers`` is parsed as JSON. If it has a ``"skills"`` list, the list is
    removed from the config and, for each item, the ``File`` with id
    ``item['file_id']`` is read and extracted as a zip archive (missing files
    are skipped). The item's optional ``params`` dict is written as ``KEY=value``
    lines to a ``.env`` file in every top-level directory of that archive. The
    remaining config is handed to ``MultiServerMCPClient``.

    @param mcp_servers: JSON string with the MCP server config
    @param temp_dir: per-run working directory
    @return: the ``MultiServerMCPClient`` built from the remaining config
    @raise ValueError: if an archive member name contains ``..`` or starts with ``/``
    """
    skills_dir = os.path.join(temp_dir, 'skills')
    mcp_config = json.loads(mcp_servers)
    if "skills" in mcp_config:
        skill_file_items = mcp_config.pop('skills')
        for skill_file in skill_file_items:
            # ORM access is synchronous, so run it through sync_to_async.
            file = await sync_to_async(lambda: QuerySet(File).filter(id=skill_file['file_id']).first())()
            if not file:
                continue
            # Reading the stored bytes may involve IO as well.
            file_bytes = await sync_to_async(file.get_bytes)()
            params = skill_file.get('params', {})
            with zipfile.ZipFile(io.BytesIO(file_bytes), 'r') as zip_ref:
                # Skip macOS resource-fork metadata entries.
                members = [m for m in zip_ref.namelist() if '__MACOSX' not in m]
                for member in members:
                    if ".." in member or member.startswith("/"):
                        raise ValueError(f"非法路径: {member}")
                zip_ref.extractall(skills_dir, members=members)
            # Top-level *directories* only: a member without "/" is a file at the
            # archive root, and writing "<file>/.env" would raise NotADirectoryError.
            top_level_dirs = set()
            for member in members:
                parts = member.split('/')
                if parts[0] and len(parts) > 1:
                    top_level_dirs.add(parts[0])
            if params:
                # NOTE(review): values are written verbatim, without quoting or escaping;
                # a value containing a newline would spill onto an extra line.
                env_content = ''.join(f'{key}={value}\n' for key, value in params.items())
                for top_dir in top_level_dirs:
                    env_path = os.path.join(skills_dir, top_dir, '.env')
                    with open(env_path, 'w', encoding='utf-8') as f:
                        f.write(env_content)
    # Give the group read/execute access to the extracted skills (no shell involved).
    _grant_group_read_exec(temp_dir)
    client = MultiServerMCPClient(mcp_config)
    return client
async def _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_enable=True, tool_init_params={},
                              source_id=None, source_type=None, temp_dir=None, chat_id=None, extra_tools=None):
    """Run a deep agent over MCP/skill tools and stream its messages.

    Tools come from the MCP client built by ``_initialize_skills`` plus
    ``extra_tools``. The agent is streamed with ``stream_mode='messages'``.
    The generator yields:

    * every ``AIMessageChunk``, after streamed tool-call fragments have been
      reassembled so ids and JSON arguments stay consistent;
    * every ``ToolMessage`` when ``mcp_output_enable`` is true. Its content is
      replaced by a ``<tool_calls_render>`` block when the matching tool call
      was recorded.

    @param chat_model: chat model passed to ``create_deep_agent``
    @param message_list: conversation; a ``SystemMessage`` is moved into ``system_prompt``
    @param mcp_servers: JSON string with the MCP config (optionally containing ``skills``)
    @param mcp_output_enable: whether tool results are yielded
    @param tool_init_params: argument names hidden from the rendered tool input (read only)
    @param source_id: stored on ``ToolRecord`` rows
    @param source_type: stored on ``ToolRecord`` rows
    @param temp_dir: sandbox root for skills and agent files
    @param chat_id: checkpointer ``thread_id``
    @param extra_tools: additional tools appended to the MCP tools
    @raise RuntimeError: wrapping the underlying error as ``"<Type>: <message>"``
    """
    try:
        checkpointer = MemorySaver()
        client = await _initialize_skills(mcp_servers, temp_dir)
        tools = await client.get_tools()
        for tool in tools:
            tool.handle_tool_error = True
        if extra_tools:
            for tool in extra_tools:
                tools.append(tool)
        # ---------------------------------------------------------------------------
        # Fix: vLLM (and Qwen chat templates) reject conversations that contain more
        # than one SystemMessage, or a SystemMessage that is not the very first
        # message. create_deep_agent always prepends its own BASE_AGENT_PROMPT as a
        # SystemMessage before calling the model (factory.py line ~1319). If
        # message_list already contains a SystemMessage (built in base_chat_node.py
        # via generate_message_list), the API receives two system messages and raises
        # "System message must be at the beginning."
        #
        # Solution: strip the user-supplied SystemMessage out of message_list and
        # pass its text as the system_prompt argument of create_deep_agent.
        # deepagents will then merge it with BASE_AGENT_PROMPT into a single
        # combined system message, so the model only ever sees one.
        # ---------------------------------------------------------------------------
        user_system_prompt = None
        filtered_message_list = []
        for msg in message_list:
            if isinstance(msg, SystemMessage):
                # Normalise content to a plain string whether the message was built
                # with a str or a list of content blocks. If several SystemMessages
                # are present, the last one wins.
                if isinstance(msg.content, str):
                    user_system_prompt = msg.content
                elif isinstance(msg.content, list):
                    user_system_prompt = ''.join(
                        item.get('text', '') if isinstance(item, dict) else str(item)
                        for item in msg.content
                    )
                else:
                    user_system_prompt = str(msg.content)
            else:
                filtered_message_list.append(msg)
        agent = create_deep_agent(
            model=chat_model,
            backend=SandboxShellBackend(root_dir=temp_dir, virtual_mode=True),
            skills=['/skills'],
            tools=tools,
            system_prompt=user_system_prompt,
            interrupt_on={
                "write_file": False,
                "read_file": False,
                "edit_file": False
            },
            checkpointer=checkpointer,
        )
        recursion_limit = int(CONFIG.get(
            "LANGCHAIN_GRAPH_RECURSION_LIMIT", '100'))
        response = agent.astream(
            {"messages": filtered_message_list},
            config={"recursion_limit": recursion_limit,
                    "configurable": {"thread_id": chat_id}},
            stream_mode='messages'
        )
        # tool_call_id -> {'name': ..., 'input': ...} ('icon' is added by save_tool_record)
        tool_calls_info = {}
        # fragment key ('idx:<n>' / 'id:<id>') -> {'id': ..., 'name': ..., 'arguments': ..., ['completed': True]}
        _tool_fragments = {}

        def _merge_arguments(entry, part_args):
            """Fold one streamed arguments fragment into ``entry['arguments']``."""
            # None carries no data; json.dumps(None) would append a literal "null"
            # and corrupt the accumulated JSON.
            if part_args is None:
                return
            if not isinstance(part_args, str):
                try:
                    part_args = json.dumps(part_args, ensure_ascii=False)
                except Exception:
                    part_args = str(part_args) if part_args else ''
            if not part_args:
                return
            # Some providers first emit placeholder args like "{}" and then
            # stream the real JSON fragments via later chunks. Prefer fragments.
            if entry['arguments'] in ('{}', '[]') and part_args.startswith('{'):
                entry['arguments'] = part_args
                return
            if entry['arguments']:
                try:
                    existing_obj = json.loads(entry['arguments'])
                    new_obj = json.loads(part_args)
                    if isinstance(existing_obj, dict) and isinstance(new_obj, dict):
                        merged = {**existing_obj, **new_obj}
                        entry['arguments'] = json.dumps(
                            merged, ensure_ascii=False)
                    else:
                        entry['arguments'] += part_args
                except (json.JSONDecodeError, ValueError):
                    # Partial JSON: plain concatenation of streamed fragments.
                    entry['arguments'] += part_args
            else:
                entry['arguments'] = part_args

        def _get_fragment_key(idx, raw_id):
            """Key fragments by stream index when available, else by normalised id."""
            if idx is not None:
                return f'idx:{idx}'
            if raw_id and str(raw_id).strip():
                return f"id:{_extract_tool_id(str(raw_id).strip())}"
            return None

        def _upsert_fragment(key, raw_id, func_name, part_args):
            """Create or update the fragment under ``key`` with id, name and arguments."""
            if key is None:
                return
            entry = _tool_fragments.setdefault(
                key, {'id': '', 'name': '', 'arguments': ''})
            if raw_id and str(raw_id).strip():
                new_id = str(raw_id).strip()
                # A new id on an already completed slot starts a new tool call.
                if entry.get('completed') and entry.get('id') and entry['id'] != new_id:
                    maxkb_logger.debug(
                        f"Resetting completed fragment {key}: old ID {entry['id']} -> new ID {new_id}")
                    entry.clear()
                    entry.update({'id': '', 'name': '', 'arguments': ''})
                entry['id'] = new_id
            if func_name:
                entry['name'] = func_name
            _merge_arguments(entry, part_args)

        def _register_tool_call(entry, tool_input):
            """Record a finished tool call under its raw and normalised ids and mark it completed."""
            normalized_id = _extract_tool_id(entry['id'])
            info = {
                'name': entry['name'],
                'input': tool_input
            }
            tool_calls_info[entry['id']] = info
            if normalized_id and normalized_id != entry['id']:
                tool_calls_info[normalized_id] = info
            entry['completed'] = True

        async for chunk in response:
            if isinstance(chunk[0], AIMessageChunk):
                # ----------------------------------------------------------------
                # 1. Aggregate tool-call fragments from tool_call_chunks
                #    (qwen/OpenAI streaming delivers them here;
                #    additional_kwargs['tool_calls'] is usually empty while streaming)
                # ----------------------------------------------------------------
                for tc_chunk in (chunk[0].tool_call_chunks or []):
                    raw_id = tc_chunk.get('id')
                    key = _get_fragment_key(tc_chunk.get('index'), raw_id)
                    _upsert_fragment(
                        key,
                        raw_id,
                        tc_chunk.get('name'),
                        tc_chunk.get('args', '')
                    )
                # ----------------------------------------------------------------
                # 1.1 Some models put tool calls in chunk.tool_calls while the
                #     tool_call_chunks index is empty (e.g. ollama/qwen)
                # ----------------------------------------------------------------
                has_tool_call_chunks = bool(chunk[0].tool_call_chunks)
                for tool_call in (chunk[0].tool_calls or []):
                    raw_id = tool_call.get('id')
                    part_args = tool_call.get('args', '')
                    # qwen-plus often emits {} here as a placeholder while
                    # the real args are split in tool_call_chunks/invalid_tool_calls.
                    if has_tool_call_chunks and (
                            part_args == '' or part_args == {} or part_args == []
                    ):
                        part_args = ''
                    key = _get_fragment_key(tool_call.get('index'), raw_id)
                    _upsert_fragment(
                        key,
                        raw_id,
                        tool_call.get('name'),
                        part_args
                    )
                # ----------------------------------------------------------------
                # 1.2 invalid_tool_calls fragments (some models put partial JSON here)
                # ----------------------------------------------------------------
                for invalid_tool_call in (chunk[0].invalid_tool_calls or []):
                    raw_id = invalid_tool_call.get('id')
                    key = _get_fragment_key(
                        invalid_tool_call.get('index'), raw_id)
                    _upsert_fragment(
                        key,
                        raw_id,
                        invalid_tool_call.get('name'),
                        invalid_tool_call.get('args', '')
                    )
                # ----------------------------------------------------------------
                # 2. additional_kwargs['tool_calls'] (legacy / non-streaming format)
                # ----------------------------------------------------------------
                legacy_tool_calls = chunk[0].additional_kwargs.get(
                    'tool_calls', [])
                for tool_call in legacy_tool_calls:
                    raw_id = tool_call.get('id')
                    func = tool_call.get('function', {})
                    if isinstance(func, dict):
                        func_name = func.get('name')
                        part_args = func.get('arguments', '')
                    else:
                        func_name = tool_call.get('name')
                        part_args = tool_call.get('arguments', '')
                    key = _get_fragment_key(tool_call.get('index'), raw_id)
                    _upsert_fragment(key, raw_id, func_name, part_args)
                # ----------------------------------------------------------------
                # 3. On the finishing chunk, promote fragments to tool_calls_info
                # ----------------------------------------------------------------
                is_finish_chunk = (
                    chunk[0].response_metadata.get(
                        'finish_reason') == 'tool_calls'
                    or chunk[0].chunk_position == 'last'
                )
                if is_finish_chunk:
                    maxkb_logger.debug(
                        f"Processing finish chunk. Tool fragments: {_tool_fragments}")
                    for idx, entry in _tool_fragments.items():
                        if entry.get('completed'):
                            maxkb_logger.debug(
                                f"Skipping fragment {idx}: already completed")
                            continue
                        if not entry.get('id'):
                            maxkb_logger.debug(
                                f"Skipping fragment {idx}: missing id. Fragment: {entry}")
                            continue
                        if not entry.get('arguments'):
                            maxkb_logger.debug(
                                f"Skipping fragment {idx}: missing arguments. Fragment: {entry}")
                            continue
                        try:
                            parsed_args = json.loads(entry['arguments'])
                        except (json.JSONDecodeError, ValueError) as e:
                            # JSON parsing failed, but still register the raw arguments
                            # to prevent "Tool ID not found" when the ToolMessage arrives.
                            maxkb_logger.warning(
                                f"Failed to parse tool arguments at finish for tool {entry.get('id', 'unknown')}: "
                                f"{entry['arguments']}, error: {e}. Using raw arguments.")
                            _register_tool_call(entry, entry['arguments'])
                            continue
                        # Only dict arguments can have init params filtered out. Other
                        # JSON values (e.g. a "[]" placeholder) are shown unchanged
                        # instead of raising AttributeError and aborting the stream.
                        if tool_init_params and isinstance(parsed_args, dict):
                            parsed_args = {
                                k: v for k, v in parsed_args.items()
                                if k not in tool_init_params
                            }
                        _register_tool_call(entry, json.dumps(parsed_args, ensure_ascii=False))
                        maxkb_logger.debug(
                            f"Added tool call {entry['id']} to tool_calls_info")
                # ----------------------------------------------------------------
                # 4. Back-fill empty ids in tool_call_chunks with the known id
                # ----------------------------------------------------------------
                if chunk[0].tool_call_chunks:
                    for tc_chunk in chunk[0].tool_call_chunks:
                        key = _get_fragment_key(
                            tc_chunk.get('index'), tc_chunk.get('id'))
                        if key is not None:
                            frag = _tool_fragments.get(key)
                            if frag and frag.get('id') and not tc_chunk.get('id'):
                                tc_chunk['id'] = frag['id']
                # ----------------------------------------------------------------
                # 5. Repair additional_kwargs['tool_calls'] (legacy format).
                #    Complete arguments are written only on the finishing chunk:
                #    intermediate chunks are accumulated, and incomplete JSON there
                #    makes the next API call fail with "arguments must be JSON".
                # ----------------------------------------------------------------
                if legacy_tool_calls and is_finish_chunk:
                    fixed_tool_calls = []
                    for tool_call in legacy_tool_calls:
                        key = _get_fragment_key(
                            tool_call.get('index'), tool_call.get('id'))
                        frag = _tool_fragments.get(
                            key) if key is not None else None
                        tc = dict(tool_call)
                        if frag and frag.get('id') and not tc.get('id'):
                            tc['id'] = frag['id']
                        if frag and isinstance(tc.get('function'), dict):
                            tc['function'] = dict(tc['function'])
                            if frag.get('completed'):
                                tc['function']['arguments'] = frag['arguments']
                        fixed_tool_calls.append(tc)
                    chunk[0].additional_kwargs['tool_calls'] = fixed_tool_calls
                yield chunk[0]
            if mcp_output_enable and isinstance(chunk[0], ToolMessage):
                tool_id = chunk[0].tool_call_id
                normalized_tool_id = _extract_tool_id(tool_id)
                tool_info = tool_calls_info.get(tool_id) or tool_calls_info.get(
                    normalized_tool_id)
                if tool_info:
                    try:
                        # Tool output may be a JSON string, a dict, or a list of
                        # blocks (only the first block is used).
                        if isinstance(chunk[0].content, str):
                            tool_result = json.loads(chunk[0].content)
                        elif isinstance(chunk[0].content, dict):
                            tool_result = chunk[0].content
                        elif isinstance(chunk[0].content, list):
                            tool_result = chunk[0].content[0] if len(
                                chunk[0].content) > 0 else {}
                        else:
                            tool_result = {}
                        # A 'text' entry is itself parsed as JSON.
                        text = tool_result.get('text') if 'text' in tool_result else None
                        text_result = json.loads(text) if text else tool_result
                        # A 'tool_id' key marks a tool-library tool whose run is recorded.
                        if text:
                            tool_lib_id = text_result.pop('tool_id') if 'tool_id' in text_result else None
                        else:
                            tool_lib_id = tool_result.pop('tool_id') if 'tool_id' in tool_result else None
                        if tool_lib_id:
                            await save_tool_record(tool_lib_id, tool_info, tool_result, source_id, source_type)
                        tool_result = json.dumps(text_result, ensure_ascii=False)
                    except Exception:
                        # Best effort: any parsing/recording failure shows the raw content.
                        tool_result = chunk[0].content
                    content = generate_tool_message_complete(
                        tool_info.get('icon', ''),
                        tool_info['name'],
                        tool_info['input'],
                        tool_result
                    )
                    chunk[0].content = content
                else:
                    maxkb_logger.warning(
                        f"Tool ID {tool_id} not found in tool_calls_info. "
                        f"Normalized Tool ID: {normalized_tool_id}. "
                        f"Available IDs: {list(tool_calls_info.keys())}. "
                        f"Tool fragments at this point: {_tool_fragments}"
                    )
                yield chunk[0]
    except ExceptionGroup as eg:
        def get_real_error(exc):
            # Descend into the first sub-exception until a non-group error is found.
            if isinstance(exc, ExceptionGroup):
                return get_real_error(exc.exceptions[0])
            return exc

        real_error = get_real_error(eg)
        error_msg = f"{type(real_error).__name__}: {str(real_error)}"
        raise RuntimeError(error_msg) from None
    except Exception as e:
        error_msg = f"{type(e).__name__}: {str(e)}"
        raise RuntimeError(error_msg) from None
async def save_tool_record(tool_id, tool_info, tool_result, source_id, source_type):
    """Persist a successful tool execution as a ``ToolRecord`` and attach the tool's icon.

    @param tool_id: id of the ``Tool`` that produced the result
    @param tool_info: tool-call info dict with ``input``; ``icon`` is set on it in place
    @param tool_result: tool output stored in ``meta['output']``
    @param source_id: stored on the record
    @param source_type: stored on the record
    """
    tool = await sync_to_async(lambda: QuerySet(Tool).filter(id=tool_id).first())()
    if tool is None:
        # No Tool row matches tool_id. There is no workspace to attach a record
        # to, so skip instead of failing with AttributeError on ``tool.icon``.
        maxkb_logger.warning(f"Tool {tool_id} not found; skipping tool record")
        return
    tool_info['icon'] = tool.icon
    tool_record = ToolRecord(
        id=uuid.uuid7(),
        workspace_id=tool.workspace_id,
        tool_id=tool_id,
        source_type=source_type,
        source_id=source_id,
        meta={'input': tool_info['input'], 'output': tool_result},
        state=State.SUCCESS
    )
    await sync_to_async(tool_record.save)()
def mcp_response_generator(chat_model, message_list, mcp_servers, mcp_output_enable=True, tool_init_params={},
                           source_id=None, source_type=None, chat_id=None, extra_tools=None):
    """Synchronously iterate ``_yield_mcp_response`` running on the shared event loop.

    The async generator is scheduled on ``get_global_loop()`` (no new loop per
    call). Its items are passed back through a thread-safe queue. Each call
    works in its own directory under ``/tmp``. That directory is removed when
    iteration ends: normally, on error, or because the consumer stopped early.
    In the early-stop case the producer task is also cancelled.

    @raise Exception: whatever ``_yield_mcp_response`` raised, re-raised in the consumer thread
    """
    result_queue = queue.Queue()
    loop = get_global_loop()  # shared loop, not a new instance
    # Per-call private directory. An 8-character prefix alone is not unique:
    # the first 8 hex digits of a UUIDv7 are the top bits of its millisecond
    # timestamp, so ids created within about a minute share them, and
    # concurrent calls for the same chat would share it too. One call's
    # cleanup would then delete another call's skills. The prefix is kept only
    # for traceability.
    run_id = uuid.uuid7().hex
    dir_name = f'{str(chat_id)[:8]}_{run_id}' if chat_id else run_id
    temp_dir = os.path.join('/tmp', dir_name)
    skills_dir = os.path.join(temp_dir, 'skills')
    os.makedirs(skills_dir, exist_ok=True)

    async def _run():
        try:
            async_gen = _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_enable, tool_init_params,
                                            source_id, source_type, temp_dir, chat_id, extra_tools)
            async for chunk in async_gen:
                result_queue.put(('data', chunk))
        except Exception as e:
            maxkb_logger.error(f'Exception: {e}', exc_info=True)
            result_queue.put(('error', e))
        finally:
            result_queue.put(('done', None))

    # Schedule the producer on the shared loop.
    future = asyncio.run_coroutine_threadsafe(_run(), loop)
    finished = False
    try:
        while True:
            msg_type, data = result_queue.get()
            if msg_type == 'done':
                finished = True
                break
            if msg_type == 'error':
                finished = True
                raise data
            yield data
    finally:
        if not finished:
            # The consumer stopped early (generator closed): cancel the producer
            # so it stops running the agent and filling a queue nobody reads.
            future.cancel()
        shutil.rmtree(temp_dir, ignore_errors=True)
async def anext_async(agen):
    """Await and return the next item of the async iterator ``agen``.

    Propagates ``StopAsyncIteration`` once ``agen`` is exhausted.
    """
    pending_step = agen.__anext__()
    return await pending_step
def _mapping_node_data(node):
    """Return ``node['properties']['node_data']`` looked up with ``dict.get``."""
    return node.get('properties').get('node_data')


def _mapping_scalar(field):
    """Resolver returning ``[node_data.get(field)]``, a one-element list that may be ``[None]``."""
    return lambda node: [_mapping_node_data(node).get(field)]


def _mapping_lists(*fields):
    """Resolver concatenating list-valued ``node_data`` fields in order; missing/None fields add nothing."""
    return lambda node: [item for field in fields for item in (_mapping_node_data(node).get(field) or [])]


# Resource type -> workflow node type -> resolver returning the ids of the
# resources referenced by that node. Falsy ids are skipped by the consumer in
# get_node_handle_callback.
target_source_node_mapping = {
    'TOOL': {
        'tool-lib-node': _mapping_scalar('tool_lib_id'),
        'ai-chat-node': _mapping_lists('mcp_tool_ids', 'tool_ids', 'skill_tool_ids'),
        'mcp-node': _mapping_scalar('mcp_tool_id'),
        'tool-workflow-lib-node': _mapping_scalar('tool_lib_id'),
    },
    'MODEL': {
        'ai-chat-node': _mapping_scalar('model_id'),
        'question-node': _mapping_scalar('model_id'),
        'speech-to-text-node': _mapping_scalar('stt_model_id'),
        'text-to-speech-node': _mapping_scalar('tts_model_id'),
        'image-to-video-node': _mapping_scalar('model_id'),
        'image-generate-node': _mapping_scalar('model_id'),
        'intent-node': _mapping_scalar('model_id'),
        'image-understand-node': _mapping_scalar('model_id'),
        'parameter-extraction-node': _mapping_scalar('model_id'),
        'video-understand-node': _mapping_scalar('model_id'),
    },
    'KNOWLEDGE': {
        # Missing knowledge_id_list yields [] rather than None, which the
        # consumer would fail to iterate.
        'search-knowledge-node': _mapping_lists('knowledge_id_list'),
    },
    'APPLICATION': {
        'application-node': _mapping_scalar('application_id'),
    },
}
def get_node_handle_callback(source_type, source_id):
    """Build a callback that turns one workflow node into ResourceMapping objects.

    :param source_type: ``source_type`` stored on every produced mapping.
    :param source_id: ``source_id`` stored on every produced mapping.
    :return: callable ``node -> list[ResourceMapping]`` (objects are not saved
        here). It yields one mapping per id the node references according to
        ``target_source_node_mapping``. Falsy ids are skipped, and an extractor
        returning ``None`` counts as "no ids".
    """

    def node_handle_callback(node):
        from system_manage.models.resource_mapping import ResourceMapping
        node_type = node.get('type')
        response = []
        for target_type, extractors in target_source_node_mapping.items():
            extract = extractors.get(node_type)
            if extract is None:
                continue
            # Guard against extractors that return None for a missing list field.
            for target_id in extract(node) or []:
                if target_id:
                    response.append(ResourceMapping(source_type=source_type, target_type=target_type,
                                                    source_id=source_id, target_id=target_id))
        return response

    return node_handle_callback
def get_workflow_resource(workflow, node_handle):
    """Collect the resource mappings referenced by a workflow, including loop bodies.

    :param workflow: workflow graph dict; only its ``nodes`` list is read.
        ``None``, a dict without ``nodes``, or ``nodes`` set to ``None`` yields ``[]``.
    :param node_handle: callable ``node -> iterable | None`` producing objects
        that expose ``target_type`` and ``target_id``.
    :return: list de-duplicated on ``(target_type, target_id)``. Entries keep
        first-occurrence order; for duplicates the last object seen is kept.
    """
    if not workflow or 'nodes' not in workflow:
        return []
    response = []
    for node in workflow.get('nodes') or []:
        response.extend(node_handle(node) or [])
        if node.get('type') == 'loop-node':
            node_data = (node.get('properties') or {}).get('node_data') or {}
            # A missing/None loop body makes the recursive call return [].
            response.extend(get_workflow_resource(node_data.get('loop_body'), node_handle))
    # Tuple key: unlike concatenating the two strings, distinct pairs cannot collide.
    unique = {(str(item.target_type), str(item.target_id)): item for item in response}
    return list(unique.values())
# Extractors, grouped by ResourceMapping target type, that read the ids of the
# resources an application instance references. TOOL extractors return the
# attribute itself when truthy (presumably a list; verify) or []. MODEL
# extractors wrap a truthy id in a one-element list, else return [].
application_instance_field_call_dict = {
    'TOOL': [
        lambda app: app.mcp_tool_ids or [],
        lambda app: app.skill_tool_ids or [],
        lambda app: app.tool_ids or [],
    ],
    'MODEL': [
        lambda app: [] if not app.model_id else [app.model_id],
        lambda app: [] if not app.tts_model_id else [app.tts_model_id],
        lambda app: [] if not app.stt_model_id else [app.stt_model_id],
    ],
}
# Extractors, grouped by ResourceMapping target type, that read the ids of the
# resources a knowledge instance references: its embedding model, if set.
knowledge_instance_field_call_dict = {
    'MODEL': [
        lambda knowledge: [] if not knowledge.embedding_model_id else [knowledge.embedding_model_id],
    ],
}
def get_instance_resource(instance, source_type, source_id, instance_field_call_dict):
    """Build ResourceMapping objects for the resources an instance references.

    :param instance: object handed to every extractor.
    :param source_type: ``source_type`` stored on every mapping.
    :param source_id: ``source_id`` stored on every mapping.
    :param instance_field_call_dict: ``{target_type: [extractor, ...]}``; each
        extractor maps ``instance`` to a list of target ids.
    :return: list of ResourceMapping objects (not saved here). An extractor
        returning ``None`` contributes nothing. Falsy ids are skipped, as the
        workflow node callback does.
    """
    from system_manage.models.resource_mapping import ResourceMapping
    response = []
    for target_type, call_list in instance_field_call_dict.items():
        # Flatten the per-extractor lists in one pass (no repeated list concatenation).
        target_ids = [target_id for call in call_list for target_id in (call(instance) or [])]
        response.extend(
            ResourceMapping(source_type=source_type, target_type=target_type, source_id=source_id,
                            target_id=target_id)
            for target_id in target_ids
            if target_id
        )
    return response
def save_workflow_mapping(workflow, source_type, source_id, other_resource_mapping=None):
    """Replace the stored ResourceMapping rows of one source.

    All rows matching ``(source_type, source_id)`` are deleted. Then the
    mappings referenced by ``workflow`` (via ``get_workflow_resource``) plus
    ``other_resource_mapping`` are bulk-created, de-duplicated on
    ``(target_type, target_id)`` with the last occurrence kept.

    The mappings are computed before anything is deleted, and the delete and
    insert share one transaction. A failure therefore cannot leave the source
    with its old rows removed and no new ones stored.

    :param workflow: workflow graph dict.
    :param source_type: source type of the owner, stored on every row.
    :param source_id: id of the owner, stored on every row.
    :param other_resource_mapping: extra unsaved ResourceMapping objects to store.
    """
    from django.db import transaction
    from django.db.models import QuerySet
    from system_manage.models.resource_mapping import ResourceMapping

    resource_mapping_list = [
        *get_workflow_resource(workflow, get_node_handle_callback(source_type, source_id)),
        *(other_resource_mapping or []),
    ]
    unique_mappings = list(
        {(str(item.target_type), str(item.target_id)): item for item in resource_mapping_list}.values())
    with transaction.atomic():
        QuerySet(ResourceMapping).filter(source_type=source_type, source_id=source_id).delete()
        if unique_mappings:
            QuerySet(ResourceMapping).bulk_create(unique_mappings)
def get_tool_id_list(workflow, with_deep=False):
    """Return the ids of tools referenced by the nodes of ``workflow``.

    Ids are read from these node types:
    - ``tool-lib-node`` / ``tool-workflow-lib-node``: ``tool_lib_id``
    - ``mcp-node``: ``mcp_tool_id``
    - ``ai-chat-node``: ``mcp_tool_ids``, ``tool_ids``, ``skill_tool_ids``
    - ``loop-node``: recurses into its ``loop_body``

    :param workflow: workflow graph dict. ``None``, missing ``nodes``, or
        ``None`` sub-dicts are treated as empty.
    :param with_deep: when true, also append the ids returned by
        ``get_child_tool_id_list`` for the ``ToolWorkflow.work_flow`` of every
        workflow-type tool found.
    :return: list of tool ids in discovery order (may contain duplicates).
    """

    def _node_data(node):
        # ``properties``/``node_data`` may be absent or explicitly None.
        return (node.get('properties') or {}).get('node_data') or {}

    _result = []
    for node in (workflow or {}).get('nodes') or []:
        node_type = node.get('type')
        if node_type in ('tool-lib-node', 'tool-workflow-lib-node'):
            tool_id = _node_data(node).get('tool_lib_id')
            if tool_id:
                _result.append(tool_id)
        elif node_type == 'loop-node':
            _result.extend(get_tool_id_list(_node_data(node).get('loop_body')))
        elif node_type == 'ai-chat-node':
            node_data = _node_data(node)
            for field in ('mcp_tool_ids', 'tool_ids', 'skill_tool_ids'):
                _result.extend(node_data.get(field) or [])
        elif node_type == 'mcp-node':
            mcp_tool_id = _node_data(node).get('mcp_tool_id')
            if mcp_tool_id:
                _result.append(mcp_tool_id)
    if with_deep:
        # Only the deep expansion needs these ORM models.
        from tools.models import ToolWorkflow, ToolType
        workflow_tools = QuerySet(Tool).filter(id__in=_result, tool_type=ToolType.WORKFLOW)
        tool_workflows = QuerySet(ToolWorkflow).filter(tool_id__in=[t.id for t in workflow_tools])
        for tool_workflow in tool_workflows:
            _result.extend(get_child_tool_id_list(tool_workflow.work_flow, []))
    return _result
def get_child_tool_id_list(work_flow, response):
    """Recursively collect the ids of non-shared tools reachable from ``work_flow``.

    Steps:
    - Take the tool ids referenced directly by ``work_flow`` (via
      ``get_tool_id_list``) that are not already in ``response``.
    - Load those tools, excluding ``ToolScope.SHARED`` ones, and append their
      ids to ``response`` as strings.
    - Walk the ``ToolWorkflow.work_flow`` of every workflow-type tool the same way.

    Each id is appended before recursing, so a tool already collected is never
    revisited and reference cycles terminate.

    :param work_flow: workflow graph dict.
    :param response: accumulator of collected tool ids; mutated in place.
    :return: ``response``.
    """
    from tools.models import ToolWorkflow, ToolType
    already_collected = set(response)
    tool_id_list = [tool_id for tool_id in get_tool_id_list(work_flow, False)
                    if tool_id not in already_collected]
    if not tool_id_list:
        return response
    tool_list = list(QuerySet(Tool).filter(id__in=tool_id_list).exclude(scope=ToolScope.SHARED))
    workflow_tool_ids = [tool.id for tool in tool_list if tool.tool_type == ToolType.WORKFLOW]
    work_flow_tool_dict = {}
    if workflow_tool_ids:
        work_flow_tool_dict = {tw.tool_id: tw for tw in
                               QuerySet(ToolWorkflow).filter(tool_id__in=workflow_tool_ids)}
    for tool in tool_list:
        response.append(str(tool.id))
        if tool.tool_type == ToolType.WORKFLOW:
            tool_workflow = work_flow_tool_dict.get(tool.id)
            # A workflow tool may lack a ToolWorkflow row; skip instead of crashing.
            if tool_workflow is not None:
                get_child_tool_id_list(tool_workflow.work_flow, response)
    return response
def build_schema(fields: dict):
    """Create a model class named ``dynamicSchema`` via ``create_model``.

    :param fields: mapping of field name -> ``(type, Field(...))`` definition,
        forwarded as keyword arguments to ``create_model`` (presumably
        pydantic's; the import is not visible here).
    :return: the generated model class.
    """
    schema_name = "dynamicSchema"
    return create_model(schema_name, **fields)
def get_type(_type: str):
    """Map a workflow field type name to the Python type used in a tool schema.

    Recognised names are ``float``, ``string``, ``int``, ``dict``, ``array`` and
    ``boolean``. Anything else, including ``None``, maps to ``object``.
    """
    type_table = (
        ('float', float),
        ('string', str),
        ('int', int),
        ('dict', dict),
        ('array', list),
        ('boolean', bool),
    )
    for type_name, python_type in type_table:
        if _type == type_name:
            return python_type
    return object
def get_workflow_args(tool, qv):
    """Build the argument schema of a workflow tool.

    The first ``tool-base-node`` in ``qv.work_flow['nodes']`` defines the
    arguments. Each entry of its ``properties.user_input_field_list`` becomes a
    field named after ``field``, typed via ``get_type(type)`` and declared as
    ``Field(..., description=desc)``. Without such a node, or with no input
    fields, the schema is empty. Missing or ``None`` ``nodes``/``properties``/
    field lists are treated as empty instead of raising.

    :param tool: unused; kept for interface compatibility.
    :param qv: workflow version exposing a ``work_flow`` dict.
    :return: model class produced by ``build_schema``.
    """
    for node in (qv.work_flow or {}).get('nodes') or []:
        if node.get('type') == 'tool-base-node':
            input_field_list = (node.get('properties') or {}).get('user_input_field_list') or []
            return build_schema({
                field.get('field'): (get_type(field.get('type')), Field(..., description=field.get('desc')))
                for field in input_field_list
            })
    return build_schema({})
def get_workflow_func(source_type, source_id, tool, qv, workspace_id):
    """Wrap a workflow tool version as a plain callable usable by ``StructuredTool``.

    The returned ``inner(**kwargs)`` does the following:
    - runs ``qv.work_flow`` in ``WorkflowMode.TOOL``, with ``kwargs`` merged
      over the base run parameters;
    - consumes the run to completion;
    - returns the manager's ``out_context``.

    NOTE(review): the record id and the ``ToolExecute`` are created once, here.
    Every invocation of the returned callable therefore reuses the same
    ``chat_record_id``/``tool_record_id``. Confirm that is intended.
    """
    tool_id = tool.id
    tool_record_id = str(uuid.uuid7())
    tool_execute = ToolExecute(tool_id, tool_record_id, workspace_id, source_type, source_id, False)

    def inner(**kwargs):
        from application.flow.tool_workflow_manage import ToolWorkflowManage
        workflow = Workflow.new_instance(qv.work_flow, WorkflowMode.TOOL)
        run_params = {
            'chat_record_id': tool_record_id,
            'tool_id': tool_id,
            'stream': True,
            'workspace_id': workspace_id,
            **kwargs,
        }
        manager = ToolWorkflowManage(
            workflow,
            run_params,
            ToolWorkflowPostHandler(tool_execute, tool_id),
            is_the_task_interrupted=lambda: False,
            child_node=None,
            start_node_id=None,
            start_node_data=None,
            chat_record=None,
        )
        # Drain the run so it executes fully; only the final context is returned.
        for _ in manager.run():
            pass
        return manager.out_context

    return inner
def get_tools(source_type, source_id, tool_workflow_ids, workspace_id):
    """Build ``StructuredTool`` wrappers for active workflow tools of a workspace.

    Considers active ``ToolType.WORKFLOW`` tools whose id is in
    ``tool_workflow_ids`` and that belong to ``workspace_id``. For each, the
    ``ToolWorkflowVersion`` with the newest ``create_time`` supplies the
    callable (``get_workflow_func``) and the argument schema
    (``get_workflow_args``). A tool without any version is skipped with a
    warning instead of failing the whole call.

    :return: list of ``StructuredTool`` instances, in the order the tools were loaded.
    """
    tools = list(QuerySet(Tool).filter(id__in=tool_workflow_ids, is_active=True, tool_type=ToolType.WORKFLOW,
                                       workspace_id=workspace_id))
    latest_subquery = ToolWorkflowVersion.objects.filter(
        tool_id=OuterRef('tool_id')
    ).order_by('-create_time')
    latest_versions = ToolWorkflowVersion.objects.filter(
        tool_id__in=[t.id for t in tools],
        id=Subquery(latest_subquery.values('id')[:1])
    )
    version_by_tool_id = {version.tool_id: version for version in latest_versions}
    results = []
    for tool in tools:
        version = version_by_tool_id.get(tool.id)
        if version is None:
            maxkb_logger.warning('Workflow tool %s has no ToolWorkflowVersion; skipped', tool.id)
            continue
        results.append(StructuredTool.from_function(
            func=get_workflow_func(source_type, source_id, tool, version, workspace_id),
            name=tool.name,
            description=tool.desc,
            args_schema=get_workflow_args(tool, version),
        ))
    return results
|