llm.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. #!/usr/bin/env python
  2. # -*- coding: UTF-8 -*-
  3. """
  4. @Project :MaxKB
  5. @File :llm.py
  6. @Author :Brian Yang
  7. @Date :5/13/24 7:40 AM
  8. """
  9. from typing import List, Dict, Optional, Any
  10. from langchain_core.messages import BaseMessage, get_buffer_string
  11. from langchain_google_genai import ChatGoogleGenerativeAI
  12. from common.config.tokenizer_manage_config import TokenizerManage
  13. from models_provider.base_model_provider import MaxKBBaseModel
  14. class GeminiChatModel(MaxKBBaseModel, ChatGoogleGenerativeAI):
  15. @staticmethod
  16. def is_cache_model():
  17. return False
  18. @staticmethod
  19. def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
  20. optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
  21. base_url = model_credential.get('base_url', "https://generativelanguage.googleapis.com")
  22. if base_url:
  23. optional_params.setdefault("model_kwargs", {})
  24. optional_params["model_kwargs"]["http_options"] = {"base_url": base_url}
  25. gemini_chat = GeminiChatModel(
  26. model=model_name,
  27. api_key=model_credential.get('api_key'),
  28. **optional_params
  29. )
  30. return gemini_chat
  31. def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
  32. return self.__dict__.get('_last_generation_info')
  33. def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
  34. try:
  35. return self.get_last_generation_info().get('input_tokens', 0)
  36. except Exception as e:
  37. tokenizer = TokenizerManage.get_tokenizer()
  38. return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
  39. def get_num_tokens(self, text: str) -> int:
  40. try:
  41. return self.get_last_generation_info().get('output_tokens', 0)
  42. except Exception as e:
  43. tokenizer = TokenizerManage.get_tokenizer()
  44. return len(tokenizer.encode(text))