image.py 1.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. from typing import Dict, List
  2. from langchain_core.messages import BaseMessage, get_buffer_string
  3. from langchain_openai import AzureChatOpenAI
  4. from common.config.tokenizer_manage_config import TokenizerManage
  5. from models_provider.base_model_provider import MaxKBBaseModel
  def custom_get_token_ids(text: str):
      """Encode ``text`` with the project tokenizer and return the result.

      The tokenizer is fetched from ``TokenizerManage.get_tokenizer()`` on
      every call, and whatever its ``encode`` method returns is handed back
      unchanged.
      """
      return TokenizerManage.get_tokenizer().encode(text)
  class AzureOpenAIImage(MaxKBBaseModel, AzureChatOpenAI):
      """Azure OpenAI chat model exposed through the ``MaxKBBaseModel`` interface.

      Token counting is delegated to ``AzureChatOpenAI`` first; if that raises,
      the count is computed with the tokenizer from ``TokenizerManage``.
      """

      @staticmethod
      def is_cache_model():
          """Return ``False``.

          NOTE(review): presumably tells the provider framework not to cache
          instances of this model -- confirm against ``MaxKBBaseModel``.
          """
          return False

      @staticmethod
      def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
          """Build an ``AzureOpenAIImage`` from stored credentials.

          Args:
              model_type: Not used by this implementation.
              model_name: Forwarded as ``model_name``.
              model_credential: Mapping read with ``.get`` for the keys
                  ``'api_key'``, ``'api_base'`` and ``'api_version'``; a missing
                  key is therefore forwarded as ``None``.
              **model_kwargs: Passed through
                  ``MaxKBBaseModel.filter_optional_params``; the result is
                  forwarded as additional keyword arguments.

          Returns:
              A new ``AzureOpenAIImage`` with ``streaming=True`` and
              ``openai_api_type="azure"``.
          """
          optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
          return AzureOpenAIImage(
              model_name=model_name,
              openai_api_key=model_credential.get('api_key'),
              azure_endpoint=model_credential.get('api_base'),
              openai_api_version=model_credential.get('api_version'),
              openai_api_type="azure",
              streaming=True,
              **optional_params,
          )

      def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
          """Return the token count for ``messages``.

          The parent implementation is tried first. If it raises any
          ``Exception``, each message is rendered on its own with
          ``get_buffer_string``, encoded with the ``TokenizerManage`` tokenizer,
          and the encoded lengths are summed.
          """
          try:
              return super().get_num_tokens_from_messages(messages)
          except Exception:
              # Deliberate best-effort fallback: counting must not fail the call.
              # NOTE(review): presumably hit when the parent cannot resolve a
              # tokenizer for the deployment's model name -- confirm.
              tokenizer = TokenizerManage.get_tokenizer()
              return sum(
                  len(tokenizer.encode(get_buffer_string([m]))) for m in messages
              )

      def get_num_tokens(self, text: str) -> int:
          """Return the token count for ``text``.

          The parent implementation is tried first. If it raises any
          ``Exception``, the text is encoded with the ``TokenizerManage``
          tokenizer and the encoded length is returned.
          """
          try:
              return super().get_num_tokens(text)
          except Exception:
              # Deliberate best-effort fallback, same as for message counting.
              tokenizer = TokenizerManage.get_tokenizer()
              return len(tokenizer.encode(text))