filesystem.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346
  1. import asyncio
  2. import json
  3. import logging
  4. import os
  5. import subprocess
  6. import traceback
  7. from typing import Optional
  8. from fastapi import APIRouter, Depends, HTTPException, Query, Request
  9. from gpustack.api.auth import worker_auth
  10. from gpustack.config.config import Config
  11. from gpustack.schemas.filesystem import (
  12. FileExistsResponse,
  13. GGUFParseRequest,
  14. GGUFParseResponse,
  15. )
  16. from gpustack.schemas.models import Model
  17. from gpustack.scheduler.calculator import (
  18. _gguf_parser_command,
  19. _gguf_parser_env,
  20. GPUOffloadEnum,
  21. calculate_local_model_weight_size,
  22. )
  23. router = APIRouter(dependencies=[Depends(worker_auth)])
  24. logger = logging.getLogger(__name__)
  25. ALLOWED_CONFIG_FILES = {
  26. "config.json",
  27. "model_index.json",
  28. "tokenizer.json",
  29. "tokenizer_config.json",
  30. "special_tokens_map.json",
  31. "generation_config.json",
  32. "adapter_config.json",
  33. "preprocessor_config.json",
  34. }
  35. def is_config_file(filename: str) -> bool:
  36. """Check if a file is a model config file."""
  37. return filename in ALLOWED_CONFIG_FILES
  38. def validate_path_security(path: str, base_path: str = None) -> str:
  39. """
  40. Validate path security to prevent directory traversal attacks.
  41. This function:
  42. 1. Resolves the absolute path (following symlinks)
  43. 2. Validates the path is within the allowed base directory (if provided)
  44. 3. Prevents directory traversal attacks
  45. Args:
  46. path: The path to validate
  47. base_path: Optional base directory that the path must be within
  48. Returns:
  49. The validated absolute path
  50. Raises:
  51. HTTPException: If the path is invalid or outside the allowed directory
  52. Security:
  53. - Uses os.path.realpath to resolve symlinks and get absolute path
  54. - Validates path is within base_path if provided
  55. - Prevents directory traversal attacks (../, symlinks, etc.)
  56. """
  57. try:
  58. # Resolve to absolute path, following symlinks
  59. # This is more secure than os.path.normpath which doesn't resolve symlinks
  60. resolved_path = os.path.realpath(path)
  61. # If base_path is provided, ensure the resolved path is within it
  62. if base_path:
  63. resolved_base = os.path.realpath(base_path)
  64. # Use os.path.commonpath to check if resolved_path is under resolved_base
  65. # This prevents directory traversal attacks
  66. try:
  67. common = os.path.commonpath([resolved_base, resolved_path])
  68. if common != resolved_base:
  69. raise HTTPException(
  70. status_code=403,
  71. detail="Access denied: Path is outside allowed directory",
  72. )
  73. except ValueError:
  74. # Paths are on different drives (Windows)
  75. raise HTTPException(
  76. status_code=403,
  77. detail="Access denied: Path is outside allowed directory",
  78. )
  79. return resolved_path
  80. except HTTPException:
  81. raise
  82. except Exception as e:
  83. logger.error(f"Error validating path {path}: {e}")
  84. raise HTTPException(status_code=400, detail=f"Invalid path: {str(e)}")
  85. @router.get("/files/model-config")
  86. async def read_model_config(path: str = Query(..., description="File path to read")):
  87. """
  88. Read and parse a model config file.
  89. Only model config files (config.json, model_index.json, etc.) can be read for security.
  90. Returns the parsed configuration object.
  91. Security:
  92. - Uses os.path.realpath to resolve symlinks and prevent directory traversal
  93. - Only allows reading of whitelisted config files
  94. - Validates file exists and is a regular file
  95. """
  96. try:
  97. # Validate path security (resolves symlinks, prevents directory traversal)
  98. validated_path = validate_path_security(path)
  99. # Check if path exists
  100. if not os.path.exists(validated_path):
  101. raise HTTPException(status_code=404, detail=f"File not found: {path}")
  102. # Check if path is a file
  103. if not os.path.isfile(validated_path):
  104. raise HTTPException(status_code=400, detail=f"Path is not a file: {path}")
  105. # Check if file is a config file for security
  106. filename = os.path.basename(validated_path)
  107. if not is_config_file(filename):
  108. raise HTTPException(
  109. status_code=403,
  110. detail="Access denied: Only model config files are allowed to be read",
  111. )
  112. # Read and parse JSON file
  113. try:
  114. with open(validated_path, "r", encoding="utf-8") as f:
  115. import json
  116. config_data = json.load(f)
  117. except PermissionError:
  118. raise HTTPException(status_code=403, detail=f"Permission denied: {path}")
  119. except json.JSONDecodeError as e:
  120. raise HTTPException(status_code=400, detail=f"Invalid JSON file: {str(e)}")
  121. except OSError as e:
  122. raise HTTPException(
  123. status_code=500, detail=f"Failed to read file: {str(e)}"
  124. )
  125. return config_data
  126. except HTTPException:
  127. raise
  128. except Exception as e:
  129. logger.error(f"Error reading file {path}: {e}")
  130. raise HTTPException(status_code=500, detail=f"Failed to read file: {str(e)}")
  131. @router.get("/files/file-exists", response_model=FileExistsResponse)
  132. async def file_exists(path: str = Query(..., description="Path to check")):
  133. """
  134. Check if a path exists.
  135. Security:
  136. - Uses os.path.realpath to resolve symlinks and prevent directory traversal
  137. """
  138. try:
  139. # Validate path security (resolves symlinks, prevents directory traversal)
  140. validated_path = validate_path_security(path)
  141. # Check if path exists
  142. exists = os.path.exists(validated_path)
  143. is_file = os.path.isfile(validated_path) if exists else False
  144. is_dir = os.path.isdir(validated_path) if exists else False
  145. return FileExistsResponse(
  146. exists=exists, path=validated_path, is_file=is_file, is_dir=is_dir
  147. )
  148. except Exception as e:
  149. logger.error(f"Error checking path {path}: {e}")
  150. raise HTTPException(status_code=500, detail=f"Failed to check path: {str(e)}")
  151. def is_diffusion_model(path: str) -> bool:
  152. """
  153. Check if a path is a diffusion model by looking for model_index.json file.
  154. Args:
  155. path: Directory path to check
  156. Returns:
  157. True if model_index.json exists in the directory, False otherwise
  158. """
  159. model_index_path = os.path.join(path, "model_index.json")
  160. try:
  161. return os.path.isfile(model_index_path)
  162. except OSError:
  163. return False
  164. @router.get("/files/model-weight-size")
  165. async def get_model_weight_size(
  166. path: str = Query(..., description="Directory path to scan"),
  167. ):
  168. """
  169. Calculate the total size of model weight files in a directory.
  170. Security:
  171. - Uses os.path.realpath to resolve symlinks and prevent directory traversal
  172. - Only scans the specified directory (not recursive for LLM, component dirs for diffusion)
  173. """
  174. try:
  175. # Validate path security (resolves symlinks, prevents directory traversal)
  176. validated_path = validate_path_security(path)
  177. if not os.path.exists(validated_path):
  178. raise HTTPException(status_code=404, detail=f"Directory not found: {path}")
  179. if not os.path.isdir(validated_path):
  180. raise HTTPException(
  181. status_code=400, detail=f"Path is not a directory: {path}"
  182. )
  183. is_diffusion = is_diffusion_model(validated_path)
  184. # Calculate size using utility function
  185. try:
  186. total_size = calculate_local_model_weight_size(validated_path, is_diffusion)
  187. except FileNotFoundError as e:
  188. raise HTTPException(status_code=404, detail=str(e))
  189. except NotADirectoryError as e:
  190. raise HTTPException(status_code=400, detail=str(e))
  191. except PermissionError as e:
  192. raise HTTPException(status_code=403, detail=str(e))
  193. except json.JSONDecodeError as e:
  194. raise HTTPException(
  195. status_code=400, detail=f"Invalid model_index.json: {str(e)}"
  196. )
  197. return {"size": total_size}
  198. except HTTPException:
  199. raise
  200. except Exception as e:
  201. logger.error(f"Error calculating model weight size for {path}: {e}")
  202. raise HTTPException(
  203. status_code=500, detail=f"Failed to calculate size: {str(e)}"
  204. )
  205. @router.post("/files/parse-gguf", response_model=GGUFParseResponse)
  206. async def parse_gguf_file(http_request: Request, body: GGUFParseRequest):
  207. """
  208. Parse a GGUF file using gguf-parser binary on the worker.
  209. Security:
  210. - Uses os.path.realpath to resolve symlinks and prevent directory traversal
  211. - Only allow parsing of existing files
  212. - 60 second timeout to prevent long-running processes
  213. """
  214. try:
  215. # 1. Deserialize Model object
  216. model = Model.model_validate(body.model_dict)
  217. # 2. Path validation - use validate_path_security for robust security
  218. validated_path = validate_path_security(model.local_path)
  219. # Check if file exists
  220. if not os.path.exists(validated_path):
  221. raise HTTPException(
  222. status_code=404, detail=f"File not found: {model.local_path}"
  223. )
  224. # Check if path is a file
  225. if not os.path.isfile(validated_path):
  226. raise HTTPException(
  227. status_code=400, detail=f"Path is not a file: {model.local_path}"
  228. )
  229. # Update model.local_path to use validated path
  230. model.local_path = validated_path
  231. # 3. Build offload enum
  232. offload_enum = GPUOffloadEnum(body.offload)
  233. # 4. Prepare kwargs (override parameters)
  234. kwargs = {}
  235. if body.tensor_split:
  236. kwargs["tensor_split"] = body.tensor_split
  237. if body.rpc:
  238. kwargs["rpc"] = body.rpc
  239. # cache_dir from this worker's app.state.config (Worker._serve_apis), not from body.
  240. worker_cfg: Optional[Config] = getattr(http_request.app.state, "config", None)
  241. if worker_cfg is not None:
  242. kwargs["cache_dir"] = worker_cfg.cache_dir
  243. # 5. Reuse _gguf_parser_command to build command
  244. command = await _gguf_parser_command(model, offload_enum, **kwargs)
  245. env = _gguf_parser_env(model)
  246. # 6. Execute command
  247. logger.debug(f"Executing gguf-parser command: {' '.join(map(str, command))}")
  248. # Use subprocess.run in a thread to avoid asyncio event loop issues
  249. # This is more reliable than asyncio.create_subprocess_exec in worker threads
  250. def run_command():
  251. """Run command synchronously in a thread."""
  252. try:
  253. result = subprocess.run(
  254. command,
  255. env=env,
  256. stdout=subprocess.PIPE,
  257. stderr=subprocess.PIPE,
  258. timeout=60,
  259. )
  260. return result.returncode, result.stdout, result.stderr
  261. except subprocess.TimeoutExpired:
  262. return -1, b"", b"Parsing timed out after 60 seconds"
  263. # Run in thread pool to avoid blocking
  264. returncode, stdout, stderr = await asyncio.to_thread(run_command)
  265. logger.debug("Process completed, processing output")
  266. if returncode == -1:
  267. # Timeout
  268. logger.error(f"GGUF parsing timed out for {model.local_path}")
  269. return GGUFParseResponse(success=False, error=stderr.decode())
  270. if returncode != 0:
  271. error_msg = stderr.decode() if stderr else "Unknown error"
  272. logger.error(f"GGUF parsing failed for {model.local_path}: {error_msg}")
  273. return GGUFParseResponse(success=False, error=error_msg)
  274. output_str = stdout.decode()
  275. logger.debug(f"GGUF parsing succeeded for {model.local_path}")
  276. return GGUFParseResponse(success=True, output=output_str)
  277. except HTTPException:
  278. raise
  279. except Exception as e:
  280. error_detail = traceback.format_exc()
  281. logger.error(f"Error parsing GGUF file: {e}\nTraceback:\n{error_detail}")
  282. return GGUFParseResponse(success=False, error=f"{type(e).__name__}: {str(e)}")