config.py 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. import logging
  2. from fastapi import APIRouter, Request
  3. from typing import Any, Dict
  4. from gpustack.api.exceptions import (
  5. InvalidException,
  6. ForbiddenException,
  7. )
  8. from gpustack.config.config import Config, set_global_config
  9. from gpustack.utils.config import (
  10. WHITELIST_CONFIG_FIELDS,
  11. READ_ONLY_CONFIG_FIELDS,
  12. coerce_value_by_field,
  13. is_local_request,
  14. )
# Router that collects the config endpoints declared in this module.
router = APIRouter()
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
  17. @router.get("/config")
  18. async def get_config(request: Request):
  19. app_state = request.app.state
  20. cfg: Config = getattr(app_state, "server_config", None) or getattr(
  21. app_state, "config", None
  22. )
  23. if cfg is None:
  24. raise InvalidException(message="Config is not available")
  25. result: Dict[str, Any] = {}
  26. for field in READ_ONLY_CONFIG_FIELDS:
  27. if hasattr(cfg, field):
  28. result[field] = getattr(cfg, field)
  29. return result
  30. @router.put("/config")
  31. async def set_config(request: Request):
  32. if not is_local_request(request):
  33. raise ForbiddenException(message="Only localhost is allowed")
  34. app_state = request.app.state
  35. cfg: Config = getattr(app_state, "server_config", None) or getattr(
  36. app_state, "config", None
  37. )
  38. if cfg is None:
  39. raise InvalidException(message="Config is not available")
  40. data = await request.json()
  41. updates: Dict[str, Any] = {}
  42. for k, v in data.items():
  43. if k in WHITELIST_CONFIG_FIELDS:
  44. updates[k] = coerce_value_by_field(k, v)
  45. for k, v in updates.items():
  46. setattr(cfg, k, v)
  47. if "debug" in updates:
  48. logging.getLogger().setLevel(
  49. logging.DEBUG if bool(updates["debug"]) else logging.INFO
  50. )
  51. set_global_config(cfg)
  52. logger.info("Applied runtime config updates")
  53. return "ok"