image_router.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. """
  2. 图片生成API路由
  3. 提供AI图片生成的RESTful API端点
  4. """
  5. from typing import List, Optional
  6. from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form
  7. from sqlalchemy.orm import Session
  8. from app.database import get_db
  9. from app.models.user import User
  10. from app.middleware import get_current_user_from_request
  11. from app.services.image_service import ImageGenerationService
  12. from app.services.system_config_manager import get_config_int
  13. from app.schemas.model_schema import ApiResponse
  14. from app.schemas.image_schema import (
  15. TextToImageRequest,
  16. TextToImageResponse,
  17. ImageToImageResponse,
  18. ImageModelInfo,
  19. ImageHistoryItem,
  20. ImageHistoryResponse
  21. )
  22. router = APIRouter(prefix="/api/image", tags=["图片生成"])
  23. def _normalize_error_detail(detail: Optional[str]) -> str:
  24. """移除形如 '402: ' 的前缀,返回更友好的错误文案。"""
  25. if not detail:
  26. return "操作失败,请稍后重试"
  27. normalized = detail.strip()
  28. if ":" in normalized:
  29. prefix, rest = normalized.split(":", 1)
  30. if prefix.strip().isdigit() and len(prefix.strip()) == 3:
  31. return rest.strip() or normalized
  32. return normalized
  33. @router.post("/text-to-image", response_model=ApiResponse[TextToImageResponse])
  34. async def text_to_image(
  35. request: TextToImageRequest,
  36. db: Session = Depends(get_db),
  37. current_user: User = Depends(get_current_user_from_request)
  38. ):
  39. """文生图接口,需要余额检查"""
  40. if not current_user.apikey:
  41. raise HTTPException(status_code=403, detail="未配置API密钥,请在用户设置中配置apikey")
  42. # 检查图片数量限制
  43. max_images = get_config_int("max_images_per_request", 4)
  44. if request.n > max_images:
  45. raise HTTPException(status_code=400, detail=f"单次最多生成{max_images}张图片")
  46. from app.services.crypto_utils import get_effective_api_key
  47. effective_key = get_effective_api_key(db, request.model, current_user.apikey)
  48. service = ImageGenerationService(db, effective_key)
  49. result = await service.text_to_image(
  50. user_id=current_user.id,
  51. prompt=request.prompt,
  52. model=request.model,
  53. n=request.n,
  54. size=request.size,
  55. negative_prompt=request.negative_prompt,
  56. prompt_extend=request.prompt_extend,
  57. watermark=request.watermark,
  58. seed=request.seed
  59. )
  60. if not result.success:
  61. error_detail = _normalize_error_detail(result.error)
  62. if "余额不足" in error_detail:
  63. raise HTTPException(status_code=402, detail=error_detail)
  64. raise HTTPException(status_code=400, detail=error_detail)
  65. return ApiResponse(
  66. code=200,
  67. message="success",
  68. data=TextToImageResponse(
  69. success=result.success,
  70. images=result.images,
  71. bill=result.bill,
  72. record_id=result.record_id
  73. )
  74. )
  75. @router.post("/image-to-image", response_model=ApiResponse[ImageToImageResponse])
  76. async def image_to_image(
  77. images: List[UploadFile] = File(..., description="参考图片(1~4张)"),
  78. prompt: str = Form(..., description="文本提示词"),
  79. model: str = Form(default="wan2.6-image", description="模型名称"),
  80. n: int = Form(default=1, ge=1, le=4, description="生成图片数量"),
  81. size: str = Form(default="1280*1280", description="图片尺寸"),
  82. negative_prompt: Optional[str] = Form(default=None, description="反向提示词"),
  83. prompt_extend: bool = Form(default=True, description="是否开启提示词智能改写"),
  84. watermark: bool = Form(default=False, description="是否添加水印"),
  85. seed: Optional[int] = Form(default=None, description="随机数种子"),
  86. db: Session = Depends(get_db),
  87. current_user: User = Depends(get_current_user_from_request)
  88. ):
  89. """图生图接口(适配wan2.6-image模型),需要余额检查"""
  90. if not current_user.apikey:
  91. raise HTTPException(status_code=403, detail="未配置API密钥,请在用户设置中配置apikey")
  92. # 检查图片数量限制
  93. max_images = get_config_int("max_images_per_request", 4)
  94. if n > max_images:
  95. raise HTTPException(status_code=400, detail=f"单次最多生成{max_images}张图片")
  96. if len(images) < 1 or len(images) > 4:
  97. raise HTTPException(status_code=400, detail="参考图片数量必须为1~4张")
  98. # 上传所有参考图片到OSS获取URL
  99. from app.services.crypto_utils import get_effective_api_key
  100. effective_key = get_effective_api_key(db, model, current_user.apikey)
  101. service = ImageGenerationService(db, effective_key)
  102. image_urls = []
  103. for img in images:
  104. image_data = await img.read()
  105. img_url = service.oss_service.upload_image(image_data, "ai-images/input")
  106. image_urls.append(img_url)
  107. result = await service.image_to_image(
  108. user_id=current_user.id,
  109. image_urls=image_urls,
  110. prompt=prompt,
  111. model=model,
  112. n=n,
  113. size=size,
  114. negative_prompt=negative_prompt,
  115. prompt_extend=prompt_extend,
  116. watermark=watermark,
  117. seed=seed
  118. )
  119. if not result.success:
  120. error_detail = _normalize_error_detail(result.error)
  121. if "余额不足" in error_detail:
  122. raise HTTPException(status_code=402, detail=error_detail)
  123. raise HTTPException(status_code=400, detail=error_detail)
  124. return ApiResponse(
  125. code=200,
  126. message="success",
  127. data=ImageToImageResponse(
  128. success=result.success,
  129. images=result.images,
  130. bill=result.bill,
  131. record_id=result.record_id
  132. )
  133. )
  134. @router.get("/text-to-image/models", response_model=ApiResponse[List[ImageModelInfo]])
  135. def get_text_to_image_models(
  136. db: Session = Depends(get_db),
  137. current_user: User = Depends(get_current_user_from_request)
  138. ):
  139. """获取文生图模型列表"""
  140. service = ImageGenerationService(db, current_user.apikey or "")
  141. models = service.get_text_to_image_models()
  142. return ApiResponse(
  143. code=200,
  144. message="success",
  145. data=[ImageModelInfo(
  146. model_id=m.model_id,
  147. model_name=m.model_name,
  148. description=m.description,
  149. price_per_image=m.price_per_image,
  150. supported_sizes=m.supported_sizes
  151. ) for m in models]
  152. )
  153. @router.get("/image-to-image/models", response_model=ApiResponse[List[ImageModelInfo]])
  154. def get_image_to_image_models(
  155. db: Session = Depends(get_db),
  156. current_user: User = Depends(get_current_user_from_request)
  157. ):
  158. """获取图生图模型列表"""
  159. service = ImageGenerationService(db, current_user.apikey or "")
  160. models = service.get_image_to_image_models()
  161. return ApiResponse(
  162. code=200,
  163. message="success",
  164. data=[ImageModelInfo(
  165. model_id=m.model_id,
  166. model_name=m.model_name,
  167. description=m.description,
  168. price_per_image=m.price_per_image,
  169. supported_sizes=m.supported_sizes
  170. ) for m in models]
  171. )
  172. @router.get("/history", response_model=ApiResponse[ImageHistoryResponse])
  173. def get_image_history(
  174. page: int = 1,
  175. page_size: int = 20,
  176. db: Session = Depends(get_db),
  177. current_user: User = Depends(get_current_user_from_request)
  178. ):
  179. """获取用户图片生成历史记录"""
  180. service = ImageGenerationService(db, current_user.apikey or "")
  181. result = service.get_user_history(
  182. user_id=current_user.id,
  183. page=page,
  184. page_size=page_size
  185. )
  186. return ApiResponse(
  187. code=200,
  188. message="success",
  189. data=ImageHistoryResponse(
  190. items=[ImageHistoryItem.model_validate(item) for item in result.items],
  191. total=result.total,
  192. page=result.page,
  193. page_size=result.page_size
  194. )
  195. )