tti.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. from typing import Dict
  2. from openai import OpenAI
  3. from common.config.tokenizer_manage_config import TokenizerManage
  4. from models_provider.base_model_provider import MaxKBBaseModel
  5. from models_provider.impl.base_tti import BaseTextToImage
  6. def custom_get_token_ids(text: str):
  7. tokenizer = TokenizerManage.get_tokenizer()
  8. return tokenizer.encode(text)
  9. class OpenAITextToImage(MaxKBBaseModel, BaseTextToImage):
  10. api_base: str
  11. api_key: str
  12. model: str
  13. params: dict
  14. def __init__(self, **kwargs):
  15. super().__init__(**kwargs)
  16. self.api_key = kwargs.get('api_key')
  17. self.api_base = kwargs.get('api_base')
  18. self.model = kwargs.get('model')
  19. self.params = kwargs.get('params')
  20. @staticmethod
  21. def is_cache_model():
  22. return False
  23. @staticmethod
  24. def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
  25. optional_params = {'params': {'size': '1024x1024', 'quality': 'standard', 'n': 1}}
  26. for key, value in model_kwargs.items():
  27. if key not in ['model_id', 'use_local', 'streaming']:
  28. optional_params['params'][key] = value
  29. return OpenAITextToImage(
  30. model=model_name,
  31. api_base=model_credential.get('api_base'),
  32. api_key=model_credential.get('api_key'),
  33. **optional_params,
  34. )
  35. def check_auth(self):
  36. chat = OpenAI(api_key=self.api_key, base_url=self.api_base)
  37. response_list = chat.models.with_raw_response.list()
  38. # self.generate_image('生成一个小猫图片')
  39. def generate_image(self, prompt: str, negative_prompt: str = None):
  40. chat = OpenAI(api_key=self.api_key, base_url=self.api_base)
  41. res = chat.images.generate(model=self.model, prompt=prompt, **self.params)
  42. file_urls = []
  43. try:
  44. for content in res.data:
  45. if content.url:
  46. file_urls.append(content.url)
  47. elif content.b64_json:
  48. file_urls.append(content.b64_json)
  49. return file_urls
  50. except Exception as e:
  51. raise f"OpenAI generate image error: {e}"