| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104 |
- # coding=utf-8
- """
- @project: maxkb
- @Author:虎
- @file: llm.py
- @date:2023/11/10 17:45
- @desc:
- """
- from typing import List, Dict, Optional, Any, Iterator
- from langchain_community.chat_models.baidu_qianfan_endpoint import _convert_dict_to_message, QianfanChatEndpoint
- from langchain_core.callbacks import CallbackManagerForLLMRun
- from langchain_core.messages import (
- AIMessageChunk,
- BaseMessage,
- )
- from langchain_core.outputs import ChatGenerationChunk
- from models_provider.base_model_provider import MaxKBBaseModel
- from models_provider.impl.base_chat_open_ai import BaseChatOpenAI
class QianfanChatModelQianfan(MaxKBBaseModel, QianfanChatEndpoint):
    """Qianfan chat model that talks to the service through the Qianfan SDK endpoint.

    While streaming, the ``usage`` object reported by the service is stored in
    ``usage_metadata``. The token-count getters below read it back from there
    instead of tokenizing locally.
    """

    @staticmethod
    def is_cache_model():
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build an instance from a MaxKB credential dict and model settings.

        :param model_type: unused here; part of the shared ``new_instance`` signature.
        :param model_name: Qianfan model name, passed as ``model``.
        :param model_credential: dict providing ``api_key`` (AK) and ``secret_key`` (SK).
        :param model_kwargs: may contain ``streaming``. Optional generation parameters
            kept by ``MaxKBBaseModel.filter_optional_params`` are forwarded as ``init_kwargs``.
        """
        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
        return QianfanChatModelQianfan(model=model_name,
                                       qianfan_ak=model_credential.get('api_key'),
                                       qianfan_sk=model_credential.get('secret_key'),
                                       streaming=model_kwargs.get('streaming', False),
                                       init_kwargs=optional_params)

    # Usage dict from the last streamed response that reported one.
    # The getters below read its 'prompt_tokens' / 'completion_tokens' keys.
    usage_metadata: dict = {}

    def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
        """Return the usage dict captured from the most recent stream."""
        return self.usage_metadata

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Return the prompt token count reported for the last call.

        ``messages`` is ignored; the value comes from the server's usage report.
        """
        return (self.usage_metadata or {}).get('prompt_tokens', 0)

    def get_num_tokens(self, text: str) -> int:
        """Return the completion token count reported for the last call.

        ``text`` is ignored; the value comes from the server's usage report.
        """
        return (self.usage_metadata or {}).get('completion_tokens', 0)

    def _stream(
            self,
            messages: List[BaseMessage],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream chat chunks from Qianfan and record the reported token usage.

        Per-call ``kwargs`` override the instance's ``init_kwargs``.
        """
        kwargs = {**self.init_kwargs, **kwargs}
        params = self._convert_prompt_msg_params(messages, **kwargs)
        params["stop"] = stop
        params["stream"] = True
        for res in self.client.do(**params):
            if not res:
                continue
            msg = _convert_dict_to_message(res)
            additional_kwargs = msg.additional_kwargs.get("function_call", {})
            # `body` may be missing; fall back to an empty dict instead of
            # calling `.get` on None.
            body = res.get("body") or {}
            if msg.content == "" or body.get("is_end"):
                token_usage = body.get("usage")
                # Only overwrite when usage is actually reported. Storing None
                # would make the token-count getters above fail.
                if token_usage:
                    self.usage_metadata = token_usage
            chunk = ChatGenerationChunk(
                # msg.content is `result` normalised to a str (never missing/None).
                text=msg.content,
                message=AIMessageChunk(  # type: ignore[call-arg]
                    content=msg.content,
                    role="assistant",
                    additional_kwargs=additional_kwargs,
                ),
                generation_info=msg.additional_kwargs,
            )
            if run_manager:
                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
            yield chunk
class QianfanChatModelOpenai(MaxKBBaseModel, BaseChatOpenAI):
    """Qianfan chat model reached through an OpenAI-style client (``BaseChatOpenAI``)."""

    @staticmethod
    def is_cache_model():
        # Always False for this provider.
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Create an instance pointed at the credential's ``api_base`` and ``api_key``.

        Optional generation parameters kept by ``MaxKBBaseModel.filter_optional_params``
        are sent to the API as ``extra_body``. ``model_type`` is unused.
        """
        extra_body = MaxKBBaseModel.filter_optional_params(model_kwargs)
        endpoint = dict(
            openai_api_base=model_credential.get('api_base'),
            openai_api_key=model_credential.get('api_key'),
        )
        return QianfanChatModelOpenai(model=model_name, extra_body=extra_body, **endpoint)
class QianfanChatModel(MaxKBBaseModel):
    """Factory selecting the Qianfan implementation from the credential's ``api_version``."""

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Dispatch to the v1 (Qianfan SDK) or v2 (OpenAI-style) implementation.

        :param model_credential: may contain ``api_version``. A missing, None or
            empty value is treated as ``"v1"``.
        :raises ValueError: if ``api_version`` is neither ``"v1"`` nor ``"v2"``.
            Previously such values silently returned ``None``.
        """
        api_version = model_credential.get('api_version') or 'v1'
        if api_version == "v1":
            return QianfanChatModelQianfan.new_instance(model_type, model_name, model_credential, **model_kwargs)
        if api_version == "v2":
            return QianfanChatModelOpenai.new_instance(model_type, model_name, model_credential, **model_kwargs)
        raise ValueError(f"Unsupported Qianfan api_version: {api_version!r} (expected 'v1' or 'v2')")
|