avatar_service.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. """
  2. 数字人合成服务
  3. 提供数字人图像检测和视频合成功能
  4. """
  5. import logging
  6. import httpx
  7. from datetime import datetime
  8. from decimal import Decimal
  9. from typing import Optional
  10. from fastapi import HTTPException
  11. from sqlalchemy.orm import Session
  12. from app.models.ai_video import AIVideo
  13. from app.schemas.video_schema import (
  14. AvatarDetectRequest, AvatarDetectResponse,
  15. AvatarGenerateRequest, VideoTaskResponse, VideoTaskResult
  16. )
  17. from app.services.oss_service import get_oss_service
  18. logger = logging.getLogger(__name__)
  19. DASHSCOPE_BASE_URL = "https://dashscope.aliyuncs.com/api/v1"
  20. class AvatarService:
  21. """数字人合成服务"""
  22. def __init__(self, db: Session, user_id: int, api_key: str):
  23. self.db = db
  24. self.user_id = user_id
  25. self.api_key = api_key
  26. self.oss_service = get_oss_service()
  27. async def detect_image(self, image_url: str) -> AvatarDetectResponse:
  28. """图像检测 - 调用wan2.2-s2v-detect"""
  29. url = f"{DASHSCOPE_BASE_URL}/services/aigc/image2video/face-detect"
  30. headers = {
  31. "Content-Type": "application/json",
  32. "Authorization": f"Bearer {self.api_key}"
  33. }
  34. body = {
  35. "model": "wan2.2-s2v-detect",
  36. "input": {"image_url": image_url}
  37. }
  38. async with httpx.AsyncClient(timeout=30.0) as client:
  39. response = await client.post(url, json=body, headers=headers)
  40. result = response.json()
  41. if "code" in result:
  42. return AvatarDetectResponse(
  43. check_pass=False,
  44. humanoid=False,
  45. message=result.get("message", "检测失败")
  46. )
  47. output = result.get("output", {})
  48. return AvatarDetectResponse(
  49. check_pass=output.get("check_pass", False),
  50. humanoid=output.get("humanoid", False),
  51. message=output.get("message")
  52. )
  53. async def generate(self, request: AvatarGenerateRequest) -> VideoTaskResponse:
  54. """数字人视频合成 - 调用wan2.2-s2v"""
  55. url = f"{DASHSCOPE_BASE_URL}/services/aigc/image2video/video-synthesis"
  56. headers = {
  57. "Content-Type": "application/json",
  58. "Authorization": f"Bearer {self.api_key}",
  59. "X-DashScope-Async": "enable"
  60. }
  61. body = {
  62. "model": "wan2.2-s2v",
  63. "input": {
  64. "image_url": request.image_url,
  65. "audio_url": request.audio_url
  66. },
  67. "parameters": {
  68. "resolution": request.resolution
  69. }
  70. }
  71. async with httpx.AsyncClient(timeout=30.0) as client:
  72. response = await client.post(url, json=body, headers=headers)
  73. result = response.json()
  74. if "code" in result:
  75. raise Exception(f"创建任务失败: {result.get('message')}")
  76. task_id = result["output"]["task_id"]
  77. task_status = result["output"]["task_status"]
  78. # 保存记录
  79. record = AIVideo(
  80. user_id=self.user_id,
  81. task_id=task_id,
  82. model_name="wan2.2-s2v",
  83. video_type="s2v",
  84. input_params={
  85. "image_url": request.image_url,
  86. "audio_url": request.audio_url,
  87. "resolution": request.resolution
  88. },
  89. audio_url=request.audio_url,
  90. resolution=request.resolution,
  91. status=task_status,
  92. submit_time=datetime.now()
  93. )
  94. self.db.add(record)
  95. self.db.commit()
  96. return VideoTaskResponse(task_id=task_id, task_status=task_status)
  97. async def get_task_status(self, task_id: str) -> VideoTaskResult:
  98. """查询任务状态"""
  99. # 验证任务归属
  100. record = self.db.query(AIVideo).filter(
  101. AIVideo.task_id == task_id,
  102. AIVideo.user_id == self.user_id
  103. ).first()
  104. if not record:
  105. raise Exception("任务不存在")
  106. # 如果已完成,直接返回
  107. if record.status in ["SUCCEEDED", "FAILED"]:
  108. return VideoTaskResult(
  109. task_id=task_id,
  110. task_status=record.status,
  111. video_url=record.video_url,
  112. video_duration=float(record.video_duration) if record.video_duration else None,
  113. error_message=record.error_message
  114. )
  115. # 查询百炼API
  116. url = f"{DASHSCOPE_BASE_URL}/tasks/{task_id}"
  117. headers = {"Authorization": f"Bearer {self.api_key}"}
  118. async with httpx.AsyncClient(timeout=30.0) as client:
  119. response = await client.get(url, headers=headers)
  120. result = response.json()
  121. output = result.get("output", {})
  122. task_status = output.get("task_status", "UNKNOWN")
  123. # 更新数据库
  124. record.status = task_status
  125. if task_status == "RUNNING" and not record.scheduled_time:
  126. record.scheduled_time = datetime.now()
  127. if task_status == "SUCCEEDED":
  128. record.end_time = datetime.now()
  129. # 注意: 视频URL在 output.results.video_url
  130. results = output.get("results", {})
  131. video_url = results.get("video_url")
  132. if video_url:
  133. # 下载到OSS
  134. oss_url = await self.oss_service.upload_from_url(video_url, "ai-videos")
  135. record.video_url = oss_url
  136. usage = result.get("usage", {})
  137. video_duration = usage.get("duration", 0)
  138. record.video_duration = Decimal(str(video_duration))
  139. # 费用(API调用免费)
  140. record.bill = Decimal("0")
  141. elif task_status == "FAILED":
  142. record.end_time = datetime.now()
  143. record.error_message = output.get("message", "任务失败")
  144. self.db.commit()
  145. return VideoTaskResult(
  146. task_id=task_id,
  147. task_status=task_status,
  148. video_url=record.video_url,
  149. video_duration=float(record.video_duration) if record.video_duration else None,
  150. error_message=record.error_message
  151. )