tti.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. from typing import Dict
  2. from django.utils.translation import gettext
  3. from langchain_core.messages import HumanMessage
  4. from langchain_openai import ChatOpenAI
  5. from zhipuai import ZhipuAI
  6. from common.config.tokenizer_manage_config import TokenizerManage
  7. from models_provider.base_model_provider import MaxKBBaseModel
  8. from models_provider.impl.base_tti import BaseTextToImage
  9. def custom_get_token_ids(text: str):
  10. tokenizer = TokenizerManage.get_tokenizer()
  11. return tokenizer.encode(text)
  12. class ZhiPuTextToImage(MaxKBBaseModel, BaseTextToImage):
  13. api_key: str
  14. model: str
  15. params: dict
  16. def __init__(self, **kwargs):
  17. super().__init__(**kwargs)
  18. self.api_key = kwargs.get('api_key')
  19. self.model = kwargs.get('model')
  20. self.params = kwargs.get('params')
  21. @staticmethod
  22. def is_cache_model():
  23. return False
  24. @staticmethod
  25. def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
  26. optional_params = {'params': {'size': '1024x1024'}}
  27. for key, value in model_kwargs.items():
  28. if key not in ['model_id', 'use_local', 'streaming']:
  29. optional_params['params'][key] = value
  30. return ZhiPuTextToImage(
  31. model=model_name,
  32. api_key=model_credential.get('api_key'),
  33. **optional_params,
  34. )
  35. def is_cache_model(self):
  36. return False
  37. def check_auth(self):
  38. chat = ChatOpenAI(
  39. api_key=self.api_key,
  40. base_url='https://open.bigmodel.cn/api/paas/v4',
  41. model=self.model,
  42. )
  43. chat.invoke([HumanMessage([{"type": "text", "text": gettext('Hello')}])])
  44. # self.generate_image('生成一个小猫图片')
  45. def generate_image(self, prompt: str, negative_prompt: str = None):
  46. # chat = ChatZhipuAI(
  47. # zhipuai_api_key=self.api_key,
  48. # model_name=self.model,
  49. # )
  50. chat = ZhipuAI(api_key=self.api_key)
  51. response = chat.images.generations(
  52. model=self.model, # 填写需要调用的模型编码
  53. prompt=prompt, # 填写需要生成图片的文本
  54. **self.params # 填写额外参数
  55. )
  56. file_urls = []
  57. try:
  58. for content in response.data:
  59. url = content.url
  60. file_urls.append(url)
  61. return file_urls
  62. except Exception as e:
  63. raise e