models.py 2.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. """
  2. Model registry for cross-instance event enrichment.
  3. This module provides lazy loading of model classes to avoid circular imports.
  4. It's separate from bus.py to prevent the import cycle:
  5. bus.py -> models -> active_record.py -> bus.py
  6. """
import importlib
import logging
from typing import Callable, Dict, Optional, Type
# Module-level logger; import failures from the lazy model loaders are reported here at DEBUG level.
logger = logging.getLogger(__name__)
  10. class _ModelRegistry:
  11. """
  12. Registry for topic-to-model mappings with lazy loading.
  13. This avoids circular imports by only importing model classes when first accessed.
  14. """
  15. _REGISTRY: Dict[str, Callable[[], Optional[Type]]] = {
  16. 'worker': lambda: _import_model('gpustack.schemas.workers', 'Worker'),
  17. 'model': lambda: _import_model('gpustack.schemas.models', 'Model'),
  18. 'modelinstance': lambda: _import_model(
  19. 'gpustack.schemas.models', 'ModelInstance'
  20. ),
  21. 'modelfile': lambda: _import_model('gpustack.schemas.model_files', 'ModelFile'),
  22. 'modelroute': lambda: _import_model(
  23. 'gpustack.schemas.model_routes', 'ModelRoute'
  24. ),
  25. 'modelroutetarget': lambda: _import_model(
  26. 'gpustack.schemas.model_routes', 'ModelRouteTarget'
  27. ),
  28. 'cluster': lambda: _import_model('gpustack.schemas.clusters', 'Cluster'),
  29. 'workerpool': lambda: _import_model('gpustack.schemas.clusters', 'WorkerPool'),
  30. 'cloudcredential': lambda: _import_model(
  31. 'gpustack.schemas.clusters', 'CloudCredential'
  32. ),
  33. 'modelprovider': lambda: _import_model(
  34. 'gpustack.schemas.model_provider', 'ModelProvider'
  35. ),
  36. 'user': lambda: _import_model('gpustack.schemas.users', 'User'),
  37. 'apikey': lambda: _import_model('gpustack.schemas.api_keys', 'ApiKey'),
  38. 'benchmark': lambda: _import_model('gpustack.schemas.benchmark', 'Benchmark'),
  39. 'inferencebackend': lambda: _import_model(
  40. 'gpustack.schemas.inference_backend', 'InferenceBackend'
  41. ),
  42. }
  43. @classmethod
  44. def get_model(cls, topic: str) -> Optional[Type]:
  45. """Get model class for a topic, or None if not registered."""
  46. loader = cls._REGISTRY.get(topic)
  47. return loader() if loader else None
  48. def _import_model(module_path: str, class_name: str) -> Optional[Type]:
  49. """Import a model class by module path and class name."""
  50. try:
  51. module = __import__(module_path, fromlist=[class_name])
  52. return getattr(module, class_name)
  53. except (ImportError, AttributeError) as e:
  54. logger.debug(f"Failed to import {module_path}.{class_name}: {e}")
  55. return None
  56. def get_model_for_topic(topic: str) -> Optional[Type]:
  57. """Get the model class associated with a topic name."""
  58. return _ModelRegistry.get_model(topic)