base_chat_open_ai.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239
  1. # coding=utf-8
  2. import base64
  3. from concurrent.futures import ThreadPoolExecutor
  4. from typing import Dict, Optional, Any, Iterator, cast, Union, Sequence, Callable, Mapping
  5. from langchain_core.language_models import LanguageModelInput
  6. from langchain_core.messages import BaseMessage, get_buffer_string, BaseMessageChunk, HumanMessageChunk, AIMessageChunk, \
  7. SystemMessageChunk, FunctionMessageChunk, ChatMessageChunk
  8. from langchain_core.messages.ai import UsageMetadata
  9. from langchain_core.messages.tool import tool_call_chunk, ToolMessageChunk
  10. from langchain_core.outputs import ChatGenerationChunk
  11. from langchain_core.runnables import RunnableConfig, ensure_config
  12. from langchain_core.tools import BaseTool
  13. from langchain_openai import ChatOpenAI
  14. from langchain_openai.chat_models.base import _create_usage_metadata
  15. from requests.exceptions import ReadTimeout
  16. from common.config.tokenizer_manage_config import TokenizerManage
  17. from common.utils.logger import maxkb_logger
  18. def custom_get_token_ids(text: str):
  19. tokenizer = TokenizerManage.get_tokenizer()
  20. return tokenizer.encode(text)
  21. def _convert_delta_to_message_chunk(
  22. _dict: Mapping[str, Any], default_class: type[BaseMessageChunk]
  23. ) -> BaseMessageChunk:
  24. """Convert to a LangChain message chunk."""
  25. id_ = _dict.get("id")
  26. role = cast(str, _dict.get("role"))
  27. content = cast(str, _dict.get("content") or "")
  28. additional_kwargs: dict = {}
  29. if reasoning := _dict.get('reasoning_content') or _dict.get('reasoning'):
  30. additional_kwargs['reasoning_content'] = reasoning
  31. if _dict.get("function_call"):
  32. function_call = dict(_dict["function_call"])
  33. if "name" in function_call and function_call["name"] is None:
  34. function_call["name"] = ""
  35. additional_kwargs["function_call"] = function_call
  36. tool_call_chunks = []
  37. if raw_tool_calls := _dict.get("tool_calls"):
  38. try:
  39. tool_call_chunks = [
  40. tool_call_chunk(
  41. name=rtc["function"].get("name"),
  42. args=rtc["function"].get("arguments"),
  43. id=rtc.get("id"),
  44. index=rtc["index"],
  45. )
  46. for rtc in raw_tool_calls
  47. ]
  48. except KeyError:
  49. pass
  50. if role == "user" or default_class == HumanMessageChunk:
  51. return HumanMessageChunk(content=content, id=id_)
  52. if role == "assistant" or default_class == AIMessageChunk:
  53. return AIMessageChunk(
  54. content=content,
  55. additional_kwargs=additional_kwargs,
  56. id=id_,
  57. tool_call_chunks=tool_call_chunks, # type: ignore[arg-type]
  58. )
  59. if role in ("system", "developer") or default_class == SystemMessageChunk:
  60. if role == "developer":
  61. additional_kwargs = {"__openai_role__": "developer"}
  62. else:
  63. additional_kwargs = {}
  64. return SystemMessageChunk(
  65. content=content, id=id_, additional_kwargs=additional_kwargs
  66. )
  67. if role == "function" or default_class == FunctionMessageChunk:
  68. return FunctionMessageChunk(content=content, name=_dict["name"], id=id_)
  69. if role == "tool" or default_class == ToolMessageChunk:
  70. return ToolMessageChunk(
  71. content=content, tool_call_id=_dict["tool_call_id"], id=id_
  72. )
  73. if role or default_class == ChatMessageChunk:
  74. return ChatMessageChunk(content=content, role=role, id=id_)
  75. return default_class(content=content, id=id_) # type: ignore[call-arg]
  76. class BaseChatOpenAI(ChatOpenAI):
  77. usage_metadata: dict = {}
  78. custom_get_token_ids = custom_get_token_ids
  79. def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
  80. return self.usage_metadata
  81. def get_num_tokens_from_messages(
  82. self,
  83. messages: list[BaseMessage],
  84. tools: Optional[
  85. Sequence[Union[dict[str, Any], type, Callable, BaseTool]]
  86. ] = None,
  87. timeout: Optional[float] = 0.5,
  88. ) -> int:
  89. if self.usage_metadata is None or self.usage_metadata == {}:
  90. with ThreadPoolExecutor(max_workers=1) as executor:
  91. future = executor.submit(super().get_num_tokens_from_messages, messages, tools)
  92. try:
  93. response = future.result()
  94. maxkb_logger.info("请求成功(未超时)")
  95. return response
  96. except Exception as e:
  97. if isinstance(e, ReadTimeout):
  98. raise # 继续抛出
  99. else:
  100. tokenizer = TokenizerManage.get_tokenizer()
  101. return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
  102. return self.usage_metadata.get('input_tokens', self.usage_metadata.get('prompt_tokens', 0))
  103. def get_num_tokens(self, text: str) -> int:
  104. if self.usage_metadata is None or self.usage_metadata == {}:
  105. try:
  106. return super().get_num_tokens(text)
  107. except Exception as e:
  108. tokenizer = TokenizerManage.get_tokenizer()
  109. return len(tokenizer.encode(text))
  110. return self.get_last_generation_info().get('output_tokens',
  111. self.get_last_generation_info().get('completion_tokens', 0))
  112. def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGenerationChunk]:
  113. kwargs['stream_usage'] = True
  114. for chunk in super()._stream(*args, **kwargs):
  115. if chunk.message.usage_metadata is not None:
  116. self.usage_metadata = chunk.message.usage_metadata
  117. yield chunk
  118. def _convert_chunk_to_generation_chunk(
  119. self,
  120. chunk: dict,
  121. default_chunk_class: type,
  122. base_generation_info: dict | None,
  123. ) -> ChatGenerationChunk | None:
  124. if chunk.get("type") == "content.delta": # From beta.chat.completions.stream
  125. return None
  126. token_usage = chunk.get("usage")
  127. choices = (
  128. chunk.get("choices", [])
  129. # From beta.chat.completions.stream
  130. or chunk.get("chunk", {}).get("choices", [])
  131. )
  132. usage_metadata: UsageMetadata | None = (
  133. _create_usage_metadata(token_usage, chunk.get("service_tier"))
  134. if token_usage
  135. else None
  136. )
  137. if len(choices) == 0:
  138. # logprobs is implicitly None
  139. generation_chunk = ChatGenerationChunk(
  140. message=default_chunk_class(content="", usage_metadata=usage_metadata),
  141. generation_info=base_generation_info,
  142. )
  143. if self.output_version == "v1":
  144. generation_chunk.message.content = []
  145. generation_chunk.message.response_metadata["output_version"] = "v1"
  146. return generation_chunk
  147. choice = choices[0]
  148. if choice["delta"] is None:
  149. return None
  150. message_chunk = _convert_delta_to_message_chunk(
  151. choice["delta"], default_chunk_class
  152. )
  153. generation_info = {**base_generation_info} if base_generation_info else {}
  154. if finish_reason := choice.get("finish_reason"):
  155. generation_info["finish_reason"] = finish_reason
  156. if model_name := chunk.get("model"):
  157. generation_info["model_name"] = model_name
  158. if system_fingerprint := chunk.get("system_fingerprint"):
  159. generation_info["system_fingerprint"] = system_fingerprint
  160. if service_tier := chunk.get("service_tier"):
  161. generation_info["service_tier"] = service_tier
  162. logprobs = choice.get("logprobs")
  163. if logprobs:
  164. generation_info["logprobs"] = logprobs
  165. if usage_metadata and isinstance(message_chunk, AIMessageChunk):
  166. message_chunk.usage_metadata = usage_metadata
  167. message_chunk.response_metadata["model_provider"] = "openai"
  168. return ChatGenerationChunk(
  169. message=message_chunk, generation_info=generation_info or None
  170. )
  171. def invoke(
  172. self,
  173. input: LanguageModelInput,
  174. config: Optional[RunnableConfig] = None,
  175. *,
  176. stop: Optional[list[str]] = None,
  177. **kwargs: Any,
  178. ) -> BaseMessage:
  179. config = ensure_config(config)
  180. chat_result = cast(
  181. "ChatGeneration",
  182. self.generate_prompt(
  183. [self._convert_input(input)],
  184. stop=stop,
  185. callbacks=config.get("callbacks"),
  186. tags=config.get("tags"),
  187. metadata=config.get("metadata"),
  188. run_name=config.get("run_name"),
  189. run_id=config.pop("run_id", None),
  190. **kwargs,
  191. ).generations[0][0],
  192. ).message
  193. self.usage_metadata = chat_result.response_metadata[
  194. 'token_usage'] if 'token_usage' in chat_result.response_metadata else chat_result.usage_metadata
  195. return chat_result
  196. def upload_file_and_get_url(self, file_stream, file_name):
  197. """上传文件并获取文件URL"""
  198. base64_video = base64.b64encode(file_stream).decode("utf-8")
  199. video_format = get_video_format(file_name)
  200. return f'data:{video_format};base64,{base64_video}'
  201. def get_video_format(file_name):
  202. extension = file_name.split('.')[-1].lower()
  203. format_map = {
  204. 'mp4': 'video/mp4',
  205. 'avi': 'video/avi',
  206. 'mov': 'video/mov',
  207. 'wmv': 'video/x-ms-wmv'
  208. }
  209. return format_map.get(extension, 'video/mp4')