llm.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. # coding=utf-8
  2. """
  3. OpenAI 兼容私域模型调用
  4. """
  5. import json
  6. from typing import Iterator, Dict
  7. import requests
  8. from django.utils.translation import gettext as _
  9. from common.exception.app_exception import AppApiException
  10. from common.utils.logger import maxkb_logger
  11. from models_provider.base_model_provider import ModelInfo, ModelTypeConst
  12. class OpenAICompatibleChatModel:
  13. """OpenAI 兼容模型调用"""
  14. def __init__(self, model_info: ModelInfo, model_credential: Dict[str, object], **kwargs):
  15. self.model_info = model_info
  16. self.model_credential = model_credential
  17. self.api_base = model_credential.get('api_base', '').rstrip('/')
  18. self.api_key = model_credential.get('api_key', '')
  19. self.model_name = model_credential.get('model_name', model_info.model_name)
  20. def _build_headers(self):
  21. return {
  22. 'Content-Type': 'application/json',
  23. 'Authorization': f'Bearer {self.api_key}',
  24. }
  25. def chat(self, messages, stream=False, **kwargs):
  26. """聊天补全"""
  27. url = f"{self.api_base}/v1/chat/completions"
  28. body = {
  29. 'model': self.model_name,
  30. 'messages': messages,
  31. 'stream': stream,
  32. **kwargs
  33. }
  34. try:
  35. if stream:
  36. return self._stream_chat(url, body)
  37. else:
  38. response = requests.post(url, json=body, headers=self._build_headers(), timeout=120)
  39. response.raise_for_status()
  40. return response.json()
  41. except requests.exceptions.Timeout:
  42. raise AppApiException(504, _('Request timeout'))
  43. except requests.exceptions.ConnectionError:
  44. raise AppApiException(502, _('Failed to connect to API server'))
  45. except Exception as e:
  46. maxkb_logger.error(f'OpenAI compatible API error: {e}')
  47. raise AppApiException(500, str(e))
  48. def _stream_chat(self, url, body):
  49. """流式聊天"""
  50. response = requests.post(url, json=body, headers=self._build_headers(), stream=True, timeout=120)
  51. response.raise_for_status()
  52. def generate():
  53. try:
  54. for line in response.iter_lines():
  55. if line:
  56. yield line.decode('utf-8') + '\n'
  57. finally:
  58. response.close()
  59. return generate()
  60. def verify_connection(self):
  61. """验证连接"""
  62. try:
  63. url = f"{self.api_base}/v1/models"
  64. response = requests.get(url, headers=self._build_headers(), timeout=10)
  65. response.raise_for_status()
  66. return True
  67. except Exception as e:
  68. raise AppApiException(500, _('Connection verification failed: {error}').format(error=str(e)))