| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120 |
- import base64
- import time
- from typing import Dict, Optional
- from models_provider.base_model_provider import MaxKBBaseModel
- from models_provider.base_ttv import BaseGenerationVideo
- from common.utils.logger import maxkb_logger
- from volcenginesdkarkruntime import Ark
class GenerationVideoModel(MaxKBBaseModel, BaseGenerationVideo):
    """Video generation model backed by Volcengine Ark (Doubao) content-generation tasks.

    ``generate_video`` creates an Ark task, polls it until it reaches a terminal
    status (or the wait budget runs out), deletes the finished task, and returns
    the resulting video URL.
    """

    api_key: str
    model_name: str
    params: dict
    # NOTE(review): max_retries / retry_delay are not read anywhere in this class;
    # kept only in case other code references them.
    max_retries: int = 3
    retry_delay: int = 5  # seconds

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.api_key = kwargs.get('api_key')
        self.model_name = kwargs.get('model_name')
        # `or {}` also covers an explicit params=None, so later .items() calls are safe.
        self.params = kwargs.get('params') or {}
        self.retry_delay = 5

    @staticmethod
    def is_cache_model():
        """Return False (presumably tells the framework not to cache instances of this model)."""
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build an instance from stored credentials and per-model settings.

        Every keyword argument except ``model_id``, ``use_local`` and ``streaming``
        is stored in ``params``; ``_build_prompt`` later decides which of them are
        actually forwarded to Ark.
        """
        excluded = ('model_id', 'use_local', 'streaming')
        params = {key: value for key, value in model_kwargs.items() if key not in excluded}
        return GenerationVideoModel(
            model_name=model_name,
            api_key=model_credential.get('api_key'),
            params=params,
        )

    def check_auth(self):
        """Always report the credentials as valid; no remote check is performed."""
        return True

    def _build_prompt(self, prompt: str) -> str:
        """Append the supported generation parameters to the prompt text as inline flags.

        Doubao reads its generation settings from flags at the end of the text
        prompt, e.g. ``"... --rt 16:9 --dur 5 --fps 24 --rs 720p --wm true --cf false"``.
        Keys of ``self.params`` not listed in the map are ignored, as are ``None``
        values. Booleans are rendered as lowercase ``true``/``false`` to match the
        flag format (Python's ``str(True)`` would give ``True``).
        """
        param_map = {
            "ratio": "rt",
            "duration": "dur",
            "framespersecond": "fps",
            "resolution": "rs",
            "watermark": "wm",
            "camerafixed": "cf",
        }
        flags = []
        for key, value in self.params.items():
            flag = param_map.get(key)
            if flag is None or value is None:
                continue
            if isinstance(value, bool):
                value = "true" if value else "false"
            flags.append(f" --{flag} {value}")
        return prompt + "".join(flags)

    def _poll_task(self, client: Ark, task_id: str, max_wait: int = 60, interval: int = 5):
        """Poll an Ark task until it reaches a terminal status or ``max_wait`` seconds pass.

        Returns the task object once its status is ``succeeded``, ``failed`` or
        ``cancelled``; returns ``None`` on timeout. The deadline uses a monotonic
        clock, so time spent in the status requests counts against it, and the
        task is queried once more at the deadline rather than giving up right
        after a sleep.
        """
        terminal_statuses = ("succeeded", "failed", "cancelled")
        # A non-positive interval would otherwise busy-loop against the API.
        step = interval if interval > 0 else 1
        deadline = time.monotonic() + max_wait
        while True:
            result = client.content_generation.tasks.get(task_id=task_id)
            status = getattr(result, "status", None)
            maxkb_logger.info(f"[ArkVideo] Task {task_id} status={status}")
            if status in terminal_statuses:
                return result
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            # Never sleep past the deadline.
            time.sleep(min(step, remaining))
        maxkb_logger.warning(f"[ArkVideo] Task {task_id} wait timeout")
        return None

    def generate_video(self, prompt, negative_prompt=None, first_frame_url=None, last_frame_url=None, **kwargs):
        """Generate a video with Ark and return its URL.

        Args:
            prompt: Text prompt; supported entries of ``self.params`` are appended
                to it as inline flags (see ``_build_prompt``).
            negative_prompt: Accepted for interface compatibility; not sent to Ark.
            first_frame_url: Optional image URL sent with role ``first_frame``.
            last_frame_url: Optional image URL sent with role ``last_frame``.
            **kwargs: ``max_wait`` (seconds, default 60) and ``poll_interval``
                (seconds, default 5) tune status polling; other keys are ignored.

        Returns:
            The video URL when the task succeeds. Otherwise a dict
            ``{"status": ..., "task_id": ...}`` whose status is ``"timeout"``
            (task still running on Ark, not deleted) or the terminal status
            reported by Ark (e.g. ``"failed"``, ``"cancelled"``).
        """
        client = Ark(api_key=self.api_key)
        # Doubao's parameters differ from other providers: they are appended to the
        # text prompt, e.g. "--rt 16:9 --dur 5 --fps 24 --rs 720p --wm true --cf false".
        content = [{"type": "text", "text": self._build_prompt(prompt)}]
        for url, role in ((first_frame_url, "first_frame"), (last_frame_url, "last_frame")):
            if url:
                content.append({
                    "type": "image_url",
                    "image_url": {
                        "url": url
                    },
                    "role": role
                })

        # Create the task exactly once; a second create would launch (and bill)
        # a duplicate generation that is never polled or cleaned up.
        task = client.content_generation.tasks.create(model=self.model_name, content=content)
        task_id = task.id
        maxkb_logger.info(f"[ArkVideo] Created task {task_id}")

        result = self._poll_task(
            client,
            task_id,
            max_wait=kwargs.get("max_wait", 60),
            interval=kwargs.get("poll_interval", 5),
        )
        if not result:
            return {"status": "timeout", "task_id": task_id}

        status = getattr(result, "status", None)

        # _poll_task only returns tasks in a terminal status, so the task can be
        # deleted here. Cleanup is best-effort: a failure must not discard a video
        # that has already been generated.
        try:
            client.content_generation.tasks.delete(task_id=task_id)
            maxkb_logger.info(f"[ArkVideo] Deleted task {task_id}")
        except Exception as e:
            maxkb_logger.error(f"[ArkVideo] Failed to delete task {task_id}: {e}")

        if status != "succeeded":
            # Failed/cancelled tasks carry no video; report them like the timeout case.
            error = getattr(result, "error", None)
            maxkb_logger.error(f"[ArkVideo] Task {task_id} ended with status={status}: {error}")
            return {"status": status, "task_id": task_id}

        video_url = result.content.video_url
        maxkb_logger.info(f"视频地址: {video_url}")
        return video_url
|