| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980 |
- # coding=utf-8
- """
- OpenAI 兼容私域模型调用
- """
- import json
- from typing import Iterator, Dict
- import requests
- from django.utils.translation import gettext as _
- from common.exception.app_exception import AppApiException
- from common.utils.logger import maxkb_logger
- from models_provider.base_model_provider import ModelInfo, ModelTypeConst
class OpenAICompatibleChatModel:
    """Client for chat models served behind an OpenAI-compatible HTTP API.

    Requests go to ``{api_base}/v1/...`` with a bearer-token ``Authorization``
    header built from the stored credential.

    NOTE(review): ``/v1`` is always appended, so an ``api_base`` that already
    ends in ``/v1`` yields ``.../v1/v1/...`` — confirm the expected credential
    format with callers.
    """

    def __init__(self, model_info: ModelInfo, model_credential: Dict[str, object], **kwargs):
        """Store connection settings taken from ``model_credential``.

        :param model_info: model metadata; its ``model_name`` attribute is read
            only when the credential has no ``model_name`` (or it is ``None``).
        :param model_credential: mapping that may contain ``api_base``,
            ``api_key`` and ``model_name``. Missing or ``None`` values for
            ``api_base``/``api_key`` become empty strings.
        :param kwargs: accepted for interface compatibility; unused here.
        """
        self.model_info = model_info
        self.model_credential = model_credential
        # ``or ''`` also covers keys that are present but set to None, which
        # would otherwise crash on ``.rstrip`` or send "Bearer None".
        self.api_base = str(model_credential.get('api_base') or '').rstrip('/')
        self.api_key = model_credential.get('api_key') or ''
        # Resolve the fallback lazily: ``dict.get(key, default)`` evaluates the
        # default eagerly and would touch model_info even when not needed.
        model_name = model_credential.get('model_name')
        self.model_name = model_name if model_name is not None else model_info.model_name

    def _build_headers(self):
        """Return the JSON + bearer-token headers sent with every request."""
        return {
            'Content-Type': 'application/json',
            'Authorization': f'Bearer {self.api_key}',
        }

    def chat(self, messages, stream=False, **kwargs):
        """Call the ``/v1/chat/completions`` endpoint.

        :param messages: chat messages, sent verbatim as the ``messages`` field.
        :param stream: when True, return an iterator of decoded response lines
            (see ``_stream_chat``); otherwise return the parsed JSON body.
        :param kwargs: extra request-body fields. They are merged last, so a
            ``model`` keyword overrides the configured model name.
        :raises AppApiException: 504 on timeout, 502 on connection failure,
            500 for any other error (including a non-2xx HTTP status). For
            streaming calls this covers only the initial request; errors raised
            while iterating the returned generator are not converted.
        """
        url = f"{self.api_base}/v1/chat/completions"
        body = {
            'model': self.model_name,
            'messages': messages,
            'stream': stream,
            **kwargs,
        }
        try:
            if stream:
                return self._stream_chat(url, body)
            response = requests.post(url, json=body, headers=self._build_headers(), timeout=120)
            response.raise_for_status()
            return response.json()
        # Timeout is checked first: ConnectTimeout is also a ConnectionError.
        except requests.exceptions.Timeout as e:
            raise AppApiException(504, _('Request timeout')) from e
        except requests.exceptions.ConnectionError as e:
            raise AppApiException(502, _('Failed to connect to API server')) from e
        except Exception as e:
            maxkb_logger.error(f'OpenAI compatible API error: {e}')
            raise AppApiException(500, str(e)) from e

    def _stream_chat(self, url, body) -> Iterator[str]:
        """POST ``body`` with ``stream=True`` and return a lazy line generator.

        The request and the HTTP status check run immediately, so the caller's
        error handling applies to them; the response body is read lazily.
        Each non-empty line is UTF-8 decoded and yielded with a trailing '\\n'.
        The response is closed when iteration finishes or fails, or when a
        started generator is closed.
        """
        response = requests.post(url, json=body, headers=self._build_headers(), stream=True, timeout=120)
        try:
            response.raise_for_status()
        except Exception:
            # A streamed response holds its connection until closed; release
            # it here because the generator (and its ``finally``) is never
            # created on this path.
            response.close()
            raise

        def generate():
            try:
                for line in response.iter_lines():
                    # Blank lines (presumably SSE event separators) are skipped.
                    if line:
                        yield line.decode('utf-8') + '\n'
            finally:
                response.close()

        return generate()

    def verify_connection(self):
        """Check that ``{api_base}/v1/models`` answers with a 2xx status.

        :returns: True when the request succeeds.
        :raises AppApiException: 500 wrapping any request or HTTP error.
        """
        url = f"{self.api_base}/v1/models"
        try:
            response = requests.get(url, headers=self._build_headers(), timeout=10)
            response.raise_for_status()
        except Exception as e:
            raise AppApiException(500, _('Connection verification failed: {error}').format(error=str(e))) from e
        return True
|