tti.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. # coding=utf-8
  2. import json
  3. import logging
  4. from typing import Dict
  5. from django.utils.translation import gettext as _
  6. from tencentcloud.common import credential
  7. from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException
  8. from tencentcloud.common.profile.client_profile import ClientProfile
  9. from tencentcloud.common.profile.http_profile import HttpProfile
  10. from tencentcloud.hunyuan.v20230901 import hunyuan_client, models
  11. from common.utils.logger import maxkb_logger
  12. from models_provider.base_model_provider import MaxKBBaseModel
  13. from models_provider.impl.base_tti import BaseTextToImage
  14. from models_provider.impl.tencent_model_provider.model.hunyuan import ChatHunyuan
  15. class TencentTextToImageModel(MaxKBBaseModel, BaseTextToImage):
  16. hunyuan_secret_id: str
  17. hunyuan_secret_key: str
  18. model: str
  19. params: dict
  20. @staticmethod
  21. def is_cache_model():
  22. return False
  23. def __init__(self, **kwargs):
  24. super().__init__(**kwargs)
  25. self.hunyuan_secret_id = kwargs.get('hunyuan_secret_id')
  26. self.hunyuan_secret_key = kwargs.get('hunyuan_secret_key')
  27. self.model = kwargs.get('model_name')
  28. self.params = kwargs.get('params')
  29. @staticmethod
  30. def new_instance(model_type: str, model_name: str, model_credential: Dict[str, object],
  31. **model_kwargs) -> 'TencentTextToImageModel':
  32. optional_params = {'params': {'Style': '201', 'Resolution': '768:768'}}
  33. for key, value in model_kwargs.items():
  34. if key not in ['model_id', 'use_local', 'streaming']:
  35. optional_params['params'][key] = value
  36. return TencentTextToImageModel(
  37. model=model_name,
  38. hunyuan_secret_id=model_credential.get('hunyuan_secret_id'),
  39. hunyuan_secret_key=model_credential.get('hunyuan_secret_key'),
  40. **optional_params
  41. )
  42. def check_auth(self):
  43. chat = ChatHunyuan(hunyuan_app_id='111111',
  44. hunyuan_secret_id=self.hunyuan_secret_id,
  45. hunyuan_secret_key=self.hunyuan_secret_key,
  46. model="hunyuan-standard")
  47. res = chat.invoke(_('Hello'))
  48. # print(res)
  49. def generate_image(self, prompt: str, negative_prompt: str = None):
  50. try:
  51. # 实例化一个认证对象,入参需要传入腾讯云账户 SecretId 和 SecretKey,此处还需注意密钥对的保密
  52. # 代码泄露可能会导致 SecretId 和 SecretKey 泄露,并威胁账号下所有资源的安全性。以下代码示例仅供参考,建议采用更安全的方式来使用密钥,请参见:https://cloud.tencent.com/document/product/1278/85305
  53. # 密钥可前往官网控制台 https://console.cloud.tencent.com/cam/capi 进行获取
  54. cred = credential.Credential(self.hunyuan_secret_id, self.hunyuan_secret_key)
  55. # 实例化一个http选项,可选的,没有特殊需求可以跳过
  56. httpProfile = HttpProfile()
  57. httpProfile.endpoint = "hunyuan.tencentcloudapi.com"
  58. # 实例化一个client选项,可选的,没有特殊需求可以跳过
  59. clientProfile = ClientProfile()
  60. clientProfile.httpProfile = httpProfile
  61. # 实例化要请求产品的client对象,clientProfile是可选的
  62. client = hunyuan_client.HunyuanClient(cred, "ap-guangzhou", clientProfile)
  63. # 实例化一个请求对象,每个接口都会对应一个request对象
  64. req = models.TextToImageLiteRequest()
  65. params = {
  66. "Prompt": prompt,
  67. "NegativePrompt": negative_prompt,
  68. "RspImgType": "url",
  69. **self.params
  70. }
  71. req.from_json_string(json.dumps(params))
  72. # 返回的resp是一个TextToImageLiteResponse的实例,与请求对象对应
  73. resp = client.TextToImageLite(req)
  74. file_urls = []
  75. file_urls.append(resp.ResultImage)
  76. return file_urls
  77. except TencentCloudSDKException as err:
  78. maxkb_logger.error(f"Tencent Text to Image API call failed: {err}")
  79. raise f"Tencent Text to Image API call failed: {err}"