Parcourir la source

feat: 重构 OpenAI 兼容模型为 LangChain 标准接口

将 OpenAI Compatible 模型重构为继承 MaxKBBaseModel，实现标准
stream/invoke/chat 接口，支持 LangChain Message 对象转换，
新增 token 计数功能。

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
mengboxin137-blip il y a 6 jours
Parent
commit
948f52c28d
1 fichier modifié avec 114 ajouts et 30 suppressions
  1. 114 30
      apps/models_provider/impl/openai_compatible_provider/model/llm.py

+ 114 - 30
apps/models_provider/impl/openai_compatible_provider/model/llm.py

@@ -3,25 +3,34 @@
 OpenAI 兼容私域模型调用
 """
 import json
-from typing import Iterator, Dict
+from typing import Iterator, Dict, List, Optional, Any
 
 import requests
 from django.utils.translation import gettext as _
 
+from common.config.tokenizer_manage_config import TokenizerManage
 from common.exception.app_exception import AppApiException
 from common.utils.logger import maxkb_logger
-from models_provider.base_model_provider import ModelInfo, ModelTypeConst
+from langchain_core.messages import BaseMessage, get_buffer_string
+from models_provider.base_model_provider import ModelInfo, ModelTypeConst, MaxKBBaseModel
 
 
-class OpenAICompatibleChatModel:
+class OpenAICompatibleChatModel(MaxKBBaseModel):
     """OpenAI 兼容模型调用"""
 
-    def __init__(self, model_info: ModelInfo, model_credential: Dict[str, object], **kwargs):
+    @staticmethod
+    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
+        from models_provider.impl.openai_compatible_provider.openai_compatible_provider import model_info_manage
+        model_info = model_info_manage.get_model_info(model_type, model_name)
+        return OpenAICompatibleChatModel(model_info=model_info, model_credential=model_credential,
+                                         model_name=model_name, **model_kwargs)
+
+    def __init__(self, model_info: ModelInfo, model_credential: Dict[str, object], model_name: str = None, **kwargs):
         self.model_info = model_info
         self.model_credential = model_credential
         self.api_base = model_credential.get('api_base', '').rstrip('/')
         self.api_key = model_credential.get('api_key', '')
-        self.model_name = model_credential.get('model_name', model_info.model_name)
+        self.model_name = model_name or model_info.name
 
     def _build_headers(self):
         return {
@@ -29,50 +38,125 @@ class OpenAICompatibleChatModel:
             'Authorization': f'Bearer {self.api_key}',
         }
 
-    def chat(self, messages, stream=False, **kwargs):
-        """聊天补全"""
-        url = f"{self.api_base}/v1/chat/completions"
+    @staticmethod
+    def _convert_messages(messages):
+        """将 langchain Message 对象转换为 OpenAI API 的 dict 格式"""
+        converted = []
+        for msg in messages:
+            if hasattr(msg, 'content'):
+                role = 'system'
+                if hasattr(msg, '__class__'):
+                    class_name = msg.__class__.__name__
+                    if 'Human' in class_name:
+                        role = 'user'
+                    elif 'AI' in class_name:
+                        role = 'assistant'
+                    elif 'System' in class_name:
+                        role = 'system'
+                converted.append({'role': role, 'content': msg.content})
+            elif isinstance(msg, dict):
+                converted.append(msg)
+            else:
+                converted.append({'role': 'user', 'content': str(msg)})
+        return converted
+
+    def stream(self, messages, **kwargs):
+        """流式聊天 - yield 带 content 和 additional_kwargs 属性的对象"""
+        url = f"{self.api_base}/chat/completions"
+        api_messages = self._convert_messages(messages)
         body = {
             'model': self.model_name,
-            'messages': messages,
-            'stream': stream,
+            'messages': api_messages,
+            'stream': True,
             **kwargs
         }
 
         try:
-            if stream:
-                return self._stream_chat(url, body)
-            else:
-                response = requests.post(url, json=body, headers=self._build_headers(), timeout=120)
-                response.raise_for_status()
-                return response.json()
+            response = requests.post(url, json=body, headers=self._build_headers(), stream=True, timeout=120)
+            response.raise_for_status()
+
+            for line in response.iter_lines():
+                if not line:
+                    continue
+                line = line.decode('utf-8')
+                if not line.startswith('data: '):
+                    continue
+                data = line[6:]
+                if data.strip() == '[DONE]':
+                    break
+                try:
+                    chunk = json.loads(data)
+                    delta = chunk.get('choices', [{}])[0].get('delta', {})
+                    content = delta.get('content')
+                    if content:
+                        yield type('StreamChunk', (), {
+                            'content': content,
+                            'additional_kwargs': {},
+                        })()
+                except json.JSONDecodeError:
+                    continue
+        except requests.exceptions.Timeout:
+            raise AppApiException(504, _('Request timeout'))
+        except requests.exceptions.ConnectionError:
+            raise AppApiException(502, _('Failed to connect to API server'))
+        except AppApiException:
+            raise
+        except Exception as e:
+            maxkb_logger.error(f'OpenAI compatible stream error: {e}')
+            raise AppApiException(500, str(e))
+
+    def invoke(self, messages, **kwargs):
+        """非流式聊天 - 返回带 content 和 additional_kwargs 的对象"""
+        url = f"{self.api_base}/chat/completions"
+        api_messages = self._convert_messages(messages)
+        body = {
+            'model': self.model_name,
+            'messages': api_messages,
+            'stream': False,
+            **kwargs
+        }
+
+        try:
+            response = requests.post(url, json=body, headers=self._build_headers(), timeout=120)
+            response.raise_for_status()
+            result = response.json()
+            content = result.get('choices', [{}])[0].get('message', {}).get('content', '')
+            return type('InvokeResult', (), {
+                'content': content,
+                'additional_kwargs': {},
+            })()
         except requests.exceptions.Timeout:
             raise AppApiException(504, _('Request timeout'))
         except requests.exceptions.ConnectionError:
             raise AppApiException(502, _('Failed to connect to API server'))
         except Exception as e:
-            maxkb_logger.error(f'OpenAI compatible API error: {e}')
+            maxkb_logger.error(f'OpenAI compatible invoke error: {e}')
             raise AppApiException(500, str(e))
 
-    def _stream_chat(self, url, body):
-        """流式聊天"""
-        response = requests.post(url, json=body, headers=self._build_headers(), stream=True, timeout=120)
-        response.raise_for_status()
+    def chat(self, messages, stream=False, **kwargs):
+        """聊天补全"""
+        if stream:
+            return self.stream(messages, **kwargs)
+        return self.invoke(messages, **kwargs)
 
-        def generate():
-            try:
-                for line in response.iter_lines():
-                    if line:
-                        yield line.decode('utf-8') + '\n'
-            finally:
-                response.close()
+    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        try:
+            tokenizer = TokenizerManage.get_tokenizer()
+            return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
+        except Exception:
+            return 0
 
-        return generate()
+    def get_num_tokens(self, text: str) -> int:
+        try:
+            tokenizer = TokenizerManage.get_tokenizer()
+            return len(tokenizer.encode(text))
+        except Exception:
+            return 0
 
     def verify_connection(self):
         """验证连接"""
         try:
-            url = f"{self.api_base}/v1/models"
+            url = f"{self.api_base}/models"
             response = requests.get(url, headers=self._build_headers(), timeout=10)
             response.raise_for_status()
             return True