llm.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. import os
  2. import re
  3. from typing import Dict, List
  4. from botocore.config import Config
  5. from langchain_aws import ChatBedrock
  6. from langchain_core.messages import BaseMessage, get_buffer_string
  7. from common.config.tokenizer_manage_config import TokenizerManage
  8. from models_provider.base_model_provider import MaxKBBaseModel
  9. def get_max_tokens_keyword(model_name):
  10. """
  11. 根据模型名称返回正确的 max_tokens 关键字。
  12. :param model_name: 模型名称字符串
  13. :return: 对应的 max_tokens 关键字字符串
  14. """
  15. maxTokens = ["ai21.j2-ultra-v1", "ai21.j2-mid-v1"]
  16. # max_tokens_to_sample = ["anthropic.claude-v2:1", "anthropic.claude-v2", "anthropic.claude-instant-v1"]
  17. maxTokenCount = ["amazon.titan-text-lite-v1", "amazon.titan-text-express-v1"]
  18. max_new_tokens = [
  19. "us.meta.llama3-2-1b-instruct-v1:0", "us.meta.llama3-2-3b-instruct-v1:0", "us.meta.llama3-2-11b-instruct-v1:0",
  20. "us.meta.llama3-2-90b-instruct-v1:0"]
  21. if model_name in maxTokens:
  22. return 'maxTokens'
  23. elif model_name in maxTokenCount:
  24. return 'maxTokenCount'
  25. elif model_name in max_new_tokens:
  26. return 'max_new_tokens'
  27. else:
  28. return 'max_tokens'
  29. class BedrockModel(MaxKBBaseModel, ChatBedrock):
  30. @staticmethod
  31. def is_cache_model():
  32. return False
  33. def __init__(self, model_id: str, region_name: str, credentials_profile_name: str,
  34. streaming: bool = False, config: Config = None, **kwargs):
  35. super().__init__(model_id=model_id, region_name=region_name,
  36. credentials_profile_name=credentials_profile_name, streaming=streaming, config=config,
  37. **kwargs)
  38. @classmethod
  39. def new_instance(cls, model_type: str, model_name: str, model_credential: Dict[str, str],
  40. **model_kwargs) -> 'BedrockModel':
  41. optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
  42. config = {}
  43. # 判断model_kwargs是否包含 base_url 且不为空
  44. if 'base_url' in model_credential and model_credential['base_url']:
  45. proxy_url = model_credential['base_url']
  46. config = Config(
  47. proxies={
  48. 'http': proxy_url,
  49. 'https': proxy_url
  50. },
  51. connect_timeout=60,
  52. read_timeout=60
  53. )
  54. _update_aws_credentials(model_credential['access_key_id'], model_credential['access_key_id'],
  55. model_credential['secret_access_key'])
  56. return cls(
  57. model_id=model_name,
  58. region_name=model_credential['region_name'],
  59. credentials_profile_name=model_credential['access_key_id'],
  60. streaming=model_kwargs.pop('streaming', True),
  61. model_kwargs=optional_params,
  62. config=config
  63. )
  64. def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
  65. try:
  66. return super().get_num_tokens_from_messages(messages)
  67. except Exception as e:
  68. tokenizer = TokenizerManage.get_tokenizer()
  69. return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
  70. def get_num_tokens(self, text: str) -> int:
  71. try:
  72. return super().get_num_tokens(text)
  73. except Exception as e:
  74. tokenizer = TokenizerManage.get_tokenizer()
  75. return len(tokenizer.encode(text))
  76. def _update_aws_credentials(profile_name, access_key_id, secret_access_key):
  77. credentials_path = os.path.join(os.path.expanduser("~"), ".aws", "credentials")
  78. os.makedirs(os.path.dirname(credentials_path), exist_ok=True)
  79. content = open(credentials_path, 'r').read() if os.path.exists(credentials_path) else ''
  80. pattern = rf'\n*\[{profile_name}\]\n*(aws_access_key_id = .*)\n*(aws_secret_access_key = .*)\n*'
  81. content = re.sub(pattern, '', content, flags=re.DOTALL)
  82. if not re.search(rf'\[{profile_name}\]', content):
  83. content += f"\n[{profile_name}]\naws_access_key_id = {access_key_id}\naws_secret_access_key = {secret_access_key}\n"
  84. with open(credentials_path, 'w') as file:
  85. file.write(content)