# config.py (3.9 KB)
from enum import Enum
from typing import Optional

from pydantic import BaseModel, Field

from gpustack import __benchmark_runner_version__
  5. class GatewayModeEnum(str, Enum):
  6. """
  7. For both server and worker gateway mode
  8. auto - automatically detect the gateway mode
  9. embedded - use the embedded gateway(server)
  10. incluster - connect to an in-cluster gateway(server)
  11. external - connect to an external gateway(server)
  12. disabled - disable the gateway(server/worker)
  13. The support of worker gateway mode is reserved for future use.
  14. Only incluster and external modes are supported for worker.
  15. """
  16. auto = "auto"
  17. embedded = "embedded"
  18. incluster = "incluster"
  19. external = "external"
  20. disabled = "disabled"
  21. class ModelInstanceProxyModeEnum(str, Enum):
  22. """
  23. Enum for Model Instance Proxy Mode
  24. WORKER - Proxy through the worker
  25. DIRECT - Direct access to the model instance
  26. DELEGATED - Preserved for proxying through cluster gateway (not implemented yet)
  27. """
  28. WORKER = "worker"
  29. DIRECT = "direct"
  30. DELEGATED = "delegated"
  31. TUNNEL = "tunnel"
  32. class SensitivePredefinedConfig(BaseModel):
  33. # Common options
  34. huggingface_token: Optional[str] = Field(
  35. default=None, json_schema_extra={"env_var": "HF_TOKEN"}
  36. )
  37. class PredefinedConfig(SensitivePredefinedConfig):
  38. # Common options
  39. debug: bool = False
  40. cache_dir: Optional[str] = None
  41. log_dir: Optional[str] = None
  42. bin_dir: Optional[str] = None
  43. benchmark_dir: Optional[str] = None
  44. system_default_container_registry: Optional[str] = None
  45. image_name_override: Optional[str] = None
  46. image_repo: str = "gpustack/gpustack"
  47. benchmark_image_repo: str = (
  48. f"gpustack/benchmark-runner:{__benchmark_runner_version__}"
  49. )
  50. gateway_mode: GatewayModeEnum = GatewayModeEnum.auto
  51. gateway_kubeconfig: Optional[str] = None
  52. gateway_namespace: str = "higress-system"
  53. service_discovery_name: Optional[str] = None
  54. namespace: str = "gpustack-system"
  55. # Worker options
  56. disable_worker_metrics: bool = False
  57. worker_port: int = 10150
  58. worker_metrics_port: int = 10151
  59. service_port_range: Optional[str] = "40000-40063"
  60. ray_port_range: Optional[str] = "41000-41999"
  61. benchmark_max_duration_seconds: Optional[int] = None
  62. system_reserved: Optional[dict] = None
  63. pipx_path: Optional[str] = None
  64. tools_download_base_url: Optional[str] = None
  65. enable_hf_transfer: bool = False # Deprecated
  66. enable_hf_xet: bool = False # Deprecated
  67. proxy_mode: Optional[ModelInstanceProxyModeEnum] = None
  68. class PredefinedConfigNoDefaults(PredefinedConfig):
  69. debug: Optional[bool] = None
  70. disable_worker_metrics: Optional[bool] = None
  71. enable_hf_transfer: Optional[bool] = None # Deprecated
  72. enable_hf_xet: Optional[bool] = None # Deprecated
  73. worker_port: Optional[int] = None
  74. worker_metrics_port: Optional[int] = None
  75. service_port_range: Optional[str] = None
  76. ray_port_range: Optional[str] = None
  77. benchmark_max_duration_seconds: Optional[int] = None
  78. image_repo: Optional[str] = None
  79. benchmark_image_repo: Optional[str] = None
  80. gateway_mode: Optional[str] = None
  81. gateway_namespace: Optional[str] = None
  82. namespace: Optional[str] = None
  83. def parse_base_model_to_env_vars(
  84. config: BaseModel,
  85. ) -> dict[str, str]:
  86. env_vars = {}
  87. for field_name, field in config.__class__.model_fields.items():
  88. extra = getattr(field, 'json_schema_extra', None) or {}
  89. env_var = extra.get("env_var")
  90. if env_var is None:
  91. # assuming the field name is in snake_case
  92. env_var = f"GPUSTACK_{field_name.upper()}"
  93. value = getattr(config, field_name)
  94. if value is not None:
  95. if isinstance(value, bool):
  96. env_vars[env_var] = "true" if value else "false"
  97. else:
  98. env_vars[env_var] = str(value)
  99. return env_vars