| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778 |
- # coding=utf-8
- """
- @project: maxkb
- @Author:虎
- @file: __init__.py.py
- @date:2024/04/19 15:55
- @desc:
- """
- from typing import List, Optional, Any, Iterator, Dict
- from langchain_community.chat_models.sparkllm import \
- ChatSparkLLM, convert_message_to_dict, _convert_delta_to_message_chunk
- from langchain_core.callbacks import CallbackManagerForLLMRun
- from langchain_core.messages import BaseMessage, AIMessageChunk
- from langchain_core.outputs import ChatGenerationChunk
- from models_provider.base_model_provider import MaxKBBaseModel
- class XFChatSparkLLM(MaxKBBaseModel, ChatSparkLLM):
- @staticmethod
- def is_cache_model():
- return False
- @staticmethod
- def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
- optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
- return XFChatSparkLLM(
- spark_app_id=model_credential.get('spark_app_id'),
- spark_api_key=model_credential.get('spark_api_key'),
- spark_api_secret=model_credential.get('spark_api_secret'),
- spark_api_url=model_credential.get('spark_api_url'),
- spark_llm_domain=model_name,
- streaming=model_kwargs.get('streaming', False),
- **optional_params
- )
- usage_metadata: dict = {}
- def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
- return self.usage_metadata
- def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
- return self.usage_metadata.get('prompt_tokens', 0)
- def get_num_tokens(self, text: str) -> int:
- return self.usage_metadata.get('completion_tokens', 0)
- def _stream(
- self,
- messages: List[BaseMessage],
- stop: Optional[List[str]] = None,
- run_manager: Optional[CallbackManagerForLLMRun] = None,
- **kwargs: Any,
- ) -> Iterator[ChatGenerationChunk]:
- default_chunk_class = AIMessageChunk
- self.client.arun(
- [convert_message_to_dict(m) for m in messages],
- self.spark_user_id,
- self.model_kwargs,
- True,
- )
- for content in self.client.subscribe(timeout=self.request_timeout):
- if "data" in content:
- delta = content["data"]
- chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
- cg_chunk = ChatGenerationChunk(message=chunk)
- elif "usage" in content:
- generation_info = content["usage"]
- self.usage_metadata = generation_info
- continue
- else:
- continue
- if cg_chunk is not None:
- if run_manager:
- run_manager.on_llm_new_token(str(cg_chunk.message.content), chunk=cg_chunk)
- yield cg_chunk
|