| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667 |
- # coding=utf-8
- '''
- requires Python 3.6 or later
- pip install asyncio
- pip install websockets
- '''
- from typing import Dict
- from models_provider.base_model_provider import MaxKBBaseModel
- from models_provider.impl.base_tti import BaseTextToImage
- from volcenginesdkarkruntime import Ark
class VolcanicEngineTextToImage(MaxKBBaseModel, BaseTextToImage):
    """Text-to-image model backed by the Volcano Engine Ark SDK client.

    Holds the credentials, endpoint, model identifier and extra generation
    parameters, and forwards them to ``Ark(...).images.generate``.
    """

    # Credential passed to the Ark client.
    api_key: str
    # Base URL passed to the Ark client.
    api_base: str
    # Model identifier sent as ``model`` in the generation request.
    model_version: str
    # Extra keyword arguments forwarded verbatim to ``images.generate``.
    params: dict

    def __init__(self, **kwargs):
        """Initialise the base classes, then copy the known settings.

        Recognised keys are ``api_key``, ``api_base``, ``model_version`` and
        ``params``; keys that are absent are stored as ``None``.
        """
        super().__init__(**kwargs)
        self.api_key = kwargs.get('api_key')
        self.api_base = kwargs.get('api_base')
        self.model_version = kwargs.get('model_version')
        self.params = kwargs.get('params')

    @staticmethod
    def is_cache_model():
        """Return ``False``."""
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build a ``VolcanicEngineTextToImage`` from a credential mapping.

        Args:
            model_type: Not used by this implementation.
            model_name: Stored as the instance's ``model_version``.
            model_credential: Mapping read for ``api_key`` and
                ``volcanic_api_url``.
            **model_kwargs: Every entry except ``model_id``, ``use_local``
                and ``streaming`` is stored in ``params`` and later forwarded
                to the image generation call.

        Returns:
            The configured ``VolcanicEngineTextToImage`` instance.
        """
        excluded_keys = ('model_id', 'use_local', 'streaming')
        generation_params = {
            key: value
            for key, value in model_kwargs.items()
            if key not in excluded_keys
        }
        return VolcanicEngineTextToImage(
            model_version=model_name,
            api_key=model_credential.get('api_key'),
            # Fallback used when the credential has no (or an empty) URL.
            # NOTE(review): confirm this default matches the endpoint the
            # Ark SDK expects for the deployment region.
            api_base=model_credential.get('volcanic_api_url') or 'https://ark-api.volcengine.com',
            params=generation_params,
        )

    def check_auth(self):
        """Return ``True`` without contacting the service."""
        return True

    def generate_image(self, prompt: str, negative_prompt: str = None):
        """Generate images for ``prompt`` and return their locations.

        Args:
            prompt: Text prompt sent to the model.
            negative_prompt: Accepted for interface compatibility; it is not
                forwarded to the Ark client.

        Returns:
            A list with one entry per returned image: its ``url`` when
            present, otherwise its ``b64_json`` payload. Items carrying
            neither are skipped, so the list may be empty.
        """
        client = Ark(
            # Endpoint; configure it for the region the service runs in.
            base_url=self.api_base,
            # API key taken from this instance's credentials.
            api_key=self.api_key,
        )
        # ``params`` may be None when the instance was built without it;
        # unpacking None would raise TypeError.
        extra_params = self.params or {}
        response = client.images.generate(
            model=self.model_version,
            prompt=prompt,
            **extra_params
        )

        file_urls = []
        # Walk every returned item instead of only the first one, so that
        # multi-image responses are kept and an empty response is harmless.
        for image in getattr(response, 'data', None) or []:
            location = getattr(image, 'url', None) or getattr(image, 'b64_json', None)
            if location:
                file_urls.append(location)
        return file_urls
|