| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394 |
- """
- 视频合成服务
- 提供文生视频和图生视频功能
- """
- import logging
- import threading
- import httpx
- from datetime import datetime
- from decimal import Decimal
- from typing import Optional
- from fastapi import HTTPException
- from sqlalchemy.orm import Session
- from app.models.ai_video import AIVideo
- from app.schemas.video_schema import (
- VideoGenerateRequest, VideoTaskResponse, VideoTaskResult,
- VideoHistoryItem, VideoHistoryResponse
- )
- from app.services.oss_service import get_oss_service
logger = logging.getLogger(__name__)

# Root of the DashScope (Alibaba Cloud Bailian) REST API; endpoint paths are
# appended to this by VideoService.
DASHSCOPE_BASE_URL = "https://dashscope.aliyuncs.com/api/v1"

# Resolution tier -> aspect ratio -> DashScope "size" string ("width*height").
RESOLUTION_MAP = {
    "720P": {
        "16:9": "1280*720",
        "9:16": "720*1280",
        "1:1": "960*960",
    },
    "1080P": {
        "16:9": "1920*1080",
        "9:16": "1080*1920",
        "1:1": "1440*1440",
    },
}
class VideoService:
    """Video synthesis service built on DashScope's asynchronous video APIs.

    Each instance is bound to one database session and one user. It submits
    text-to-video / image-to-video tasks to DashScope, stores one ``AIVideo``
    row per submitted task, polls task status on demand, and lists the user's
    generation history.
    """

    # DashScope task states after which a task no longer changes. Records in
    # one of these states are answered from the database without re-querying.
    _TERMINAL_STATUSES = ("SUCCEEDED", "FAILED", "CANCELED")

    def __init__(self, db: Session, user_id: int, api_key: str):
        """
        Args:
            db: SQLAlchemy session used for all reads and writes of this request.
            user_id: Owner of created tasks; every lookup is scoped to this user.
            api_key: DashScope API key, sent as a Bearer token.
        """
        self.db = db
        self.user_id = user_id
        self.api_key = api_key
        self.oss_service = get_oss_service()

    # ------------------------------------------------------------------
    # Task submission
    # ------------------------------------------------------------------
    async def generate(self, request: VideoGenerateRequest) -> VideoTaskResponse:
        """Unified entry point: submit a video generation task.

        Any request carrying a first frame (with or without a last frame) is
        sent to the image-to-video endpoint; otherwise text-to-video is used.
        """
        if request.first_frame_url:
            return await self._image_to_video(request)
        return await self._text_to_video(request)

    async def _text_to_video(self, request: VideoGenerateRequest) -> VideoTaskResponse:
        """Submit a text-to-video task (model from the request, default ``wan2.6-t2v``)."""
        model = request.model or "wan2.6-t2v"
        url = f"{DASHSCOPE_BASE_URL}/services/aigc/video-generation/video-synthesis"
        body = {
            "model": model,
            "input": {
                "prompt": request.prompt,
            },
            "parameters": {
                "size": self._size_for(request.resolution),
                "duration": request.duration,
                "prompt_extend": request.prompt_extend,
                "shot_type": request.shot_type,
                "watermark": request.watermark,
            },
        }
        if request.negative_prompt:
            body["input"]["negative_prompt"] = request.negative_prompt
        if request.audio_url:
            body["input"]["audio_url"] = request.audio_url
        self._apply_seed(body, request.seed)

        task_id, task_status = await self._submit_task(url, body)
        self._save_record(
            task_id=task_id,
            model_name=model,
            video_type="t2v",
            input_params=body,
            prompt=request.prompt,
            audio_url=request.audio_url,
            resolution=request.resolution,
            status=task_status,
        )
        return VideoTaskResponse(task_id=task_id, task_status=task_status)

    async def _first_last_frame_to_video(self, request: VideoGenerateRequest) -> VideoTaskResponse:
        """Submit a first/last keyframe task to ``wan2.2-kf2v-flash``.

        NOTE: ``generate`` currently routes every first-frame request to
        ``_image_to_video``; nothing in this class calls this method.
        """
        model = "wan2.2-kf2v-flash"
        url = f"{DASHSCOPE_BASE_URL}/services/aigc/image2video/video-synthesis"
        body = {
            "model": model,
            "input": {
                "first_frame_url": request.first_frame_url,
                "last_frame_url": request.last_frame_url,
                "prompt": request.prompt,
            },
            "parameters": {
                "resolution": request.resolution,
                "prompt_extend": request.prompt_extend,
                "watermark": request.watermark,
            },
        }
        if request.negative_prompt:
            body["input"]["negative_prompt"] = request.negative_prompt
        self._apply_seed(body, request.seed)

        task_id, task_status = await self._submit_task(url, body)
        self._save_record(
            task_id=task_id,
            model_name=model,
            video_type="kf2v",
            input_params=body,
            prompt=request.prompt,
            first_frame_url=request.first_frame_url,
            last_frame_url=request.last_frame_url,
            resolution=request.resolution,
            status=task_status,
        )
        return VideoTaskResponse(task_id=task_id, task_status=task_status)

    async def _image_to_video(self, request: VideoGenerateRequest) -> VideoTaskResponse:
        """Submit an image-to-video task (model from the request, default ``wan2.6-i2v``).

        The first frame is sent as ``img_url``; an optional last frame is sent
        as ``last_img_url``.
        """
        model = request.model or "wan2.6-i2v"
        url = f"{DASHSCOPE_BASE_URL}/services/aigc/video-generation/video-synthesis"
        body = {
            "model": model,
            "input": {
                "prompt": request.prompt,
                "img_url": request.first_frame_url,
            },
            "parameters": {
                "size": self._size_for(request.resolution),
                "duration": request.duration,
                "prompt_extend": request.prompt_extend,
                "watermark": request.watermark,
            },
        }
        if request.negative_prompt:
            body["input"]["negative_prompt"] = request.negative_prompt
        if request.last_frame_url:
            body["input"]["last_img_url"] = request.last_frame_url
        if request.audio_url:
            body["input"]["audio_url"] = request.audio_url
        if request.audio is not None:
            body["parameters"]["audio"] = request.audio
        self._apply_seed(body, request.seed)

        task_id, task_status = await self._submit_task(url, body)
        self._save_record(
            task_id=task_id,
            model_name=model,
            video_type="i2v",
            # Only a summary of the request is stored for i2v (not the raw body).
            input_params={
                "prompt": request.prompt,
                "first_frame_url": request.first_frame_url,
                "last_frame_url": request.last_frame_url,
                "resolution": request.resolution,
                "duration": request.duration,
            },
            prompt=request.prompt,
            first_frame_url=request.first_frame_url,
            last_frame_url=request.last_frame_url,
            audio_url=request.audio_url,
            resolution=request.resolution,
            status=task_status,
        )
        return VideoTaskResponse(task_id=task_id, task_status=task_status)

    # ------------------------------------------------------------------
    # Submission helpers
    # ------------------------------------------------------------------
    def _headers(self, async_task: bool = False) -> dict:
        """Build DashScope request headers; ``async_task`` adds the JSON/async-submit headers."""
        headers = {"Authorization": f"Bearer {self.api_key}"}
        if async_task:
            headers["Content-Type"] = "application/json"
            headers["X-DashScope-Async"] = "enable"
        return headers

    @staticmethod
    def _size_for(resolution) -> str:
        """Map a resolution tier to its 16:9 DashScope size; unknown tiers fall back to 720P."""
        return RESOLUTION_MAP.get(resolution, RESOLUTION_MAP["720P"])["16:9"]

    @staticmethod
    def _apply_seed(body: dict, seed) -> None:
        """Add ``seed`` to the request parameters when the caller supplied one.

        Compared against ``None`` rather than truthiness so that ``0`` (a valid
        seed) is not silently dropped.
        """
        if seed is not None:
            body["parameters"]["seed"] = seed

    async def _submit_task(self, url: str, body: dict) -> tuple:
        """POST an async task to DashScope and return ``(task_id, task_status)``.

        Raises:
            Exception: "创建任务失败: ..." when DashScope reports an error or the
                response is not a usable task-creation payload.
        """
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.post(url, json=body, headers=self._headers(async_task=True))
        try:
            result = response.json()
        except ValueError as exc:
            # Non-JSON body (e.g. a gateway error page): surface the HTTP status.
            raise Exception(f"创建任务失败: HTTP {response.status_code}") from exc
        if not isinstance(result, dict):
            raise Exception(f"创建任务失败: HTTP {response.status_code}")
        # DashScope error payloads carry "code"/"message" instead of "output".
        if "code" in result:
            raise Exception(f"创建任务失败: {result.get('message')}")
        output = result.get("output") or {}
        if "task_id" not in output or "task_status" not in output:
            raise Exception(f"创建任务失败: HTTP {response.status_code}")
        return output["task_id"], output["task_status"]

    def _save_record(self, **fields) -> None:
        """Insert and commit an ``AIVideo`` row for the current user, stamped with submit time."""
        record = AIVideo(user_id=self.user_id, submit_time=datetime.now(), **fields)
        self.db.add(record)
        self.db.commit()

    # ------------------------------------------------------------------
    # Status polling
    # ------------------------------------------------------------------
    @staticmethod
    def _to_result(record) -> VideoTaskResult:
        """Build the API result object from a stored ``AIVideo`` record."""
        return VideoTaskResult(
            task_id=record.task_id,
            task_status=record.status,
            video_url=record.video_url,
            video_duration=float(record.video_duration) if record.video_duration else None,
            actual_prompt=record.actual_prompt,
            error_message=record.error_message,
        )

    async def get_task_status(self, task_id: str) -> VideoTaskResult:
        """Return the status of one of the current user's tasks.

        Terminal records are answered from the database. Otherwise DashScope is
        queried and the record is updated and committed for every status, so
        RUNNING's ``scheduled_time`` is persisted too. On success the DashScope
        video URL is stored right away, and a background thread later replaces
        it with the OSS copy.

        Raises:
            Exception: "任务不存在" if the task does not belong to this user.
        """
        record = self.db.query(AIVideo).filter(
            AIVideo.task_id == task_id,
            AIVideo.user_id == self.user_id,
        ).first()
        if not record:
            raise Exception("任务不存在")

        if record.status in self._TERMINAL_STATUSES:
            return self._to_result(record)

        url = f"{DASHSCOPE_BASE_URL}/tasks/{task_id}"
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.get(url, headers=self._headers())
        result = response.json()
        output = result.get("output", {})
        task_status = output.get("task_status", "UNKNOWN")

        now = datetime.now()
        record.status = task_status
        if task_status == "RUNNING" and not record.scheduled_time:
            record.scheduled_time = now
        if task_status in self._TERMINAL_STATUSES:
            record.end_time = now

        dashscope_video_url = None
        if task_status == "SUCCEEDED":
            actual_prompt = output.get("actual_prompt")
            if actual_prompt:
                record.actual_prompt = actual_prompt
            # API calls are free, so the bill is always zero.
            record.bill = Decimal("0")
            dashscope_video_url = output.get("video_url")
            if dashscope_video_url:
                # Persist the original URL now so later polls (and a failed OSS
                # upload) never see a SUCCEEDED record without a URL.
                record.video_url = dashscope_video_url
        elif task_status == "FAILED":
            record.error_message = output.get("message", "任务失败")

        record_id = record.id
        self.db.commit()

        if dashscope_video_url:
            # Copy to OSS off the request path; the thread uses its own session.
            threading.Thread(
                target=self._upload_video_to_oss_sync,
                args=(record_id, dashscope_video_url),
                daemon=True,
            ).start()

        return self._to_result(record)

    def _upload_video_to_oss_sync(self, record_id: int, video_url: str) -> None:
        """Background thread: upload the video from its DashScope URL to OSS and
        store the OSS URL on the record.

        Uses its own database session so it never shares the request session.
        Failures are logged; the record then keeps its DashScope URL.
        """
        from app.database import SessionLocal

        try:
            oss_url = self.oss_service.upload_from_url_sync(video_url, "ai-videos")
        except Exception as e:
            logger.error("OSS 上传失败: record_id=%s, error=%s", record_id, e)
            return

        db = SessionLocal()
        try:
            rec = db.query(AIVideo).filter(AIVideo.id == record_id).first()
            if rec:
                rec.video_url = oss_url
                db.commit()
                logger.info("OSS 上传完成: record_id=%s, oss_url=%s", record_id, oss_url)
        except Exception as e:
            logger.error("更新 OSS URL 失败: record_id=%s, error=%s", record_id, e)
            db.rollback()
        finally:
            db.close()

    # ------------------------------------------------------------------
    # History
    # ------------------------------------------------------------------
    def get_history(self, page: int = 1, page_size: int = 20) -> VideoHistoryResponse:
        """Return one page of the user's video records, newest first.

        Rejected content is excluded. Records whose ``review_status`` is NULL
        are kept, because in SQL ``NULL != 'rejected'`` evaluates to NULL
        (not true), and a bare inequality would hide every unreviewed record.

        Args:
            page: 1-based page number; values below 1 are treated as 1.
            page_size: Maximum number of items on the page.
        """
        from sqlalchemy import or_

        page = max(page, 1)
        query = self.db.query(AIVideo).filter(
            AIVideo.user_id == self.user_id,
            or_(AIVideo.review_status.is_(None), AIVideo.review_status != "rejected"),
        ).order_by(AIVideo.created_at.desc())

        total = query.count()
        records = query.offset((page - 1) * page_size).limit(page_size).all()

        items = [
            VideoHistoryItem(
                id=r.id,
                task_id=r.task_id,
                model_name=r.model_name,
                video_type=r.video_type,
                prompt=r.prompt,
                custom_name=r.custom_name,
                video_url=r.video_url,
                video_duration=float(r.video_duration) if r.video_duration else None,
                resolution=r.resolution,
                status=r.status,
                bill=float(r.bill) if r.bill else 0,
                created_at=r.created_at.isoformat() if r.created_at else "",
            )
            for r in records
        ]
        return VideoHistoryResponse(
            items=items,
            total=total,
            page=page,
            page_size=page_size,
        )
|