llm.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. # coding=utf-8
  2. """
  3. @project: maxkb
  4. @Author:虎
  5. @file: llm.py
  6. @date:2024/04/19 15:55
  7. @desc: iFlytek (XunFei) Spark chat model integration for MaxKB
  8. """
  9. from typing import List, Optional, Any, Iterator, Dict
  10. from langchain_community.chat_models.sparkllm import \
  11. ChatSparkLLM, convert_message_to_dict, _convert_delta_to_message_chunk
  12. from langchain_core.callbacks import CallbackManagerForLLMRun
  13. from langchain_core.messages import BaseMessage, AIMessageChunk
  14. from langchain_core.outputs import ChatGenerationChunk
  15. from models_provider.base_model_provider import MaxKBBaseModel
  16. class XFChatSparkLLM(MaxKBBaseModel, ChatSparkLLM):
  17. @staticmethod
  18. def is_cache_model():
  19. return False
  20. @staticmethod
  21. def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
  22. optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
  23. return XFChatSparkLLM(
  24. spark_app_id=model_credential.get('spark_app_id'),
  25. spark_api_key=model_credential.get('spark_api_key'),
  26. spark_api_secret=model_credential.get('spark_api_secret'),
  27. spark_api_url=model_credential.get('spark_api_url'),
  28. spark_llm_domain=model_name,
  29. streaming=model_kwargs.get('streaming', False),
  30. **optional_params
  31. )
  32. usage_metadata: dict = {}
  33. def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
  34. return self.usage_metadata
  35. def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
  36. return self.usage_metadata.get('prompt_tokens', 0)
  37. def get_num_tokens(self, text: str) -> int:
  38. return self.usage_metadata.get('completion_tokens', 0)
  39. def _stream(
  40. self,
  41. messages: List[BaseMessage],
  42. stop: Optional[List[str]] = None,
  43. run_manager: Optional[CallbackManagerForLLMRun] = None,
  44. **kwargs: Any,
  45. ) -> Iterator[ChatGenerationChunk]:
  46. default_chunk_class = AIMessageChunk
  47. self.client.arun(
  48. [convert_message_to_dict(m) for m in messages],
  49. self.spark_user_id,
  50. self.model_kwargs,
  51. True,
  52. )
  53. for content in self.client.subscribe(timeout=self.request_timeout):
  54. if "data" in content:
  55. delta = content["data"]
  56. chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
  57. cg_chunk = ChatGenerationChunk(message=chunk)
  58. elif "usage" in content:
  59. generation_info = content["usage"]
  60. self.usage_metadata = generation_info
  61. continue
  62. else:
  63. continue
  64. if cg_chunk is not None:
  65. if run_manager:
  66. run_manager.on_llm_new_token(str(cg_chunk.message.content), chunk=cg_chunk)
  67. yield cg_chunk