ttv.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. import base64
  2. import time
  3. from typing import Dict, Optional
  4. from models_provider.base_model_provider import MaxKBBaseModel
  5. from models_provider.base_ttv import BaseGenerationVideo
  6. from common.utils.logger import maxkb_logger
  7. from volcenginesdkarkruntime import Ark
  8. class GenerationVideoModel(MaxKBBaseModel, BaseGenerationVideo):
  9. api_key: str
  10. model_name: str
  11. params: dict
  12. max_retries: int = 3
  13. retry_delay: int = 5 # seconds
  14. def __init__(self, **kwargs):
  15. super().__init__(**kwargs)
  16. self.api_key = kwargs.get('api_key')
  17. self.model_name = kwargs.get('model_name')
  18. self.params = kwargs.get('params', {})
  19. self.retry_delay = 5
  20. @staticmethod
  21. def is_cache_model():
  22. return False
  23. @staticmethod
  24. def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
  25. optional_params = {'params': {}}
  26. for key, value in model_kwargs.items():
  27. if key not in ['model_id', 'use_local', 'streaming']:
  28. optional_params['params'][key] = value
  29. return GenerationVideoModel(
  30. model_name=model_name,
  31. api_key=model_credential.get('api_key'),
  32. **optional_params,
  33. )
  34. def check_auth(self):
  35. return True
  36. def _build_prompt(self, prompt: str) -> str:
  37. """拼接参数到 prompt 文本"""
  38. param_map = {
  39. "ratio": "rt",
  40. "duration": "dur",
  41. "framespersecond": "fps",
  42. "resolution": "rs",
  43. "watermark": "wm",
  44. "camerafixed": "cf",
  45. }
  46. for key, value in self.params.items():
  47. if key in param_map:
  48. prompt += f" --{param_map[key]} {value}"
  49. return prompt
  50. def _poll_task(self, client: Ark, task_id: str, max_wait: int = 60, interval: int = 5):
  51. """轮询任务状态,直到完成或超时"""
  52. elapsed = 0
  53. while elapsed < max_wait:
  54. result = client.content_generation.tasks.get(task_id=task_id)
  55. status = getattr(result, "status", None)
  56. maxkb_logger.info(f"[ArkVideo] Task {task_id} status={status}")
  57. if status in ("succeeded", "failed", "cancelled"):
  58. return result
  59. time.sleep(interval)
  60. elapsed += interval
  61. maxkb_logger.warning(f"[ArkVideo] Task {task_id} wait timeout")
  62. return None
  63. # --- 通用异步生成函数 ---
  64. def generate_video(self, prompt, negative_prompt=None, first_frame_url=None, last_frame_url=None, **kwargs):
  65. client = Ark(api_key=self.api_key)
  66. # 根据params设置其他参数 豆包的参数和别的不一样 需要拼接在text里
  67. # --rt 16:9 --dur 5 --fps 24 --rs 720p --wm true --cf false
  68. prompt = self._build_prompt(prompt)
  69. content = [{"type": "text", "text": prompt}]
  70. if first_frame_url:
  71. content.append({
  72. "type": "image_url",
  73. "image_url": {
  74. "url": first_frame_url
  75. },
  76. "role": "first_frame"
  77. })
  78. if last_frame_url:
  79. content.append({
  80. "type": "image_url",
  81. "image_url": {
  82. "url": last_frame_url
  83. },
  84. "role": "last_frame"
  85. })
  86. create_result = client.content_generation.tasks.create(
  87. model=self.model_name,
  88. content=content
  89. )
  90. task = client.content_generation.tasks.create(model=self.model_name, content=content)
  91. task_id = task.id
  92. maxkb_logger.info(f"[ArkVideo] Created task {task_id}")
  93. # 轮询获取结果
  94. result = self._poll_task(client, task_id)
  95. if not result:
  96. return {"status": "timeout", "task_id": task_id}
  97. try:
  98. if getattr(result, "status", None) in ("succeeded", "failed", "cancelled"):
  99. client.content_generation.tasks.delete(task_id=task_id)
  100. maxkb_logger.info(f"[ArkVideo] Deleted task {task_id}")
  101. except Exception as e:
  102. maxkb_logger.error(f"[ArkVideo] Failed to delete task {task_id}: {e}")
  103. raise e
  104. maxkb_logger.info("视频地址", result.content.video_url)
  105. return result.content.video_url