# video_router.py
"""
AI视频API路由
提供数字人合成和视频合成的RESTful API端点

AI video API router: RESTful endpoints for avatar (digital-human) synthesis
and video generation.
"""
  5. from typing import List
  6. from fastapi import APIRouter, Depends, HTTPException, Query
  7. from sqlalchemy.orm import Session
  8. from app.database import get_db
  9. from app.models.user import User
  10. from app.middleware import get_current_user_from_request
  11. from app.schemas.model_schema import ApiResponse
  12. from app.schemas.video_schema import (
  13. AvatarDetectRequest, AvatarDetectResponse,
  14. AvatarGenerateRequest, VideoGenerateRequest,
  15. VideoTaskResponse, VideoTaskResult,
  16. VideoHistoryResponse, UpdateVideoNameRequest,
  17. VideoModelInfo
  18. )
  19. from app.services.avatar_service import AvatarService
  20. from app.services.video_service import VideoService
  21. from app.services.system_config_manager import get_config_int
  22. from app.models.ai_video import AIVideo
  23. from app.models.model import ModelNew, ModelCategory
  24. router = APIRouter(prefix="/api/video", tags=["AI视频"])
  25. @router.get("/models", response_model=ApiResponse[list])
  26. def get_video_models(
  27. db: Session = Depends(get_db),
  28. current_user: User = Depends(get_current_user_from_request)
  29. ):
  30. """获取视频模型列表(文生视频 + 图生视频 + 数字人)"""
  31. from app.models.model import ModelPriceNew
  32. from sqlalchemy.orm import selectinload
  33. # 用 selectinload 一次性加载所有模型的价格,避免 N+1
  34. models = db.query(ModelNew).options(
  35. selectinload(ModelNew.prices)
  36. ).filter(
  37. ModelNew.categories.any(int(ModelCategory.VIDEO_GEN)),
  38. ModelNew.is_api_enabled == True,
  39. ModelNew.is_show_enabled == True,
  40. ).all()
  41. result = []
  42. for m in models:
  43. kw = (m.keywords or "").lower()
  44. if "数字人" in kw or "avatar" in kw or "s2v" in m.model_code:
  45. video_type = "avatar"
  46. elif "i2v" in m.model_code or "图生视频" in kw:
  47. video_type = "image_to_video"
  48. else:
  49. video_type = "text_to_video"
  50. # 从已预加载的 prices 里取,不再单独查询
  51. price_row = next((p for p in (m.prices or []) if p.is_active), None)
  52. price_per_second = str(price_row.output_price_discounted) if price_row else "0"
  53. result.append(VideoModelInfo(
  54. model_id=m.model_code,
  55. model_name=m.display_name or m.model_code,
  56. description=m.custom_description or m.description or "",
  57. video_type=video_type,
  58. price_per_second=price_per_second,
  59. ))
  60. return ApiResponse(code=200, message="success", data=result)
# ==================== 数字人端点 (avatar / digital-human endpoints) ====================
  62. @router.post("/avatar/detect", response_model=ApiResponse[AvatarDetectResponse])
  63. async def detect_avatar_image(
  64. request: AvatarDetectRequest,
  65. db: Session = Depends(get_db),
  66. current_user: User = Depends(get_current_user_from_request)
  67. ):
  68. """数字人图像检测"""
  69. if not current_user.apikey:
  70. raise HTTPException(status_code=403, detail="未配置API密钥")
  71. from app.services.crypto_utils import get_effective_api_key
  72. effective_key = get_effective_api_key(db, "wan2.2-s2v", current_user.apikey)
  73. service = AvatarService(db, current_user.id, effective_key)
  74. result = await service.detect_image(request.image_url)
  75. return ApiResponse(code=200, message="success", data=result)
  76. @router.post("/avatar/generate", response_model=ApiResponse[VideoTaskResponse])
  77. async def create_avatar_task(
  78. request: AvatarGenerateRequest,
  79. db: Session = Depends(get_db),
  80. current_user: User = Depends(get_current_user_from_request)
  81. ):
  82. """创建数字人合成任务,需要余额检查"""
  83. if not current_user.apikey:
  84. raise HTTPException(status_code=403, detail="未配置API密钥")
  85. from app.services.crypto_utils import get_effective_api_key
  86. effective_key = get_effective_api_key(db, "wan2.2-s2v", current_user.apikey)
  87. service = AvatarService(db, current_user.id, effective_key)
  88. result = await service.generate(request)
  89. return ApiResponse(code=200, message="success", data=result)
  90. @router.get("/avatar/task/{task_id}", response_model=ApiResponse[VideoTaskResult])
  91. async def get_avatar_task_status(
  92. task_id: str,
  93. db: Session = Depends(get_db),
  94. current_user: User = Depends(get_current_user_from_request)
  95. ):
  96. """查询数字人任务状态"""
  97. if not current_user.apikey:
  98. raise HTTPException(status_code=403, detail="未配置API密钥")
  99. from app.services.crypto_utils import get_effective_api_key
  100. effective_key = get_effective_api_key(db, "wan2.2-s2v", current_user.apikey)
  101. service = AvatarService(db, current_user.id, effective_key)
  102. result = await service.get_task_status(task_id)
  103. return ApiResponse(code=200, message="success", data=result)
# ==================== 视频合成端点 (video generation endpoints) ====================
  105. @router.post("/generate", response_model=ApiResponse[VideoTaskResponse])
  106. async def create_video_task(
  107. request: VideoGenerateRequest,
  108. db: Session = Depends(get_db),
  109. current_user: User = Depends(get_current_user_from_request)
  110. ):
  111. """创建视频生成任务(统一端点),需要余额检查"""
  112. if not current_user.apikey:
  113. raise HTTPException(status_code=403, detail="未配置API密钥")
  114. # 检查视频时长限制
  115. max_duration = get_config_int("max_video_duration", 10)
  116. if request.duration > max_duration:
  117. raise HTTPException(status_code=400, detail=f"视频时长超过限制(最大{max_duration}秒)")
  118. from app.services.crypto_utils import get_effective_api_key
  119. # 根据是否有首帧判断实际使用的模型
  120. model_code = "wan2.6-i2v" if request.first_frame_url else "wan2.6-t2v"
  121. effective_key = get_effective_api_key(db, model_code, current_user.apikey)
  122. service = VideoService(db, current_user.id, effective_key)
  123. result = await service.generate(request)
  124. return ApiResponse(code=200, message="success", data=result)
  125. @router.get("/task/{task_id}", response_model=ApiResponse[VideoTaskResult])
  126. async def get_video_task_status(
  127. task_id: str,
  128. db: Session = Depends(get_db),
  129. current_user: User = Depends(get_current_user_from_request)
  130. ):
  131. """查询视频任务状态"""
  132. if not current_user.apikey:
  133. raise HTTPException(status_code=403, detail="未配置API密钥")
  134. # 查询任务对应的模型,以便获取正确的 API Key(与创建任务时保持一致)
  135. from app.models.ai_video import AIVideo as AIVideoModel
  136. from app.services.crypto_utils import get_effective_api_key
  137. record = db.query(AIVideoModel).filter(
  138. AIVideoModel.task_id == task_id,
  139. AIVideoModel.user_id == current_user.id
  140. ).first()
  141. model_code = record.model_name if record else "wan2.6-t2v"
  142. effective_key = get_effective_api_key(db, model_code, current_user.apikey)
  143. service = VideoService(db, current_user.id, effective_key)
  144. result = await service.get_task_status(task_id)
  145. return ApiResponse(code=200, message="success", data=result)
  146. @router.get("/history", response_model=ApiResponse[VideoHistoryResponse])
  147. def get_video_history(
  148. page: int = Query(default=1, ge=1, description="页码"),
  149. page_size: int = Query(default=20, ge=1, le=100, description="每页数量"),
  150. db: Session = Depends(get_db),
  151. current_user: User = Depends(get_current_user_from_request)
  152. ):
  153. """获取视频生成历史"""
  154. if not current_user.apikey:
  155. raise HTTPException(status_code=403, detail="未配置API密钥")
  156. service = VideoService(db, current_user.id, current_user.apikey)
  157. result = service.get_history(page, page_size)
  158. return ApiResponse(code=200, message="success", data=result)
  159. @router.put("/history/{record_id}/name", response_model=ApiResponse[None])
  160. def update_video_name(
  161. record_id: int,
  162. request: UpdateVideoNameRequest,
  163. db: Session = Depends(get_db),
  164. current_user: User = Depends(get_current_user_from_request)
  165. ):
  166. """更新视频记录的自定义名称"""
  167. record = db.query(AIVideo).filter(
  168. AIVideo.id == record_id,
  169. AIVideo.user_id == current_user.id
  170. ).first()
  171. if not record:
  172. raise HTTPException(status_code=404, detail="记录不存在")
  173. record.custom_name = request.custom_name.strip()
  174. db.commit()
  175. return ApiResponse(code=200, message="success", data=None)