vllm_model_provider.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. # coding=utf-8
  2. import os
  3. from urllib.parse import urlparse, ParseResult
  4. import requests
  5. from common.utils.common import get_file_content
  6. from models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \
  7. ModelInfoManage
  8. from models_provider.impl.vllm_model_provider.credential.embedding import VllmEmbeddingCredential
  9. from models_provider.impl.vllm_model_provider.credential.image import VllmImageModelCredential
  10. from models_provider.impl.vllm_model_provider.credential.llm import VLLMModelCredential
  11. from models_provider.impl.vllm_model_provider.credential.reranker import VllmRerankerCredential
  12. from models_provider.impl.vllm_model_provider.credential.whisper_stt import VLLMWhisperModelCredential
  13. from models_provider.impl.vllm_model_provider.model.embedding import VllmEmbeddingModel
  14. from models_provider.impl.vllm_model_provider.model.image import VllmImage
  15. from models_provider.impl.vllm_model_provider.model.llm import VllmChatModel
  16. from maxkb.conf import PROJECT_DIR
  17. from django.utils.translation import gettext as _
  18. from models_provider.impl.vllm_model_provider.model.reranker import VllmBgeReranker
  19. from models_provider.impl.vllm_model_provider.model.whisper_sst import VllmWhisperSpeechToText
  20. v_llm_model_credential = VLLMModelCredential()
  21. image_model_credential = VllmImageModelCredential()
  22. embedding_model_credential = VllmEmbeddingCredential()
  23. whisper_model_credential = VLLMWhisperModelCredential()
  24. rerank_model_credential = VllmRerankerCredential()
  25. model_info_list = [
  26. ModelInfo('facebook/opt-125m', _('Facebook’s 125M parameter model'), ModelTypeConst.LLM, v_llm_model_credential,
  27. VllmChatModel),
  28. ModelInfo('BAAI/Aquila-7B', _('BAAI’s 7B parameter model'), ModelTypeConst.LLM, v_llm_model_credential,
  29. VllmChatModel),
  30. ModelInfo('BAAI/AquilaChat-7B', _('BAAI’s 13B parameter mode'), ModelTypeConst.LLM, v_llm_model_credential,
  31. VllmChatModel),
  32. ]
  33. image_model_info_list = [
  34. ModelInfo('Qwen/Qwen2-VL-2B-Instruct', '', ModelTypeConst.IMAGE, image_model_credential, VllmImage),
  35. ]
  36. embedding_model_info_list = [
  37. ModelInfo('HIT-TMG/KaLM-embedding-multilingual-mini-instruct-v1.5', '', ModelTypeConst.EMBEDDING,
  38. embedding_model_credential, VllmEmbeddingModel),
  39. ]
  40. whisper_model_info_list = [
  41. ModelInfo('whisper-tiny', '', ModelTypeConst.STT, whisper_model_credential, VllmWhisperSpeechToText),
  42. ModelInfo('whisper-large-v3-turbo', '', ModelTypeConst.STT, whisper_model_credential, VllmWhisperSpeechToText),
  43. ModelInfo('whisper-small', '', ModelTypeConst.STT, whisper_model_credential, VllmWhisperSpeechToText),
  44. ModelInfo('whisper-large-v3', '', ModelTypeConst.STT, whisper_model_credential, VllmWhisperSpeechToText),
  45. ]
  46. reranker_model_info_list = [
  47. ModelInfo('BAAI/bge-reranker-v2-m3', '', ModelTypeConst.RERANKER, rerank_model_credential, VllmBgeReranker),
  48. ]
  49. model_info_manage = (
  50. ModelInfoManage.builder()
  51. .append_model_info_list(model_info_list)
  52. .append_default_model_info(ModelInfo('facebook/opt-125m',
  53. _('Facebook’s 125M parameter model'),
  54. ModelTypeConst.LLM, v_llm_model_credential, VllmChatModel))
  55. .append_model_info_list(image_model_info_list)
  56. .append_default_model_info(image_model_info_list[0])
  57. .append_model_info_list(embedding_model_info_list)
  58. .append_default_model_info(embedding_model_info_list[0])
  59. .append_model_info_list(whisper_model_info_list)
  60. .append_default_model_info(whisper_model_info_list[0])
  61. .append_model_info_list(reranker_model_info_list)
  62. .append_default_model_info(reranker_model_info_list[0])
  63. .build()
  64. )
  65. def get_base_url(url: str):
  66. parse = urlparse(url)
  67. result_url = ParseResult(scheme=parse.scheme, netloc=parse.netloc, path=parse.path, params='',
  68. query='',
  69. fragment='').geturl()
  70. return result_url[:-1] if result_url.endswith("/") else result_url
  71. class VllmModelProvider(IModelProvider):
  72. def get_model_info_manage(self):
  73. return model_info_manage
  74. def get_model_provide_info(self):
  75. return ModelProvideInfo(provider='model_vllm_provider', name='vLLM', icon=get_file_content(
  76. os.path.join(PROJECT_DIR, "apps", 'models_provider', 'impl', 'vllm_model_provider', 'icon',
  77. 'vllm_icon_svg')))
  78. @staticmethod
  79. def get_base_model_list(api_base, api_key):
  80. base_url = get_base_url(api_base)
  81. base_url = base_url if base_url.endswith('/v1') else (base_url + '/v1')
  82. headers = {}
  83. if api_key:
  84. headers['Authorization'] = f"Bearer {api_key}"
  85. r = requests.request(method="GET", url=f"{base_url}/models", headers=headers, timeout=5)
  86. r.raise_for_status()
  87. return r.json().get('data')
  88. @staticmethod
  89. def get_model_info_by_name(model_list, model_name):
  90. if model_list is None:
  91. return []
  92. return [model for model in model_list if model.get('id') == model_name]