openai_gateway_view.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. # coding=utf-8
  2. """
  3. OpenAI 兼容网关视图
  4. 提供 /api/v1/chat/completions 和 /api/v1/models 接口
  5. 用于外部客户端通过平台API Key访问私域模型
  6. """
  7. import json
  8. import time
  9. import uuid
  10. import requests
  11. from typing import Optional
  12. from django.http import StreamingHttpResponse
  13. from rest_framework.views import APIView
  14. from rest_framework.request import Request
  15. from common.result import result
  16. from models_provider.models.platform_api_key import PlatformApiKey, PlatformApiKeyStatus
  17. from models_provider.services.crypto_utils import hash_api_key
  18. from models_provider.models import Model
  19. def _verify_bearer_token(request):
  20. """
  21. 验证 Bearer Token,返回 (user_id, api_key_id) 或 None
  22. """
  23. auth_header = request.META.get('HTTP_AUTHORIZATION', '')
  24. if not auth_header.startswith('Bearer '):
  25. return None
  26. api_key = auth_header[7:]
  27. hashed_key = hash_api_key(api_key)
  28. api_key_record = PlatformApiKey.objects.filter(api_key_hash=hashed_key).first()
  29. if not api_key_record or api_key_record.status != PlatformApiKeyStatus.ACTIVE:
  30. return None
  31. # 更新最后使用时间
  32. from django.utils import timezone
  33. api_key_record.last_used_at = timezone.now()
  34. api_key_record.save(update_fields=["last_used_at"])
  35. return (str(api_key_record.user_id), str(api_key_record.id))
  36. def _get_model_by_name(model_name, user_id):
  37. """
  38. 根据模型名称查找模型
  39. 支持模糊匹配:优先精确匹配 model_name,再匹配 name
  40. """
  41. # 精确匹配 model_name
  42. model = Model.objects.filter(model_name=model_name, status='SUCCESS').first()
  43. if model:
  44. return model
  45. # 精确匹配 name
  46. model = Model.objects.filter(name=model_name, status='SUCCESS').first()
  47. if model:
  48. return model
  49. # 模糊匹配
  50. model = Model.objects.filter(model_name__icontains=model_name, status='SUCCESS').first()
  51. if model:
  52. return model
  53. return None
  54. def _get_model_credential(model):
  55. """解密模型凭证"""
  56. try:
  57. credential = model.credential
  58. if isinstance(credential, str):
  59. credential = json.loads(credential)
  60. return credential
  61. except Exception:
  62. return {}
  63. def _call_openai_compatible(base_url, api_key, model_name, request_body, stream=False):
  64. """
  65. 调用 OpenAI 兼容接口
  66. """
  67. url = f"{base_url.rstrip('/')}/v1/chat/completions"
  68. headers = {
  69. 'Content-Type': 'application/json',
  70. 'Authorization': f'Bearer {api_key}',
  71. }
  72. body = {**request_body, 'model': model_name}
  73. if stream:
  74. return requests.post(url, json=body, headers=headers, stream=True, timeout=120)
  75. else:
  76. response = requests.post(url, json=body, headers=headers, timeout=120)
  77. return response.json()
  78. class OpenAIGatewayView(APIView):
  79. """
  80. OpenAI 兼容网关
  81. POST /api/v1/chat/completions - 聊天补全
  82. GET /api/v1/models - 模型列表
  83. """
  84. def post(self, request: Request):
  85. """聊天补全接口"""
  86. auth_result = _verify_bearer_token(request)
  87. if not auth_result:
  88. return self._openai_error(401, "Incorrect API key provided", "authentication_error")
  89. user_id, api_key_id = auth_result
  90. body = request.data
  91. model_name = body.get('model')
  92. stream = body.get('stream', False)
  93. if not model_name:
  94. return self._openai_error(400, "model is required", "invalid_request_error")
  95. # 查找模型
  96. model = _get_model_by_name(model_name, user_id)
  97. if not model:
  98. return self._openai_error(404, f"The model '{model_name}' does not exist", "model_not_found")
  99. # 获取凭证(兼容不同提供商的字段名)
  100. credential = _get_model_credential(model)
  101. api_key = credential.get('api_key', '')
  102. # 兼容 api_base_url (OpenAI) 和 api_base (Docker AI/Ollama)
  103. base_url = credential.get('api_base_url', '') or credential.get('api_base', '')
  104. if not api_key or not base_url:
  105. return self._openai_error(500, "Model credential not configured", "server_error")
  106. try:
  107. if stream:
  108. return self._stream_response(base_url, api_key, model_name, body)
  109. else:
  110. response_data = _call_openai_compatible(base_url, api_key, model_name, body, stream=False)
  111. return result.success(response_data)
  112. except requests.exceptions.Timeout:
  113. return self._openai_error(504, "Gateway timeout", "server_error")
  114. except requests.exceptions.ConnectionError:
  115. return self._openai_error(502, "Failed to connect to upstream model", "server_error")
  116. except Exception as e:
  117. return self._openai_error(500, str(e), "server_error")
  118. def get(self, request: Request):
  119. """获取可用模型列表"""
  120. auth_result = _verify_bearer_token(request)
  121. if not auth_result:
  122. return self._openai_error(401, "Incorrect API key provided", "authentication_error")
  123. # 返回所有可用模型
  124. models = Model.objects.filter(status='SUCCESS').values('model_name', 'name')
  125. model_list = []
  126. seen = set()
  127. for m in models:
  128. name = m['model_name'] or m['name']
  129. if name and name not in seen:
  130. seen.add(name)
  131. model_list.append({
  132. "id": name,
  133. "object": "model",
  134. "owned_by": "zhagent",
  135. })
  136. return result.success({
  137. "object": "list",
  138. "data": model_list,
  139. })
  140. def _stream_response(self, base_url, api_key, model_name, body):
  141. """流式响应"""
  142. url = f"{base_url.rstrip('/')}/v1/chat/completions"
  143. headers = {
  144. 'Content-Type': 'application/json',
  145. 'Authorization': f'Bearer {api_key}',
  146. }
  147. body = {**body, 'model': model_name, 'stream': True}
  148. try:
  149. upstream = requests.post(url, json=body, headers=headers, stream=True, timeout=120)
  150. upstream.raise_for_status()
  151. def generate():
  152. try:
  153. for line in upstream.iter_lines():
  154. if line:
  155. yield line.decode('utf-8') + '\n'
  156. finally:
  157. upstream.close()
  158. response = StreamingHttpResponse(
  159. generate(),
  160. content_type='text/event-stream',
  161. headers={
  162. 'Cache-Control': 'no-cache',
  163. 'Connection': 'keep-alive',
  164. 'X-Accel-Buffering': 'no',
  165. }
  166. )
  167. return response
  168. except Exception as e:
  169. return self._openai_error(502, f"Failed to stream from upstream: {e}", "server_error")
  170. @staticmethod
  171. def _openai_error(status_code, message, error_type):
  172. """返回 OpenAI 格式的错误响应"""
  173. from django.http import JsonResponse
  174. return JsonResponse(
  175. {"error": {"message": message, "type": error_type}},
  176. status=status_code,
  177. )