# tti.py
import base64
import logging
from typing import Dict

from openai import OpenAI

from common.config.tokenizer_manage_config import TokenizerManage
from models_provider.base_model_provider import MaxKBBaseModel
from models_provider.impl.base_tti import BaseTextToImage
  7. def custom_get_token_ids(text: str):
  8. tokenizer = TokenizerManage.get_tokenizer()
  9. return tokenizer.encode(text)
  10. class GeminiTextToImage(MaxKBBaseModel, BaseTextToImage):
  11. base_url: str
  12. api_key: str
  13. model: str
  14. params: dict
  15. def __init__(self, **kwargs):
  16. super().__init__(**kwargs)
  17. self.api_key = kwargs.get('api_key')
  18. self.base_url = kwargs.get('base_url')
  19. self.model = kwargs.get('model')
  20. self.params = kwargs.get('params')
  21. @staticmethod
  22. def is_cache_model():
  23. return False
  24. @staticmethod
  25. def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
  26. optional_params = {'params': {}}
  27. for key, value in model_kwargs.items():
  28. if key not in ['model_id', 'use_local', 'streaming']:
  29. optional_params['params'][key] = value
  30. return GeminiTextToImage(
  31. model=model_name,
  32. base_url=model_credential.get('base_url', "https://generativelanguage.googleapis.com"),
  33. api_key=model_credential.get('api_key'),
  34. **optional_params,
  35. )
  36. def check_auth(self):
  37. return True
  38. def generate_image(self, prompt: str, negative_prompt: str = None):
  39. from google import genai
  40. from google.genai import types
  41. from PIL import Image
  42. file_urls = []
  43. client = genai.Client(api_key=self.api_key, http_options={"base_url": self.base_url}, **self.params)
  44. if self.model.startswith('imagen'):
  45. config = types.GenerateImagesConfig(**self.params)
  46. # 如果有 negative_prompt 就加入
  47. if negative_prompt:
  48. config.negative_prompt = negative_prompt
  49. response = client.models.generate_images(
  50. model=self.model,
  51. prompt=prompt,
  52. config=config
  53. )
  54. for generated_image in response.generated_images:
  55. img_base64 = base64.b64encode(generated_image.image.image_bytes).decode("utf-8")
  56. file_urls.append(f'data:{generated_image.image.mime_type};base64,{img_base64}')
  57. else:
  58. config = types.GenerateContentConfig(**self.params)
  59. if negative_prompt:
  60. config.negative_prompt = negative_prompt
  61. response = client.models.generate_content(
  62. model=self.model,
  63. contents=[prompt],
  64. config=config
  65. )
  66. for part in response.parts:
  67. if part.text is not None:
  68. print(part.text)
  69. elif part.inline_data is not None:
  70. image_bytes = part.inline_data.data
  71. img_base64 = base64.b64encode(image_bytes).decode("utf-8")
  72. file_urls.append(f'data:{part.inline_data.mime_type};base64,{img_base64}')
  73. return file_urls