| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100 |
- # coding=utf-8
- import base64
- import os
- from typing import Dict, Any, List, Optional, Iterator
- #from docutils.utils import SystemMessage
- from langchain_community.chat_models.sparkllm import ChatSparkLLM, _convert_delta_to_message_chunk
- from langchain_core.callbacks import CallbackManagerForLLMRun
- from langchain_core.messages import BaseMessage, ChatMessage, HumanMessage, AIMessage, AIMessageChunk
- from langchain_core.outputs import ChatGenerationChunk
- from models_provider.base_model_provider import MaxKBBaseModel
class ImageMessage(HumanMessage):
    """A user message whose ``content`` carries an image payload instead of text.

    ``convert_message_to_dict`` serializes it as a ``"user"`` message with an
    extra ``"content_type": "image"`` key.
    """

    # Narrowed to a single string; XFSparkImage.generate_message fills it with a
    # base64 data URI of an image.
    content: str
def convert_message_to_dict(message: BaseMessage) -> dict:
    """Translate a LangChain message into the dict shape passed to the Spark client.

    Args:
        message: The message to serialize. ``ChatMessage``, ``ImageMessage``,
            ``HumanMessage`` and ``AIMessage`` instances are supported.

    Returns:
        A dict with ``role`` and ``content`` keys. ``ImageMessage`` adds
        ``"content_type": "image"``. ``AIMessage`` may additionally carry
        ``function_call`` and/or ``tool_calls`` copied from its
        ``additional_kwargs``, in which case an empty ``content`` becomes None.

    Raises:
        ValueError: If ``message`` is of any other type.
    """
    if isinstance(message, ChatMessage):
        return {"role": "user", "content": message.content}
    if isinstance(message, ImageMessage):
        # Must be tested before HumanMessage, because ImageMessage subclasses it.
        return {"role": "user", "content": message.content, "content_type": "image"}
    if isinstance(message, HumanMessage):
        return {"role": "user", "content": message.content}
    if not isinstance(message, AIMessage):
        # NOTE(review): SystemMessage has no branch, so system prompts end up
        # here. Confirm whether the Spark image API accepts a "system" role
        # (import it from langchain_core.messages, not docutils) before adding one.
        raise ValueError(f"Got unknown type {message}")

    extra = message.additional_kwargs
    result: Dict[str, Any] = {"role": "assistant", "content": message.content}
    for field in ("function_call", "tool_calls"):
        if field not in extra:
            continue
        result[field] = extra[field]
        # A reply that is only a function/tool call is sent with content None, not "".
        if result["content"] == "":
            result["content"] = None
    return result
class XFSparkImage(MaxKBBaseModel, ChatSparkLLM):
    """iFlytek Spark image-understanding chat model adapted for MaxKB.

    Subclasses LangChain's ``ChatSparkLLM`` and overrides ``_stream``. The
    override serializes messages with the module-level
    ``convert_message_to_dict``, which understands ``ImageMessage``.
    """

    # Credentials and endpoint, redeclared here as required ``str`` fields.
    spark_app_id: str
    spark_api_key: str
    spark_api_secret: str
    spark_api_url: str

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build an ``XFSparkImage`` from a MaxKB credential mapping.

        Args:
            model_type: Not used by this implementation.
            model_name: Not used by this implementation.
            model_credential: Mapping read for ``spark_app_id``,
                ``spark_api_key``, ``spark_api_secret`` and ``spark_api_url``.
            **model_kwargs: Extra options. Only those kept by
                ``MaxKBBaseModel.filter_optional_params`` are forwarded.

        Returns:
            A new ``XFSparkImage`` instance.
        """
        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
        return XFSparkImage(
            spark_app_id=model_credential.get('spark_app_id'),
            spark_api_key=model_credential.get('spark_api_key'),
            spark_api_secret=model_credential.get('spark_api_secret'),
            spark_api_url=model_credential.get('spark_api_url'),
            **optional_params
        )

    @staticmethod
    def generate_message(prompt: str, image) -> list[BaseMessage]:
        """Build the message list for ``prompt``.

        When ``image`` is None, the ``img_1.png`` file stored next to this
        module is base64-encoded and sent as an ``ImageMessage`` ahead of the
        prompt.

        Args:
            prompt: The user's text prompt.
            image: Only compared against None (see the review note below).

        Returns:
            ``[ImageMessage, HumanMessage]`` when ``image`` is None, otherwise
            ``[HumanMessage]``.
        """
        if image is None:
            cwd = os.path.dirname(os.path.abspath(__file__))
            with open(os.path.join(cwd, 'img_1.png'), 'rb') as f:
                base64_image = base64.b64encode(f.read()).decode("utf-8")
            # The bundled file is a PNG, so declare image/png (previously this
            # mislabelled it as image/jpeg).
            return [ImageMessage(f'data:image/png;base64,{base64_image}'), HumanMessage(prompt)]
        # NOTE(review): a caller-supplied ``image`` is discarded here and only the
        # prompt is sent. Confirm whether this is intended; its expected type
        # (bytes / base64 str / data URI) is not visible from this file.
        return [HumanMessage(prompt)]

    def _stream(
            self,
            messages: List[BaseMessage],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream the model's reply as ``ChatGenerationChunk`` objects.

        The messages are serialized with ``convert_message_to_dict`` and
        submitted through ``self.client.arun(..., streaming=True)``. Each item
        from ``self.client.subscribe`` that has a ``"data"`` key is converted to
        an ``AIMessageChunk``-based chunk. The chunk is reported to
        ``run_manager`` (when given) and then yielded.

        Args:
            messages: Conversation to send.
            stop: Accepted for interface compatibility; not used here.
            run_manager: Optional callback manager notified of each new token.
            **kwargs: Accepted for interface compatibility; not used here.

        Yields:
            One ``ChatGenerationChunk`` per streamed ``"data"`` item.
        """
        default_chunk_class = AIMessageChunk
        self.client.arun(
            [convert_message_to_dict(m) for m in messages],
            self.spark_user_id,
            self.model_kwargs,
            streaming=True,
        )
        for content in self.client.subscribe(timeout=self.request_timeout):
            # Items without a "data" payload are skipped.
            if "data" not in content:
                continue
            delta = content["data"]
            chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
            cg_chunk = ChatGenerationChunk(message=chunk)
            if run_manager:
                run_manager.on_llm_new_token(str(chunk.content), chunk=cg_chunk)
            yield cg_chunk

    @staticmethod
    def is_cache_model():
        """Always return False.

        Presumably this tells MaxKB not to cache instances of this model;
        confirm against ``MaxKBBaseModel``.
        """
        return False
|