# s6_services.py (3.6 KB)
from dataclasses import dataclass
from typing import Dict, List, Optional, Set
  3. gpustack_service_name = "gpustack"
  4. @dataclass
  5. class S6Service:
  6. name: str
  7. ports: Optional[List[int | str]] = None
  8. is_dependency: bool = False
  9. longrun: bool = True
  10. class S6Services:
  11. services: Set[str]
  12. support_pipeline: bool = False
  13. pipeline_prefix: str = "pipeline-"
  14. dependencies: Set[str] # dependency services
  15. service_port_getters: Dict[
  16. str, List[str | int]
  17. ] # service name to port or config field
  18. def __init__(
  19. self,
  20. *services: S6Service,
  21. support_pipeline: bool = False,
  22. pipeline_prefix: str = "pipeline-",
  23. ):
  24. self.services = set()
  25. self.dependencies = set()
  26. self.service_port_getters = {}
  27. self.support_pipeline = support_pipeline
  28. self.pipeline_prefix = pipeline_prefix
  29. for service in services:
  30. if service.longrun:
  31. self.services.add(service.name)
  32. if service.is_dependency:
  33. self.dependencies.add(service.name)
  34. if service.ports:
  35. self.service_port_getters[service.name] = list(service.ports)
  36. def all_services(self) -> List[str]:
  37. managed_services = set(self.services) | set(self.dependencies)
  38. if self.support_pipeline:
  39. pipeline_services = [
  40. self.pipeline_prefix + service for service in self.services
  41. ]
  42. return list(managed_services) + pipeline_services
  43. return list(managed_services)
  44. def set_ports(self, config: object, ports: Dict[int, str]):
  45. if not self.service_port_getters:
  46. return
  47. for service, port_list in self.service_port_getters.items():
  48. for port_or_field in port_list:
  49. if isinstance(port_or_field, int):
  50. ports[port_or_field] = service
  51. else:
  52. port_value = getattr(config, port_or_field, None)
  53. if port_value is None or not isinstance(port_value, int):
  54. continue
  55. if not port_conflict(port_value, ports):
  56. ports[port_value] = service
  57. @property
  58. def dep_services(self) -> List[str]:
  59. return list(self.dependencies or [])
  60. gpustack_server_services = S6Services(
  61. S6Service("gpustack-server", ["port", "proxy_port", "tls_port", "metrics_port", "api_port"], True),
  62. )
  63. gateway_services = S6Services(
  64. S6Service("apiserver", [18443], True),
  65. S6Service("pilot", [9876, 15010, 15012]),
  66. S6Service("controller", [8888, 15051]),
  67. S6Service("gateway", [15000, 15021, 15090, 15020]),
  68. S6Service("supercronic"),
  69. support_pipeline=True,
  70. )
  71. postgres_services = S6Services(
  72. S6Service("postgres", ["database_port"], True),
  73. )
  74. migration_services = S6Services(
  75. S6Service("gpustack-migration", [], True, False),
  76. )
  77. observability_services = S6Services(
  78. S6Service("grafana", ["builtin_grafana_port"]),
  79. S6Service("prometheus", ["builtin_prometheus_port"]),
  80. support_pipeline=True,
  81. )
  82. def all_services() -> List[str]:
  83. return [
  84. *gateway_services.all_services(),
  85. *postgres_services.all_services(),
  86. *migration_services.all_services(),
  87. *observability_services.all_services(),
  88. *gpustack_server_services.all_services(),
  89. ]
  90. def port_conflict(port: int, ports: Dict[int, str]) -> bool:
  91. existing_service = ports.get(port, None)
  92. if existing_service is not None:
  93. raise Exception(f"Port conflict: {port} is already used by " + existing_service)
  94. return False