test_pretrained_config.py 869 B

12345678910111213141516171819202122
  1. from transformers import AutoConfig
  2. from gpustack.utils.hub import get_hf_text_config, get_max_model_len
  3. def test_get_max_model_len():
  4. hf_model_lengths = {
  5. "Qwen/Qwen2-0.5B-Instruct": 32768,
  6. "Qwen/Qwen2-VL-7B-Instruct": 32768,
  7. "tiiuae/falcon-7b": 2048,
  8. "microsoft/Phi-3.5-mini-instruct": 131072,
  9. "llava-hf/llava-v1.6-mistral-7b-hf": 32768,
  10. "unsloth/Llama-3.2-11B-Vision-Instruct": 131072,
  11. "THUDM/glm-4-9b-chat-1m": 1048576,
  12. }
  13. for model_name, expected_max_model_len in hf_model_lengths.items():
  14. config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
  15. pretrained_or_hf_text_config = get_hf_text_config(config)
  16. assert (
  17. get_max_model_len(pretrained_or_hf_text_config) == expected_max_model_len
  18. ), f"max_model_len mismatch for {model_name}"