image.py 1.5 KB

1234567891011121314151617181920212223242526272829303132333435363738
  1. from typing import Dict, List
  2. from langchain_core.messages import get_buffer_string, BaseMessage
  3. from common.config.tokenizer_manage_config import TokenizerManage
  4. from models_provider.base_model_provider import MaxKBBaseModel
  5. from models_provider.impl.base_chat_open_ai import BaseChatOpenAI
  6. class VllmImage(MaxKBBaseModel, BaseChatOpenAI):
  7. @staticmethod
  8. def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
  9. optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
  10. return VllmImage(
  11. model_name=model_name,
  12. openai_api_base=model_credential.get('api_base'),
  13. openai_api_key=model_credential.get('api_key'),
  14. # stream_options={"include_usage": True},
  15. streaming=True,
  16. stream_usage=True,
  17. **optional_params,
  18. )
  19. def is_cache_model(self):
  20. return False
  21. def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
  22. if self.usage_metadata is None or self.usage_metadata == {}:
  23. tokenizer = TokenizerManage.get_tokenizer()
  24. return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
  25. return self.usage_metadata.get('input_tokens', 0)
  26. def get_num_tokens(self, text: str) -> int:
  27. if self.usage_metadata is None or self.usage_metadata == {}:
  28. tokenizer = TokenizerManage.get_tokenizer()
  29. return len(tokenizer.encode(text))
  30. return self.get_last_generation_info().get('output_tokens', 0)