"""
Video synthesis service.

Provides text-to-video and image-to-video generation on top of the
DashScope asynchronous HTTP API.
"""
import logging
import threading
from datetime import datetime
from decimal import Decimal
from typing import Any, Dict, Optional, Tuple

import httpx
from fastapi import HTTPException
from sqlalchemy.orm import Session

from app.models.ai_video import AIVideo
from app.schemas.video_schema import (
    VideoGenerateRequest,
    VideoTaskResponse,
    VideoTaskResult,
    VideoHistoryItem,
    VideoHistoryResponse,
)
from app.services.oss_service import get_oss_service

logger = logging.getLogger(__name__)

DASHSCOPE_BASE_URL = "https://dashscope.aliyuncs.com/api/v1"

# Resolution label -> aspect ratio -> "width*height" size string.
RESOLUTION_MAP = {
    "720P": {"16:9": "1280*720", "9:16": "720*1280", "1:1": "960*960"},
    "1080P": {"16:9": "1920*1080", "9:16": "1080*1920", "1:1": "1440*1440"}
}

# Timeout (seconds) applied to every DashScope HTTP call made by this module.
_REQUEST_TIMEOUT = 30.0

# Task statuses after which the stored record is returned without re-querying.
_TERMINAL_STATUSES = ("SUCCEEDED", "FAILED")


class VideoService:
    """Video synthesis service bound to one user, DB session and API key."""

    def __init__(self, db: Session, user_id: int, api_key: str):
        """Store the request-scoped dependencies and obtain the OSS service.

        Args:
            db: SQLAlchemy session used for all reads/writes of ``AIVideo``.
            user_id: Owner of every record created or queried by this instance.
            api_key: Sent as the Bearer token on DashScope requests.
        """
        self.db = db
        self.user_id = user_id
        self.api_key = api_key
        self.oss_service = get_oss_service()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _async_headers(self) -> Dict[str, str]:
        """Return the headers used when submitting an asynchronous task."""
        return {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
            "X-DashScope-Async": "enable"
        }

    @staticmethod
    def _resolve_size(resolution: Any) -> str:
        """Map a resolution label to a size string, defaulting to 720P.

        NOTE(review): the aspect ratio is fixed at "16:9" even though
        RESOLUTION_MAP also defines "9:16" and "1:1" — confirm whether the
        request is supposed to carry an aspect ratio.
        """
        return RESOLUTION_MAP.get(resolution, RESOLUTION_MAP["720P"])["16:9"]

    async def _submit_task(self, url: str, body: Dict[str, Any]) -> Tuple[str, str]:
        """POST a task-creation request and return ``(task_id, task_status)``.

        Raises:
            Exception: when the response JSON contains a ``code`` field.
            KeyError: when the response has neither ``code`` nor the expected
                ``output.task_id`` / ``output.task_status`` fields.
        """
        async with httpx.AsyncClient(timeout=_REQUEST_TIMEOUT) as client:
            response = await client.post(url, json=body, headers=self._async_headers())
            result = response.json()

        if "code" in result:
            raise Exception(f"创建任务失败: {result.get('message')}")

        output = result["output"]
        return output["task_id"], output["task_status"]

    def _save_record(self, **fields: Any) -> None:
        """Insert and commit an ``AIVideo`` row for the current user.

        ``user_id`` and ``submit_time`` are filled in here; everything else is
        passed through unchanged as column values.
        """
        record = AIVideo(
            user_id=self.user_id,
            submit_time=datetime.now(),
            **fields
        )
        self.db.add(record)
        self.db.commit()

    @staticmethod
    def _to_result(
        task_id: str,
        record: Any,
        task_status: Any,
        video_url: Any,
    ) -> VideoTaskResult:
        """Build a ``VideoTaskResult`` from a record plus explicit status/URL."""
        return VideoTaskResult(
            task_id=task_id,
            task_status=task_status,
            video_url=video_url,
            video_duration=float(record.video_duration) if record.video_duration else None,
            actual_prompt=record.actual_prompt,
            error_message=record.error_message
        )

    # ------------------------------------------------------------------
    # Task creation
    # ------------------------------------------------------------------

    async def generate(self, request: VideoGenerateRequest) -> VideoTaskResponse:
        """Single entry point for video generation.

        Any request that carries a first frame is routed to the image-to-video
        endpoint, whether or not a last frame is also present; everything else
        goes to text-to-video.
        """
        if request.first_frame_url:
            return await self._image_to_video(request)
        return await self._text_to_video(request)

    async def _text_to_video(self, request: VideoGenerateRequest) -> VideoTaskResponse:
        """Create a text-to-video task over HTTP and persist its record.

        Uses the model supplied by the caller, falling back to ``wan2.6-t2v``.

        Raises:
            Exception: when DashScope reports an error for the submission.
        """
        model = request.model or "wan2.6-t2v"
        url = f"{DASHSCOPE_BASE_URL}/services/aigc/video-generation/video-synthesis"

        body = {
            "model": model,
            "input": {
                "prompt": request.prompt
            },
            "parameters": {
                "size": self._resolve_size(request.resolution),
                "duration": request.duration,
                "prompt_extend": request.prompt_extend,
                "shot_type": request.shot_type,
                "watermark": request.watermark
            }
        }
        if request.negative_prompt:
            body["input"]["negative_prompt"] = request.negative_prompt
        if request.audio_url:
            body["input"]["audio_url"] = request.audio_url
        # Compare against None so an explicit seed of 0 is still forwarded.
        if request.seed is not None:
            body["parameters"]["seed"] = request.seed

        task_id, task_status = await self._submit_task(url, body)

        self._save_record(
            task_id=task_id,
            model_name=model,
            video_type="t2v",
            input_params=body,
            prompt=request.prompt,
            audio_url=request.audio_url,
            resolution=request.resolution,
            status=task_status
        )
        return VideoTaskResponse(task_id=task_id, task_status=task_status)

    async def _first_last_frame_to_video(self, request: VideoGenerateRequest) -> VideoTaskResponse:
        """Create a first/last-frame video task with ``wan2.2-kf2v-flash``.

        Not referenced by ``generate()`` in this file; kept for any other
        callers.

        Raises:
            Exception: when DashScope reports an error for the submission.
        """
        url = f"{DASHSCOPE_BASE_URL}/services/aigc/image2video/video-synthesis"

        body = {
            "model": "wan2.2-kf2v-flash",
            "input": {
                "first_frame_url": request.first_frame_url,
                "last_frame_url": request.last_frame_url,
                "prompt": request.prompt
            },
            "parameters": {
                "resolution": request.resolution,
                "prompt_extend": request.prompt_extend,
                "watermark": request.watermark
            }
        }
        if request.negative_prompt:
            body["input"]["negative_prompt"] = request.negative_prompt
        # Compare against None so an explicit seed of 0 is still forwarded.
        if request.seed is not None:
            body["parameters"]["seed"] = request.seed

        task_id, task_status = await self._submit_task(url, body)

        self._save_record(
            task_id=task_id,
            model_name="wan2.2-kf2v-flash",
            video_type="kf2v",
            input_params=body,
            prompt=request.prompt,
            first_frame_url=request.first_frame_url,
            last_frame_url=request.last_frame_url,
            resolution=request.resolution,
            status=task_status
        )
        return VideoTaskResponse(task_id=task_id, task_status=task_status)

    async def _image_to_video(self, request: VideoGenerateRequest) -> VideoTaskResponse:
        """Create an image-to-video task over HTTP and persist its record.

        Uses the model supplied by the caller, falling back to ``wan2.6-i2v``.
        The optional last frame, audio URL and audio flag are forwarded only
        when provided.

        Raises:
            Exception: when DashScope reports an error for the submission.
        """
        model = request.model or "wan2.6-i2v"
        url = f"{DASHSCOPE_BASE_URL}/services/aigc/video-generation/video-synthesis"

        body = {
            "model": model,
            "input": {
                "prompt": request.prompt,
                "img_url": request.first_frame_url
            },
            "parameters": {
                "size": self._resolve_size(request.resolution),
                "duration": request.duration,
                "prompt_extend": request.prompt_extend,
                "watermark": request.watermark
            }
        }
        if request.negative_prompt:
            body["input"]["negative_prompt"] = request.negative_prompt
        if request.last_frame_url:
            body["input"]["last_img_url"] = request.last_frame_url
        if request.audio_url:
            body["input"]["audio_url"] = request.audio_url
        if request.audio is not None:
            body["parameters"]["audio"] = request.audio
        # Compare against None so an explicit seed of 0 is still forwarded.
        if request.seed is not None:
            body["parameters"]["seed"] = request.seed

        task_id, task_status = await self._submit_task(url, body)

        # Unlike the other two paths, only a summary of the request is stored
        # in input_params here, not the full request body.
        self._save_record(
            task_id=task_id,
            model_name=model,
            video_type="i2v",
            input_params={
                "prompt": request.prompt,
                "first_frame_url": request.first_frame_url,
                "last_frame_url": request.last_frame_url,
                "resolution": request.resolution,
                "duration": request.duration
            },
            prompt=request.prompt,
            first_frame_url=request.first_frame_url,
            last_frame_url=request.last_frame_url,
            audio_url=request.audio_url,
            resolution=request.resolution,
            status=task_status
        )
        return VideoTaskResponse(task_id=task_id, task_status=task_status)

    # ------------------------------------------------------------------
    # Task status
    # ------------------------------------------------------------------

    async def get_task_status(self, task_id: str) -> VideoTaskResult:
        """Return the current state of a task owned by this user.

        Records already in a terminal state are answered from the database.
        Otherwise the DashScope task endpoint is queried, the record is
        updated and committed, and on success an OSS upload of the video is
        started in a background thread.

        Raises:
            Exception: when no record matches ``task_id`` for this user.
        """
        # Ownership check: the task must belong to the current user.
        record = (
            self.db.query(AIVideo)
            .filter(
                AIVideo.task_id == task_id,
                AIVideo.user_id == self.user_id
            )
            .first()
        )
        if not record:
            raise Exception("任务不存在")

        # Already finished: answer from the stored record.
        if record.status in _TERMINAL_STATUSES:
            return self._to_result(task_id, record, record.status, record.video_url)

        # Ask DashScope for the live status.
        url = f"{DASHSCOPE_BASE_URL}/tasks/{task_id}"
        headers = {"Authorization": f"Bearer {self.api_key}"}
        async with httpx.AsyncClient(timeout=_REQUEST_TIMEOUT) as client:
            response = await client.get(url, headers=headers)
            result = response.json()

        output = result.get("output", {})
        # NOTE(review): a response without "output" stores "UNKNOWN" as the
        # record status; it is non-terminal, so the next poll queries again.
        task_status = output.get("task_status", "UNKNOWN")

        record.status = task_status
        if task_status == "RUNNING" and not record.scheduled_time:
            record.scheduled_time = datetime.now()

        if task_status == "SUCCEEDED":
            record.end_time = datetime.now()

            actual_prompt = output.get("actual_prompt")
            if actual_prompt:
                record.actual_prompt = actual_prompt

            # Cost: recorded as zero (original author's note: API calls are free).
            record.bill = Decimal("0")

            dashscope_video_url = output.get("video_url")
            if dashscope_video_url:
                # Persist the DashScope URL right away. Later polls short-circuit
                # on the terminal status and read record.video_url, so without
                # this they would see no URL until (or unless) the OSS upload
                # finishes. The background upload overwrites it on success.
                record.video_url = dashscope_video_url

            # Commit first so the SUCCEEDED state is visible immediately; the
            # OSS upload runs in a background thread and does not block this
            # request.
            self.db.commit()

            if dashscope_video_url:
                threading.Thread(
                    target=self._upload_video_to_oss_sync,
                    args=(record.id, dashscope_video_url),
                    daemon=True
                ).start()

            # Return the original DashScope URL so the client can play the
            # video before the OSS copy exists.
            return self._to_result(task_id, record, record.status, dashscope_video_url)

        elif task_status == "FAILED":
            record.end_time = datetime.now()
            record.error_message = output.get("message", "任务失败")

        self.db.commit()

        # Report the status held by the record, falling back to the status
        # just received when the record has none.
        return self._to_result(
            task_id,
            record,
            record.status if record.status else task_status,
            record.video_url
        )

    def _upload_video_to_oss_sync(self, record_id: int, video_url: str) -> None:
        """Background-thread job: copy the video to OSS and store the new URL.

        Opens its own database session instead of reusing the request session.
        Failures are logged and swallowed; nothing is raised to the caller.
        """
        from app.database import SessionLocal

        try:
            oss_url = self.oss_service.upload_from_url_sync(video_url, "ai-videos")
        except Exception as e:
            logger.error(f"OSS 上传失败: record_id={record_id}, error={e}")
            return

        if not oss_url:
            # Do not overwrite the stored URL with an empty value.
            logger.error(f"OSS 上传未返回 URL: record_id={record_id}")
            return

        db = SessionLocal()
        try:
            rec = db.query(AIVideo).filter(AIVideo.id == record_id).first()
            if rec:
                rec.video_url = oss_url
                db.commit()
                logger.info(f"OSS 上传完成: record_id={record_id}, oss_url={oss_url}")
        except Exception as e:
            logger.error(f"更新 OSS URL 失败: record_id={record_id}, error={e}")
            db.rollback()
        finally:
            db.close()

    # ------------------------------------------------------------------
    # History
    # ------------------------------------------------------------------

    def get_history(self, page: int = 1, page_size: int = 20) -> VideoHistoryResponse:
        """Return one page of this user's videos, newest first.

        Records whose ``review_status`` is ``'rejected'`` are excluded.

        Args:
            page: 1-based page number.
            page_size: Maximum number of items per page.
        """
        # NOTE(review): with SQL three-valued logic, `!= 'rejected'` also drops
        # rows whose review_status is NULL — confirm the column is NOT NULL.
        query = (
            self.db.query(AIVideo)
            .filter(
                AIVideo.user_id == self.user_id,
                AIVideo.review_status != 'rejected'
            )
            .order_by(AIVideo.created_at.desc())
        )

        total = query.count()
        offset = (page - 1) * page_size
        records = query.offset(offset).limit(page_size).all()

        items = [
            VideoHistoryItem(
                id=r.id,
                task_id=r.task_id,
                model_name=r.model_name,
                video_type=r.video_type,
                prompt=r.prompt,
                custom_name=r.custom_name,
                video_url=r.video_url,
                video_duration=float(r.video_duration) if r.video_duration else None,
                resolution=r.resolution,
                status=r.status,
                bill=float(r.bill) if r.bill else 0,
                created_at=r.created_at.isoformat() if r.created_at else ""
            )
            for r in records
        ]

        return VideoHistoryResponse(
            items=items,
            total=total,
            page=page,
            page_size=page_size
        )