""" AI视频API路由 提供数字人合成和视频合成的RESTful API端点 """ from typing import List from fastapi import APIRouter, Depends, HTTPException, Query from sqlalchemy.orm import Session from app.database import get_db from app.models.user import User from app.middleware import get_current_user_from_request from app.schemas.model_schema import ApiResponse from app.schemas.video_schema import ( AvatarDetectRequest, AvatarDetectResponse, AvatarGenerateRequest, VideoGenerateRequest, VideoTaskResponse, VideoTaskResult, VideoHistoryResponse, UpdateVideoNameRequest, VideoModelInfo ) from app.services.avatar_service import AvatarService from app.services.video_service import VideoService from app.services.system_config_manager import get_config_int from app.models.ai_video import AIVideo from app.models.model import ModelNew, ModelCategory router = APIRouter(prefix="/api/video", tags=["AI视频"]) @router.get("/models", response_model=ApiResponse[list]) def get_video_models( db: Session = Depends(get_db), current_user: User = Depends(get_current_user_from_request) ): """获取视频模型列表(文生视频 + 图生视频 + 数字人)""" from app.models.model import ModelPriceNew from sqlalchemy.orm import selectinload # 用 selectinload 一次性加载所有模型的价格,避免 N+1 models = db.query(ModelNew).options( selectinload(ModelNew.prices) ).filter( ModelNew.categories.any(int(ModelCategory.VIDEO_GEN)), ModelNew.is_api_enabled == True, ModelNew.is_show_enabled == True, ).all() result = [] for m in models: kw = (m.keywords or "").lower() if "数字人" in kw or "avatar" in kw or "s2v" in m.model_code: video_type = "avatar" elif "i2v" in m.model_code or "图生视频" in kw: video_type = "image_to_video" else: video_type = "text_to_video" # 从已预加载的 prices 里取,不再单独查询 price_row = next((p for p in (m.prices or []) if p.is_active), None) price_per_second = str(price_row.output_price_discounted) if price_row else "0" result.append(VideoModelInfo( model_id=m.model_code, model_name=m.display_name or m.model_code, description=m.custom_description or m.description or "", 
video_type=video_type, price_per_second=price_per_second, )) return ApiResponse(code=200, message="success", data=result) # ==================== 数字人端点 ==================== @router.post("/avatar/detect", response_model=ApiResponse[AvatarDetectResponse]) async def detect_avatar_image( request: AvatarDetectRequest, db: Session = Depends(get_db), current_user: User = Depends(get_current_user_from_request) ): """数字人图像检测""" if not current_user.apikey: raise HTTPException(status_code=403, detail="未配置API密钥") from app.services.crypto_utils import get_effective_api_key effective_key = get_effective_api_key(db, "wan2.2-s2v", current_user.apikey) service = AvatarService(db, current_user.id, effective_key) result = await service.detect_image(request.image_url) return ApiResponse(code=200, message="success", data=result) @router.post("/avatar/generate", response_model=ApiResponse[VideoTaskResponse]) async def create_avatar_task( request: AvatarGenerateRequest, db: Session = Depends(get_db), current_user: User = Depends(get_current_user_from_request) ): """创建数字人合成任务,需要余额检查""" if not current_user.apikey: raise HTTPException(status_code=403, detail="未配置API密钥") from app.services.crypto_utils import get_effective_api_key effective_key = get_effective_api_key(db, "wan2.2-s2v", current_user.apikey) service = AvatarService(db, current_user.id, effective_key) result = await service.generate(request) return ApiResponse(code=200, message="success", data=result) @router.get("/avatar/task/{task_id}", response_model=ApiResponse[VideoTaskResult]) async def get_avatar_task_status( task_id: str, db: Session = Depends(get_db), current_user: User = Depends(get_current_user_from_request) ): """查询数字人任务状态""" if not current_user.apikey: raise HTTPException(status_code=403, detail="未配置API密钥") from app.services.crypto_utils import get_effective_api_key effective_key = get_effective_api_key(db, "wan2.2-s2v", current_user.apikey) service = AvatarService(db, current_user.id, effective_key) result = await 
service.get_task_status(task_id) return ApiResponse(code=200, message="success", data=result) # ==================== 视频合成端点 ==================== @router.post("/generate", response_model=ApiResponse[VideoTaskResponse]) async def create_video_task( request: VideoGenerateRequest, db: Session = Depends(get_db), current_user: User = Depends(get_current_user_from_request) ): """创建视频生成任务(统一端点),需要余额检查""" if not current_user.apikey: raise HTTPException(status_code=403, detail="未配置API密钥") # 检查视频时长限制 max_duration = get_config_int("max_video_duration", 10) if request.duration > max_duration: raise HTTPException(status_code=400, detail=f"视频时长超过限制(最大{max_duration}秒)") from app.services.crypto_utils import get_effective_api_key # 根据是否有首帧判断实际使用的模型 model_code = "wan2.6-i2v" if request.first_frame_url else "wan2.6-t2v" effective_key = get_effective_api_key(db, model_code, current_user.apikey) service = VideoService(db, current_user.id, effective_key) result = await service.generate(request) return ApiResponse(code=200, message="success", data=result) @router.get("/task/{task_id}", response_model=ApiResponse[VideoTaskResult]) async def get_video_task_status( task_id: str, db: Session = Depends(get_db), current_user: User = Depends(get_current_user_from_request) ): """查询视频任务状态""" if not current_user.apikey: raise HTTPException(status_code=403, detail="未配置API密钥") # 查询任务对应的模型,以便获取正确的 API Key(与创建任务时保持一致) from app.models.ai_video import AIVideo as AIVideoModel from app.services.crypto_utils import get_effective_api_key record = db.query(AIVideoModel).filter( AIVideoModel.task_id == task_id, AIVideoModel.user_id == current_user.id ).first() model_code = record.model_name if record else "wan2.6-t2v" effective_key = get_effective_api_key(db, model_code, current_user.apikey) service = VideoService(db, current_user.id, effective_key) result = await service.get_task_status(task_id) return ApiResponse(code=200, message="success", data=result) @router.get("/history", 
response_model=ApiResponse[VideoHistoryResponse]) def get_video_history( page: int = Query(default=1, ge=1, description="页码"), page_size: int = Query(default=20, ge=1, le=100, description="每页数量"), db: Session = Depends(get_db), current_user: User = Depends(get_current_user_from_request) ): """获取视频生成历史""" if not current_user.apikey: raise HTTPException(status_code=403, detail="未配置API密钥") service = VideoService(db, current_user.id, current_user.apikey) result = service.get_history(page, page_size) return ApiResponse(code=200, message="success", data=result) @router.put("/history/{record_id}/name", response_model=ApiResponse[None]) def update_video_name( record_id: int, request: UpdateVideoNameRequest, db: Session = Depends(get_db), current_user: User = Depends(get_current_user_from_request) ): """更新视频记录的自定义名称""" record = db.query(AIVideo).filter( AIVideo.id == record_id, AIVideo.user_id == current_user.id ).first() if not record: raise HTTPException(status_code=404, detail="记录不存在") record.custom_name = request.custom_name.strip() db.commit() return ApiResponse(code=200, message="success", data=None)
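

# ---------------------------------------------------------------------------
# Sketch (hypothetical helper, not called by the endpoints in this module):
# the video-type classification inlined in get_video_models, written as a
# pure function so the rules can be unit-tested without a database session.
# The name `classify_video_type` is an assumption. The rules mirror the
# inline checks: "数字人" (digital human), "avatar", or "s2v" in the model
# code -> avatar; "i2v" in the model code or "图生视频" (image-to-video)
# -> image_to_video; anything else -> text_to_video.
# ---------------------------------------------------------------------------
def classify_video_type(model_code, keywords=None):
    """Return "avatar", "image_to_video", or "text_to_video" for a model."""
    code = model_code or ""
    kw = (keywords or "").lower()
    # Avatar rules are checked first, so an "s2v" model is never image_to_video.
    if "数字人" in kw or "avatar" in kw or "s2v" in code:
        return "avatar"
    if "i2v" in code or "图生视频" in kw:
        return "image_to_video"
    return "text_to_video"


# Example:
#   classify_video_type("wan2.2-s2v")          -> "avatar"
#   classify_video_type("wan2.6-i2v")          -> "image_to_video"
#   classify_video_type("x", "Text, 图生视频")  -> "image_to_video"
#   classify_video_type("wan2.6-t2v", "")      -> "text_to_video"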