# test_model_provider.py
  1. import pytest
  2. from gpustack.routes.model_provider import determine_model_category
  3. from gpustack.schemas.model_provider import ModelProviderTypeEnum
  4. from gpustack.schemas.models import CategoryEnum
  5. doubao_llm_data = {
  6. "created": 1767587883,
  7. "domain": "VLM",
  8. "features": {
  9. "batch": {"batch_chat": False, "batch_job": False},
  10. "cache": {"prefix_cache": False, "session_cache": False},
  11. "structured_outputs": {"json_object": True, "json_schema": True},
  12. "tools": {"function_calling": True},
  13. },
  14. "id": "doubao-seed-1-8-251228",
  15. "modalities": {
  16. "input_modalities": ["text", "image", "video"],
  17. "output_modalities": ["text"],
  18. },
  19. "name": "doubao-seed-1-8",
  20. "object": "model",
  21. "task_type": ["VisualQuestionAnswering", "TextGeneration"],
  22. "token_limits": {
  23. "context_window": 262144,
  24. "max_input_token_length": 229376,
  25. "max_output_token_length": 65536,
  26. "max_reasoning_token_length": 32768,
  27. },
  28. "version": "251228",
  29. }
  30. doubao_llm_data2 = {
  31. "created": 1736337657,
  32. "domain": "LLM",
  33. "features": {
  34. "cache": {
  35. "prefix_cache": True,
  36. "session_cache": True,
  37. },
  38. "structured_outputs": {
  39. "json_object": False,
  40. "json_schema": False,
  41. },
  42. "tools": {
  43. "function_calling": True,
  44. },
  45. },
  46. "id": "doubao-1-5-lite-32k-250115",
  47. "modalities": {"input_modalities": ["text"], "output_modalities": ["text"]},
  48. "name": "doubao-1-5-lite-32k",
  49. "object": "model",
  50. "task_type": ["TextGeneration"],
  51. "token_limits": {"context_window": 32768, "max_output_token_length": 12288},
  52. "version": "250115",
  53. }
  54. doubao_embedding_data = {
  55. "created": 1715588483,
  56. "domain": "Embedding",
  57. "features": {},
  58. "id": "doubao-embedding-text-240515",
  59. "modalities": {"input_modalities": ["text"]},
  60. "name": "doubao-embedding",
  61. "object": "model",
  62. "status": "Retiring",
  63. "task_type": ["TextEmbedding"],
  64. "token_limits": {},
  65. "version": "text-240515",
  66. }
  67. qwen_llm_data = {
  68. "id": "qwen3-max-2026-01-23",
  69. "object": "model",
  70. "created": 1769481796,
  71. "owned_by": "system",
  72. }
  73. qwen_image_data = {
  74. "id": "qwen-image-edit-max",
  75. "object": "model",
  76. "created": 1768570977,
  77. "owned_by": "system",
  78. }
  79. qwen_llm_split_name_data = {
  80. "id": "siliconflow/deepseek-v3.2",
  81. "object": "model",
  82. "created": 1769611475,
  83. "owned_by": "system",
  84. }
  85. @pytest.mark.parametrize(
  86. "provider_type,model,expected",
  87. [
  88. # actual data from doubao
  89. (
  90. ModelProviderTypeEnum.DOUBAO,
  91. doubao_llm_data,
  92. [CategoryEnum.LLM.value],
  93. ),
  94. (
  95. ModelProviderTypeEnum.DOUBAO,
  96. doubao_llm_data2,
  97. [CategoryEnum.LLM.value],
  98. ),
  99. (
  100. ModelProviderTypeEnum.DOUBAO,
  101. doubao_embedding_data,
  102. [CategoryEnum.EMBEDDING.value],
  103. ),
  104. # actual data from qwen
  105. (
  106. ModelProviderTypeEnum.QWEN,
  107. qwen_image_data,
  108. [CategoryEnum.IMAGE.value],
  109. ),
  110. (
  111. ModelProviderTypeEnum.QWEN,
  112. qwen_llm_data,
  113. [CategoryEnum.LLM.value],
  114. ),
  115. (
  116. ModelProviderTypeEnum.QWEN,
  117. qwen_llm_split_name_data,
  118. [CategoryEnum.LLM.value],
  119. ),
  120. # actual data from deepseek
  121. (
  122. ModelProviderTypeEnum.DEEPSEEK,
  123. {
  124. "id": "deepseek-chat",
  125. "object": "model",
  126. "owned_by": "deepseek",
  127. },
  128. [CategoryEnum.LLM.value],
  129. ),
  130. (
  131. ModelProviderTypeEnum.DEEPSEEK,
  132. {
  133. "id": "deepseek-reasoner",
  134. "object": "model",
  135. "owned_by": "deepseek",
  136. },
  137. [CategoryEnum.LLM.value],
  138. ),
  139. ],
  140. )
  141. def test_determine_model_category(provider_type, model, expected):
  142. assert determine_model_category(provider_type, model) == expected