| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091 |
- # coding=utf-8
- """
- @project: MaxKB
- @file: image.py
- @desc: AWS Bedrock Vision-Language Model Implementation
- """
- from typing import Dict, List
- from botocore.config import Config
- from langchain_aws import ChatBedrock
- from langchain_core.messages import BaseMessage, get_buffer_string
- from common.config.tokenizer_manage_config import TokenizerManage
- from models_provider.base_model_provider import MaxKBBaseModel
- from models_provider.impl.aws_bedrock_model_provider.model.llm import _update_aws_credentials
class BedrockVLModel(MaxKBBaseModel, ChatBedrock):
    """
    AWS Bedrock Vision-Language Model

    Supports Claude 3 models with vision capabilities (Haiku, Sonnet, Opus)
    """

    @staticmethod
    def is_cache_model():
        """Always return ``False`` for this model class."""
        return False

    def __init__(self, model_id: str, region_name: str, credentials_profile_name: str,
                 streaming: bool = False, config: Config = None, **kwargs):
        """
        Forward every argument unchanged to the parent initializer.

        :param model_id: Bedrock model identifier.
        :param region_name: AWS region name.
        :param credentials_profile_name: AWS credentials profile name.
        :param streaming: whether streaming is enabled (default ``False``).
        :param config: optional botocore ``Config``; ``None`` means none is supplied.
        :param kwargs: any additional keyword arguments, passed through as-is.
        """
        super().__init__(
            model_id=model_id,
            region_name=region_name,
            credentials_profile_name=credentials_profile_name,
            streaming=streaming,
            config=config,
            **kwargs
        )

    @classmethod
    def new_instance(cls, model_type: str, model_name: str, model_credential: Dict[str, str],
                     **model_kwargs) -> 'BedrockVLModel':
        """
        Build a model instance from a credential mapping.

        :param model_type: not used by this method.
        :param model_name: passed to the constructor as ``model_id``.
        :param model_credential: must contain ``access_key_id``,
            ``secret_access_key`` and ``region_name``; a non-empty ``base_url``
            is used as the HTTP and HTTPS proxy URL.
        :param model_kwargs: optional parameters; ``streaming`` (default
            ``True``) is read from here, the rest are filtered through
            ``MaxKBBaseModel.filter_optional_params``.
        :return: a new ``BedrockVLModel``.
        :raises KeyError: if a required credential key is missing.
        """
        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)

        # No proxy configured -> pass None (the constructor's declared default)
        # instead of an empty dict, which is not a botocore Config object.
        config = None
        proxy_url = model_credential.get('base_url')
        if proxy_url:
            config = Config(
                proxies={
                    'http': proxy_url,
                    'https': proxy_url
                },
                connect_timeout=60,
                read_timeout=60
            )

        # The access key id is deliberately passed twice and is reused below as
        # credentials_profile_name.
        # NOTE(review): presumably the first argument is the profile name -
        # confirm against _update_aws_credentials in llm.py.
        access_key_id = model_credential['access_key_id']
        _update_aws_credentials(
            access_key_id,
            access_key_id,
            model_credential['secret_access_key']
        )

        return cls(
            model_id=model_name,
            region_name=model_credential['region_name'],
            credentials_profile_name=access_key_id,
            streaming=model_kwargs.get('streaming', True),
            model_kwargs=optional_params,
            config=config
        )

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """
        Get the number of tokens from messages

        Falls back to the local tokenizer if the parent implementation raises:
        each message is rendered with ``get_buffer_string`` and encoded
        separately, and the per-message token counts are summed.

        :param messages: messages to count.
        :return: total token count.
        """
        try:
            return super().get_num_tokens_from_messages(messages)
        except Exception:
            # Deliberate best-effort fallback: any failure of the parent's
            # counting switches to the local tokenizer.
            tokenizer = TokenizerManage.get_tokenizer()
            return sum(len(tokenizer.encode(get_buffer_string([m]))) for m in messages)

    def get_num_tokens(self, text: str) -> int:
        """
        Get the number of tokens from text

        Falls back to the local tokenizer if the parent implementation raises.

        :param text: text to count.
        :return: token count.
        """
        try:
            return super().get_num_tokens(text)
        except Exception:
            # Deliberate best-effort fallback to the local tokenizer.
            tokenizer = TokenizerManage.get_tokenizer()
            return len(tokenizer.encode(text))
|