#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project :MaxKB
@File :llm.py
@Author :Brian Yang
@Date :5/12/24 7:44 AM
"""
  9. import json
  10. from typing import Dict, Any
  11. from langchain_core.language_models import LanguageModelInput
  12. from langchain_core.messages import AIMessage
  13. from models_provider.base_model_provider import MaxKBBaseModel
  14. from models_provider.impl.base_chat_open_ai import BaseChatOpenAI
  15. class DeepSeekChatModel(MaxKBBaseModel, BaseChatOpenAI):
  16. @staticmethod
  17. def is_cache_model():
  18. return False
  19. @staticmethod
  20. def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
  21. optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
  22. deepseek_chat_open_ai = DeepSeekChatModel(
  23. model=model_name,
  24. openai_api_base=model_credential.get('api_base') or 'https://api.deepseek.com',
  25. openai_api_key=model_credential.get('api_key'),
  26. **optional_params,
  27. )
  28. return deepseek_chat_open_ai
  29. def _get_request_payload(
  30. self,
  31. input_: LanguageModelInput,
  32. *,
  33. stop: list[str] | None = None,
  34. **kwargs: Any,
  35. ) -> dict:
  36. # Get original messages to preserve reasoning_content before base conversion
  37. messages = self._convert_input(input_).to_messages()
  38. # Store reasoning_content for AIMessages with tool_calls
  39. # According to DeepSeek API docs, reasoning_content is REQUIRED when tool_calls
  40. # are present during the tool invocation process (within same question/turn).
  41. # See: https://api-docs.deepseek.com/guides/thinking_mode#tool-calls
  42. reasoning_content_map = {}
  43. for i, msg in enumerate(messages):
  44. if (
  45. isinstance(msg, AIMessage)
  46. and (msg.tool_calls or msg.invalid_tool_calls)
  47. and (reasoning := msg.additional_kwargs.get("reasoning_content"))
  48. ):
  49. reasoning_content_map[i] = reasoning
  50. payload = super()._get_request_payload(input_, stop=stop, **kwargs)
  51. # Restore reasoning_content for assistant messages with tool_calls
  52. # This is required by DeepSeek API - missing it causes 400 error
  53. if "messages" in payload and reasoning_content_map:
  54. for i, message in enumerate(payload["messages"]):
  55. if (
  56. i in reasoning_content_map
  57. and message.get("role") == "assistant"
  58. and message.get("tool_calls")
  59. ):
  60. message["reasoning_content"] = reasoning_content_map[i]
  61. # Apply DeepSeek-specific message formatting
  62. for message in payload["messages"]:
  63. if message["role"] == "tool" and isinstance(message["content"], list):
  64. message["content"] = json.dumps(message["content"])
  65. elif message["role"] == "assistant" and isinstance(
  66. message["content"], list
  67. ):
  68. # DeepSeek API expects assistant content to be a string, not a list.
  69. # Extract text blocks and join them, or use empty string if none exist.
  70. text_parts = [
  71. block.get("text", "")
  72. for block in message["content"]
  73. if isinstance(block, dict) and block.get("type") == "text"
  74. ]
  75. message["content"] = "".join(text_parts) if text_parts else ""
  76. return payload