|
|
@@ -3,25 +3,34 @@
|
|
|
OpenAI 兼容私域模型调用
|
|
|
"""
|
|
|
import json
|
|
|
-from typing import Iterator, Dict
|
|
|
+from typing import Iterator, Dict, List, Optional, Any
|
|
|
|
|
|
import requests
|
|
|
from django.utils.translation import gettext as _
|
|
|
|
|
|
+from common.config.tokenizer_manage_config import TokenizerManage
|
|
|
from common.exception.app_exception import AppApiException
|
|
|
from common.utils.logger import maxkb_logger
|
|
|
-from models_provider.base_model_provider import ModelInfo, ModelTypeConst
|
|
|
+from langchain_core.messages import BaseMessage, get_buffer_string
|
|
|
+from models_provider.base_model_provider import ModelInfo, ModelTypeConst, MaxKBBaseModel
|
|
|
|
|
|
|
|
|
-class OpenAICompatibleChatModel:
|
|
|
+class OpenAICompatibleChatModel(MaxKBBaseModel):
|
|
|
"""OpenAI 兼容模型调用"""
|
|
|
|
|
|
- def __init__(self, model_info: ModelInfo, model_credential: Dict[str, object], **kwargs):
|
|
|
+ @staticmethod
|
|
|
+ def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
|
|
+ from models_provider.impl.openai_compatible_provider.openai_compatible_provider import model_info_manage
|
|
|
+ model_info = model_info_manage.get_model_info(model_type, model_name)
|
|
|
+ return OpenAICompatibleChatModel(model_info=model_info, model_credential=model_credential,
|
|
|
+ model_name=model_name, **model_kwargs)
|
|
|
+
|
|
|
+ def __init__(self, model_info: ModelInfo, model_credential: Dict[str, object], model_name: str = None, **kwargs):
|
|
|
self.model_info = model_info
|
|
|
self.model_credential = model_credential
|
|
|
self.api_base = model_credential.get('api_base', '').rstrip('/')
|
|
|
self.api_key = model_credential.get('api_key', '')
|
|
|
- self.model_name = model_credential.get('model_name', model_info.model_name)
|
|
|
+ self.model_name = model_name or model_info.name
|
|
|
|
|
|
def _build_headers(self):
|
|
|
return {
|
|
|
@@ -29,50 +38,125 @@ class OpenAICompatibleChatModel:
|
|
|
'Authorization': f'Bearer {self.api_key}',
|
|
|
}
|
|
|
|
|
|
- def chat(self, messages, stream=False, **kwargs):
|
|
|
- """聊天补全"""
|
|
|
- url = f"{self.api_base}/v1/chat/completions"
|
|
|
+ @staticmethod
|
|
|
+ def _convert_messages(messages):
|
|
|
+ """将 langchain Message 对象转换为 OpenAI API 的 dict 格式"""
|
|
|
+ converted = []
|
|
|
+ for msg in messages:
|
|
|
+ if hasattr(msg, 'content'):
|
|
|
+ role = 'system'
|
|
|
+ if hasattr(msg, '__class__'):
|
|
|
+ class_name = msg.__class__.__name__
|
|
|
+ if 'Human' in class_name:
|
|
|
+ role = 'user'
|
|
|
+ elif 'AI' in class_name:
|
|
|
+ role = 'assistant'
|
|
|
+ elif 'System' in class_name:
|
|
|
+ role = 'system'
|
|
|
+ converted.append({'role': role, 'content': msg.content})
|
|
|
+ elif isinstance(msg, dict):
|
|
|
+ converted.append(msg)
|
|
|
+ else:
|
|
|
+ converted.append({'role': 'user', 'content': str(msg)})
|
|
|
+ return converted
|
|
|
+
|
|
|
+ def stream(self, messages, **kwargs):
|
|
|
+ """流式聊天 - yield 带 content 和 additional_kwargs 属性的对象"""
|
|
|
+ url = f"{self.api_base}/chat/completions"
|
|
|
+ api_messages = self._convert_messages(messages)
|
|
|
body = {
|
|
|
'model': self.model_name,
|
|
|
- 'messages': messages,
|
|
|
- 'stream': stream,
|
|
|
+ 'messages': api_messages,
|
|
|
+ 'stream': True,
|
|
|
**kwargs
|
|
|
}
|
|
|
|
|
|
try:
|
|
|
- if stream:
|
|
|
- return self._stream_chat(url, body)
|
|
|
- else:
|
|
|
- response = requests.post(url, json=body, headers=self._build_headers(), timeout=120)
|
|
|
- response.raise_for_status()
|
|
|
- return response.json()
|
|
|
+ response = requests.post(url, json=body, headers=self._build_headers(), stream=True, timeout=120)
|
|
|
+ response.raise_for_status()
|
|
|
+
|
|
|
+ for line in response.iter_lines():
|
|
|
+ if not line:
|
|
|
+ continue
|
|
|
+ line = line.decode('utf-8')
|
|
|
+ if not line.startswith('data: '):
|
|
|
+ continue
|
|
|
+ data = line[6:]
|
|
|
+ if data.strip() == '[DONE]':
|
|
|
+ break
|
|
|
+ try:
|
|
|
+ chunk = json.loads(data)
|
|
|
+ delta = chunk.get('choices', [{}])[0].get('delta', {})
|
|
|
+ content = delta.get('content')
|
|
|
+ if content:
|
|
|
+ yield type('StreamChunk', (), {
|
|
|
+ 'content': content,
|
|
|
+ 'additional_kwargs': {},
|
|
|
+ })()
|
|
|
+ except json.JSONDecodeError:
|
|
|
+ continue
|
|
|
+ except requests.exceptions.Timeout:
|
|
|
+ raise AppApiException(504, _('Request timeout'))
|
|
|
+ except requests.exceptions.ConnectionError:
|
|
|
+ raise AppApiException(502, _('Failed to connect to API server'))
|
|
|
+ except AppApiException:
|
|
|
+ raise
|
|
|
+ except Exception as e:
|
|
|
+ maxkb_logger.error(f'OpenAI compatible stream error: {e}')
|
|
|
+ raise AppApiException(500, str(e))
|
|
|
+
|
|
|
+ def invoke(self, messages, **kwargs):
|
|
|
+ """非流式聊天 - 返回带 content 和 additional_kwargs 的对象"""
|
|
|
+ url = f"{self.api_base}/chat/completions"
|
|
|
+ api_messages = self._convert_messages(messages)
|
|
|
+ body = {
|
|
|
+ 'model': self.model_name,
|
|
|
+ 'messages': api_messages,
|
|
|
+ 'stream': False,
|
|
|
+ **kwargs
|
|
|
+ }
|
|
|
+
|
|
|
+ try:
|
|
|
+ response = requests.post(url, json=body, headers=self._build_headers(), timeout=120)
|
|
|
+ response.raise_for_status()
|
|
|
+ result = response.json()
|
|
|
+ content = result.get('choices', [{}])[0].get('message', {}).get('content', '')
|
|
|
+ return type('InvokeResult', (), {
|
|
|
+ 'content': content,
|
|
|
+ 'additional_kwargs': {},
|
|
|
+ })()
|
|
|
except requests.exceptions.Timeout:
|
|
|
raise AppApiException(504, _('Request timeout'))
|
|
|
except requests.exceptions.ConnectionError:
|
|
|
raise AppApiException(502, _('Failed to connect to API server'))
|
|
|
except Exception as e:
|
|
|
- maxkb_logger.error(f'OpenAI compatible API error: {e}')
|
|
|
+ maxkb_logger.error(f'OpenAI compatible invoke error: {e}')
|
|
|
raise AppApiException(500, str(e))
|
|
|
|
|
|
- def _stream_chat(self, url, body):
|
|
|
- """流式聊天"""
|
|
|
- response = requests.post(url, json=body, headers=self._build_headers(), stream=True, timeout=120)
|
|
|
- response.raise_for_status()
|
|
|
+ def chat(self, messages, stream=False, **kwargs):
|
|
|
+ """聊天补全"""
|
|
|
+ if stream:
|
|
|
+ return self.stream(messages, **kwargs)
|
|
|
+ return self.invoke(messages, **kwargs)
|
|
|
|
|
|
- def generate():
|
|
|
- try:
|
|
|
- for line in response.iter_lines():
|
|
|
- if line:
|
|
|
- yield line.decode('utf-8') + '\n'
|
|
|
- finally:
|
|
|
- response.close()
|
|
|
+ def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
|
|
+ try:
|
|
|
+ tokenizer = TokenizerManage.get_tokenizer()
|
|
|
+ return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
|
|
|
+ except Exception:
|
|
|
+ return 0
|
|
|
|
|
|
- return generate()
|
|
|
+ def get_num_tokens(self, text: str) -> int:
|
|
|
+ try:
|
|
|
+ tokenizer = TokenizerManage.get_tokenizer()
|
|
|
+ return len(tokenizer.encode(text))
|
|
|
+ except Exception:
|
|
|
+ return 0
|
|
|
|
|
|
def verify_connection(self):
|
|
|
"""验证连接"""
|
|
|
try:
|
|
|
- url = f"{self.api_base}/v1/models"
|
|
|
+ url = f"{self.api_base}/models"
|
|
|
response = requests.get(url, headers=self._build_headers(), timeout=10)
|
|
|
response.raise_for_status()
|
|
|
return True
|