llm.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. # coding=utf-8
  2. """
  3. @project: maxkb
  4. @Author:虎
  5. @file: llm.py
  6. @date:2023/11/10 17:45
  7. @desc:
  8. """
  9. from typing import List, Dict, Optional, Any, Iterator
  10. from langchain_community.chat_models.baidu_qianfan_endpoint import _convert_dict_to_message, QianfanChatEndpoint
  11. from langchain_core.callbacks import CallbackManagerForLLMRun
  12. from langchain_core.messages import (
  13. AIMessageChunk,
  14. BaseMessage,
  15. )
  16. from langchain_core.outputs import ChatGenerationChunk
  17. from models_provider.base_model_provider import MaxKBBaseModel
  18. from models_provider.impl.base_chat_open_ai import BaseChatOpenAI
  19. class QianfanChatModelQianfan(MaxKBBaseModel, QianfanChatEndpoint):
  20. @staticmethod
  21. def is_cache_model():
  22. return False
  23. @staticmethod
  24. def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
  25. optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
  26. return QianfanChatModelQianfan(model=model_name,
  27. qianfan_ak=model_credential.get('api_key'),
  28. qianfan_sk=model_credential.get('secret_key'),
  29. streaming=model_kwargs.get('streaming', False),
  30. init_kwargs=optional_params)
  31. usage_metadata: dict = {}
  32. def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
  33. return self.usage_metadata
  34. def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
  35. return self.usage_metadata.get('prompt_tokens', 0)
  36. def get_num_tokens(self, text: str) -> int:
  37. return self.usage_metadata.get('completion_tokens', 0)
  38. def _stream(
  39. self,
  40. messages: List[BaseMessage],
  41. stop: Optional[List[str]] = None,
  42. run_manager: Optional[CallbackManagerForLLMRun] = None,
  43. **kwargs: Any,
  44. ) -> Iterator[ChatGenerationChunk]:
  45. kwargs = {**self.init_kwargs, **kwargs}
  46. params = self._convert_prompt_msg_params(messages, **kwargs)
  47. params["stop"] = stop
  48. params["stream"] = True
  49. for res in self.client.do(**params):
  50. if res:
  51. msg = _convert_dict_to_message(res)
  52. additional_kwargs = msg.additional_kwargs.get("function_call", {})
  53. if msg.content == "" or res.get("body").get("is_end"):
  54. token_usage = res.get("body").get("usage")
  55. self.usage_metadata = token_usage
  56. chunk = ChatGenerationChunk(
  57. text=res["result"],
  58. message=AIMessageChunk( # type: ignore[call-arg]
  59. content=msg.content,
  60. role="assistant",
  61. additional_kwargs=additional_kwargs,
  62. ),
  63. generation_info=msg.additional_kwargs,
  64. )
  65. if run_manager:
  66. run_manager.on_llm_new_token(chunk.text, chunk=chunk)
  67. yield chunk
  68. class QianfanChatModelOpenai(MaxKBBaseModel, BaseChatOpenAI):
  69. @staticmethod
  70. def is_cache_model():
  71. return False
  72. @staticmethod
  73. def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
  74. optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
  75. return QianfanChatModelOpenai(
  76. model=model_name,
  77. openai_api_base=model_credential.get('api_base'),
  78. openai_api_key=model_credential.get('api_key'),
  79. extra_body=optional_params
  80. )
  81. class QianfanChatModel(MaxKBBaseModel):
  82. @staticmethod
  83. def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
  84. api_version = model_credential.get('api_version', 'v1')
  85. if api_version == "v1":
  86. return QianfanChatModelQianfan.new_instance(model_type, model_name, model_credential, **model_kwargs)
  87. elif api_version == "v2":
  88. return QianfanChatModelOpenai.new_instance(model_type, model_name, model_credential, **model_kwargs)