| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209 |
- # coding=utf-8
- """
- OpenAI 兼容网关视图
- 提供 /api/v1/chat/completions 和 /api/v1/models 接口
- 用于外部客户端通过平台API Key访问私域模型
- """
- import json
- import time
- import uuid
- import requests
- from typing import Optional
- from django.http import StreamingHttpResponse
- from rest_framework.views import APIView
- from rest_framework.request import Request
- from common.result import result
- from models_provider.models.platform_api_key import PlatformApiKey, PlatformApiKeyStatus
- from models_provider.services.crypto_utils import hash_api_key
- from models_provider.models import Model
def _verify_bearer_token(request) -> Optional[tuple]:
    """
    Authenticate a request from its ``Authorization: Bearer <key>`` header.

    The presented key is hashed with ``hash_api_key`` and looked up in
    ``PlatformApiKey`` by ``api_key_hash``. Only a record whose status is
    ``PlatformApiKeyStatus.ACTIVE`` is accepted; on success its
    ``last_used_at`` column is refreshed.

    Args:
        request: Object exposing a Django-style ``META`` mapping.

    Returns:
        ``(user_id, api_key_id)`` as strings, or ``None`` when the header is
        missing or malformed, the token is empty, or no active key matches.
    """
    auth_header = request.META.get('HTTP_AUTHORIZATION', '')

    # HTTP auth schemes are case-insensitive and separated from the
    # credentials by whitespace, so split on whitespace instead of matching
    # the literal prefix "Bearer ".
    parts = auth_header.split(None, 1)
    if len(parts) != 2 or parts[0].lower() != 'bearer':
        return None

    api_key = parts[1].strip()
    if not api_key:
        # Nothing to hash; skip the database round-trip entirely.
        return None

    hashed_key = hash_api_key(api_key)
    api_key_record = PlatformApiKey.objects.filter(api_key_hash=hashed_key).first()
    if not api_key_record or api_key_record.status != PlatformApiKeyStatus.ACTIVE:
        return None

    # Record the last time this key was used. Only that column is written so
    # concurrent edits to other fields are not overwritten.
    from django.utils import timezone
    api_key_record.last_used_at = timezone.now()
    api_key_record.save(update_fields=["last_used_at"])
    return (str(api_key_record.user_id), str(api_key_record.id))
def _get_model_by_name(model_name, user_id):
    """
    Resolve a client-supplied model identifier to a ``Model`` row.

    Only rows with ``status='SUCCESS'`` are considered. Lookups are tried in
    order and the first hit wins:

    1. exact match on ``model_name``
    2. exact match on ``name``
    3. case-insensitive substring match on ``model_name``

    Args:
        model_name: Identifier sent by the client.
        user_id: Accepted for interface compatibility; not used by any of
            the queries below.

    Returns:
        The first matching ``Model`` instance, or ``None`` if nothing matched.
    """
    lookup_order = (
        {'model_name': model_name},             # exact provider model id
        {'name': model_name},                   # exact display name
        {'model_name__icontains': model_name},  # fuzzy fallback
    )
    for criteria in lookup_order:
        candidate = Model.objects.filter(status='SUCCESS', **criteria).first()
        if candidate:
            return candidate
    return None
def _get_model_credential(model) -> dict:
    """
    Return the model's credential as a dictionary.

    ``model.credential`` is used as-is when it is already a dict and parsed
    with ``json.loads`` when it is a string. This function itself performs
    no decryption.

    Args:
        model: Object exposing a ``credential`` attribute.

    Returns:
        The credential dict, or ``{}`` when the attribute is missing, cannot
        be parsed, or does not hold a JSON object.
    """
    try:
        credential = model.credential
        if isinstance(credential, str):
            credential = json.loads(credential)
    except Exception:
        # Deliberately best-effort: a missing or unparsable credential is
        # reported to the caller as "no credential" rather than raised.
        return {}

    # Callers call ``.get`` on the result straight away, so never hand back
    # None, a list, or a scalar (e.g. a NULL column or the JSON text "[]").
    return credential if isinstance(credential, dict) else {}
def _call_openai_compatible(base_url, api_key, model_name, request_body, stream=False):
    """
    POST a chat-completion request to an OpenAI-compatible endpoint.

    Args:
        base_url: Upstream base URL, with or without a trailing ``/v1``
            and/or trailing slash.
        api_key: Secret sent as the upstream ``Authorization: Bearer`` token.
        model_name: Value written into the outgoing body's ``model`` field,
            overriding whatever ``request_body`` carried.
        request_body: Mapping with the client's request payload; it is
            copied, not mutated.
        stream: When truthy, the upstream response is opened in streaming
            mode and returned unread.

    Returns:
        The ``requests.Response`` object when ``stream`` is truthy (the
        caller must close it); otherwise the decoded JSON body.

    Raises:
        requests.exceptions.RequestException: On connection errors or
            timeouts (120 s).
        ValueError: If a non-streamed response body is not valid JSON.
    """
    root = base_url.rstrip('/')
    # Some OpenAI-compatible base URLs already include the version segment
    # (e.g. "https://api.openai.com/v1"). Appending "/v1" again would produce
    # ".../v1/v1/chat/completions", so only add it when it is missing.
    if root.endswith('/v1'):
        url = f"{root}/chat/completions"
    else:
        url = f"{root}/v1/chat/completions"

    headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {api_key}',
    }
    body = {**request_body, 'model': model_name}

    response = requests.post(url, json=body, headers=headers, stream=bool(stream), timeout=120)
    if stream:
        return response
    # NOTE(review): the upstream HTTP status is not inspected here, so an
    # upstream error payload is returned to the caller like a normal result.
    return response.json()
class OpenAIGatewayView(APIView):
    """
    OpenAI-compatible gateway.

    POST /api/v1/chat/completions - chat completion
    GET  /api/v1/models           - model list

    Both handlers authenticate the caller with a platform API key passed as
    a Bearer token and report failures in the OpenAI error envelope
    (``{"error": {"message": ..., "type": ...}}``).
    """

    def post(self, request: Request):
        """
        Proxy a chat-completion request to the upstream model.

        Returns a streaming ``text/event-stream`` response when the body has
        a truthy ``stream`` field, otherwise the upstream JSON wrapped by
        ``result.success``. Error statuses: 401 (bad key), 400 (bad body or
        missing model), 404 (unknown model), 500 (credential not configured
        or unexpected failure), 502 (upstream unreachable), 504 (timeout).
        """
        auth_result = _verify_bearer_token(request)
        if not auth_result:
            return self._openai_error(401, "Incorrect API key provided", "authentication_error")
        user_id, _api_key_id = auth_result

        body = request.data
        # A JSON array/string/number parses fine but has no ``.get``; reject
        # it as a client error instead of crashing with an AttributeError.
        if not isinstance(body, dict):
            return self._openai_error(400, "Request body must be a JSON object", "invalid_request_error")

        model_name = body.get('model')
        stream = body.get('stream', False)
        if not model_name:
            return self._openai_error(400, "model is required", "invalid_request_error")

        # Look up the model.
        model = _get_model_by_name(model_name, user_id)
        if not model:
            return self._openai_error(404, f"The model '{model_name}' does not exist", "model_not_found")

        # Fetch the credential (field names differ between providers).
        credential = _get_model_credential(model)
        if not isinstance(credential, dict):
            # Treat a missing/non-object credential as "not configured".
            credential = {}
        api_key = credential.get('api_key', '')
        # Accept both ``api_base_url`` and ``api_base`` as the base-URL key.
        base_url = credential.get('api_base_url', '') or credential.get('api_base', '')
        if not api_key or not base_url:
            return self._openai_error(500, "Model credential not configured", "server_error")

        # The lookup may have matched on the display name or by substring, in
        # which case the client's string is only an alias. The upstream needs
        # the identifier stored on the matched row.
        upstream_model = model.model_name or model_name

        try:
            if stream:
                return self._stream_response(base_url, api_key, upstream_model, body)
            response_data = _call_openai_compatible(
                base_url, api_key, upstream_model, body, stream=False
            )
            # NOTE(review): ``result.success`` is a project helper. If it
            # wraps the payload in a platform envelope, OpenAI clients will
            # not find the completion at the top level - confirm.
            return result.success(response_data)
        except requests.exceptions.Timeout:
            return self._openai_error(504, "Gateway timeout", "server_error")
        except requests.exceptions.ConnectionError:
            return self._openai_error(502, "Failed to connect to upstream model", "server_error")
        except Exception as e:
            return self._openai_error(500, str(e), "server_error")

    def get(self, request: Request):
        """
        List the available models in OpenAI ``list`` format.

        Every ``Model`` row with ``status='SUCCESS'`` is included once, keyed
        by ``model_name`` (falling back to ``name`` when that is empty).
        """
        auth_result = _verify_bearer_token(request)
        if not auth_result:
            return self._openai_error(401, "Incorrect API key provided", "authentication_error")

        # Return all usable models, de-duplicated by identifier while
        # preserving the query's order.
        rows = Model.objects.filter(status='SUCCESS').values('model_name', 'name')
        model_list = []
        seen = set()
        for row in rows:
            name = row['model_name'] or row['name']
            if not name or name in seen:
                continue
            seen.add(name)
            model_list.append({
                "id": name,
                "object": "model",
                "owned_by": "zhagent",
            })

        return result.success({
            "object": "list",
            "data": model_list,
        })

    def _stream_response(self, base_url, api_key, model_name, body):
        """
        Relay an upstream server-sent-event stream to the client.

        Args:
            base_url: Upstream base URL.
            api_key: Upstream bearer token.
            model_name: Model identifier to send upstream.
            body: Client request payload; forwarded with ``stream`` forced
                to ``True``.

        Returns:
            A ``StreamingHttpResponse`` relaying the upstream lines, or a
            502 OpenAI-style error if the upstream request fails or answers
            with an HTTP error status.
        """
        upstream = None
        try:
            # Reuse the shared helper so URL and header construction live in
            # one place for both the streaming and non-streaming paths.
            upstream = _call_openai_compatible(
                base_url, api_key, model_name, {**body, 'stream': True}, stream=True
            )
            upstream.raise_for_status()
        except Exception as e:
            # Release the pooled connection when the upstream answered with
            # an error status; nothing will ever consume its body.
            if upstream is not None:
                upstream.close()
            return self._openai_error(502, f"Failed to stream from upstream: {e}", "server_error")

        def generate():
            try:
                for line in upstream.iter_lines():
                    # Forward blank lines too: in the SSE wire format a blank
                    # line terminates an event, so dropping them leaves
                    # clients with one never-dispatched event.
                    yield line + b'\n'
            finally:
                upstream.close()

        return StreamingHttpResponse(
            generate(),
            content_type='text/event-stream',
            headers={
                'Cache-Control': 'no-cache',
                # "Connection" is a hop-by-hop header that WSGI applications
                # must not set (wsgiref asserts on it), so it is omitted.
                'X-Accel-Buffering': 'no',
            },
        )

    @staticmethod
    def _openai_error(status_code, message, error_type):
        """
        Build an OpenAI-style error response.

        Args:
            status_code: HTTP status to return.
            message: Human-readable error message.
            error_type: Value for the ``type`` field of the error object.

        Returns:
            ``JsonResponse`` with body
            ``{"error": {"message": message, "type": error_type}}``.
        """
        from django.http import JsonResponse
        return JsonResponse(
            {"error": {"message": message, "type": error_type}},
            status=status_code,
        )
|