tool_code.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350
  1. # coding=utf-8
  2. import ast
  3. import base64
  4. import getpass
  5. import gzip
  6. import json
  7. import os
  8. try:
  9. import pwd
  10. except ImportError:
  11. pwd = None
  12. import random
  13. try:
  14. import resource
  15. except ImportError:
  16. resource = None
  17. import socket
  18. import subprocess
  19. import sys
  20. import tempfile
  21. import time
  22. from contextlib import contextmanager
  23. from contextlib import suppress
  24. from textwrap import dedent
  25. import uuid_utils.compat as uuid
  26. from django.utils.translation import gettext_lazy as _
  27. from common.utils.logger import maxkb_logger
  28. from maxkb.const import BASE_DIR, CONFIG
  29. from maxkb.const import PROJECT_DIR
  30. _enable_sandbox = bool(int(CONFIG.get('SANDBOX', 0)))
  31. _run_user = 'sandbox' if _enable_sandbox else getpass.getuser()
  32. _sandbox_path = CONFIG.get("SANDBOX_HOME", '/opt/maxkb-app/sandbox') if _enable_sandbox else os.path.join(PROJECT_DIR, 'data', 'sandbox')
  33. _sandbox_python_sys_path = CONFIG.get_sandbox_python_package_paths().split(',')
  34. _process_limit_timeout_seconds = int(CONFIG.get("SANDBOX_PYTHON_PROCESS_LIMIT_TIMEOUT_SECONDS", '3600'))
  35. _process_limit_cpu_cores = min(max(int(CONFIG.get("SANDBOX_PYTHON_PROCESS_LIMIT_CPU_CORES", '1')), 1), len(os.sched_getaffinity(0))) if sys.platform.startswith("linux") else os.cpu_count() # 只支持linux,window和mac不支持
  36. _process_limit_mem_mb = int(CONFIG.get("SANDBOX_PYTHON_PROCESS_LIMIT_MEM_MB", '256'))
  37. class ToolExecutor:
    def __init__(self):
        """Create an executor; no per-instance state is initialised here."""
        pass
  40. @staticmethod
  41. def init_sandbox_dir():
  42. if not _enable_sandbox:
  43. # 不启用sandbox就不初始化目录
  44. return
  45. try:
  46. # 只初始化一次
  47. fd = os.open(os.path.join(PROJECT_DIR, 'tmp', 'tool_executor_init_dir.lock'), os.O_CREAT | os.O_EXCL | os.O_WRONLY)
  48. os.close(fd)
  49. except FileExistsError:
  50. # 文件已存在 → 已初始化过
  51. return
  52. maxkb_logger.info("Init sandbox dir.")
  53. try:
  54. os.system("chmod -R g-rwx /dev/shm /dev/mqueue")
  55. os.system("chmod o-rwx /run/postgresql")
  56. except Exception as e:
  57. maxkb_logger.warning(f'Exception: {e}', exc_info=True)
  58. pass
  59. if CONFIG.get("SANDBOX_TMP_DIR_ENABLED", '0') == "1":
  60. os.system("chmod g+rwx /tmp")
  61. # 初始化sandbox配置文件
  62. sandbox_lib_path = os.path.dirname(f'{_sandbox_path}/lib/sandbox.so')
  63. sandbox_conf_file_path = f'{sandbox_lib_path}/.sandbox.conf'
  64. if os.path.exists(sandbox_conf_file_path):
  65. os.remove(sandbox_conf_file_path)
  66. banned_hosts = CONFIG.get("SANDBOX_PYTHON_BANNED_HOSTS", '').strip()
  67. allow_dl_paths = CONFIG.get("SANDBOX_PYTHON_ALLOW_DL_PATHS",'').strip()
  68. allow_dl_open = CONFIG.get("SANDBOX_PYTHON_ALLOW_DL_OPEN",'0')
  69. allow_subprocess = CONFIG.get("SANDBOX_PYTHON_ALLOW_SUBPROCESS", '0')
  70. allow_syscall = CONFIG.get("SANDBOX_PYTHON_ALLOW_SYSCALL", '0')
  71. if banned_hosts:
  72. hostname = socket.gethostname()
  73. local_ip = socket.gethostbyname(hostname)
  74. banned_hosts = f"{banned_hosts},{local_ip}"
  75. banned_hosts = ",".join(s.strip() for s in banned_hosts.split(",") if s.strip() and s.strip().lower() != hostname.lower())
  76. with open(sandbox_conf_file_path, "w") as f:
  77. f.write(f"SANDBOX_PYTHON_BANNED_HOSTS={banned_hosts}\n")
  78. f.write(f"SANDBOX_PYTHON_ALLOW_DL_PATHS={','.join(sorted(set(filter(None, sys.path + _sandbox_python_sys_path + allow_dl_paths.split(',')))))}\n")
  79. f.write(f"SANDBOX_PYTHON_ALLOW_DL_OPEN={allow_dl_open}\n")
  80. f.write(f"SANDBOX_PYTHON_ALLOW_SUBPROCESS={allow_subprocess}\n")
  81. f.write(f"SANDBOX_PYTHON_ALLOW_SYSCALL={allow_syscall}\n")
  82. os.system(f"chmod -R 550 {_sandbox_path}")
    # Executed as part of the class body, i.e. once when the class statement
    # itself runs. The staticmethod object is called directly by its bare
    # name here, which relies on staticmethod objects being callable
    # (Python 3.10+).
    # Any failure is only logged so that it does not propagate out of the
    # class body.
    try:
        init_sandbox_dir()
    except Exception as e:
        maxkb_logger.error(f'Exception: {e}', exc_info=True)
    def exec_code(self, code_str, keywords, function_name=None):
        """Run ``code_str`` in a separate Python process and return its result.

        A wrapper script is generated, written to a temporary ``.py`` file and
        started through ``self._exec``. Inside the child the wrapper trims and
        extends ``sys.path``, switches to the sandbox user when the sandbox is
        enabled, clears the environment, executes ``code_str`` with stdout
        redirected to ``os.devnull``, calls the selected function with
        ``**keywords`` and prints one JSON result line prefixed with a
        per-run id, followed by an ``<id>__END__`` line.

        Args:
            code_str: Python source defining the function to call.
            keywords: Keyword arguments for the function. The value is
                interpolated into the generated script as text.
                NOTE(review): this presumably requires a dict whose ``str()``
                is a valid Python literal - confirm against callers.
            function_name: Name of the function to call. When falsy, the entry
                returned by ``locals_v.popitem()`` after executing the code is
                used instead.

        Returns:
            The ``data`` field of the JSON result printed by the child
            (values that are not JSON serialisable are emitted via ``str``).

        Raises:
            Exception: if the child exits with a non-zero status, if the
                trailing ``<id>__END__`` marker or the id-prefixed result line
                is missing, or if the child reports a code other than 200
                (the message then carries the child's error text and stderr).
        """
        # Per-run marker used to find and authenticate the result line in stdout.
        _id = str(uuid.uuid7())
        # Source fragment evaluated in the child to obtain (name, callable).
        action_function = f'({function_name !a}, locals_v.get({function_name !a}))' if function_name else 'locals_v.popitem()'
        # Statement that drops to the sandbox user's gid/uid; empty when the sandbox is off.
        set_run_user = f'os.setgid({pwd.getpwnam(_run_user).pw_gid});os.setuid({pwd.getpwnam(_run_user).pw_uid});' if _enable_sandbox else ''
        # The template below is the child's program text: doubled braces are
        # literal braces, and "\\n" becomes the two-character escape \n in the
        # generated script.
        _exec_code = f"""
try:
    import os, sys, json
    from contextlib import redirect_stdout
    path_to_exclude = ['/opt/py3/lib/python3.11/site-packages', '/opt/maxkb-app/apps']
    sys.path = [p for p in sys.path if p not in path_to_exclude]
    sys.path += {_sandbox_python_sys_path}
    _id = os.environ.get("_ID")
    locals_v = {{}}
    keywords = {keywords}
    globals_v = {{}}
    {set_run_user}
    os.environ.clear()
    with redirect_stdout(open(os.devnull, 'w')):
        exec({dedent(code_str)!a}, globals_v, locals_v)
        f_name, f = {action_function}
        globals_v.update(locals_v)
        exec_result = f(**keywords)
    sys.stdout.write("\\n" + _id)
    json.dump({{'code':200,'msg':'success','data':exec_result}}, sys.stdout, default=str)
except Exception as e:
    if isinstance(e, MemoryError): e = Exception("Cannot allocate more memory: exceeded the limit of {_process_limit_mem_mb} MB.")
    sys.stdout.write("\\n" + _id)
    json.dump({{'code':500,'msg':str(e),'data':None}}, sys.stdout, default=str)
sys.stdout.write("\\n" + _id + "__END__\\n")
sys.stdout.flush()
"""
        maxkb_logger.debug(f"Tool execution({_id}) execute code: {_exec_code}")
        # The temporary script is removed when the with-block exits (delete=True).
        with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=True) as f:
            f.write(_exec_code)
            # Flush so the child sees the complete script on disk.
            f.flush()
            with execution_timer(_id):
                subprocess_result = self._exec(f.name, _id)
        if subprocess_result.returncode != 0:
            raise Exception(subprocess_result.stderr or subprocess_result.stdout or "Unknown exception occurred")
        lines = subprocess_result.stdout.splitlines()
        # The last line must be this run's end marker, otherwise the output is
        # incomplete or was not produced by the wrapper.
        if len(lines) < 2 or lines[-1] != f"{_id}__END__":
            raise Exception("Execution interrupted or tampered")
        # The line just before the end marker carries "<id><json result>".
        last_line = lines[-2]
        if not last_line.startswith(_id):
            raise Exception("No result found.")
        result = json.loads(last_line[len(_id):])
        if result.get('code') == 200:
            return result.get('data')
        raise Exception(result.get('msg') + (f'\n{subprocess_result.stderr}' if subprocess_result.stderr else ''))
  136. def _generate_mcp_server_code(self, _code, params, name=None, description=None, tool_id=None):
  137. # 解析代码,提取导入语句和函数定义
  138. try:
  139. tree = ast.parse(_code)
  140. except SyntaxError:
  141. return _code
  142. imports = []
  143. functions = []
  144. other_code = []
  145. for node in tree.body:
  146. if isinstance(node, ast.Import) or isinstance(node, ast.ImportFrom):
  147. imports.append(ast.unparse(node))
  148. elif isinstance(node, ast.FunctionDef):
  149. if node.name.startswith('_'):
  150. other_code.append(ast.unparse(node))
  151. continue
  152. # 修改函数参数以包含 params 中的默认值
  153. arg_names = [arg.arg for arg in node.args.args]
  154. # 为参数添加默认值,确保参数顺序正确
  155. defaults = []
  156. num_defaults = 0
  157. # 从后往前检查哪些参数有默认值
  158. for i, arg_name in enumerate(arg_names):
  159. if arg_name in params:
  160. num_defaults = len(arg_names) - i
  161. break
  162. # 为有默认值的参数创建默认值列表
  163. if num_defaults > 0:
  164. for i in range(len(arg_names) - num_defaults, len(arg_names)):
  165. arg_name = arg_names[i]
  166. if arg_name in params:
  167. default_value = params[arg_name]
  168. if isinstance(default_value, str):
  169. defaults.append(ast.Constant(value=default_value))
  170. elif isinstance(default_value, (int, float, bool)):
  171. defaults.append(ast.Constant(value=default_value))
  172. elif default_value is None:
  173. defaults.append(ast.Constant(value=None))
  174. else:
  175. defaults.append(ast.Constant(value=str(default_value)))
  176. else:
  177. # 如果某个参数没有默认值,需要添加 None 占位
  178. defaults.append(ast.Constant(value=None))
  179. node.args.defaults = defaults
  180. # 将不支持 JSON Schema 的参数类型注解替换为 Any,
  181. # 避免 FastMCP/Pydantic 生成 schema 时崩溃(如 requests.Response)
  182. _safe_annotation_names = {
  183. 'str', 'int', 'float', 'bool', 'dict', 'list', 'tuple',
  184. 'set', 'bytes', 'Any', 'Optional', 'Union', 'List',
  185. 'Dict', 'Tuple', 'Set', 'Sequence', 'None', 'NoneType',
  186. }
  187. def _is_safe_annotation(node_ann):
  188. if node_ann is None:
  189. return True
  190. if isinstance(node_ann, ast.Constant):
  191. return True
  192. if isinstance(node_ann, ast.Name):
  193. return node_ann.id in _safe_annotation_names
  194. if isinstance(node_ann, ast.Attribute):
  195. # e.g. requests.Response, typing.Optional — treat none as safe
  196. return False
  197. if isinstance(node_ann, (ast.Subscript, ast.BinOp)):
  198. # e.g. Optional[str], str | None — recurse
  199. if isinstance(node_ann, ast.Subscript):
  200. return _is_safe_annotation(node_ann.value) and _is_safe_annotation(node_ann.slice)
  201. return _is_safe_annotation(node_ann.left) and _is_safe_annotation(node_ann.right)
  202. return False
  203. for arg in node.args.args:
  204. if not _is_safe_annotation(arg.annotation):
  205. arg.annotation = ast.Name(id='Any', ctx=ast.Load())
  206. # 修改返回类型注解为 Result
  207. node.returns = ast.Name(id='Result', ctx=ast.Load())
  208. # 修改 return 语句为 return Result(result=..., tool_id=...)
  209. class ReturnTransformer(ast.NodeTransformer):
  210. def __init__(self, func_name):
  211. self.func_name = func_name
  212. def visit_Return(self, node):
  213. if node.value is None:
  214. # return 语句没有返回值
  215. new_return = ast.Return(
  216. value=ast.Call(
  217. func=ast.Name(id='Result', ctx=ast.Load()),
  218. args=[],
  219. keywords=[
  220. ast.keyword(arg='result', value=ast.Constant(value=None)),
  221. ast.keyword(arg='tool_id', value=ast.Constant(value=tool_id))
  222. ]
  223. )
  224. )
  225. else:
  226. # return 语句有返回值
  227. new_return = ast.Return(
  228. value=ast.Call(
  229. func=ast.Name(id='Result', ctx=ast.Load()),
  230. args=[],
  231. keywords=[
  232. ast.keyword(arg='result', value=node.value),
  233. ast.keyword(arg='tool_id', value=ast.Constant(value=tool_id))
  234. ]
  235. )
  236. )
  237. return ast.copy_location(new_return, node)
  238. transformer = ReturnTransformer(node.name)
  239. node = transformer.visit(node)
  240. ast.fix_missing_locations(node)
  241. func_code = ast.unparse(node)
  242. # 有些模型不支持name是中文,例如: deepseek, 其他模型未知
  243. escaped_desc = (name + ' ' + description).replace('\n', ' ').replace("'", " ")
  244. functions.append(f"@mcp.tool(description='{escaped_desc}')\n{func_code}\n")
  245. else:
  246. other_code.append(ast.unparse(node))
  247. # 构建完整的 MCP 服务器代码
  248. code_parts = ["from mcp.server.fastmcp import FastMCP"]
  249. code_parts.extend(imports)
  250. code_parts.append(f"\nfrom pydantic import BaseModel")
  251. code_parts.append(f"\nfrom typing import Any")
  252. code_parts.append(f"\nclass Result(BaseModel):")
  253. code_parts.append(f"\n\tresult: Any")
  254. code_parts.append(f"\n\ttool_id: str\n")
  255. code_parts.append(f"\nmcp = FastMCP(\"{uuid.uuid7()}\")\n")
  256. code_parts.extend(other_code)
  257. code_parts.extend(functions)
  258. code_parts.append("\nmcp.run(transport=\"stdio\")\n")
  259. return "\n".join(code_parts)
    def generate_mcp_server_code(self, code_str, params, name, description, tool_id):
        """Wrap the generated MCP server source in a bootstrap script.

        The server source comes from ``_generate_mcp_server_code``. The
        returned script lowers the log level of the ``mcp`` loggers, trims and
        extends ``sys.path``, switches to the sandbox user when the sandbox is
        enabled, clears the environment and finally ``exec``s the server
        source.

        Args:
            code_str: Python source of the tool.
            params: Configured parameter values forwarded to the generator.
            name: Tool name forwarded to the generator.
            description: Tool description forwarded to the generator.
            tool_id: Tool id forwarded to the generator.

        Returns:
            The bootstrap script as a string.
        """
        code = self._generate_mcp_server_code(code_str, params, name, description, tool_id)
        # Statement that drops to the sandbox user's gid/uid; empty when the sandbox is off.
        set_run_user = f'os.setgid({pwd.getpwnam(_run_user).pw_gid});os.setuid({pwd.getpwnam(_run_user).pw_uid});' if _enable_sandbox else ''
        # The text below is the program that will be run; the server source is
        # embedded as an ASCII-escaped string literal (!a) and exec'd.
        return f"""
import os, sys, logging
logging.basicConfig(level=logging.WARNING)
logging.getLogger("mcp").setLevel(logging.ERROR)
logging.getLogger("mcp.server").setLevel(logging.ERROR)
path_to_exclude = ['/opt/py3/lib/python3.11/site-packages', '/opt/maxkb-app/apps']
sys.path = [p for p in sys.path if p not in path_to_exclude]
sys.path += {_sandbox_python_sys_path}
{set_run_user}
os.environ.clear()
exec({dedent(code)!a})
"""
  275. def get_tool_mcp_config(self, tool, params):
  276. _code = self.generate_mcp_server_code(tool.code, params, tool.name, tool.desc, str(tool.id))
  277. maxkb_logger.debug(f"Python code of mcp tool: {_code}")
  278. compressed_and_base64_encoded_code_str = base64.b64encode(gzip.compress(_code.encode())).decode()
  279. tool_config = {
  280. 'command': sys.executable,
  281. 'args': [
  282. '-c',
  283. f'import base64,gzip; exec(gzip.decompress(base64.b64decode(\'{compressed_and_base64_encoded_code_str}\')).decode())',
  284. ],
  285. 'cwd': _sandbox_path,
  286. 'env': {
  287. 'LD_PRELOAD': f'{_sandbox_path}/lib/sandbox.so',
  288. },
  289. 'transport': 'stdio',
  290. }
  291. return tool_config
  292. def get_app_mcp_config(self, api_key):
  293. app_config = {
  294. 'url': f'http://127.0.0.1:8080{CONFIG.get_chat_path()}/api/mcp',
  295. 'transport': 'streamable_http',
  296. 'headers': {
  297. 'Authorization': f'Bearer {api_key}',
  298. },
  299. }
  300. return app_config
  301. def _exec(self, execute_file, _id):
  302. kwargs = {'cwd': BASE_DIR, 'env': {
  303. 'LD_PRELOAD': f'{_sandbox_path}/lib/sandbox.so',
  304. '_ID': _id,
  305. }}
  306. def _set_resource_limit():
  307. if not _enable_sandbox or not sys.platform.startswith("linux"): return
  308. with suppress(Exception): resource.setrlimit(resource.RLIMIT_AS, (_process_limit_mem_mb * 1024 * 1024,) * 2)
  309. with suppress(Exception): os.sched_setaffinity(0, set(random.sample(list(os.sched_getaffinity(0)), _process_limit_cpu_cores)))
  310. try:
  311. subprocess_result = subprocess.run(
  312. [sys.executable, execute_file],
  313. timeout=_process_limit_timeout_seconds,
  314. text=True,
  315. capture_output=True,
  316. **kwargs,
  317. preexec_fn=_set_resource_limit
  318. )
  319. return subprocess_result
  320. except subprocess.TimeoutExpired:
  321. raise Exception(_(f"Process execution timed out after {_process_limit_timeout_seconds} seconds."))
  322. def validate_mcp_transport(self, code_str):
  323. servers = json.loads(code_str)
  324. for server, config in servers.items():
  325. if config.get('transport') not in ['sse', 'streamable_http']:
  326. raise Exception(_('Only support transport=sse or transport=streamable_http'))
  327. @contextmanager
  328. def execution_timer(id=""):
  329. start = time.perf_counter()
  330. try:
  331. yield
  332. finally:
  333. maxkb_logger.debug(f"Tool execution({id}) takes {time.perf_counter() - start:.6f} seconds.")