tools.py 49 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091
  1. # coding=utf-8
  2. """
  3. @project: maxkb
  4. @Author:虎
  5. @file: utils.py
  6. @date:2024/6/6 15:15
  7. @desc:
  8. """
  9. from langchain_core.tools import StructuredTool
  10. from application.flow.common import Workflow, WorkflowMode
  11. from application.serializers.common import ToolExecute
  12. from tools.models import ToolRecord, Tool, ToolScope, ToolWorkflowVersion, ToolType
  13. from maxkb.const import CONFIG
  14. from knowledge.models.knowledge_action import State
  15. from knowledge.models import File
  16. from common.utils.logger import maxkb_logger
  17. from common.result import result
  18. from application.flow.i_step_node import WorkFlowPostHandler, ToolWorkflowPostHandler
  19. from application.flow.backend.sandbox_shell import SandboxShellBackend
  20. import asyncio
  21. import io
  22. import json
  23. import os
  24. import queue
  25. import re
  26. import shutil
  27. import threading
  28. import zipfile
  29. from functools import reduce
  30. from typing import Iterator
  31. from pydantic import Field, create_model
  32. import uuid_utils.compat as uuid
  33. from asgiref.sync import sync_to_async
  34. from deepagents import create_deep_agent
  35. from django.db.models import QuerySet, OuterRef, Subquery
  36. from django.http import StreamingHttpResponse
  37. from langchain_core.messages import BaseMessageChunk, BaseMessage, ToolMessage, AIMessageChunk, SystemMessage
  38. from langchain_mcp_adapters.client import MultiServerMCPClient
  39. from langgraph.checkpoint.memory import MemorySaver
  40. # ---------------------------------------------------------------------------
  41. # Fix: qwen's OpenAI-compatible streaming sends id='' (empty string) for
  42. # intermediate tool_call_chunks while only the first chunk carries the real
  43. # id ('call_xxx...'). langchain-core's merge_lists treats '' != 'call_xxx' as
  44. # an ID conflict and _appends_ instead of merging → the accumulated AIMessage
  45. # ends up with two separate tool_calls (one with empty args, one with empty
  46. # id) instead of one correct entry. This causes the Qwen API to reject the
  47. # next request with "function.arguments must be in JSON format".
  48. #
  49. # Patch: normalise id='' → None for items that have an 'index' key
  50. # (i.e. tool_call_chunk dicts). merge_lists treats None as "no id" and will
  51. # merge with any existing entry, keeping the real id from the first chunk.
  52. # ---------------------------------------------------------------------------
  53. import langchain_core.messages.ai as _lc_ai_module
  54. from langchain_core.utils._merge import merge_lists as _original_merge_lists
  55. def _merge_lists_normalize_empty_tool_chunk_ids(left, *others):
  56. """Wrapper around merge_lists that normalises empty-string IDs to None in
  57. tool_call_chunk items (those with an 'index' key) so that qwen streaming
  58. chunks with id='' are merged correctly by index."""
  59. def _norm(lst):
  60. if lst is None:
  61. return lst
  62. result = []
  63. for item in lst:
  64. if isinstance(item, dict) and 'index' in item and item.get('id') == '':
  65. item = {**item, 'id': None}
  66. result.append(item)
  67. return result
  68. return _original_merge_lists(
  69. _norm(left),
  70. *[_norm(o) for o in others],
  71. )
  72. # Replace the module-level reference used by add_ai_message_chunks in ai.py
  73. _lc_ai_module.merge_lists = _merge_lists_normalize_empty_tool_chunk_ids
  74. class Reasoning:
  75. def __init__(self, reasoning_content_start, reasoning_content_end):
  76. self.content = ""
  77. self.reasoning_content = ""
  78. self.all_content = ""
  79. self.reasoning_content_start_tag = reasoning_content_start
  80. self.reasoning_content_end_tag = reasoning_content_end
  81. self.reasoning_content_start_tag_len = len(
  82. reasoning_content_start) if reasoning_content_start is not None else 0
  83. self.reasoning_content_end_tag_len = len(
  84. reasoning_content_end) if reasoning_content_end is not None else 0
  85. self.reasoning_content_end_tag_prefix = reasoning_content_end[
  86. 0] if self.reasoning_content_end_tag_len > 0 else ''
  87. self.reasoning_content_is_start = False
  88. self.reasoning_content_is_end = False
  89. self.reasoning_content_chunk = ""
  90. def get_end_reasoning_content(self):
  91. if not self.reasoning_content_is_start and not self.reasoning_content_is_end:
  92. r = {'content': self.all_content, 'reasoning_content': ''}
  93. self.reasoning_content_chunk = ""
  94. return r
  95. if self.reasoning_content_is_start and not self.reasoning_content_is_end:
  96. r = {'content': '', 'reasoning_content': self.reasoning_content_chunk}
  97. self.reasoning_content_chunk = ""
  98. return r
  99. return {'content': '', 'reasoning_content': ''}
  100. def _normalize_content(self, content):
  101. """将不同类型的内容统一转换为字符串"""
  102. if isinstance(content, str):
  103. return content
  104. elif isinstance(content, list):
  105. # 处理包含多种内容类型的列表
  106. normalized_parts = []
  107. for item in content:
  108. if isinstance(item, dict):
  109. if item.get('type') == 'text':
  110. normalized_parts.append(item.get('text', ''))
  111. return ''.join(normalized_parts)
  112. else:
  113. return str(content)
  114. def get_reasoning_content(self, chunk):
  115. # 如果没有开始思考过程标签那么就全是结果
  116. if self.reasoning_content_start_tag is None or len(self.reasoning_content_start_tag) == 0:
  117. self.content += chunk.content
  118. return {'content': chunk.content, 'reasoning_content': ''}
  119. # 如果没有结束思考过程标签那么就全部是思考过程
  120. if self.reasoning_content_end_tag is None or len(self.reasoning_content_end_tag) == 0:
  121. return {'content': '', 'reasoning_content': chunk.content}
  122. chunk.content = self._normalize_content(chunk.content)
  123. self.all_content += chunk.content
  124. if not self.reasoning_content_is_start and len(self.all_content) >= self.reasoning_content_start_tag_len:
  125. if self.all_content.startswith(self.reasoning_content_start_tag):
  126. self.reasoning_content_is_start = True
  127. self.reasoning_content_chunk = self.all_content[self.reasoning_content_start_tag_len:]
  128. else:
  129. if not self.reasoning_content_is_end:
  130. self.reasoning_content_is_end = True
  131. self.content += self.all_content
  132. return {'content': self.all_content,
  133. 'reasoning_content': chunk.additional_kwargs.get('reasoning_content',
  134. '') if chunk.additional_kwargs else ''
  135. }
  136. else:
  137. if self.reasoning_content_is_start:
  138. self.reasoning_content_chunk += chunk.content
  139. reasoning_content_end_tag_prefix_index = self.reasoning_content_chunk.find(
  140. self.reasoning_content_end_tag_prefix)
  141. if self.reasoning_content_is_end:
  142. self.content += chunk.content
  143. return {'content': chunk.content, 'reasoning_content': chunk.additional_kwargs.get('reasoning_content',
  144. '') if chunk.additional_kwargs else ''
  145. }
  146. # 是否包含结束
  147. if reasoning_content_end_tag_prefix_index > -1:
  148. if len(self.reasoning_content_chunk) - reasoning_content_end_tag_prefix_index >= self.reasoning_content_end_tag_len:
  149. reasoning_content_end_tag_index = self.reasoning_content_chunk.find(
  150. self.reasoning_content_end_tag)
  151. if reasoning_content_end_tag_index > -1:
  152. reasoning_content_chunk = self.reasoning_content_chunk[
  153. 0:reasoning_content_end_tag_index]
  154. content_chunk = self.reasoning_content_chunk[
  155. reasoning_content_end_tag_index + self.reasoning_content_end_tag_len:]
  156. self.reasoning_content += reasoning_content_chunk
  157. self.content += content_chunk
  158. self.reasoning_content_chunk = ""
  159. self.reasoning_content_is_end = True
  160. return {'content': content_chunk, 'reasoning_content': reasoning_content_chunk}
  161. else:
  162. reasoning_content_chunk = self.reasoning_content_chunk[
  163. 0:reasoning_content_end_tag_prefix_index + 1]
  164. self.reasoning_content_chunk = self.reasoning_content_chunk.replace(
  165. reasoning_content_chunk, '')
  166. self.reasoning_content += reasoning_content_chunk
  167. return {'content': '', 'reasoning_content': reasoning_content_chunk}
  168. else:
  169. return {'content': '', 'reasoning_content': ''}
  170. else:
  171. if self.reasoning_content_is_end:
  172. self.content += chunk.content
  173. return {'content': chunk.content, 'reasoning_content': chunk.additional_kwargs.get('reasoning_content',
  174. '') if chunk.additional_kwargs else ''
  175. }
  176. else:
  177. # aaa
  178. result = {'content': '',
  179. 'reasoning_content': self.reasoning_content_chunk}
  180. self.reasoning_content += self.reasoning_content_chunk
  181. self.reasoning_content_chunk = ""
  182. return result
  183. def event_content(chat_id, chat_record_id, response, workflow,
  184. write_context,
  185. post_handler: WorkFlowPostHandler):
  186. """
  187. 用于处理流式输出
  188. @param chat_id: 会话id
  189. @param chat_record_id: 对话记录id
  190. @param response: 响应数据
  191. @param workflow: 工作流管理器
  192. @param write_context 写入节点上下文
  193. @param post_handler: 后置处理器
  194. """
  195. answer = ''
  196. try:
  197. for chunk in response:
  198. answer += chunk.content
  199. yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
  200. 'content': chunk.content, 'is_end': False}, ensure_ascii=False) + "\n\n"
  201. write_context(answer, 200)
  202. post_handler.handler(chat_id, chat_record_id, answer, workflow)
  203. yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
  204. 'content': '', 'is_end': True}, ensure_ascii=False) + "\n\n"
  205. except Exception as e:
  206. answer = str(e)
  207. write_context(answer, 500)
  208. post_handler.handler(chat_id, chat_record_id, answer, workflow)
  209. yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
  210. 'content': answer, 'is_end': True}, ensure_ascii=False) + "\n\n"
  211. def to_stream_response(chat_id, chat_record_id, response: Iterator[BaseMessageChunk], workflow, write_context,
  212. post_handler):
  213. """
  214. 将结果转换为服务流输出
  215. @param chat_id: 会话id
  216. @param chat_record_id: 对话记录id
  217. @param response: 响应数据
  218. @param workflow: 工作流管理器
  219. @param write_context 写入节点上下文
  220. @param post_handler: 后置处理器
  221. @return: 响应
  222. """
  223. r = StreamingHttpResponse(
  224. streaming_content=event_content(
  225. chat_id, chat_record_id, response, workflow, write_context, post_handler),
  226. content_type='text/event-stream;charset=utf-8',
  227. charset='utf-8')
  228. r['Cache-Control'] = 'no-cache'
  229. return r
  230. def to_response(chat_id, chat_record_id, response: BaseMessage, workflow, write_context,
  231. post_handler: WorkFlowPostHandler):
  232. """
  233. 将结果转换为服务输出
  234. @param chat_id: 会话id
  235. @param chat_record_id: 对话记录id
  236. @param response: 响应数据
  237. @param workflow: 工作流管理器
  238. @param write_context 写入节点上下文
  239. @param post_handler: 后置处理器
  240. @return: 响应
  241. """
  242. answer = response.content
  243. write_context(answer)
  244. post_handler.handler(chat_id, chat_record_id, answer, workflow)
  245. return result.success({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
  246. 'content': answer, 'is_end': True})
  247. def to_response_simple(chat_id, chat_record_id, response: BaseMessage, workflow,
  248. post_handler: WorkFlowPostHandler):
  249. answer = response.content
  250. post_handler.handler(chat_id, chat_record_id, answer, workflow)
  251. return result.success({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
  252. 'content': answer, 'is_end': True})
  253. def to_stream_response_simple(stream_event):
  254. r = StreamingHttpResponse(
  255. streaming_content=stream_event,
  256. content_type='text/event-stream;charset=utf-8',
  257. charset='utf-8')
  258. r['Cache-Control'] = 'no-cache'
  259. return r
  260. def generate_tool_message_complete(icon, name, input_content, output_content):
  261. """生成包含输入和输出的工具消息模版"""
  262. # 确保输入内容是字符串,如果不是则尝试转换为 JSON 字符串
  263. if not isinstance(input_content, str):
  264. input_content = json.dumps(input_content, ensure_ascii=False)
  265. # 格式化输出
  266. if not isinstance(output_content, str):
  267. output_content = json.dumps(output_content, ensure_ascii=False)
  268. content = {
  269. "icon": icon,
  270. "title": name,
  271. "type": "simple-tool-calls",
  272. "content": {
  273. "input": input_content,
  274. "output": output_content
  275. }
  276. }
  277. return f'<tool_calls_render>{json.dumps(content)}</tool_calls_render>'
  278. # 全局单例事件循环
  279. _global_loop = None
  280. _loop_thread = None
  281. _loop_lock = threading.Lock()
  282. def get_global_loop():
  283. """获取全局共享的事件循环"""
  284. global _global_loop, _loop_thread
  285. with _loop_lock:
  286. if _global_loop is None:
  287. _global_loop = asyncio.new_event_loop()
  288. def run_forever():
  289. asyncio.set_event_loop(_global_loop)
  290. _global_loop.run_forever()
  291. _loop_thread = threading.Thread(
  292. target=run_forever, daemon=True, name="GlobalAsyncLoop")
  293. _loop_thread.start()
  294. return _global_loop
  295. def _extract_tool_id(raw_id):
  296. """从 raw_id 中提取最后一个符合 call_... 模式的 id,若无匹配则返回原值或 None"""
  297. if not raw_id:
  298. return None
  299. if not isinstance(raw_id, str):
  300. raw_id = str(raw_id)
  301. s = raw_id
  302. prefix = 'call_'
  303. positions = [m.start() for m in re.finditer(re.escape(prefix), s)]
  304. if not positions:
  305. return raw_id
  306. # 取最后一个前缀位置,截到下一个前缀或结尾
  307. start = positions[-1]
  308. end = len(s)
  309. for pos in positions:
  310. if pos > start:
  311. end = pos
  312. break
  313. tool_id = s[start:end]
  314. return tool_id or raw_id
  315. async def _initialize_skills(mcp_servers, temp_dir):
  316. skills_dir = os.path.join(temp_dir, 'skills')
  317. mcp_config = json.loads(mcp_servers)
  318. if "skills" in mcp_config:
  319. skill_file_items = mcp_config.pop('skills')
  320. for skill_file in skill_file_items:
  321. # 使用 sync_to_async 包装 ORM 查询
  322. file = await sync_to_async(lambda: QuerySet(File).filter(id=skill_file['file_id']).first())()
  323. if not file:
  324. continue
  325. # get_bytes 可能也涉及 IO,也用 sync_to_async 包装
  326. file_bytes = await sync_to_async(file.get_bytes)()
  327. params = skill_file.get('params', {})
  328. with zipfile.ZipFile(io.BytesIO(file_bytes), 'r') as zip_ref:
  329. members = [
  330. m for m in zip_ref.namelist()
  331. if not m.startswith('__MACOSX/') and '__MACOSX' not in m
  332. ]
  333. for member in members:
  334. if ".." in member or member.startswith("/"):
  335. raise ValueError(f"非法路径: {member}")
  336. zip_ref.extractall(skills_dir, members=members)
  337. # 获取技能解压后的顶级目录名
  338. top_level_dirs = set()
  339. for member in members:
  340. parts = member.split('/')
  341. if parts[0]:
  342. top_level_dirs.add(parts[0])
  343. # 将 params 写入每个顶级目录下的 .env 文件
  344. if params:
  345. env_lines = []
  346. for key, value in params.items():
  347. # 对含空格或特殊字符的值加引号
  348. env_lines.append(f'{key}={value}')
  349. env_content = '\n'.join(env_lines) + '\n'
  350. for top_dir in top_level_dirs:
  351. env_path = os.path.join(skills_dir, top_dir, '.env')
  352. with open(env_path, 'w', encoding='utf-8') as f:
  353. f.write(env_content)
  354. os.system("chmod -R g+rx " + temp_dir) # 确保技能目录可访问
  355. client = MultiServerMCPClient(mcp_config)
  356. return client
async def _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_enable=True, tool_init_params={},
                              source_id=None, source_type=None, temp_dir=None, chat_id=None, extra_tools=None):
    """Run a deep agent with MCP tools and yield its streamed messages.

    The tools come from the MCP client built by ``_initialize_skills`` plus
    ``extra_tools``. The agent is streamed with ``stream_mode='messages'``;
    for every streamed item only its first element is inspected:

    * ``AIMessageChunk``: tool-call fragments are collected from
      ``tool_call_chunks``, ``tool_calls``, ``invalid_tool_calls`` and
      ``additional_kwargs['tool_calls']``; on a finish chunk the collected
      calls are recorded, then the chunk is yielded.
    * ``ToolMessage`` (only when ``mcp_output_enable`` is true): if the call
      was recorded, the content is replaced by the rendering produced by
      ``generate_tool_message_complete`` and the message is yielded.

    :param chat_model: model passed to ``create_deep_agent``
    :param message_list: conversation messages; ``SystemMessage`` items are
                         removed and the text of the last one is used as the
                         agent's ``system_prompt``
    :param mcp_servers: JSON string handed to ``_initialize_skills``
    :param mcp_output_enable: whether ``ToolMessage`` items are processed and yielded
    :param tool_init_params: argument names that are removed from the recorded
                             tool input (only read here, never modified)
    :param source_id: forwarded to ``save_tool_record``
    :param source_type: forwarded to ``save_tool_record``
    :param temp_dir: root directory of the sandbox backend
    :param chat_id: used as the checkpointer ``thread_id``
    :param extra_tools: additional tools appended after the MCP tools
    :raises RuntimeError: for any failure; the message is
                          ``"<ExceptionType>: <text>"`` of the underlying error
    """
    try:
        checkpointer = MemorySaver()
        client = await _initialize_skills(mcp_servers, temp_dir)
        tools = await client.get_tools()
        for tool in tools:
            tool.handle_tool_error = True
        if extra_tools:
            for tool in extra_tools:
                tools.append(tool)
        # ---------------------------------------------------------------------------
        # Fix: vLLM (and Qwen chat templates) reject conversations that contain more
        # than one SystemMessage, or a SystemMessage that is not the very first
        # message. create_deep_agent always prepends its own BASE_AGENT_PROMPT as a
        # SystemMessage before calling the model (factory.py line ~1319). If
        # message_list already contains a SystemMessage (built in base_chat_node.py
        # via generate_message_list), the API receives two system messages and raises
        # "System message must be at the beginning."
        #
        # Solution: strip the user-supplied SystemMessage out of message_list and
        # pass its text as the system_prompt argument of create_deep_agent.
        # deepagents will then merge it with BASE_AGENT_PROMPT into a single
        # combined system message, so the model only ever sees one.
        # ---------------------------------------------------------------------------
        user_system_prompt = None
        filtered_message_list = []
        for msg in message_list:
            if isinstance(msg, SystemMessage):
                # Normalise content to plain string regardless of whether the
                # message was built with a str or a list of content blocks.
                # If several SystemMessages are present the last one wins.
                if isinstance(msg.content, str):
                    user_system_prompt = msg.content
                elif isinstance(msg.content, list):
                    user_system_prompt = ''.join(
                        item.get('text', '') if isinstance(item, dict) else str(item)
                        for item in msg.content
                    )
                else:
                    user_system_prompt = str(msg.content)
            else:
                filtered_message_list.append(msg)
        agent = create_deep_agent(
            model=chat_model,
            backend=SandboxShellBackend(root_dir=temp_dir, virtual_mode=True),
            skills=['/skills'],
            tools=tools,
            system_prompt=user_system_prompt,
            interrupt_on={
                "write_file": False,
                "read_file": False,
                "edit_file": False
            },
            checkpointer=checkpointer,
        )
        recursion_limit = int(CONFIG.get(
            "LANGCHAIN_GRAPH_RECURSION_LIMIT", '100'))
        response = agent.astream(
            {"messages": filtered_message_list},
            config={"recursion_limit": recursion_limit,
                    "configurable": {"thread_id": chat_id}},
            stream_mode='messages'
        )
        tool_calls_info = {}  # tool_id -> {'name': ..., 'input': ...}
        # key(index/id) -> {'id': ..., 'name': ..., 'arguments': ...}
        _tool_fragments = {}

        def _merge_arguments(entry, part_args):
            # Fold one argument fragment into entry['arguments'].
            if not isinstance(part_args, str):
                # NOTE(review): a None fragment is serialised to the string
                # 'null' here and is then appended like any other fragment —
                # confirm that no provider sends args=None.
                try:
                    part_args = json.dumps(part_args, ensure_ascii=False)
                except Exception:
                    part_args = str(part_args) if part_args else ''
            if not part_args:
                return
            # Some providers first emit placeholder args like "{}" and then
            # stream the real JSON fragments via later chunks. Prefer fragments.
            if entry['arguments'] in ('{}', '[]') and part_args.startswith('{'):
                entry['arguments'] = part_args
                return
            if entry['arguments']:
                try:
                    # Two complete JSON objects are merged key-wise; anything
                    # else is treated as a raw text fragment and concatenated.
                    existing_obj = json.loads(entry['arguments'])
                    new_obj = json.loads(part_args)
                    if isinstance(existing_obj, dict) and isinstance(new_obj, dict):
                        merged = {**existing_obj, **new_obj}
                        entry['arguments'] = json.dumps(
                            merged, ensure_ascii=False)
                    else:
                        entry['arguments'] += part_args
                except (json.JSONDecodeError, ValueError):
                    entry['arguments'] += part_args
            else:
                entry['arguments'] = part_args

        def _get_fragment_key(idx, raw_id):
            # The stream index identifies a call when present; otherwise the
            # normalised id is used. None means the fragment cannot be attributed.
            if idx is not None:
                return f'idx:{idx}'
            if raw_id and str(raw_id).strip():
                return f"id:{_extract_tool_id(str(raw_id).strip())}"
            return None

        def _upsert_fragment(key, raw_id, func_name, part_args):
            # Create or update the fragment stored under key.
            if key is None:
                return
            entry = _tool_fragments.setdefault(
                key, {'id': '', 'name': '', 'arguments': ''})
            if raw_id and str(raw_id).strip():
                new_id = str(raw_id).strip()
                # A completed fragment that receives a different id belongs to
                # a new tool call reusing the same key: start over.
                if entry.get('completed') and entry.get('id') and entry['id'] != new_id:
                    maxkb_logger.debug(
                        f"Resetting completed fragment {key}: old ID {entry['id']} -> new ID {new_id}")
                    entry.clear()
                    entry.update({'id': '', 'name': '', 'arguments': ''})
                entry['id'] = new_id
            if func_name:
                entry['name'] = func_name
            _merge_arguments(entry, part_args)

        async for chunk in response:
            if isinstance(chunk[0], AIMessageChunk):
                # ----------------------------------------------------------------
                # 1. Aggregate tool-call fragments from tool_call_chunks
                #    (qwen/OpenAI streaming delivers them via tool_call_chunks;
                #    additional_kwargs['tool_calls'] is usually empty while streaming)
                # ----------------------------------------------------------------
                for tc_chunk in (chunk[0].tool_call_chunks or []):
                    raw_id = tc_chunk.get('id')
                    key = _get_fragment_key(tc_chunk.get('index'), raw_id)
                    _upsert_fragment(
                        key,
                        raw_id,
                        tc_chunk.get('name'),
                        tc_chunk.get('args', '')
                    )
                # ----------------------------------------------------------------
                # 1.1 Some models put the tool call in chunk.tool_calls while the
                #     index of tool_call_chunks is empty (e.g. ollama/qwen)
                # ----------------------------------------------------------------
                has_tool_call_chunks = bool(chunk[0].tool_call_chunks)
                for tool_call in (chunk[0].tool_calls or []):
                    raw_id = tool_call.get('id')
                    part_args = tool_call.get('args', '')
                    # qwen-plus often emits {} here as a placeholder while
                    # the real args are split in tool_call_chunks/invalid_tool_calls.
                    if has_tool_call_chunks and (
                            part_args == '' or part_args == {} or part_args == []
                    ):
                        part_args = ''
                    key = _get_fragment_key(tool_call.get('index'), raw_id)
                    _upsert_fragment(
                        key,
                        raw_id,
                        tool_call.get('name'),
                        part_args
                    )
                # ----------------------------------------------------------------
                # 1.2 invalid_tool_calls fragments (some models put intermediate
                #     JSON fragments here)
                # ----------------------------------------------------------------
                for invalid_tool_call in (chunk[0].invalid_tool_calls or []):
                    raw_id = invalid_tool_call.get('id')
                    key = _get_fragment_key(
                        invalid_tool_call.get('index'), raw_id)
                    _upsert_fragment(
                        key,
                        raw_id,
                        invalid_tool_call.get('name'),
                        invalid_tool_call.get('args', '')
                    )
                # ----------------------------------------------------------------
                # 2. additional_kwargs['tool_calls'] (legacy format / non-streaming)
                # ----------------------------------------------------------------
                legacy_tool_calls = chunk[0].additional_kwargs.get(
                    'tool_calls', [])
                for tool_call in legacy_tool_calls:
                    raw_id = tool_call.get('id')
                    func = tool_call.get('function', {})
                    if isinstance(func, dict):
                        func_name = func.get('name')
                        part_args = func.get('arguments', '')
                    else:
                        func_name = tool_call.get('name')
                        part_args = tool_call.get('arguments', '')
                    key = _get_fragment_key(tool_call.get('index'), raw_id)
                    _upsert_fragment(key, raw_id, func_name, part_args)
                # ----------------------------------------------------------------
                # 3. Detect the end of a tool call and update tool_calls_info
                # ----------------------------------------------------------------
                is_finish_chunk = (
                        chunk[0].response_metadata.get(
                            'finish_reason') == 'tool_calls'
                        or chunk[0].chunk_position == 'last'
                )
                if is_finish_chunk:
                    # On a finish chunk every fragment that is not completed yet
                    # is marked completed and copied into tool_calls_info.
                    maxkb_logger.debug(
                        f"Processing finish chunk. Tool fragments: {_tool_fragments}")
                    for idx, entry in _tool_fragments.items():
                        if entry.get('completed'):
                            maxkb_logger.debug(
                                f"Skipping fragment {idx}: already completed")
                            continue
                        if not entry.get('id'):
                            maxkb_logger.debug(
                                f"Skipping fragment {idx}: missing id. Fragment: {entry}")
                            continue
                        if not entry.get('arguments'):
                            maxkb_logger.debug(
                                f"Skipping fragment {idx}: missing arguments. Fragment: {entry}")
                            continue
                        # Always true at this point because of the three guards above.
                        if not entry.get('completed') and entry.get('id') and entry.get('arguments'):
                            try:
                                parsed_args = json.loads(entry['arguments'])
                                # Hide arguments that were supplied as tool init parameters.
                                filtered_args = {
                                    k: v for k, v in parsed_args.items()
                                    if k not in tool_init_params
                                } if tool_init_params else parsed_args
                                normalized_id = _extract_tool_id(entry['id'])
                                info = {
                                    'name': entry['name'],
                                    'input': json.dumps(filtered_args, ensure_ascii=False)
                                }
                                # Register under both the raw and the normalised id
                                # so the later ToolMessage lookup finds either.
                                tool_calls_info[entry['id']] = info
                                if normalized_id and normalized_id != entry['id']:
                                    tool_calls_info[normalized_id] = info
                                entry['completed'] = True
                                maxkb_logger.debug(
                                    f"Added tool call {entry['id']} to tool_calls_info")
                            except (json.JSONDecodeError, ValueError) as e:
                                # JSON parsing failed, but still add to tool_calls_info with raw arguments
                                # to prevent "Tool ID not found" errors when ToolMessage arrives
                                maxkb_logger.warning(
                                    f"Failed to parse tool arguments at finish for tool {entry.get('id', 'unknown')}: "
                                    f"{entry['arguments']}, error: {e}. Using raw arguments.")
                                normalized_id = _extract_tool_id(entry['id'])
                                info = {
                                    'name': entry['name'],
                                    # Use raw arguments
                                    'input': entry['arguments']
                                }
                                tool_calls_info[entry['id']] = info
                                if normalized_id and normalized_id != entry['id']:
                                    tool_calls_info[normalized_id] = info
                                entry['completed'] = True
                # ----------------------------------------------------------------
                # 4. Repair empty ids in tool_call_chunks (back-fill the known id)
                # ----------------------------------------------------------------
                if chunk[0].tool_call_chunks:
                    for tc_chunk in chunk[0].tool_call_chunks:
                        key = _get_fragment_key(
                            tc_chunk.get('index'), tc_chunk.get('id'))
                        if key is not None:
                            frag = _tool_fragments.get(key)
                            if frag and frag.get('id') and not tc_chunk.get('id'):
                                tc_chunk['id'] = frag['id']
                # ----------------------------------------------------------------
                # 5. Repair additional_kwargs['tool_calls'] (legacy format).
                #    The complete arguments are written only on the finish chunk
                #    so that intermediate chunks are not polluted: intermediate
                #    chunks are accumulated by ainvoke, and writing incomplete
                #    JSON into them makes the next API call fail because the
                #    arguments are not valid JSON.
                # ----------------------------------------------------------------
                if legacy_tool_calls and is_finish_chunk:
                    fixed_tool_calls = []
                    for tool_call in legacy_tool_calls:
                        key = _get_fragment_key(
                            tool_call.get('index'), tool_call.get('id'))
                        frag = _tool_fragments.get(
                            key) if key is not None else None
                        # Work on copies so the original dicts stay untouched.
                        tc = dict(tool_call)
                        if frag and frag.get('id') and not tc.get('id'):
                            tc['id'] = frag['id']
                        if frag and isinstance(tc.get('function'), dict):
                            tc['function'] = dict(tc['function'])
                            if frag.get('completed'):
                                tc['function']['arguments'] = frag['arguments']
                        fixed_tool_calls.append(tc)
                    chunk[0].additional_kwargs['tool_calls'] = fixed_tool_calls
                yield chunk[0]
            if mcp_output_enable and isinstance(chunk[0], ToolMessage):
                tool_id = chunk[0].tool_call_id
                normalized_tool_id = _extract_tool_id(tool_id)
                tool_info = tool_calls_info.get(tool_id) or tool_calls_info.get(
                    normalized_tool_id)
                if tool_info:
                    try:
                        # Bring the tool output into dict form: a JSON string is
                        # parsed, a list contributes its first element.
                        if isinstance(chunk[0].content, str):
                            tool_result = json.loads(chunk[0].content)
                        elif isinstance(chunk[0].content, dict):
                            tool_result = chunk[0].content
                        elif isinstance(chunk[0].content, list):
                            tool_result = chunk[0].content[0] if len(
                                chunk[0].content) > 0 else {}
                        else:
                            tool_result = {}
                        # A 'text' entry, when present, is itself parsed as JSON.
                        text = tool_result.get('text') if 'text' in tool_result else None
                        text_result = json.loads(text) if text else tool_result
                        # 'tool_id' is removed from the result it was found in.
                        if text:
                            tool_lib_id = text_result.pop('tool_id') if 'tool_id' in text_result else None
                        else:
                            tool_lib_id = tool_result.pop('tool_id') if 'tool_id' in tool_result else None
                        if tool_lib_id:
                            await save_tool_record(tool_lib_id, tool_info, tool_result, source_id, source_type)
                            # NOTE(review): the nesting of the next statement could
                            # not be recovered from the reviewed source; it may
                            # belong one level up (run for every result) — confirm.
                            tool_result = json.dumps(text_result, ensure_ascii=False)
                    except Exception as e:
                        # Any failure above falls back to the untouched tool output.
                        tool_result = chunk[0].content
                    content = generate_tool_message_complete(
                        tool_info.get('icon', ''),
                        tool_info['name'],
                        tool_info['input'],
                        tool_result
                    )
                    chunk[0].content = content
                else:
                    maxkb_logger.warning(
                        f"Tool ID {tool_id} not found in tool_calls_info. "
                        f"Normalized Tool ID: {normalized_tool_id}. "
                        f"Available IDs: {list(tool_calls_info.keys())}. "
                        f"Tool fragments at this point: {_tool_fragments}"
                    )
                yield chunk[0]
    except ExceptionGroup as eg:
        def get_real_error(exc):
            # Descend into the first member until a non-group exception is reached.
            if isinstance(exc, ExceptionGroup):
                return get_real_error(exc.exceptions[0])
            return exc

        real_error = get_real_error(eg)
        error_msg = f"{type(real_error).__name__}: {str(real_error)}"
        raise RuntimeError(error_msg) from None
    except Exception as e:
        error_msg = f"{type(e).__name__}: {str(e)}"
        raise RuntimeError(error_msg) from None
  685. async def save_tool_record(tool_id, tool_info, tool_result, source_id, source_type):
  686. tool = await sync_to_async(lambda: QuerySet(Tool).filter(id=tool_id).first())()
  687. tool_info['icon'] = tool.icon
  688. tool_record = ToolRecord(
  689. id=uuid.uuid7(),
  690. workspace_id=tool.workspace_id,
  691. tool_id=tool_id,
  692. source_type=source_type,
  693. source_id=source_id,
  694. meta={'input': tool_info['input'], 'output': tool_result},
  695. state=State.SUCCESS
  696. )
  697. await sync_to_async(tool_record.save)()
  698. def mcp_response_generator(chat_model, message_list, mcp_servers, mcp_output_enable=True, tool_init_params={},
  699. source_id=None, source_type=None, chat_id=None, extra_tools=None):
  700. """使用全局事件循环,不创建新实例"""
  701. result_queue = queue.Queue()
  702. loop = get_global_loop() # 使用共享循环
  703. # 创建临时文件夹
  704. if chat_id:
  705. temp_dir = os.path.join('/tmp', chat_id[:8])
  706. else:
  707. temp_dir = os.path.join('/tmp', uuid.uuid7().hex[:8])
  708. skills_dir = os.path.join(temp_dir, 'skills')
  709. os.makedirs(skills_dir, exist_ok=True)
  710. # print(f"Initializing skills in temporary directory: {skills_dir}")
  711. async def _run():
  712. try:
  713. async_gen = _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_enable, tool_init_params,
  714. source_id, source_type, temp_dir, chat_id, extra_tools)
  715. async for chunk in async_gen:
  716. result_queue.put(('data', chunk))
  717. except Exception as e:
  718. maxkb_logger.error(f'Exception: {e}', exc_info=True)
  719. result_queue.put(('error', e))
  720. finally:
  721. result_queue.put(('done', None))
  722. # 在全局循环中调度任务
  723. asyncio.run_coroutine_threadsafe(_run(), loop)
  724. while True:
  725. msg_type, data = result_queue.get()
  726. if msg_type == 'done':
  727. # 清理临时文件夹
  728. shutil.rmtree(temp_dir, ignore_errors=True)
  729. break
  730. if msg_type == 'error':
  731. # 清理临时文件夹
  732. shutil.rmtree(temp_dir, ignore_errors=True)
  733. raise data
  734. yield data
  735. async def anext_async(agen):
  736. return await agen.__anext__()
  737. target_source_node_mapping = {
  738. 'TOOL': {'tool-lib-node': lambda n: [n.get('properties').get('node_data').get('tool_lib_id')],
  739. 'ai-chat-node': lambda n: [*(n.get('properties').get('node_data').get('mcp_tool_ids') or []),
  740. *(n.get('properties').get('node_data').get('tool_ids') or []),
  741. *(n.get('properties').get('node_data').get('skill_tool_ids') or [])],
  742. 'mcp-node': lambda n: [n.get('properties').get('node_data').get('mcp_tool_id')],
  743. 'tool-workflow-lib-node': lambda n: [n.get('properties').get('node_data').get('tool_lib_id')]
  744. },
  745. 'MODEL': {'ai-chat-node': lambda n: [n.get('properties').get('node_data').get('model_id')],
  746. 'question-node': lambda n: [n.get('properties').get('node_data').get('model_id')],
  747. 'speech-to-text-node': lambda n: [n.get('properties').get('node_data').get('stt_model_id')],
  748. 'text-to-speech-node': lambda n: [n.get('properties').get('node_data').get('tts_model_id')],
  749. 'image-to-video-node': lambda n: [n.get('properties').get('node_data').get('model_id')],
  750. 'image-generate-node': lambda n: [n.get('properties').get('node_data').get('model_id')],
  751. 'intent-node': lambda n: [n.get('properties').get('node_data').get('model_id')],
  752. 'image-understand-node': lambda n: [n.get('properties').get('node_data').get('model_id')],
  753. 'parameter-extraction-node': lambda n: [n.get('properties').get('node_data').get('model_id')],
  754. 'video-understand-node': lambda n: [n.get('properties').get('node_data').get('model_id')],
  755. },
  756. 'KNOWLEDGE': {'search-knowledge-node': lambda n: n.get('properties').get('node_data').get('knowledge_id_list')},
  757. 'APPLICATION': {
  758. 'application-node': lambda n: [n.get('properties').get('node_data').get('application_id')]
  759. }
  760. }
  761. def get_node_handle_callback(source_type, source_id):
  762. def node_handle_callback(node):
  763. from system_manage.models.resource_mapping import ResourceMapping
  764. response = []
  765. for key, value in target_source_node_mapping.items():
  766. if node.get('type') in value:
  767. call = value.get(node.get('type'))
  768. target_source_id_list = call(node)
  769. for target_source_id in target_source_id_list:
  770. if target_source_id:
  771. response.append(ResourceMapping(source_type=source_type, target_type=key, source_id=source_id,
  772. target_id=target_source_id))
  773. return response
  774. return node_handle_callback
  775. def get_workflow_resource(workflow, node_handle):
  776. response = []
  777. if 'nodes' in workflow:
  778. for node in workflow.get('nodes'):
  779. rs = node_handle(node)
  780. if rs:
  781. for r in rs:
  782. response.append(r)
  783. if node.get('type') == 'loop-node':
  784. r = get_workflow_resource(node.get('properties', {}).get(
  785. 'node_data', {}).get('loop_body'), node_handle)
  786. for rn in r:
  787. response.append(rn)
  788. return list({(str(item.target_type) + str(item.target_id)): item for item in response}.values())
  789. return []
  790. application_instance_field_call_dict = {
  791. 'TOOL': [
  792. lambda instance: instance.mcp_tool_ids or [],
  793. lambda instance: instance.skill_tool_ids or [],
  794. lambda instance: instance.tool_ids or []
  795. ],
  796. 'MODEL': [
  797. lambda instance: [instance.model_id] if instance.model_id else [],
  798. lambda instance: [
  799. instance.tts_model_id] if instance.tts_model_id else [],
  800. lambda instance: [
  801. instance.stt_model_id] if instance.stt_model_id else []
  802. ]
  803. }
  804. knowledge_instance_field_call_dict = {
  805. 'MODEL': [lambda instance: [instance.embedding_model_id] if instance.embedding_model_id else []],
  806. }
  807. def get_instance_resource(instance, source_type, source_id, instance_field_call_dict):
  808. response = []
  809. from system_manage.models.resource_mapping import ResourceMapping
  810. for target_type, call_list in instance_field_call_dict.items():
  811. target_id_list = reduce(
  812. lambda x, y: [*x, *y], [call(instance) for call in call_list], [])
  813. if target_id_list:
  814. for target_id in target_id_list:
  815. response.append(ResourceMapping(source_type=source_type, target_type=target_type, source_id=source_id,
  816. target_id=target_id))
  817. return response
  818. def save_workflow_mapping(workflow, source_type, source_id, other_resource_mapping=None):
  819. if not other_resource_mapping:
  820. other_resource_mapping = []
  821. from system_manage.models.resource_mapping import ResourceMapping
  822. from django.db.models import QuerySet
  823. QuerySet(ResourceMapping).filter(
  824. source_type=source_type, source_id=source_id).delete()
  825. resource_mapping_list = get_workflow_resource(workflow,
  826. get_node_handle_callback(source_type,
  827. source_id))
  828. resource_mapping_list += other_resource_mapping
  829. if resource_mapping_list:
  830. QuerySet(ResourceMapping).bulk_create(
  831. {(str(item.target_type) + str(item.target_id)): item for item in resource_mapping_list}.values())
  832. def get_tool_id_list(workflow, with_deep=False):
  833. from tools.models import ToolWorkflow, ToolType
  834. _result = []
  835. for node in workflow.get('nodes', []):
  836. if node.get('type') == 'tool-lib-node':
  837. tool_id = node.get('properties', {}).get(
  838. 'node_data', {}).get('tool_lib_id')
  839. if tool_id:
  840. _result.append(tool_id)
  841. elif node.get('type') == 'loop-node':
  842. r = get_tool_id_list(node.get('properties', {}).get(
  843. 'node_data', {}).get('loop_body', {}))
  844. for item in r:
  845. _result.append(item)
  846. elif node.get('type') == 'tool-workflow-lib-node':
  847. tool_id = node.get('properties', {}).get(
  848. 'node_data', {}).get('tool_lib_id')
  849. if tool_id:
  850. _result.append(tool_id)
  851. elif node.get('type') == 'ai-chat-node':
  852. node_data = node.get('properties', {}).get('node_data', {})
  853. mcp_tool_ids = node_data.get('mcp_tool_ids') or []
  854. skill_tool_ids = node_data.get('skill_tool_ids') or []
  855. tool_ids = node_data.get('tool_ids') or []
  856. for _id in mcp_tool_ids + tool_ids + skill_tool_ids:
  857. _result.append(_id)
  858. elif node.get('type') == 'mcp-node':
  859. mcp_tool_id = node.get('properties', {}).get(
  860. 'node_data', {}).get('mcp_tool_id')
  861. if mcp_tool_id:
  862. _result.append(mcp_tool_id)
  863. if with_deep:
  864. workflow_list = QuerySet(Tool).filter(id__in=_result, tool_type=ToolType.WORKFLOW)
  865. tool_work_flow_list = QuerySet(ToolWorkflow).filter(tool_id__in=[wl.id for wl in workflow_list])
  866. for tool_work_flow in tool_work_flow_list:
  867. child_tool_id_list = get_child_tool_id_list(tool_work_flow.work_flow, [])
  868. for c in child_tool_id_list:
  869. _result.append(c)
  870. return _result
  871. def get_child_tool_id_list(work_flow, response):
  872. from tools.models import ToolWorkflow, ToolType
  873. tool_id_list = get_tool_id_list(work_flow, False)
  874. tool_id_list = [tool_id for tool_id in tool_id_list if
  875. len([r for r in response if r == tool_id]) == 0]
  876. tool_list = []
  877. if len(tool_id_list) > 0:
  878. tool_list = QuerySet(Tool).filter(id__in=tool_id_list).exclude(scope=ToolScope.SHARED)
  879. work_flow_tools = [tool for tool in tool_list if tool.tool_type == ToolType.WORKFLOW]
  880. if len(work_flow_tools) > 0:
  881. work_flow_tool_dict = {tw.tool_id: tw for tw in
  882. QuerySet(ToolWorkflow).filter(tool_id__in=[t.id for t in work_flow_tools])}
  883. for tool in tool_list:
  884. response.append(str(tool.id))
  885. if tool.tool_type == ToolType.WORKFLOW:
  886. get_child_tool_id_list(work_flow_tool_dict.get(tool.id).work_flow, response)
  887. else:
  888. for tool in tool_list:
  889. response.append(str(tool.id))
  890. return response
  891. def build_schema(fields: dict):
  892. return create_model("dynamicSchema", **fields)
  893. def get_type(_type: str):
  894. if _type == 'float':
  895. return float
  896. if _type == 'string':
  897. return str
  898. if _type == 'int':
  899. return int
  900. if _type == 'dict':
  901. return dict
  902. if _type == 'array':
  903. return list
  904. if _type == 'boolean':
  905. return bool
  906. return object
  907. def get_workflow_args(tool, qv):
  908. for node in qv.work_flow.get('nodes'):
  909. if node.get('type') == 'tool-base-node':
  910. input_field_list = node.get('properties').get('user_input_field_list')
  911. return build_schema(
  912. {field.get('field'): (get_type(field.get('type')), Field(..., description=field.get('desc')))
  913. for field in input_field_list})
  914. return build_schema({})
  915. def get_workflow_func(source_type, source_id, tool, qv, workspace_id):
  916. tool_id = tool.id
  917. tool_record_id = str(uuid.uuid7())
  918. took_execute = ToolExecute(tool_id, tool_record_id,
  919. workspace_id,
  920. source_type,
  921. source_id,
  922. False)
  923. def inner(**kwargs):
  924. from application.flow.tool_workflow_manage import ToolWorkflowManage
  925. work_flow_manage = ToolWorkflowManage(
  926. Workflow.new_instance(qv.work_flow, WorkflowMode.TOOL),
  927. {
  928. 'chat_record_id': tool_record_id,
  929. 'tool_id': tool_id,
  930. 'stream': True,
  931. 'workspace_id': workspace_id,
  932. **kwargs},
  933. ToolWorkflowPostHandler(took_execute, tool_id),
  934. is_the_task_interrupted=lambda: False,
  935. child_node=None,
  936. start_node_id=None,
  937. start_node_data=None,
  938. chat_record=None
  939. )
  940. res = work_flow_manage.run()
  941. for r in res:
  942. pass
  943. return work_flow_manage.out_context
  944. return inner
  945. def get_tools(source_type, source_id, tool_workflow_ids, workspace_id):
  946. tools = QuerySet(Tool).filter(id__in=tool_workflow_ids, is_active=True, tool_type=ToolType.WORKFLOW,
  947. workspace_id=workspace_id)
  948. latest_subquery = ToolWorkflowVersion.objects.filter(
  949. tool_id=OuterRef('tool_id')
  950. ).order_by('-create_time')
  951. qs = ToolWorkflowVersion.objects.filter(
  952. tool_id__in=[t.id for t in tools],
  953. id=Subquery(latest_subquery.values('id')[:1])
  954. )
  955. qd = {q.tool_id: q for q in qs}
  956. results = []
  957. for tool in tools:
  958. qv = qd.get(tool.id)
  959. func = get_workflow_func(source_type, source_id, tool, qv,
  960. workspace_id)
  961. args = get_workflow_args(tool, qv)
  962. tool = StructuredTool.from_function(
  963. func=func,
  964. name=tool.name,
  965. description=tool.desc,
  966. args_schema=args,
  967. )
  968. results.append(tool)
  969. return results