gemini_model_provider.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. #!/usr/bin/env python
  2. # -*- coding: UTF-8 -*-
  3. """
  4. @Project :MaxKB
  5. @File :gemini_model_provider.py
  6. @Author :Brian Yang
  7. @Date :5/13/24 7:47 AM
  8. """
  9. import os
  10. from common.utils.common import get_file_content
  11. from models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \
  12. ModelInfoManage
  13. from models_provider.impl.gemini_model_provider.credential.embedding import GeminiEmbeddingCredential
  14. from models_provider.impl.gemini_model_provider.credential.image import GeminiImageModelCredential
  15. from models_provider.impl.gemini_model_provider.credential.llm import GeminiLLMModelCredential
  16. from models_provider.impl.gemini_model_provider.credential.stt import GeminiSTTModelCredential
  17. from models_provider.impl.gemini_model_provider.credential.tti import GeminiTextToImageModelCredential
  18. from models_provider.impl.gemini_model_provider.model.embedding import GeminiEmbeddingModel
  19. from models_provider.impl.gemini_model_provider.model.image import GeminiImage
  20. from models_provider.impl.gemini_model_provider.model.llm import GeminiChatModel
  21. from models_provider.impl.gemini_model_provider.model.stt import GeminiSpeechToText
  22. from maxkb.conf import PROJECT_DIR
  23. from django.utils.translation import gettext as _
  24. from models_provider.impl.gemini_model_provider.model.tti import GeminiTextToImage
  25. gemini_llm_model_credential = GeminiLLMModelCredential()
  26. gemini_image_model_credential = GeminiImageModelCredential()
  27. gemini_stt_model_credential = GeminiSTTModelCredential()
  28. gemini_embedding_model_credential = GeminiEmbeddingCredential()
  29. gemini_tti_model_credential = GeminiTextToImageModelCredential()
  30. model_info_list = [
  31. ModelInfo('gemini-1.0-pro', _('Latest Gemini 1.0 Pro model, updated with Google update'),
  32. ModelTypeConst.LLM,
  33. gemini_llm_model_credential,
  34. GeminiChatModel),
  35. ModelInfo('gemini-1.0-pro-vision', _('Latest Gemini 1.0 Pro Vision model, updated with Google update'),
  36. ModelTypeConst.LLM,
  37. gemini_llm_model_credential,
  38. GeminiChatModel),
  39. ]
  40. model_image_info_list = [
  41. ModelInfo('gemini-1.5-flash', _('Latest Gemini 1.5 Flash model, updated with Google updates'),
  42. ModelTypeConst.IMAGE,
  43. gemini_image_model_credential,
  44. GeminiImage),
  45. ModelInfo('gemini-1.5-pro', _('Latest Gemini 1.5 Flash model, updated with Google updates'),
  46. ModelTypeConst.IMAGE,
  47. gemini_image_model_credential,
  48. GeminiImage),
  49. ]
  50. model_stt_info_list = [
  51. ModelInfo('gemini-1.5-flash', _('Latest Gemini 1.5 Flash model, updated with Google updates'),
  52. ModelTypeConst.STT,
  53. gemini_stt_model_credential,
  54. GeminiSpeechToText),
  55. ModelInfo('gemini-1.5-pro', _('Latest Gemini 1.5 Flash model, updated with Google updates'),
  56. ModelTypeConst.STT,
  57. gemini_stt_model_credential,
  58. GeminiSpeechToText),
  59. ]
  60. model_embedding_info_list = [
  61. ModelInfo('models/embedding-001', '',
  62. ModelTypeConst.EMBEDDING,
  63. gemini_embedding_model_credential,
  64. GeminiEmbeddingModel),
  65. ModelInfo('models/text-embedding-004', '',
  66. ModelTypeConst.EMBEDDING,
  67. gemini_embedding_model_credential,
  68. GeminiEmbeddingModel),
  69. ]
  70. model_tti_info_list = [
  71. ModelInfo('gemini-3.1-flash-image-preview', "",
  72. ModelTypeConst.TTI,
  73. gemini_tti_model_credential,
  74. GeminiTextToImage)
  75. ]
  76. model_info_manage = (
  77. ModelInfoManage.builder()
  78. .append_model_info_list(model_info_list)
  79. .append_model_info_list(model_image_info_list)
  80. .append_model_info_list(model_stt_info_list)
  81. .append_model_info_list(model_embedding_info_list)
  82. .append_model_info_list(model_tti_info_list)
  83. .append_default_model_info(model_info_list[0])
  84. .append_default_model_info(model_image_info_list[0])
  85. .append_default_model_info(model_stt_info_list[0])
  86. .append_default_model_info(model_embedding_info_list[0])
  87. .append_default_model_info(model_tti_info_list[0])
  88. .build()
  89. )
  90. class GeminiModelProvider(IModelProvider):
  91. def get_model_info_manage(self):
  92. return model_info_manage
  93. def get_model_provide_info(self):
  94. return ModelProvideInfo(provider='model_gemini_provider', name='Gemini', icon=get_file_content(
  95. os.path.join(PROJECT_DIR, "apps", 'models_provider', 'impl', 'gemini_model_provider', 'icon',
  96. 'gemini_icon_svg')))