# llm.py (1.6 KB)
  1. # coding=utf-8
  2. """
  3. @project: maxkb
  4. @Author:虎
  5. @file: llm.py
  6. @date:2024/3/6 11:48
  7. @desc:
  8. """
  9. from typing import List, Dict
  10. from urllib.parse import urlparse, ParseResult
  11. from langchain_core.messages import BaseMessage, get_buffer_string
  12. from langchain_ollama.chat_models import ChatOllama
  13. from common.config.tokenizer_manage_config import TokenizerManage
  14. from models_provider.base_model_provider import MaxKBBaseModel
  15. def get_base_url(url: str):
  16. parse = urlparse(url)
  17. result_url = ParseResult(scheme=parse.scheme, netloc=parse.netloc, path=parse.path, params='',
  18. query='',
  19. fragment='').geturl()
  20. return result_url[:-1] if result_url.endswith("/") else result_url
  21. class OllamaChatModel(MaxKBBaseModel, ChatOllama):
  22. @staticmethod
  23. def is_cache_model():
  24. return False
  25. @staticmethod
  26. def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
  27. api_base = model_credential.get('api_base', '')
  28. base_url = get_base_url(api_base)
  29. optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
  30. return OllamaChatModel(model=model_name, base_url=base_url,
  31. stream=True, **optional_params)
  32. def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
  33. tokenizer = TokenizerManage.get_tokenizer()
  34. return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
  35. def get_num_tokens(self, text: str) -> int:
  36. tokenizer = TokenizerManage.get_tokenizer()
  37. return len(tokenizer.encode(text))