# coding=utf-8
"""
Client for privately deployed chat models that expose an OpenAI-compatible API.
"""
import json
from typing import Iterator, Dict, List, Optional, Any

import requests
from django.utils.translation import gettext as _

from common.config.tokenizer_manage_config import TokenizerManage
from common.exception.app_exception import AppApiException
from common.utils.logger import maxkb_logger
from langchain_core.messages import BaseMessage, get_buffer_string
from models_provider.base_model_provider import ModelInfo, ModelTypeConst, MaxKBBaseModel


class OpenAICompatibleChatModel(MaxKBBaseModel):
    """Chat model backed by a server implementing the OpenAI REST API.

    Chat requests go to ``{api_base}/chat/completions``; the connection check
    goes to ``{api_base}/models``. Every request carries a Bearer token built
    from ``api_key``. Chat results are lightweight objects exposing ``content``
    and ``additional_kwargs`` attributes.
    """

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Create an instance, resolving ``ModelInfo`` from the provider's registry."""
        # Imported lazily — presumably to avoid a circular import with the provider module.
        from models_provider.impl.openai_compatible_provider.openai_compatible_provider import model_info_manage
        model_info = model_info_manage.get_model_info(model_type, model_name)
        return OpenAICompatibleChatModel(model_info=model_info, model_credential=model_credential,
                                         model_name=model_name, **model_kwargs)

    def __init__(self, model_info: ModelInfo, model_credential: Dict[str, object], model_name: str = None,
                 **kwargs):
        """
        :param model_info: registry metadata; its ``name`` is used when ``model_name`` is empty.
        :param model_credential: dict providing ``api_base`` and ``api_key``.
        :param model_name: model identifier sent as ``model`` in every request body.
        :param kwargs: accepted for interface compatibility; not used by this class.
        """
        # NOTE(review): extra kwargs (e.g. sampling parameters passed via new_instance)
        # are not forwarded to the API — confirm this is intended.
        self.model_info = model_info
        self.model_credential = model_credential
        # `or ''` also covers keys that are present but explicitly set to None.
        self.api_base = (model_credential.get('api_base') or '').rstrip('/')
        self.api_key = model_credential.get('api_key') or ''
        self.model_name = model_name or model_info.name

    def _build_headers(self):
        """Return the JSON content-type and Bearer authorization headers."""
        return {
            'Content-Type': 'application/json',
            'Authorization': f'Bearer {self.api_key}',
        }

    @staticmethod
    def _convert_messages(messages):
        """Convert LangChain messages (or dicts / plain values) to OpenAI chat message dicts.

        - Objects with a ``content`` attribute: the role is taken from a non-empty
          string ``role`` attribute when present (e.g. LangChain ``ChatMessage``);
          otherwise from the class name: contains ``Human`` -> ``user``,
          contains ``AI`` -> ``assistant``, anything else -> ``system``.
        - Dicts are assumed to already be in OpenAI format and are passed through.
        - Any other value is stringified and sent as a ``user`` message.
        """
        converted = []
        for msg in messages:
            if hasattr(msg, 'content'):
                role = getattr(msg, 'role', None)
                if not isinstance(role, str) or not role:
                    class_name = type(msg).__name__
                    if 'Human' in class_name:
                        role = 'user'
                    elif 'AI' in class_name:
                        role = 'assistant'
                    else:
                        role = 'system'
                converted.append({'role': role, 'content': msg.content})
            elif isinstance(msg, dict):
                converted.append(msg)
            else:
                converted.append({'role': 'user', 'content': str(msg)})
        return converted

    @staticmethod
    def _first_choice(payload) -> dict:
        """Return ``payload['choices'][0]``, or ``{}`` if the payload has no usable choice.

        Some responses legitimately carry an empty ``choices`` list (for example
        OpenAI's usage-only final stream chunk), which must not be treated as an error.
        """
        choices = payload.get('choices') if isinstance(payload, dict) else None
        if not choices:
            return {}
        return choices[0] or {}

    @staticmethod
    def _sse_data(raw_line: bytes) -> Optional[str]:
        """Return the value of a server-sent-event ``data:`` line, or None for other lines."""
        line = raw_line.decode('utf-8')
        if not line.startswith('data:'):
            return None
        data = line[len('data:'):]
        # Per the SSE format, a single space after the colon is optional and not part of the value.
        return data[1:] if data.startswith(' ') else data

    @staticmethod
    def _make_result(type_name: str, content):
        """Build an ad-hoc object exposing ``content`` and an empty ``additional_kwargs`` dict."""
        return type(type_name, (), {
            'content': content,
            'additional_kwargs': {},
        })()

    def stream(self, messages, **kwargs):
        """Run a streaming chat completion.

        Yields one object per non-empty ``delta.content`` fragment, each exposing
        ``content`` and ``additional_kwargs`` (always ``{}``). Extra ``kwargs`` are
        merged into the request body and can override the defaults.

        :raises AppApiException: 504 on timeout, 502 on connection failure, 500 on
            any other error (including HTTP error statuses).
        """
        url = f"{self.api_base}/chat/completions"
        body = {
            'model': self.model_name,
            'messages': self._convert_messages(messages),
            'stream': True,
            **kwargs,
        }
        try:
            # The `with` block closes the response (releasing the pooled connection) even when
            # the consumer stops iterating early or an error is raised mid-stream.
            with requests.post(url, json=body, headers=self._build_headers(), stream=True,
                               timeout=120) as response:
                response.raise_for_status()
                for raw_line in response.iter_lines():
                    if not raw_line:
                        continue
                    data = self._sse_data(raw_line)
                    if data is None:
                        continue
                    if data.strip() == '[DONE]':
                        break
                    try:
                        chunk = json.loads(data)
                    except json.JSONDecodeError:
                        # Skip malformed lines instead of aborting the whole stream.
                        continue
                    delta = self._first_choice(chunk).get('delta') or {}
                    content = delta.get('content')
                    if content:
                        yield self._make_result('StreamChunk', content)
        except requests.exceptions.Timeout as e:
            raise AppApiException(504, _('Request timeout')) from e
        except requests.exceptions.ConnectionError as e:
            raise AppApiException(502, _('Failed to connect to API server')) from e
        except AppApiException:
            raise
        except Exception as e:
            maxkb_logger.error(f'OpenAI compatible stream error: {e}')
            raise AppApiException(500, str(e)) from e

    def invoke(self, messages, **kwargs):
        """Run a non-streaming chat completion.

        Returns an object exposing ``content`` (the first choice's message content,
        or ``''`` when the server returns none) and ``additional_kwargs`` (``{}``).
        Extra ``kwargs`` are merged into the request body and can override the defaults.

        :raises AppApiException: 504 on timeout, 502 on connection failure, 500 on
            any other error (including HTTP error statuses).
        """
        url = f"{self.api_base}/chat/completions"
        body = {
            'model': self.model_name,
            'messages': self._convert_messages(messages),
            'stream': False,
            **kwargs,
        }
        try:
            response = requests.post(url, json=body, headers=self._build_headers(), timeout=120)
            response.raise_for_status()
            result = response.json()
            message = self._first_choice(result).get('message') or {}
            return self._make_result('InvokeResult', message.get('content') or '')
        except requests.exceptions.Timeout as e:
            raise AppApiException(504, _('Request timeout')) from e
        except requests.exceptions.ConnectionError as e:
            raise AppApiException(502, _('Failed to connect to API server')) from e
        except AppApiException:
            raise
        except Exception as e:
            maxkb_logger.error(f'OpenAI compatible invoke error: {e}')
            raise AppApiException(500, str(e)) from e

    def chat(self, messages, stream=False, **kwargs):
        """Chat completion: returns the ``stream()`` generator when ``stream`` is true, else ``invoke()``'s result."""
        if stream:
            return self.stream(messages, **kwargs)
        return self.invoke(messages, **kwargs)

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Sum the token counts of each message's buffer string; returns 0 if tokenization fails."""
        try:
            tokenizer = TokenizerManage.get_tokenizer()
            return sum(len(tokenizer.encode(get_buffer_string([m]))) for m in messages)
        except Exception:
            return 0

    def get_num_tokens(self, text: str) -> int:
        """Return the token count of ``text``; returns 0 if tokenization fails."""
        try:
            tokenizer = TokenizerManage.get_tokenizer()
            return len(tokenizer.encode(text))
        except Exception:
            return 0

    def verify_connection(self):
        """Check the endpoint by requesting ``{api_base}/models``.

        :returns: True when the request succeeds with a non-error status.
        :raises AppApiException: 500 wrapping any failure.
        """
        try:
            url = f"{self.api_base}/models"
            response = requests.get(url, headers=self._build_headers(), timeout=10)
            response.raise_for_status()
            return True
        except Exception as e:
            raise AppApiException(500, _('Connection verification failed: {error}').format(error=str(e))) from e