| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164 |
- # coding=utf-8
- """
- OpenAI 兼容私域模型调用
- """
- import json
- from typing import Iterator, Dict, List, Optional, Any
- import requests
- from django.utils.translation import gettext as _
- from common.config.tokenizer_manage_config import TokenizerManage
- from common.exception.app_exception import AppApiException
- from common.utils.logger import maxkb_logger
- from langchain_core.messages import BaseMessage, get_buffer_string
- from models_provider.base_model_provider import ModelInfo, ModelTypeConst, MaxKBBaseModel
class OpenAICompatibleChatModel(MaxKBBaseModel):
    """Chat model that talks to an OpenAI-compatible HTTP API using ``requests``.

    Chat requests go to ``{api_base}/chat/completions`` and the connection check
    goes to ``{api_base}/models``, both authenticated with ``Bearer {api_key}``.
    """

    class _MessageResult:
        """Minimal message-like result exposing ``content`` and ``additional_kwargs``."""

        def __init__(self, content):
            self.content = content
            # Always an empty dict; part of the result contract of stream()/invoke().
            self.additional_kwargs = {}

        def __repr__(self):
            return f'{type(self).__name__}(content={self.content!r})'

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Create an instance, resolving ``ModelInfo`` for ``(model_type, model_name)``.

        :param model_type: model type passed to ``model_info_manage.get_model_info``.
        :param model_name: model name used for the lookup and sent to the API.
        :param model_credential: mapping read for ``api_base`` and ``api_key``.
        :param model_kwargs: forwarded to ``__init__`` (which currently ignores them).
        """
        # Imported inside the function as in the original -- presumably to avoid a
        # circular import with the provider module; TODO confirm.
        from models_provider.impl.openai_compatible_provider.openai_compatible_provider import model_info_manage
        model_info = model_info_manage.get_model_info(model_type, model_name)
        return OpenAICompatibleChatModel(model_info=model_info, model_credential=model_credential,
                                         model_name=model_name, **model_kwargs)

    def __init__(self, model_info: ModelInfo, model_credential: Dict[str, object],
                 model_name: Optional[str] = None, **kwargs):
        """Store connection settings.

        :param model_info: provider metadata; ``model_info.name`` is the fallback model name.
        :param model_credential: mapping read for ``api_base`` and ``api_key``.
        :param model_name: model id sent to the API; defaults to ``model_info.name``.
        :param kwargs: accepted for interface compatibility; not used.
        """
        self.model_info = model_info
        self.model_credential = model_credential
        # ``or ''`` also covers keys stored with an explicit None value, which would
        # otherwise crash on .rstrip() or produce the header "Bearer None".
        self.api_base = (model_credential.get('api_base') or '').rstrip('/')
        self.api_key = model_credential.get('api_key') or ''
        self.model_name = model_name or model_info.name
        # NOTE(review): ``kwargs`` (e.g. model parameters forwarded by new_instance) are
        # never sent to the API -- confirm whether they belong in the request body.

    def _build_headers(self):
        """Return the JSON + bearer-token headers sent with every request."""
        return {
            'Content-Type': 'application/json',
            'Authorization': f'Bearer {self.api_key}',
        }

    def _chat_url(self):
        """Return the chat-completions endpoint URL."""
        return f"{self.api_base}/chat/completions"

    def _build_body(self, messages, stream, extra):
        """Assemble the chat-completions JSON body.

        Keys in ``extra`` override the defaults built here (model, messages, stream).
        """
        return {
            'model': self.model_name,
            'messages': self._convert_messages(messages),
            'stream': stream,
            **extra,
        }

    @staticmethod
    def _message_role(msg):
        """Pick the OpenAI role for a message-like object that has ``content``.

        An explicit non-empty string ``role`` attribute (e.g. LangChain ``ChatMessage``)
        wins; otherwise the role is inferred from the class name, defaulting to system.
        """
        explicit_role = getattr(msg, 'role', None)
        if isinstance(explicit_role, str) and explicit_role:
            return explicit_role
        class_name = msg.__class__.__name__
        if 'Human' in class_name:
            return 'user'
        if 'AI' in class_name:
            return 'assistant'
        return 'system'

    @staticmethod
    def _convert_messages(messages):
        """Convert messages into the OpenAI chat-completions ``messages`` list.

        - objects with a ``content`` attribute (e.g. LangChain messages) become
          ``{'role': ..., 'content': msg.content}``;
        - dicts are passed through unchanged;
        - anything else is stringified and sent as a ``user`` message.
        """
        converted = []
        for msg in messages:
            if hasattr(msg, 'content'):
                converted.append({
                    'role': OpenAICompatibleChatModel._message_role(msg),
                    'content': msg.content,
                })
            elif isinstance(msg, dict):
                converted.append(msg)
            else:
                converted.append({'role': 'user', 'content': str(msg)})
        return converted

    @staticmethod
    def _sse_data(raw_line):
        """Return the payload of an SSE ``data:`` line, or None for any other line.

        Accepts both ``data: {...}`` and ``data:{...}`` (the space is optional in SSE).
        """
        if not raw_line:
            return None
        line = raw_line.decode('utf-8') if isinstance(raw_line, bytes) else raw_line
        if not line.startswith('data:'):
            return None
        return line[len('data:'):].strip()

    @staticmethod
    def _first_choice(payload):
        """Return ``payload['choices'][0]`` as a dict, or ``{}`` if absent or malformed.

        An empty ``choices`` list is valid in streams (e.g. the usage-only chunk OpenAI
        sends when ``stream_options.include_usage`` is set) and must not raise.
        """
        if not isinstance(payload, dict):
            return {}
        choices = payload.get('choices')
        if not isinstance(choices, list) or not choices:
            return {}
        first = choices[0]
        return first if isinstance(first, dict) else {}

    @staticmethod
    def _to_app_exception(error, action):
        """Map a request failure to the AppApiException raised to callers.

        Timeouts -> 504, connection failures -> 502, anything else is logged -> 500.
        """
        # Timeout is checked first so ConnectTimeout (a subclass of both) maps to 504.
        if isinstance(error, requests.exceptions.Timeout):
            return AppApiException(504, _('Request timeout'))
        if isinstance(error, requests.exceptions.ConnectionError):
            return AppApiException(502, _('Failed to connect to API server'))
        maxkb_logger.error(f'OpenAI compatible {action} error: {error}')
        return AppApiException(500, str(error))

    def stream(self, messages, **kwargs):
        """Stream a chat completion.

        Yields one result object per SSE event whose ``choices[0].delta.content`` is
        non-empty; each exposes ``content`` and ``additional_kwargs`` (empty dict).
        Iteration stops at the ``[DONE]`` sentinel or when the server closes the stream.

        :param messages: see ``_convert_messages``.
        :param kwargs: extra request-body fields; they override the defaults.
        :raises AppApiException: 504 on timeout, 502 on connection failure, 500 otherwise.
        """
        url = self._chat_url()
        body = self._build_body(messages, True, kwargs)
        try:
            # ``with`` releases the connection even if the consumer stops iterating early
            # (generator close) or an error is raised mid-stream.
            with requests.post(url, json=body, headers=self._build_headers(),
                               stream=True, timeout=120) as response:
                response.raise_for_status()
                for raw_line in response.iter_lines():
                    data = self._sse_data(raw_line)
                    if data is None:
                        continue
                    if data == '[DONE]':
                        break
                    try:
                        chunk = json.loads(data)
                    except json.JSONDecodeError:
                        # Skip malformed events rather than aborting the whole stream.
                        continue
                    delta = self._first_choice(chunk).get('delta')
                    content = delta.get('content') if isinstance(delta, dict) else None
                    if content:
                        yield self._MessageResult(content)
        except AppApiException:
            raise
        except Exception as e:
            raise self._to_app_exception(e, 'stream') from e

    def invoke(self, messages, **kwargs):
        """Run a non-streaming chat completion.

        :param messages: see ``_convert_messages``.
        :param kwargs: extra request-body fields; they override the defaults.
        :returns: result object with ``content`` (``''`` when the response carries no
            message content) and ``additional_kwargs`` (empty dict).
        :raises AppApiException: 504 on timeout, 502 on connection failure, 500 otherwise.
        """
        url = self._chat_url()
        body = self._build_body(messages, False, kwargs)
        try:
            with requests.post(url, json=body, headers=self._build_headers(), timeout=120) as response:
                response.raise_for_status()
                result = response.json()
        except AppApiException:
            raise
        except Exception as e:
            raise self._to_app_exception(e, 'invoke') from e
        message = self._first_choice(result).get('message')
        content = message.get('content') if isinstance(message, dict) else None
        # ``content`` may be JSON null (e.g. tool-call-only replies); normalise to ''.
        return self._MessageResult(content or '')

    def chat(self, messages, stream=False, **kwargs):
        """Dispatch to ``stream`` (returns a generator) or ``invoke`` (returns a result)."""
        if stream:
            return self.stream(messages, **kwargs)
        return self.invoke(messages, **kwargs)

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Count tokens over the buffer-string form of each message; 0 if counting fails."""
        try:
            tokenizer = TokenizerManage.get_tokenizer()
            return sum(len(tokenizer.encode(get_buffer_string([m]))) for m in messages)
        except Exception as e:
            # Best-effort: token counting must not break chat, but record why it failed.
            maxkb_logger.error(f'OpenAI compatible token counting error: {e}')
            return 0

    def get_num_tokens(self, text: str) -> int:
        """Count tokens in ``text``; 0 if counting fails."""
        try:
            tokenizer = TokenizerManage.get_tokenizer()
            return len(tokenizer.encode(text))
        except Exception as e:
            # Best-effort: token counting must not break chat, but record why it failed.
            maxkb_logger.error(f'OpenAI compatible token counting error: {e}')
            return 0

    def verify_connection(self):
        """Check that ``GET {api_base}/models`` succeeds within 10 seconds.

        :returns: True on success.
        :raises AppApiException: 500 wrapping any failure.
        """
        try:
            url = f"{self.api_base}/models"
            response = requests.get(url, headers=self._build_headers(), timeout=10)
            response.raise_for_status()
            return True
        except Exception as e:
            raise AppApiException(
                500, _('Connection verification failed: {error}').format(error=str(e))
            ) from e
|