llm.py 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
# coding=utf-8
from typing import Dict, List, Optional, Sequence
from urllib.parse import urlparse, ParseResult

from langchain_core.messages import BaseMessage, get_buffer_string

from common.config.tokenizer_manage_config import TokenizerManage
from models_provider.base_model_provider import MaxKBBaseModel
from models_provider.impl.base_chat_open_ai import BaseChatOpenAI
  8. def get_base_url(url: str):
  9. parse = urlparse(url)
  10. result_url = ParseResult(scheme=parse.scheme, netloc=parse.netloc, path=parse.path, params='',
  11. query='',
  12. fragment='').geturl()
  13. return result_url[:-1] if result_url.endswith("/") else result_url
  14. class XinferenceChatModel(MaxKBBaseModel, BaseChatOpenAI):
  15. @staticmethod
  16. def is_cache_model():
  17. return False
  18. @staticmethod
  19. def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
  20. api_base = model_credential.get('api_base', '')
  21. base_url = get_base_url(api_base)
  22. base_url = base_url if base_url.endswith('/v1') else (base_url + '/v1')
  23. optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
  24. return XinferenceChatModel(
  25. model=model_name,
  26. openai_api_base=base_url,
  27. openai_api_key=model_credential.get('api_key'),
  28. **optional_params,
  29. )
  30. def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
  31. if self.usage_metadata is None or self.usage_metadata == {}:
  32. tokenizer = TokenizerManage.get_tokenizer()
  33. return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
  34. return self.usage_metadata.get('input_tokens', 0)
  35. def get_num_tokens(self, text: str) -> int:
  36. if self.usage_metadata is None or self.usage_metadata == {}:
  37. tokenizer = TokenizerManage.get_tokenizer()
  38. return len(tokenizer.encode(text))
  39. return self.get_last_generation_info().get('output_tokens', 0)