# coding=utf-8
"""vLLM model provider: static model catalogue plus helpers for querying a vLLM server."""
import os
from urllib.parse import urlparse, ParseResult

import requests

from common.utils.common import get_file_content
from models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \
    ModelInfoManage
from models_provider.impl.vllm_model_provider.credential.embedding import VllmEmbeddingCredential
from models_provider.impl.vllm_model_provider.credential.image import VllmImageModelCredential
from models_provider.impl.vllm_model_provider.credential.llm import VLLMModelCredential
from models_provider.impl.vllm_model_provider.credential.reranker import VllmRerankerCredential
from models_provider.impl.vllm_model_provider.credential.whisper_stt import VLLMWhisperModelCredential
from models_provider.impl.vllm_model_provider.model.embedding import VllmEmbeddingModel
from models_provider.impl.vllm_model_provider.model.image import VllmImage
from models_provider.impl.vllm_model_provider.model.llm import VllmChatModel
from maxkb.conf import PROJECT_DIR
from django.utils.translation import gettext as _

from models_provider.impl.vllm_model_provider.model.reranker import VllmBgeReranker
from models_provider.impl.vllm_model_provider.model.whisper_sst import VllmWhisperSpeechToText

# One shared credential instance per model category.
v_llm_model_credential = VLLMModelCredential()
image_model_credential = VllmImageModelCredential()
embedding_model_credential = VllmEmbeddingCredential()
whisper_model_credential = VLLMWhisperModelCredential()
rerank_model_credential = VllmRerankerCredential()

# NOTE(review): `_` is `gettext` (not `gettext_lazy`) and is called at import time, so these
# descriptions are translated once, in whichever language is active when the module loads.
# Confirm how ModelInfo descriptions are consumed before switching to lazy translation.
model_info_list = [
    ModelInfo('facebook/opt-125m', _('Facebook’s 125M parameter model'), ModelTypeConst.LLM, v_llm_model_credential,
              VllmChatModel),
    ModelInfo('BAAI/Aquila-7B', _('BAAI’s 7B parameter model'), ModelTypeConst.LLM, v_llm_model_credential,
              VllmChatModel),
    # NOTE(review): the description says "13B" and "mode" although the model id is AquilaChat-7B.
    # The text is a gettext msgid, so fix it together with the locale catalogues, not here alone.
    ModelInfo('BAAI/AquilaChat-7B', _('BAAI’s 13B parameter mode'), ModelTypeConst.LLM, v_llm_model_credential,
              VllmChatModel),
]

image_model_info_list = [
    ModelInfo('Qwen/Qwen2-VL-2B-Instruct', '', ModelTypeConst.IMAGE, image_model_credential, VllmImage),
]

embedding_model_info_list = [
    ModelInfo('HIT-TMG/KaLM-embedding-multilingual-mini-instruct-v1.5', '', ModelTypeConst.EMBEDDING,
              embedding_model_credential, VllmEmbeddingModel),
]

whisper_model_info_list = [
    ModelInfo('whisper-tiny', '', ModelTypeConst.STT, whisper_model_credential, VllmWhisperSpeechToText),
    ModelInfo('whisper-large-v3-turbo', '', ModelTypeConst.STT, whisper_model_credential, VllmWhisperSpeechToText),
    ModelInfo('whisper-small', '', ModelTypeConst.STT, whisper_model_credential, VllmWhisperSpeechToText),
    ModelInfo('whisper-large-v3', '', ModelTypeConst.STT, whisper_model_credential, VllmWhisperSpeechToText),
]

reranker_model_info_list = [
    ModelInfo('BAAI/bge-reranker-v2-m3', '', ModelTypeConst.RERANKER, rerank_model_credential, VllmBgeReranker),
]

# Register every catalogue and mark the first entry of each category as its default.
model_info_manage = (
    ModelInfoManage.builder()
    .append_model_info_list(model_info_list)
    # Reuse the catalogue entry (as the other categories do) instead of a duplicated literal.
    .append_default_model_info(model_info_list[0])
    .append_model_info_list(image_model_info_list)
    .append_default_model_info(image_model_info_list[0])
    .append_model_info_list(embedding_model_info_list)
    .append_default_model_info(embedding_model_info_list[0])
    .append_model_info_list(whisper_model_info_list)
    .append_default_model_info(whisper_model_info_list[0])
    .append_model_info_list(reranker_model_info_list)
    .append_default_model_info(reranker_model_info_list[0])
    .build()
)


def get_base_url(url: str) -> str:
    """Return *url* reduced to scheme, host and path, with all trailing slashes removed.

    Params, query string and fragment are dropped, so the caller can append a path
    segment (e.g. ``/v1``) safely.
    """
    parse = urlparse(url)
    result_url = ParseResult(scheme=parse.scheme, netloc=parse.netloc, path=parse.path,
                             params='', query='', fragment='').geturl()
    # Strip every trailing slash, not just one: "http://host:8000/v1//" must normalise to
    # ".../v1", otherwise the caller's endswith('/v1') check fails and "/v1" is appended twice.
    return result_url.rstrip('/')


class VllmModelProvider(IModelProvider):
    """Model provider for vLLM servers, exposing the catalogue defined in this module."""

    def get_model_info_manage(self):
        """Return the module-level ModelInfoManage built from the catalogues above."""
        return model_info_manage

    def get_model_provide_info(self):
        """Return the provider descriptor (id, display name and SVG icon content)."""
        return ModelProvideInfo(provider='model_vllm_provider', name='vLLM', icon=get_file_content(
            os.path.join(PROJECT_DIR, "apps", 'models_provider', 'impl', 'vllm_model_provider', 'icon',
                         'vllm_icon_svg')))

    @staticmethod
    def get_base_model_list(api_base, api_key):
        """Fetch the model list advertised by the server at *api_base*.

        ``/v1`` is appended to the normalised base URL unless it already ends with it, and
        ``GET {base}/models`` is issued with a 5-second timeout. When *api_key* is truthy it
        is sent as a Bearer token.

        Returns the ``data`` field of the JSON response, or ``None`` when it is absent
        (presumably an OpenAI-style model list — confirm against the vLLM server version).
        Raises ``requests.HTTPError`` on a non-success status and other ``requests``
        exceptions on connection failure or timeout.
        """
        base_url = get_base_url(api_base)
        base_url = base_url if base_url.endswith('/v1') else (base_url + '/v1')
        headers = {}
        if api_key:
            headers['Authorization'] = f"Bearer {api_key}"
        r = requests.request(method="GET", url=f"{base_url}/models", headers=headers, timeout=5)
        r.raise_for_status()
        return r.json().get('data')

    @staticmethod
    def get_model_info_by_name(model_list, model_name):
        """Return the entries of *model_list* whose ``'id'`` equals *model_name*.

        A ``None`` model list yields an empty list.
        """
        if model_list is None:
            return []
        return [model for model in model_list if model.get('id') == model_name]