hunyuan.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
  1. import json
  2. import logging
  3. from typing import Any, Dict, Iterator, List, Mapping, Optional, Type
  4. from langchain_core.callbacks import CallbackManagerForLLMRun
  5. from langchain_core.language_models.chat_models import (
  6. BaseChatModel,
  7. generate_from_stream,
  8. )
  9. from langchain_core.messages import (
  10. AIMessage,
  11. AIMessageChunk,
  12. BaseMessage,
  13. BaseMessageChunk,
  14. ChatMessage,
  15. ChatMessageChunk,
  16. HumanMessage,
  17. HumanMessageChunk, SystemMessage,
  18. )
  19. from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
  20. from pydantic import Field, SecretStr, root_validator
  21. from langchain_core.utils import (
  22. convert_to_secret_str,
  23. get_from_dict_or_env,
  24. get_pydantic_field_names,
  25. pre_init,
  26. )
# Module-level logger; ``ChatHunyuan.build_extra`` uses it to warn when an
# unrecognized constructor argument is moved into ``model_kwargs``.
logger = logging.getLogger(__name__)
  28. def _convert_message_to_dict(message: BaseMessage) -> dict:
  29. message_dict: Dict[str, Any]
  30. if isinstance(message, ChatMessage):
  31. message_dict = {"Role": message.role, "Content": message.content}
  32. elif isinstance(message, HumanMessage):
  33. message_dict = {"Role": "user", "Content": message.content}
  34. elif isinstance(message, AIMessage):
  35. message_dict = {"Role": "assistant", "Content": message.content}
  36. elif isinstance(message, SystemMessage):
  37. message_dict = {"Role": "system", "Content": message.content}
  38. else:
  39. raise TypeError(f"Got unknown type {message}")
  40. return message_dict
  41. def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
  42. role = _dict["Role"]
  43. if role == "user":
  44. return HumanMessage(content=_dict["Content"])
  45. elif role == "assistant":
  46. return AIMessage(content=_dict.get("Content", "") or "")
  47. else:
  48. return ChatMessage(content=_dict["Content"], role=role)
  49. def _convert_delta_to_message_chunk(
  50. _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
  51. ) -> BaseMessageChunk:
  52. role = _dict.get("Role")
  53. content = _dict.get("Content") or ""
  54. if role == "user" or default_class == HumanMessageChunk:
  55. return HumanMessageChunk(content=content)
  56. elif role == "assistant" or default_class == AIMessageChunk:
  57. return AIMessageChunk(content=content)
  58. elif role or default_class == ChatMessageChunk:
  59. return ChatMessageChunk(content=content, role=role) # type: ignore[arg-type]
  60. else:
  61. return default_class(content=content) # type: ignore[call-arg]
  62. def _create_chat_result(response: Mapping[str, Any]) -> ChatResult:
  63. generations = []
  64. for choice in response["Choices"]:
  65. message = _convert_dict_to_message(choice["Message"])
  66. generations.append(ChatGeneration(message=message))
  67. token_usage = response["Usage"]
  68. llm_output = {"token_usage": token_usage}
  69. return ChatResult(generations=generations, llm_output=llm_output)
  70. class ChatHunyuan(BaseChatModel):
  71. """Tencent Hunyuan chat models API by Tencent.
  72. For more information, see https://cloud.tencent.com/document/product/1729
  73. """
  74. @property
  75. def lc_secrets(self) -> Dict[str, str]:
  76. return {
  77. "hunyuan_app_id": "HUNYUAN_APP_ID",
  78. "hunyuan_secret_id": "HUNYUAN_SECRET_ID",
  79. "hunyuan_secret_key": "HUNYUAN_SECRET_KEY",
  80. }
  81. @property
  82. def lc_serializable(self) -> bool:
  83. return True
  84. hunyuan_app_id: Optional[int] = None
  85. """Hunyuan App ID"""
  86. hunyuan_secret_id: Optional[str] = None
  87. """Hunyuan Secret ID"""
  88. hunyuan_secret_key: Optional[SecretStr] = None
  89. """Hunyuan Secret Key"""
  90. streaming: bool = False
  91. """Whether to stream the results or not."""
  92. request_timeout: int = 60
  93. """Timeout for requests to Hunyuan API. Default is 60 seconds."""
  94. temperature: float = 1.0
  95. """What sampling temperature to use."""
  96. top_p: float = 1.0
  97. """What probability mass to use."""
  98. model: str = "hunyuan-lite"
  99. """What Model to use.
  100. Optional model:
  101. - hunyuan-lite、
  102. - hunyuan-standard
  103. - hunyuan-standard-256K
  104. - hunyuan-pro
  105. - hunyuan-code
  106. - hunyuan-role
  107. - hunyuan-functioncall
  108. - hunyuan-vision
  109. """
  110. stream_moderation: bool = False
  111. """Whether to review the results or not when streaming is true."""
  112. enable_enhancement: bool = True
  113. """Whether to enhancement the results or not."""
  114. model_kwargs: Dict[str, Any] = Field(default_factory=dict)
  115. """Holds any model parameters valid for API call not explicitly specified."""
  116. class Config:
  117. """Configuration for this pydantic object."""
  118. validate_by_name = True
  119. @root_validator(pre=True)
  120. def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
  121. """Build extra kwargs from additional params that were passed in."""
  122. all_required_field_names = get_pydantic_field_names(cls)
  123. extra = values.get("model_kwargs", {})
  124. for field_name in list(values):
  125. if field_name in extra:
  126. raise ValueError(f"Found {field_name} supplied twice.")
  127. if field_name not in all_required_field_names:
  128. logger.warning(
  129. f"""WARNING! {field_name} is not default parameter.
  130. {field_name} was transferred to model_kwargs.
  131. Please confirm that {field_name} is what you intended."""
  132. )
  133. extra[field_name] = values.pop(field_name)
  134. invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
  135. if invalid_model_kwargs:
  136. raise ValueError(
  137. f"Parameters {invalid_model_kwargs} should be specified explicitly. "
  138. f"Instead they were passed in as part of `model_kwargs` parameter."
  139. )
  140. values["model_kwargs"] = extra
  141. return values
  142. @pre_init
  143. def validate_environment(cls, values: Dict) -> Dict:
  144. values["hunyuan_app_id"] = get_from_dict_or_env(
  145. values,
  146. "hunyuan_app_id",
  147. "HUNYUAN_APP_ID",
  148. )
  149. values["hunyuan_secret_id"] = get_from_dict_or_env(
  150. values,
  151. "hunyuan_secret_id",
  152. "HUNYUAN_SECRET_ID",
  153. )
  154. values["hunyuan_secret_key"] = convert_to_secret_str(
  155. get_from_dict_or_env(
  156. values,
  157. "hunyuan_secret_key",
  158. "HUNYUAN_SECRET_KEY",
  159. )
  160. )
  161. return values
  162. @property
  163. def _default_params(self) -> Dict[str, Any]:
  164. """Get the default parameters for calling Hunyuan API."""
  165. normal_params = {
  166. "Temperature": self.temperature,
  167. "TopP": self.top_p,
  168. "Model": self.model,
  169. "Stream": self.streaming,
  170. "StreamModeration": self.stream_moderation,
  171. "EnableEnhancement": self.enable_enhancement,
  172. }
  173. return {**normal_params, **self.model_kwargs}
  174. def _generate(
  175. self,
  176. messages: List[BaseMessage],
  177. stop: Optional[List[str]] = None,
  178. run_manager: Optional[CallbackManagerForLLMRun] = None,
  179. **kwargs: Any,
  180. ) -> ChatResult:
  181. if self.streaming:
  182. stream_iter = self._stream(
  183. messages=messages, stop=stop, run_manager=run_manager, **kwargs
  184. )
  185. return generate_from_stream(stream_iter)
  186. res = self._chat(messages, **kwargs)
  187. return _create_chat_result(json.loads(res.to_json_string()))
  188. usage_metadata: dict = {}
  189. def _stream(
  190. self,
  191. messages: List[BaseMessage],
  192. stop: Optional[List[str]] = None,
  193. run_manager: Optional[CallbackManagerForLLMRun] = None,
  194. **kwargs: Any,
  195. ) -> Iterator[ChatGenerationChunk]:
  196. res = self._chat(messages, **kwargs)
  197. default_chunk_class = AIMessageChunk
  198. for chunk in res:
  199. chunk = chunk.get("data", "")
  200. if len(chunk) == 0:
  201. continue
  202. response = json.loads(chunk)
  203. if "error" in response:
  204. raise ValueError(f"Error from Hunyuan api response: {response}")
  205. for choice in response["Choices"]:
  206. chunk = _convert_delta_to_message_chunk(
  207. choice["Delta"], default_chunk_class
  208. )
  209. default_chunk_class = chunk.__class__
  210. # FinishReason === stop
  211. if choice.get("FinishReason") == "stop":
  212. self.usage_metadata = response.get("Usage", {})
  213. cg_chunk = ChatGenerationChunk(message=chunk)
  214. if run_manager:
  215. run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
  216. yield cg_chunk
  217. def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> Any:
  218. if self.hunyuan_secret_key is None:
  219. raise ValueError("Hunyuan secret key is not set.")
  220. try:
  221. from tencentcloud.common import credential
  222. from tencentcloud.hunyuan.v20230901 import hunyuan_client, models
  223. except ImportError:
  224. raise ImportError(
  225. "Could not import tencentcloud python package. "
  226. "Please install it with `pip install tencentcloud-sdk-python`."
  227. )
  228. parameters = {**self._default_params, **kwargs}
  229. cred = credential.Credential(
  230. self.hunyuan_secret_id, str(self.hunyuan_secret_key.get_secret_value())
  231. )
  232. client = hunyuan_client.HunyuanClient(cred, "")
  233. req = models.ChatCompletionsRequest()
  234. params = {
  235. "Messages": [_convert_message_to_dict(m) for m in messages],
  236. **parameters,
  237. }
  238. req.from_json_string(json.dumps(params))
  239. resp = client.ChatCompletions(req)
  240. return resp
  241. @property
  242. def _llm_type(self) -> str:
  243. return "hunyuan-chat"