image.py 3.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. # coding=utf-8
  2. """
  3. @project: MaxKB
  4. @file: image.py
  5. @desc: AWS Bedrock Vision-Language Model Implementation
  6. """
  7. from typing import Dict, List
  8. from botocore.config import Config
  9. from langchain_aws import ChatBedrock
  10. from langchain_core.messages import BaseMessage, get_buffer_string
  11. from common.config.tokenizer_manage_config import TokenizerManage
  12. from models_provider.base_model_provider import MaxKBBaseModel
  13. from models_provider.impl.aws_bedrock_model_provider.model.llm import _update_aws_credentials
  class BedrockVLModel(MaxKBBaseModel, ChatBedrock):
      """
      AWS Bedrock Vision-Language Model.

      Wraps langchain_aws ``ChatBedrock`` and plugs it into MaxKB's model
      provider framework via ``MaxKBBaseModel``. Intended for Bedrock models
      with vision capabilities (e.g. Claude 3 Haiku, Sonnet, Opus).
      """

      @staticmethod
      def is_cache_model():
          """
          Report that this model should not be cached.

          NOTE(review): presumably consulted by MaxKB's model cache — confirm
          in MaxKBBaseModel.
          """
          return False

      def __init__(self, model_id: str, region_name: str, credentials_profile_name: str,
                   streaming: bool = False, config: Config = None, **kwargs):
          """
          Build the underlying ChatBedrock chat model.

          :param model_id: Bedrock model identifier.
          :param region_name: AWS region the model is invoked in.
          :param credentials_profile_name: AWS credentials profile used for auth.
          :param streaming: whether responses are streamed.
          :param config: optional botocore ``Config`` (proxy, timeouts); None
              means no custom client configuration.
          :param kwargs: forwarded unchanged to ``ChatBedrock``.
          """
          super().__init__(
              model_id=model_id,
              region_name=region_name,
              credentials_profile_name=credentials_profile_name,
              streaming=streaming,
              config=config,
              **kwargs
          )

      @classmethod
      def new_instance(cls, model_type: str, model_name: str, model_credential: Dict[str, str],
                       **model_kwargs) -> 'BedrockVLModel':
          """
          Create a model instance from MaxKB credential and option dictionaries.

          :param model_type: model category (not used by this implementation).
          :param model_name: Bedrock model id, passed on as ``model_id``.
          :param model_credential: must provide ``access_key_id``,
              ``secret_access_key`` and ``region_name``; an optional non-empty
              ``base_url`` is used as the HTTP and HTTPS proxy.
          :param model_kwargs: extra options. ``streaming`` is popped
              (default True); the remainder is filtered through
              ``MaxKBBaseModel.filter_optional_params`` and forwarded as
              ``model_kwargs``.
          :return: a configured ``BedrockVLModel``.
          :raises KeyError: if a required credential key is missing.
          """
          optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)

          # Only build a botocore Config when a proxy URL is supplied. Otherwise
          # pass None, which matches __init__'s default, rather than an empty
          # dict, which is not a botocore Config.
          config = None
          proxy_url = model_credential.get('base_url')
          if proxy_url:
              config = Config(
                  proxies={
                      'http': proxy_url,
                      'https': proxy_url,
                  },
                  connect_timeout=60,
                  read_timeout=60,
              )

          # The access key id doubles as the AWS profile name. It is passed as
          # the first argument here and as credentials_profile_name below, so
          # the repetition is deliberate.
          # NOTE(review): presumably the signature is
          # _update_aws_credentials(profile_name, access_key_id, secret_access_key)
          # — confirm in llm.py.
          _update_aws_credentials(
              model_credential['access_key_id'],
              model_credential['access_key_id'],
              model_credential['secret_access_key'],
          )
          return cls(
              model_id=model_name,
              region_name=model_credential['region_name'],
              credentials_profile_name=model_credential['access_key_id'],
              streaming=model_kwargs.pop('streaming', True),
              model_kwargs=optional_params,
              config=config,
          )

      def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
          """
          Return the total number of tokens across ``messages``.

          Tries ChatBedrock's own token counting first. If that raises any
          exception, each message's buffer string is encoded with the local
          tokenizer from ``TokenizerManage`` and the lengths are summed.
          """
          try:
              return super().get_num_tokens_from_messages(messages)
          except Exception:
              # Deliberate best-effort fallback: a token count is still needed
              # even when the model-side counting path fails.
              tokenizer = TokenizerManage.get_tokenizer()
              return sum(len(tokenizer.encode(get_buffer_string([m]))) for m in messages)

      def get_num_tokens(self, text: str) -> int:
          """
          Return the number of tokens in ``text``.

          Tries ChatBedrock's own token counting first. If that raises any
          exception, falls back to encoding ``text`` with the local tokenizer
          from ``TokenizerManage``.
          """
          try:
              return super().get_num_tokens(text)
          except Exception:
              # Same best-effort fallback as get_num_tokens_from_messages.
              tokenizer = TokenizerManage.get_tokenizer()
              return len(tokenizer.encode(text))