# model_source.py
  1. from typing import Optional
  2. from gpustack.schemas.models import Model, ModelSource, SourceEnum
  3. from gpustack.server.catalog import get_catalog_draft_models
  4. def get_draft_model_source(model: Model) -> Optional[ModelSource]:
  5. """
  6. Get the model source for the draft model.
  7. First check the catalog for the draft model.
  8. If not found, get the model source empirically to support custom draft models.
  9. """
  10. if model.speculative_config is None or not model.speculative_config.draft_model:
  11. return None
  12. draft_model = model.speculative_config.draft_model
  13. catalog_draft_models = get_catalog_draft_models()
  14. for catalog_draft_model in catalog_draft_models:
  15. if catalog_draft_model.name == draft_model:
  16. return catalog_draft_model
  17. # If draft_model looks like a path, assume it's a local path.
  18. if draft_model.startswith("/"):
  19. return ModelSource(source=SourceEnum.LOCAL_PATH, local_path=draft_model)
  20. # Otherwise, assume it comes from the same source as the main model.
  21. if model.source == SourceEnum.HUGGING_FACE:
  22. return ModelSource(
  23. source=SourceEnum.HUGGING_FACE,
  24. huggingface_repo_id=draft_model,
  25. )
  26. elif model.source == SourceEnum.MODEL_SCOPE:
  27. return ModelSource(
  28. source=SourceEnum.MODEL_SCOPE,
  29. model_scope_model_id=draft_model,
  30. )
  31. return None