test_catalog.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. import os
  2. import time
  3. import pytest
  4. from tenacity import retry, stop_after_attempt, wait_fixed
  5. from gpustack.schemas.models import SourceEnum
  6. from gpustack.server.catalog import get_model_set_specs, init_model_catalog
  7. from gpustack.utils.hub import match_hugging_face_files, match_model_scope_file_paths
  8. from gpustack.utils.compat_importlib import pkg_resources
  9. from huggingface_hub import HfApi
  10. from modelscope.hub.api import HubApi
  11. @pytest.mark.skipif(
  12. os.getenv("HF_TOKEN") is None,
  13. reason="Skipped by default unless HF_TOKEN is set. Unauthed requests are rate limited.",
  14. )
  15. def test_model_catalog():
  16. init_model_catalog()
  17. model_set_specs = get_model_set_specs()
  18. Hfapi = HfApi()
  19. model_name_filter = os.getenv("TEST_CATALOG_MODEL_NAME_FILTER")
  20. for model_set_id, model_specs in model_set_specs.items():
  21. assert model_set_id is not None
  22. assert len(model_specs) > 0
  23. for model_spec in model_specs:
  24. assert (
  25. model_spec.source == SourceEnum.HUGGING_FACE
  26. ), f"Expected huggingface source but got: {model_spec.source}"
  27. if (
  28. model_name_filter is not None
  29. and model_name_filter not in model_spec.huggingface_repo_id
  30. ):
  31. continue
  32. time.sleep(0.01) # mitigate rate limit
  33. print(model_spec.huggingface_repo_id, model_spec.huggingface_filename)
  34. if model_spec.huggingface_filename is None:
  35. model_info = Hfapi.model_info(model_spec.huggingface_repo_id)
  36. assert model_info is not None
  37. else:
  38. match_files = match_hugging_face_files(
  39. model_spec.huggingface_repo_id, model_spec.huggingface_filename
  40. )
  41. assert (
  42. len(match_files) > 0
  43. ), f"Failed to find model files: {model_spec.huggingface_repo_id}, {model_spec.huggingface_filename}"
  44. @pytest.mark.skipif(
  45. os.getenv("HF_TOKEN") is None,
  46. reason="Skipped by default unless HF_TOKEN is set. Unauthed requests are rate limited.",
  47. )
  48. def test_model_catalog_modelscope():
  49. modelscope_catalog_file = pkg_resources.files("gpustack.assets").joinpath(
  50. "model-catalog-modelscope.yaml"
  51. )
  52. init_model_catalog(str(modelscope_catalog_file))
  53. model_set_specs = get_model_set_specs()
  54. Msapi = HubApi()
  55. model_name_filter = os.getenv("TEST_CATALOG_MODEL_NAME_FILTER")
  56. for model_set_id, model_specs in model_set_specs.items():
  57. assert model_set_id is not None
  58. assert len(model_specs) > 0
  59. for model_spec in model_specs:
  60. assert (
  61. model_spec.source == SourceEnum.MODEL_SCOPE
  62. ), f"Expected modelscope source but got: {model_spec.source}"
  63. if (
  64. model_name_filter is not None
  65. and model_name_filter not in model_spec.model_scope_model_id
  66. ):
  67. continue
  68. print(model_spec.model_scope_model_id, model_spec.model_scope_file_path)
  69. if model_spec.model_scope_file_path is None:
  70. model_info = Msapi.get_model(model_spec.model_scope_model_id)
  71. assert model_info is not None
  72. else:
  73. match_files = match_model_scope_file_paths_with_retry(
  74. model_spec.model_scope_model_id,
  75. model_spec.model_scope_file_path,
  76. )
  77. assert (
  78. len(match_files) > 0
  79. ), f"Failed to find model files: {model_spec.model_scope_model_id}, {model_spec.model_scope_file_path}"
  80. @retry(stop=stop_after_attempt(3), wait=wait_fixed(1))
  81. def match_model_scope_file_paths_with_retry(
  82. model_scope_model_id, model_scope_file_path
  83. ):
  84. return match_model_scope_file_paths(model_scope_model_id, model_scope_file_path)