| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103 |
- import os
- import re
- from typing import Dict, List
- from botocore.config import Config
- from langchain_aws import ChatBedrock
- from langchain_core.messages import BaseMessage, get_buffer_string
- from common.config.tokenizer_manage_config import TokenizerManage
- from models_provider.base_model_provider import MaxKBBaseModel
def get_max_tokens_keyword(model_name):
    """
    Return the correct max-tokens keyword for the given model name.

    :param model_name: model name string
    :return: 'maxTokens', 'maxTokenCount' or 'max_new_tokens' when the name is
        one of the models listed below; 'max_tokens' for every other name
    """
    # Each entry pairs a keyword with the exact model names that use it.
    keyword_table = (
        ('maxTokens', ("ai21.j2-ultra-v1", "ai21.j2-mid-v1")),
        ('maxTokenCount', ("amazon.titan-text-lite-v1", "amazon.titan-text-express-v1")),
        ('max_new_tokens', (
            "us.meta.llama3-2-1b-instruct-v1:0",
            "us.meta.llama3-2-3b-instruct-v1:0",
            "us.meta.llama3-2-11b-instruct-v1:0",
            "us.meta.llama3-2-90b-instruct-v1:0",
        )),
    )
    for keyword, known_models in keyword_table:
        if model_name in known_models:
            return keyword
    # Any name not listed above uses the generic keyword.
    return 'max_tokens'
class BedrockModel(MaxKBBaseModel, ChatBedrock):
    """Chat model combining ``MaxKBBaseModel`` with langchain-aws ``ChatBedrock``."""

    @staticmethod
    def is_cache_model():
        """Return ``False``.

        NOTE(review): presumably tells the provider framework not to cache
        instances of this model -- confirm against ``MaxKBBaseModel``.
        """
        return False

    def __init__(self, model_id: str, region_name: str, credentials_profile_name: str,
                 streaming: bool = False, config: Config = None, **kwargs):
        """Forward all arguments unchanged to the parent initializer.

        :param model_id: Bedrock model identifier
        :param region_name: AWS region name
        :param credentials_profile_name: profile name in the AWS credentials file
        :param streaming: whether streaming is enabled
        :param config: optional botocore ``Config`` (``None`` for none)
        :param kwargs: any further keyword arguments for the parent class
        """
        super().__init__(model_id=model_id, region_name=region_name,
                         credentials_profile_name=credentials_profile_name, streaming=streaming, config=config,
                         **kwargs)

    @classmethod
    def new_instance(cls, model_type: str, model_name: str, model_credential: Dict[str, str],
                     **model_kwargs) -> 'BedrockModel':
        """Build a ``BedrockModel`` from stored credentials.

        The access key id doubles as the AWS profile name: the key pair is
        written to the credentials file under that profile, and the same name
        is passed on as ``credentials_profile_name``.

        :param model_type: not used by this method
        :param model_name: used as the Bedrock ``model_id``
        :param model_credential: must contain ``access_key_id``,
            ``secret_access_key`` and ``region_name``; a non-empty ``base_url``
            is used as an HTTP/HTTPS proxy
        :param model_kwargs: optional model parameters; ``streaming`` defaults
            to ``True`` when absent
        :return: the configured model instance
        :raises KeyError: if a required credential key is missing
        """
        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)

        # Build a botocore Config only when model_credential carries a
        # non-empty base_url. Otherwise pass None (the declared default of
        # __init__) instead of an empty dict, which is not a Config object.
        config = None
        proxy_url = model_credential.get('base_url')
        if proxy_url:
            config = Config(
                proxies={
                    'http': proxy_url,
                    'https': proxy_url
                },
                connect_timeout=60,
                read_timeout=60
            )

        # The profile name is deliberately the access key id (first two arguments).
        _update_aws_credentials(model_credential['access_key_id'], model_credential['access_key_id'],
                                model_credential['secret_access_key'])

        return cls(
            model_id=model_name,
            region_name=model_credential['region_name'],
            credentials_profile_name=model_credential['access_key_id'],
            streaming=model_kwargs.pop('streaming', True),
            model_kwargs=optional_params,
            config=config
        )

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Count the tokens in ``messages``.

        Uses the parent implementation first; if it raises, falls back
        (best-effort) to the tokenizer from ``TokenizerManage`` and sums the
        encoded length of each message's buffer string.
        """
        try:
            return super().get_num_tokens_from_messages(messages)
        except Exception:
            tokenizer = TokenizerManage.get_tokenizer()
            return sum(len(tokenizer.encode(get_buffer_string([m]))) for m in messages)

    def get_num_tokens(self, text: str) -> int:
        """Count the tokens in ``text``.

        Uses the parent implementation first; if it raises, falls back
        (best-effort) to the tokenizer from ``TokenizerManage``.
        """
        try:
            return super().get_num_tokens(text)
        except Exception:
            tokenizer = TokenizerManage.get_tokenizer()
            return len(tokenizer.encode(text))
def _update_aws_credentials(profile_name, access_key_id, secret_access_key):
    """Write or refresh one profile in the user's ``~/.aws/credentials`` file.

    An existing section for ``profile_name`` that has exactly the two-key
    layout this function writes is removed and re-appended with the new
    values. All other content of the file is preserved.

    :param profile_name: section name to write (``[profile_name]``)
    :param access_key_id: value for ``aws_access_key_id``
    :param secret_access_key: value for ``aws_secret_access_key``
    """
    credentials_path = os.path.join(os.path.expanduser("~"), ".aws", "credentials")
    os.makedirs(os.path.dirname(credentials_path), exist_ok=True)

    content = ''
    if os.path.exists(credentials_path):
        # Context manager so the read handle is closed before the rewrite.
        with open(credentials_path, 'r') as file:
            content = file.read()

    # Escape the name so regex metacharacters in it are matched literally.
    header = re.escape(f'[{profile_name}]')
    # Each value is limited to its own line ([^\n]*). A greedy '.*' under
    # re.DOTALL would run to the end of the file and delete every profile
    # stored after this one. The match is anchored at a line start and does
    # not consume preceding newlines, so the previous section's last line is
    # never glued to the next header.
    section = re.compile(
        rf'^{header}\n*'
        r'aws_access_key_id = [^\n]*\n*'
        r'aws_secret_access_key = [^\n]*\n*',
        flags=re.MULTILINE,
    )
    content = section.sub('', content)

    # A section in some other layout is left untouched rather than duplicated.
    if not re.search(rf'^{header}', content, flags=re.MULTILINE):
        content += f"\n[{profile_name}]\naws_access_key_id = {access_key_id}\naws_secret_access_key = {secret_access_key}\n"

    # NOTE(review): the file is written with default permissions; consider
    # restricting it to the owner since it stores secrets.
    with open(credentials_path, 'w') as file:
        file.write(content)
|