llm.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. # coding=utf-8
  2. from typing import List, Dict, Optional, Any
  3. from langchain_core.messages import BaseMessage
  4. from models_provider.base_model_provider import MaxKBBaseModel
  5. from models_provider.impl.tencent_model_provider.model.hunyuan import ChatHunyuan
  6. class TencentModel(MaxKBBaseModel, ChatHunyuan):
  7. @staticmethod
  8. def is_cache_model():
  9. return False
  10. def __init__(self, model_name: str, credentials: Dict[str, str], streaming: bool = False, **kwargs):
  11. hunyuan_app_id = credentials.get('hunyuan_app_id')
  12. hunyuan_secret_id = credentials.get('hunyuan_secret_id')
  13. hunyuan_secret_key = credentials.get('hunyuan_secret_key')
  14. optional_params = MaxKBBaseModel.filter_optional_params(kwargs)
  15. if not all([hunyuan_app_id, hunyuan_secret_id, hunyuan_secret_key]):
  16. raise ValueError(
  17. "All of 'hunyuan_app_id', 'hunyuan_secret_id', and 'hunyuan_secret_key' must be provided in credentials.")
  18. super().__init__(model=model_name, hunyuan_app_id=hunyuan_app_id, hunyuan_secret_id=hunyuan_secret_id,
  19. hunyuan_secret_key=hunyuan_secret_key, streaming=streaming,
  20. temperature=optional_params.get('temperature', 1.0)
  21. )
  22. @staticmethod
  23. def new_instance(model_type: str, model_name: str, model_credential: Dict[str, object],
  24. **model_kwargs) -> 'TencentModel':
  25. streaming = model_kwargs.pop('streaming', False)
  26. return TencentModel(model_name=model_name, credentials=model_credential, streaming=streaming, **model_kwargs)
  27. def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
  28. return self.usage_metadata
  29. def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
  30. return self.usage_metadata.get('PromptTokens', 0)
  31. def get_num_tokens(self, text: str) -> int:
  32. return self.usage_metadata.get('CompletionTokens', 0)