video_service.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394
  1. """
  2. 视频合成服务
  3. 提供文生视频和图生视频功能
  4. """
  5. import logging
  6. import threading
  7. import httpx
  8. from datetime import datetime
  9. from decimal import Decimal
  10. from typing import Optional
  11. from fastapi import HTTPException
  12. from sqlalchemy.orm import Session
  13. from app.models.ai_video import AIVideo
  14. from app.schemas.video_schema import (
  15. VideoGenerateRequest, VideoTaskResponse, VideoTaskResult,
  16. VideoHistoryItem, VideoHistoryResponse
  17. )
  18. from app.services.oss_service import get_oss_service
  19. logger = logging.getLogger(__name__)
  20. DASHSCOPE_BASE_URL = "https://dashscope.aliyuncs.com/api/v1"
  21. # 分辨率映射
  22. RESOLUTION_MAP = {
  23. "720P": {"16:9": "1280*720", "9:16": "720*1280", "1:1": "960*960"},
  24. "1080P": {"16:9": "1920*1080", "9:16": "1080*1920", "1:1": "1440*1440"}
  25. }
  26. class VideoService:
  27. """视频合成服务"""
  28. def __init__(self, db: Session, user_id: int, api_key: str):
  29. self.db = db
  30. self.user_id = user_id
  31. self.api_key = api_key
  32. self.oss_service = get_oss_service()
  33. async def generate(self, request: VideoGenerateRequest) -> VideoTaskResponse:
  34. """统一视频生成入口"""
  35. # 统一策略:只要有首帧(不管是否包含尾帧)都使用图生视频接口
  36. if request.first_frame_url:
  37. return await self._image_to_video(request)
  38. else:
  39. return await self._text_to_video(request)
  40. async def _text_to_video(self, request: VideoGenerateRequest) -> VideoTaskResponse:
  41. """文生视频(HTTP接口)"""
  42. # 使用前端传入的模型,默认 wan2.6-t2v
  43. model = request.model or "wan2.6-t2v"
  44. url = f"{DASHSCOPE_BASE_URL}/services/aigc/video-generation/video-synthesis"
  45. headers = {
  46. "Content-Type": "application/json",
  47. "Authorization": f"Bearer {self.api_key}",
  48. "X-DashScope-Async": "enable"
  49. }
  50. # 获取分辨率尺寸
  51. size = RESOLUTION_MAP.get(request.resolution, RESOLUTION_MAP["720P"])["16:9"]
  52. body = {
  53. "model": model,
  54. "input": {
  55. "prompt": request.prompt
  56. },
  57. "parameters": {
  58. "size": size,
  59. "duration": request.duration,
  60. "prompt_extend": request.prompt_extend,
  61. "shot_type": request.shot_type,
  62. "watermark": request.watermark
  63. }
  64. }
  65. if request.negative_prompt:
  66. body["input"]["negative_prompt"] = request.negative_prompt
  67. if request.audio_url:
  68. body["input"]["audio_url"] = request.audio_url
  69. if request.seed:
  70. body["parameters"]["seed"] = request.seed
  71. async with httpx.AsyncClient(timeout=30.0) as client:
  72. response = await client.post(url, json=body, headers=headers)
  73. result = response.json()
  74. if "code" in result:
  75. raise Exception(f"创建任务失败: {result.get('message')}")
  76. task_id = result["output"]["task_id"]
  77. task_status = result["output"]["task_status"]
  78. # 保存记录
  79. record = AIVideo(
  80. user_id=self.user_id,
  81. task_id=task_id,
  82. model_name=model,
  83. video_type="t2v",
  84. input_params=body,
  85. prompt=request.prompt,
  86. audio_url=request.audio_url,
  87. resolution=request.resolution,
  88. status=task_status,
  89. submit_time=datetime.now()
  90. )
  91. self.db.add(record)
  92. self.db.commit()
  93. return VideoTaskResponse(task_id=task_id, task_status=task_status)
  94. async def _first_last_frame_to_video(self, request: VideoGenerateRequest) -> VideoTaskResponse:
  95. """首尾帧生视频 - 调用wan2.2-kf2v-flash(HTTP接口)"""
  96. url = f"{DASHSCOPE_BASE_URL}/services/aigc/image2video/video-synthesis"
  97. headers = {
  98. "Content-Type": "application/json",
  99. "Authorization": f"Bearer {self.api_key}",
  100. "X-DashScope-Async": "enable"
  101. }
  102. body = {
  103. "model": "wan2.2-kf2v-flash",
  104. "input": {
  105. "first_frame_url": request.first_frame_url,
  106. "last_frame_url": request.last_frame_url,
  107. "prompt": request.prompt
  108. },
  109. "parameters": {
  110. "resolution": request.resolution,
  111. "prompt_extend": request.prompt_extend,
  112. "watermark": request.watermark
  113. }
  114. }
  115. if request.negative_prompt:
  116. body["input"]["negative_prompt"] = request.negative_prompt
  117. if request.seed:
  118. body["parameters"]["seed"] = request.seed
  119. async with httpx.AsyncClient(timeout=30.0) as client:
  120. response = await client.post(url, json=body, headers=headers)
  121. result = response.json()
  122. if "code" in result:
  123. raise Exception(f"创建任务失败: {result.get('message')}")
  124. task_id = result["output"]["task_id"]
  125. task_status = result["output"]["task_status"]
  126. # 保存记录
  127. record = AIVideo(
  128. user_id=self.user_id,
  129. task_id=task_id,
  130. model_name="wan2.2-kf2v-flash",
  131. video_type="kf2v",
  132. input_params=body,
  133. prompt=request.prompt,
  134. first_frame_url=request.first_frame_url,
  135. last_frame_url=request.last_frame_url,
  136. resolution=request.resolution,
  137. status=task_status,
  138. submit_time=datetime.now()
  139. )
  140. self.db.add(record)
  141. self.db.commit()
  142. return VideoTaskResponse(task_id=task_id, task_status=task_status)
  143. async def _image_to_video(self, request: VideoGenerateRequest) -> VideoTaskResponse:
  144. """图生视频(HTTP接口)"""
  145. # 使用前端传入的模型,默认 wan2.6-i2v
  146. model = request.model or "wan2.6-i2v"
  147. url = f"{DASHSCOPE_BASE_URL}/services/aigc/video-generation/video-synthesis"
  148. headers = {
  149. "Content-Type": "application/json",
  150. "Authorization": f"Bearer {self.api_key}",
  151. "X-DashScope-Async": "enable"
  152. }
  153. # 获取分辨率尺寸
  154. size = RESOLUTION_MAP.get(request.resolution, RESOLUTION_MAP["720P"])["16:9"]
  155. body = {
  156. "model": model,
  157. "input": {
  158. "prompt": request.prompt,
  159. "img_url": request.first_frame_url
  160. },
  161. "parameters": {
  162. "size": size,
  163. "duration": request.duration,
  164. "prompt_extend": request.prompt_extend,
  165. "watermark": request.watermark
  166. }
  167. }
  168. if request.negative_prompt:
  169. body["input"]["negative_prompt"] = request.negative_prompt
  170. if request.last_frame_url:
  171. body["input"]["last_img_url"] = request.last_frame_url
  172. if request.audio_url:
  173. body["input"]["audio_url"] = request.audio_url
  174. if request.audio is not None:
  175. body["parameters"]["audio"] = request.audio
  176. if request.seed:
  177. body["parameters"]["seed"] = request.seed
  178. async with httpx.AsyncClient(timeout=30.0) as client:
  179. response = await client.post(url, json=body, headers=headers)
  180. result = response.json()
  181. if "code" in result:
  182. raise Exception(f"创建任务失败: {result.get('message')}")
  183. task_id = result["output"]["task_id"]
  184. task_status = result["output"]["task_status"]
  185. # 保存记录
  186. record = AIVideo(
  187. user_id=self.user_id,
  188. task_id=task_id,
  189. model_name=model,
  190. video_type="i2v",
  191. input_params={
  192. "prompt": request.prompt,
  193. "first_frame_url": request.first_frame_url,
  194. "last_frame_url": request.last_frame_url,
  195. "resolution": request.resolution,
  196. "duration": request.duration
  197. },
  198. prompt=request.prompt,
  199. first_frame_url=request.first_frame_url,
  200. last_frame_url=request.last_frame_url,
  201. audio_url=request.audio_url,
  202. resolution=request.resolution,
  203. status=task_status,
  204. submit_time=datetime.now()
  205. )
  206. self.db.add(record)
  207. self.db.commit()
  208. return VideoTaskResponse(task_id=task_id, task_status=task_status)
  209. async def get_task_status(self, task_id: str) -> VideoTaskResult:
  210. """查询任务状态"""
  211. # 验证任务归属
  212. record = self.db.query(AIVideo).filter(
  213. AIVideo.task_id == task_id,
  214. AIVideo.user_id == self.user_id
  215. ).first()
  216. if not record:
  217. raise Exception("任务不存在")
  218. # 如果已完成,直接返回
  219. if record.status in ["SUCCEEDED", "FAILED"]:
  220. return VideoTaskResult(
  221. task_id=task_id,
  222. task_status=record.status,
  223. video_url=record.video_url,
  224. video_duration=float(record.video_duration) if record.video_duration else None,
  225. actual_prompt=record.actual_prompt,
  226. error_message=record.error_message
  227. )
  228. # 查询百炼API
  229. url = f"{DASHSCOPE_BASE_URL}/tasks/{task_id}"
  230. headers = {"Authorization": f"Bearer {self.api_key}"}
  231. async with httpx.AsyncClient(timeout=30.0) as client:
  232. response = await client.get(url, headers=headers)
  233. result = response.json()
  234. output = result.get("output", {})
  235. task_status = output.get("task_status", "UNKNOWN")
  236. # 更新数据库
  237. record.status = task_status
  238. if task_status == "RUNNING" and not record.scheduled_time:
  239. record.scheduled_time = datetime.now()
  240. if task_status == "SUCCEEDED":
  241. record.end_time = datetime.now()
  242. # 保存actual_prompt
  243. actual_prompt = output.get("actual_prompt")
  244. if actual_prompt:
  245. record.actual_prompt = actual_prompt
  246. # 费用(API调用免费)
  247. record.bill = Decimal("0")
  248. # 先落库(状态 SUCCEEDED),让前端立刻拿到结果
  249. # OSS 上传放到后台线程,不阻塞当前请求
  250. dashscope_video_url = output.get("video_url")
  251. self.db.commit()
  252. if dashscope_video_url:
  253. threading.Thread(
  254. target=self._upload_video_to_oss_sync,
  255. args=(record.id, dashscope_video_url),
  256. daemon=True
  257. ).start()
  258. return VideoTaskResult(
  259. task_id=task_id,
  260. task_status=record.status,
  261. # OSS 上传未完成前先返回 DashScope 原始 URL,前端可立即播放
  262. video_url=dashscope_video_url,
  263. video_duration=float(record.video_duration) if record.video_duration else None,
  264. actual_prompt=record.actual_prompt,
  265. error_message=record.error_message
  266. )
  267. elif task_status == "FAILED":
  268. record.end_time = datetime.now()
  269. record.error_message = output.get("message", "任务失败")
  270. self.db.commit()
  271. # 返回时以数据库中最终的 record.status 为准,确保扣费失败后前端不会误判任务为 SUCCEEDED
  272. return VideoTaskResult(
  273. task_id=task_id,
  274. task_status=record.status if record.status else task_status,
  275. video_url=record.video_url,
  276. video_duration=float(record.video_duration) if record.video_duration else None,
  277. actual_prompt=record.actual_prompt,
  278. error_message=record.error_message
  279. )
  280. def _upload_video_to_oss_sync(self, record_id: int, video_url: str) -> None:
  281. """后台线程:将视频从 DashScope URL 上传到 OSS,完成后更新数据库。
  282. 使用独立的数据库会话,避免与请求会话冲突。
  283. """
  284. from app.database import SessionLocal
  285. try:
  286. oss_url = self.oss_service.upload_from_url_sync(video_url, "ai-videos")
  287. except Exception as e:
  288. logger.error(f"OSS 上传失败: record_id={record_id}, error={e}")
  289. return
  290. db = SessionLocal()
  291. try:
  292. rec = db.query(AIVideo).filter(AIVideo.id == record_id).first()
  293. if rec:
  294. rec.video_url = oss_url
  295. db.commit()
  296. logger.info(f"OSS 上传完成: record_id={record_id}, oss_url={oss_url}")
  297. except Exception as e:
  298. logger.error(f"更新 OSS URL 失败: record_id={record_id}, error={e}")
  299. db.rollback()
  300. finally:
  301. db.close()
  302. def get_history(self, page: int = 1, page_size: int = 20) -> VideoHistoryResponse:
  303. """获取用户历史记录"""
  304. query = self.db.query(AIVideo).filter(
  305. AIVideo.user_id == self.user_id,
  306. AIVideo.review_status != 'rejected' # 排除被拒绝的内容
  307. ).order_by(AIVideo.created_at.desc())
  308. total = query.count()
  309. offset = (page - 1) * page_size
  310. records = query.offset(offset).limit(page_size).all()
  311. items = [
  312. VideoHistoryItem(
  313. id=r.id,
  314. task_id=r.task_id,
  315. model_name=r.model_name,
  316. video_type=r.video_type,
  317. prompt=r.prompt,
  318. custom_name=r.custom_name,
  319. video_url=r.video_url,
  320. video_duration=float(r.video_duration) if r.video_duration else None,
  321. resolution=r.resolution,
  322. status=r.status,
  323. bill=float(r.bill) if r.bill else 0,
  324. created_at=r.created_at.isoformat() if r.created_at else ""
  325. )
  326. for r in records
  327. ]
  328. return VideoHistoryResponse(
  329. items=items,
  330. total=total,
  331. page=page,
  332. page_size=page_size
  333. )