llm.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. # coding=utf-8
  2. """
  3. OpenAI 兼容私域模型调用
  4. """
  5. import json
  6. from typing import Iterator, Dict, List, Optional, Any
  7. import requests
  8. from django.utils.translation import gettext as _
  9. from common.config.tokenizer_manage_config import TokenizerManage
  10. from common.exception.app_exception import AppApiException
  11. from common.utils.logger import maxkb_logger
  12. from langchain_core.messages import BaseMessage, get_buffer_string
  13. from models_provider.base_model_provider import ModelInfo, ModelTypeConst, MaxKBBaseModel
  class OpenAICompatibleChatModel(MaxKBBaseModel):
      """Chat model backed by an OpenAI-compatible ``/chat/completions`` HTTP endpoint.

      Requests are issued directly with ``requests``. The endpoint root comes from
      the ``api_base`` credential and the bearer token from ``api_key``.
      ``stream()`` / ``invoke()`` return lightweight objects exposing ``content``
      and ``additional_kwargs`` so callers can treat them like chat messages.
      """

      # Timeout (seconds) passed to requests for chat-completion calls.
      REQUEST_TIMEOUT = 120
      # Timeout (seconds) for the ``/models`` probe in verify_connection().
      VERIFY_TIMEOUT = 10
      # Maximum number of characters of an upstream error body echoed into error messages.
      ERROR_BODY_LIMIT = 500

      class StreamChunk:
          """Object yielded by stream(): one piece of streamed assistant text."""

          def __init__(self, content):
              self.content = content
              self.additional_kwargs = {}

      class InvokeResult:
          """Object returned by invoke(): the complete assistant reply."""

          def __init__(self, content):
              self.content = content
              self.additional_kwargs = {}

      @staticmethod
      def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
          """Create an instance for ``model_name`` using the provider's registered ModelInfo.

          ``model_kwargs`` are forwarded unchanged to ``__init__``.
          """
          # Imported here rather than at module level, presumably to avoid an
          # import cycle with the provider module.
          from models_provider.impl.openai_compatible_provider.openai_compatible_provider import model_info_manage
          model_info = model_info_manage.get_model_info(model_type, model_name)
          return OpenAICompatibleChatModel(model_info=model_info, model_credential=model_credential,
                                           model_name=model_name, **model_kwargs)

      def __init__(self, model_info: ModelInfo, model_credential: Dict[str, object], model_name: str = None,
                   **kwargs):
          """Store connection settings.

          :param model_info: registered model metadata; its ``name`` is the fallback model name.
          :param model_credential: mapping providing ``api_base`` and ``api_key``.
          :param model_name: model identifier sent in the request body.
          """
          self.model_info = model_info
          self.model_credential = model_credential
          # ``or ''`` also tolerates credentials stored with an explicit None value.
          self.api_base = str(model_credential.get('api_base') or '').rstrip('/')
          self.api_key = str(model_credential.get('api_key') or '')
          self.model_name = model_name or model_info.name
          # NOTE(review): extra ``kwargs`` (e.g. model parameters forwarded by
          # new_instance) are accepted but not used; only per-call kwargs given to
          # stream()/invoke() reach the request body. Confirm this is intended.

      def _build_headers(self):
          """Return JSON request headers, adding a bearer token only when an API key is set."""
          headers = {'Content-Type': 'application/json'}
          # An empty key would otherwise produce a malformed "Bearer " header.
          if self.api_key:
              headers['Authorization'] = f'Bearer {self.api_key}'
          return headers

      @staticmethod
      def _convert_messages(messages):
          """Convert message objects into OpenAI chat-message dicts.

          * Objects with a ``content`` attribute: the role comes from a non-empty
            string ``role`` attribute when present (e.g. langchain ``ChatMessage``),
            otherwise from the class name ('Human' -> user, 'AI' -> assistant,
            anything else -> system).
          * Dicts are passed through unchanged.
          * Anything else is sent as user text via ``str()``.
          """
          converted = []
          for msg in messages:
              if hasattr(msg, 'content'):
                  role = getattr(msg, 'role', None)
                  if not (isinstance(role, str) and role):
                      class_name = msg.__class__.__name__
                      if 'Human' in class_name:
                          role = 'user'
                      elif 'AI' in class_name:
                          role = 'assistant'
                      else:
                          role = 'system'
                  converted.append({'role': role, 'content': msg.content})
              elif isinstance(msg, dict):
                  converted.append(msg)
              else:
                  converted.append({'role': 'user', 'content': str(msg)})
          return converted

      def _build_body(self, messages, stream: bool, extra: Dict[str, Any]) -> Dict[str, Any]:
          """Assemble the chat-completions request body.

          ``extra`` (the caller's kwargs) may override ``model`` and add options such
          as sampling parameters, but ``messages`` and ``stream`` are always set here
          because the response handling in stream()/invoke() depends on them.
          """
          body = {'model': self.model_name, **extra}
          body['messages'] = self._convert_messages(messages)
          body['stream'] = stream
          return body

      @classmethod
      def _raise_for_status(cls, response):
          """Raise ``HTTPError`` for 4xx/5xx responses, appending the upstream error body.

          ``raise_for_status()`` alone only reports the status line; the response
          body often explains the failure (unknown model, invalid key, ...).
          """
          try:
              response.raise_for_status()
          except requests.exceptions.HTTPError as e:
              try:
                  detail = (response.text or '').strip()[:cls.ERROR_BODY_LIMIT]
              except requests.exceptions.RequestException:
                  detail = ''
              if not detail:
                  raise
              raise requests.exceptions.HTTPError(f'{e}: {detail}', response=response) from e

      @staticmethod
      def _extract_delta_content(data: str) -> Optional[str]:
          """Return ``choices[0].delta.content`` from one SSE payload, or None.

          Non-JSON payloads and chunks without choices (e.g. a trailing usage-only
          chunk whose ``choices`` is empty) yield None instead of aborting the stream.
          """
          try:
              chunk = json.loads(data)
          except json.JSONDecodeError:
              return None
          if not isinstance(chunk, dict):
              return None
          choices = chunk.get('choices') or []
          if not choices or not isinstance(choices[0], dict):
              return None
          delta = choices[0].get('delta') or {}
          return delta.get('content')

      @staticmethod
      def _extract_message_content(result) -> str:
          """Return ``choices[0].message.content`` from a non-streaming response, or ''.

          Missing/empty ``choices`` and a null ``content`` (e.g. a tool-call reply)
          both map to '' so callers always receive a string.
          """
          choices = result.get('choices') or []
          if not choices:
              return ''
          message = choices[0].get('message') or {}
          return message.get('content') or ''

      def stream(self, messages, **kwargs):
          """Stream a chat completion.

          Parses the server-sent-event response and yields one ``StreamChunk`` (with
          ``content`` and ``additional_kwargs``) per chunk whose first choice carries
          non-empty ``delta.content``; stops at ``data: [DONE]``.

          :param messages: message objects, OpenAI-style dicts, or other objects (sent via ``str()``).
          :param kwargs: extra request-body fields (e.g. ``temperature``).
          :raises AppApiException: 504 on timeout, 502 on connection failure,
              500 on any other error (HTTP errors include the upstream body).
          """
          url = f"{self.api_base}/chat/completions"
          body = self._build_body(messages, True, kwargs)
          try:
              # The ``with`` block closes the streamed response (releasing the
              # connection) even if the consumer closes the generator early.
              with requests.post(url, json=body, headers=self._build_headers(), stream=True,
                                 timeout=self.REQUEST_TIMEOUT) as response:
                  self._raise_for_status(response)
                  for raw_line in response.iter_lines():
                      if not raw_line:
                          continue
                      line = raw_line.decode('utf-8')
                      # SSE "data" field; the single space after the colon is optional.
                      if not line.startswith('data:'):
                          continue
                      data = line[len('data:'):].strip()
                      if data == '[DONE]':
                          break
                      content = self._extract_delta_content(data)
                      if content:
                          yield self.StreamChunk(content)
          except requests.exceptions.Timeout as e:
              # Checked before ConnectionError: ConnectTimeout subclasses both.
              raise AppApiException(504, _('Request timeout')) from e
          except requests.exceptions.ConnectionError as e:
              raise AppApiException(502, _('Failed to connect to API server')) from e
          except AppApiException:
              raise
          except Exception as e:
              maxkb_logger.error(f'OpenAI compatible stream error: {e}')
              raise AppApiException(500, str(e)) from e

      def invoke(self, messages, **kwargs):
          """Run a non-streaming chat completion.

          :param messages: message objects, OpenAI-style dicts, or other objects (sent via ``str()``).
          :param kwargs: extra request-body fields (e.g. ``temperature``).
          :returns: an ``InvokeResult`` whose ``content`` is the reply text ('' if none)
              and whose ``additional_kwargs`` is an empty dict.
          :raises AppApiException: 504 on timeout, 502 on connection failure,
              500 on any other error (HTTP errors include the upstream body).
          """
          url = f"{self.api_base}/chat/completions"
          body = self._build_body(messages, False, kwargs)
          try:
              response = requests.post(url, json=body, headers=self._build_headers(),
                                       timeout=self.REQUEST_TIMEOUT)
              self._raise_for_status(response)
              content = self._extract_message_content(response.json())
          except requests.exceptions.Timeout as e:
              # Checked before ConnectionError: ConnectTimeout subclasses both.
              raise AppApiException(504, _('Request timeout')) from e
          except requests.exceptions.ConnectionError as e:
              raise AppApiException(502, _('Failed to connect to API server')) from e
          except AppApiException:
              raise
          except Exception as e:
              maxkb_logger.error(f'OpenAI compatible invoke error: {e}')
              raise AppApiException(500, str(e)) from e
          return self.InvokeResult(content)

      def chat(self, messages, stream=False, **kwargs):
          """Chat completion: a chunk generator when ``stream`` is true, else a single result."""
          if stream:
              return self.stream(messages, **kwargs)
          return self.invoke(messages, **kwargs)

      def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
          """Count tokens across ``messages`` with the shared tokenizer (0 on any failure)."""
          try:
              tokenizer = TokenizerManage.get_tokenizer()
              return sum(len(tokenizer.encode(get_buffer_string([m]))) for m in messages)
          except Exception:
              # Deliberately best-effort: a tokenizer failure reports 0 instead of raising.
              return 0

      def get_num_tokens(self, text: str) -> int:
          """Count tokens in ``text`` with the shared tokenizer (0 on any failure)."""
          try:
              tokenizer = TokenizerManage.get_tokenizer()
              return len(tokenizer.encode(text))
          except Exception:
              # Deliberately best-effort, matching get_num_tokens_from_messages.
              return 0

      def verify_connection(self):
          """Check the endpoint and credentials by requesting ``{api_base}/models``.

          :returns: True when the server answers with a non-error status.
          :raises AppApiException: 500 wrapping any failure (HTTP errors include the upstream body).
          """
          try:
              url = f"{self.api_base}/models"
              response = requests.get(url, headers=self._build_headers(), timeout=self.VERIFY_TIMEOUT)
              self._raise_for_status(response)
              return True
          except Exception as e:
              raise AppApiException(500, _('Connection verification failed: {error}').format(error=str(e))) from e