# backend_dependency_manager.py (4.5 KB)

import logging
from dataclasses import dataclass
from typing import Dict, List, Optional

from packaging.specifiers import SpecifierSet
from packaging.version import Version

from gpustack.schemas.models import BackendEnum
  7. logger = logging.getLogger(__name__)
  8. @dataclass
  9. class BackendDependencySpec:
  10. """
  11. Represents a backend dependency specification.
  12. Attributes:
  13. backend: The backend name (e.g., 'vox-box', 'vllm')
  14. dependencies: List of dependency specifications (e.g., ['transformers==4.51.3', 'torch>=2.0.0'])
  15. """
  16. backend: str
  17. dependencies: List[str]
  18. def to_pip_args(self) -> str:
  19. """
  20. Convert dependencies to pip arguments format.
  21. Returns:
  22. String in format "--pip-args='dep1 dep2 dep3'"
  23. """
  24. if not self.dependencies:
  25. return ""
  26. deps_str = " ".join(self.dependencies)
  27. return f"--pip-args='{deps_str}'"
  28. class BackendDependencyManager:
  29. """
  30. Manages backend dependencies for different inference backends.
  31. Examples:
  32. - model_env: {"GPUSTACK_BACKEND_DEPS"="transformers==4.53.3,torch>=2.0.0"}
  33. """
  34. def __init__(self, backend: str, version: str, model_env: Dict[str, str] = None):
  35. self.backend = backend
  36. self.version = version
  37. self._custom_specs: BackendDependencySpec = None
  38. # Initialize default dependencies for each backend using version specifiers
  39. # Format: {backend: {version_specifier: [dependencies]}}
  40. self.default_dependencies_specs: Dict[str, Dict[str, List[str]]] = {
  41. BackendEnum.VLLM: {
  42. "<=0.10.0": ["transformers==4.53.3"],
  43. },
  44. }
  45. self._load_from_environment(model_env)
  46. def _load_from_environment(self, model_env: Dict[str, str] = None):
  47. """
  48. Load custom dependency specifications from model environment variables.
  49. Environment variable format:
  50. GPUSTACK_BACKEND_DEPS="dep1,dep2"
  51. """
  52. if not model_env:
  53. return
  54. # First try to get from model_env, then fallback to system environment
  55. env_deps = model_env.get("GPUSTACK_BACKEND_DEPS")
  56. if not env_deps:
  57. return
  58. try:
  59. dependencies = [dep.strip() for dep in env_deps.split(",") if dep.strip()]
  60. self._custom_specs = BackendDependencySpec(
  61. backend=self.backend, dependencies=dependencies
  62. )
  63. logger.info(f"Loaded custom dependency spec: {dependencies}")
  64. except Exception as e:
  65. logger.warning(f"Failed to parse GPUSTACK_BACKEND_DEPS: {e}")
  66. def get_dependency_spec(self) -> BackendDependencySpec:
  67. """
  68. Get dependency specification for a backend and version.
  69. Returns:
  70. BackendDependencySpec with custom or default dependencies
  71. """
  72. # First check for legacy format (backend:version)
  73. if self._custom_specs:
  74. return self._custom_specs
  75. # Fall back to default dependencies using version specifiers
  76. default_version_deps = self.default_dependencies_specs.get(self.backend, {})
  77. if not default_version_deps:
  78. return None
  79. # Normalize version by removing 'v' prefix if present
  80. normalized_version = self.version.lstrip('v')
  81. try:
  82. version_obj = Version(normalized_version)
  83. except Exception as e:
  84. logger.warning(
  85. f"Invalid version format '{self.version}' for backend {self.backend}: {e}"
  86. )
  87. return None
  88. # Check each version specifier to find a match
  89. for version_spec, dependencies in default_version_deps.items():
  90. specifier_set = SpecifierSet(version_spec)
  91. if version_obj in specifier_set:
  92. logger.debug(
  93. f"Found matching dependency spec for {self.backend} {self.version}: {version_spec}"
  94. )
  95. return BackendDependencySpec(
  96. backend=self.backend, dependencies=dependencies
  97. )
  98. return None
  99. def get_pipx_install_args(self) -> List[str]:
  100. """
  101. Get pipx installation arguments for a backend.
  102. Args:
  103. backend: Backend name
  104. version: Backend version
  105. Returns:
  106. List of additional arguments for pipx install command
  107. """
  108. spec = self.get_dependency_spec()
  109. if not spec or not spec.dependencies:
  110. return []
  111. pip_args = spec.to_pip_args()
  112. return [pip_args] if pip_args else []