tti.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. # coding=utf-8
  2. '''
  3. requires Python 3.6 or later
  4. pip install asyncio
  5. pip install websockets
  6. '''
  7. from typing import Dict
  8. from models_provider.base_model_provider import MaxKBBaseModel
  9. from models_provider.impl.base_tti import BaseTextToImage
  10. from volcenginesdkarkruntime import Ark
  11. class VolcanicEngineTextToImage(MaxKBBaseModel, BaseTextToImage):
  12. api_key: str
  13. api_base: str
  14. model_version: str
  15. params: dict
  16. def __init__(self, **kwargs):
  17. super().__init__(**kwargs)
  18. self.api_key = kwargs.get('api_key')
  19. self.api_base = kwargs.get('api_base')
  20. self.model_version = kwargs.get('model_version')
  21. self.params = kwargs.get('params')
  22. @staticmethod
  23. def is_cache_model():
  24. return False
  25. @staticmethod
  26. def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
  27. optional_params = {'params': {}}
  28. for key, value in model_kwargs.items():
  29. if key not in ['model_id', 'use_local', 'streaming']:
  30. optional_params['params'][key] = value
  31. return VolcanicEngineTextToImage(
  32. model_version=model_name,
  33. api_key=model_credential.get('api_key'),
  34. api_base=model_credential.get('volcanic_api_url') or 'https://ark-api.volcengine.com',
  35. **optional_params
  36. )
  37. def check_auth(self):
  38. return True
  39. def generate_image(self, prompt: str, negative_prompt: str = None):
  40. client = Ark(
  41. # 此为默认路径,您可根据业务所在地域进行配置
  42. base_url=self.api_base,
  43. # 从环境变量中获取您的 API Key。此为默认方式,您可根据需要进行修改
  44. api_key=self.api_key,
  45. )
  46. file_urls = []
  47. imagesResponse = client.images.generate(
  48. model=self.model_version,
  49. prompt=prompt,
  50. **self.params
  51. )
  52. if imagesResponse.data[0].url:
  53. file_urls.append(imagesResponse.data[0].url)
  54. elif imagesResponse.data[0].b64_json:
  55. file_urls.append(imagesResponse.data[0].b64_json)
  56. return file_urls