image.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. # coding=utf-8
  2. import base64
  3. import os
  4. from typing import Dict, Any, List, Optional, Iterator
  5. #from docutils.utils import SystemMessage
  6. from langchain_community.chat_models.sparkllm import ChatSparkLLM, _convert_delta_to_message_chunk
  7. from langchain_core.callbacks import CallbackManagerForLLMRun
  8. from langchain_core.messages import BaseMessage, ChatMessage, HumanMessage, AIMessage, AIMessageChunk
  9. from langchain_core.outputs import ChatGenerationChunk
  10. from models_provider.base_model_provider import MaxKBBaseModel
  11. class ImageMessage(HumanMessage):
  12. content: str
  13. def convert_message_to_dict(message: BaseMessage) -> dict:
  14. message_dict: Dict[str, Any]
  15. if isinstance(message, ChatMessage):
  16. message_dict = {"role": "user", "content": message.content}
  17. elif isinstance(message, ImageMessage):
  18. message_dict = {"role": "user", "content": message.content, "content_type": "image"}
  19. elif isinstance(message, HumanMessage):
  20. message_dict = {"role": "user", "content": message.content}
  21. elif isinstance(message, AIMessage):
  22. message_dict = {"role": "assistant", "content": message.content}
  23. if "function_call" in message.additional_kwargs:
  24. message_dict["function_call"] = message.additional_kwargs["function_call"]
  25. # If function call only, content is None not empty string
  26. if message_dict["content"] == "":
  27. message_dict["content"] = None
  28. if "tool_calls" in message.additional_kwargs:
  29. message_dict["tool_calls"] = message.additional_kwargs["tool_calls"]
  30. # If tool calls only, content is None not empty string
  31. if message_dict["content"] == "":
  32. message_dict["content"] = None
  33. # elif isinstance(message, SystemMessage):
  34. # message_dict = {"role": "system", "content": message.content}
  35. else:
  36. raise ValueError(f"Got unknown type {message}")
  37. return message_dict
  38. class XFSparkImage(MaxKBBaseModel, ChatSparkLLM):
  39. spark_app_id: str
  40. spark_api_key: str
  41. spark_api_secret: str
  42. spark_api_url: str
  43. @staticmethod
  44. def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
  45. optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
  46. return XFSparkImage(
  47. spark_app_id=model_credential.get('spark_app_id'),
  48. spark_api_key=model_credential.get('spark_api_key'),
  49. spark_api_secret=model_credential.get('spark_api_secret'),
  50. spark_api_url=model_credential.get('spark_api_url'),
  51. **optional_params
  52. )
  53. @staticmethod
  54. def generate_message(prompt: str, image) -> list[BaseMessage]:
  55. if image is None:
  56. cwd = os.path.dirname(os.path.abspath(__file__))
  57. with open(f'{cwd}/img_1.png', 'rb') as f:
  58. base64_image = base64.b64encode(f.read()).decode("utf-8")
  59. return [ImageMessage(f'data:image/jpeg;base64,{base64_image}'), HumanMessage(prompt)]
  60. return [HumanMessage(prompt)]
  61. def _stream(
  62. self,
  63. messages: List[BaseMessage],
  64. stop: Optional[List[str]] = None,
  65. run_manager: Optional[CallbackManagerForLLMRun] = None,
  66. **kwargs: Any,
  67. ) -> Iterator[ChatGenerationChunk]:
  68. default_chunk_class = AIMessageChunk
  69. self.client.arun(
  70. [convert_message_to_dict(m) for m in messages],
  71. self.spark_user_id,
  72. self.model_kwargs,
  73. streaming=True,
  74. )
  75. for content in self.client.subscribe(timeout=self.request_timeout):
  76. if "data" not in content:
  77. continue
  78. delta = content["data"]
  79. chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
  80. cg_chunk = ChatGenerationChunk(message=chunk)
  81. if run_manager:
  82. run_manager.on_llm_new_token(str(chunk.content), chunk=cg_chunk)
  83. yield cg_chunk
  84. @staticmethod
  85. def is_cache_model():
  86. return False