| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220 |
- """
- 图片生成API路由
- 提供AI图片生成的RESTful API端点
- """
- from typing import List, Optional
- from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form
- from sqlalchemy.orm import Session
- from app.database import get_db
- from app.models.user import User
- from app.middleware import get_current_user_from_request
- from app.services.image_service import ImageGenerationService
- from app.services.system_config_manager import get_config_int
- from app.schemas.model_schema import ApiResponse
- from app.schemas.image_schema import (
- TextToImageRequest,
- TextToImageResponse,
- ImageToImageResponse,
- ImageModelInfo,
- ImageHistoryItem,
- ImageHistoryResponse
- )
# Router shared by every image-generation endpoint in this module; all routes
# are mounted under the /api/image prefix and grouped under one OpenAPI tag.
router = APIRouter(
    tags=["图片生成"],
    prefix="/api/image",
)
- def _normalize_error_detail(detail: Optional[str]) -> str:
- """移除形如 '402: ' 的前缀,返回更友好的错误文案。"""
- if not detail:
- return "操作失败,请稍后重试"
- normalized = detail.strip()
- if ":" in normalized:
- prefix, rest = normalized.split(":", 1)
- if prefix.strip().isdigit() and len(prefix.strip()) == 3:
- return rest.strip() or normalized
- return normalized
@router.post("/text-to-image", response_model=ApiResponse[TextToImageResponse])
async def text_to_image(
    request: TextToImageRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user_from_request)
):
    """Generate images from a text prompt.

    Raises:
        HTTPException: 403 when the user has no API key configured; 400 when
            ``request.n`` exceeds the ``max_images_per_request`` setting
            (default 4); 402 when the service reports insufficient balance
            ("余额不足"); 400 for any other generation failure.
    """
    if not current_user.apikey:
        raise HTTPException(status_code=403, detail="未配置API密钥,请在用户设置中配置apikey")

    limit = get_config_int("max_images_per_request", 4)
    if request.n > limit:
        raise HTTPException(status_code=400, detail=f"单次最多生成{limit}张图片")

    # Imported at call time rather than at module level.
    from app.services.crypto_utils import get_effective_api_key

    resolved_key = get_effective_api_key(db, request.model, current_user.apikey)
    generator = ImageGenerationService(db, resolved_key)
    outcome = await generator.text_to_image(
        user_id=current_user.id,
        prompt=request.prompt,
        model=request.model,
        n=request.n,
        size=request.size,
        negative_prompt=request.negative_prompt,
        prompt_extend=request.prompt_extend,
        watermark=request.watermark,
        seed=request.seed
    )

    if not outcome.success:
        message = _normalize_error_detail(outcome.error)
        # Insufficient balance maps to 402 Payment Required; everything else is 400.
        status = 402 if "余额不足" in message else 400
        raise HTTPException(status_code=status, detail=message)

    payload = TextToImageResponse(
        success=outcome.success,
        images=outcome.images,
        bill=outcome.bill,
        record_id=outcome.record_id
    )
    return ApiResponse(code=200, message="success", data=payload)
@router.post("/image-to-image", response_model=ApiResponse[ImageToImageResponse])
async def image_to_image(
    images: List[UploadFile] = File(..., description="参考图片(1~4张)"),
    prompt: str = Form(..., description="文本提示词"),
    model: str = Form(default="wan2.6-image", description="模型名称"),
    n: int = Form(default=1, ge=1, le=4, description="生成图片数量"),
    size: str = Form(default="1280*1280", description="图片尺寸"),
    negative_prompt: Optional[str] = Form(default=None, description="反向提示词"),
    prompt_extend: bool = Form(default=True, description="是否开启提示词智能改写"),
    watermark: bool = Form(default=False, description="是否添加水印"),
    seed: Optional[int] = Form(default=None, description="随机数种子"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user_from_request)
):
    """Generate images from 1-4 uploaded reference images plus a prompt.

    Tailored to the ``wan2.6-image`` model (the form default). Each reference
    image is uploaded to OSS under ``ai-images/input`` and the resulting URLs
    are passed to the generation service.

    Raises:
        HTTPException: 403 when the user has no API key configured; 400 when
            ``n`` exceeds the ``max_images_per_request`` setting (default 4),
            when the number of reference images is outside 1-4, or when an
            uploaded file is empty; 402 when the service reports insufficient
            balance ("余额不足"); 400 for any other generation failure.
    """
    if not current_user.apikey:
        raise HTTPException(status_code=403, detail="未配置API密钥,请在用户设置中配置apikey")

    # Enforce the configured per-request image limit.
    max_images = get_config_int("max_images_per_request", 4)
    if n > max_images:
        raise HTTPException(status_code=400, detail=f"单次最多生成{max_images}张图片")

    if len(images) < 1 or len(images) > 4:
        raise HTTPException(status_code=400, detail="参考图片数量必须为1~4张")

    from app.services.crypto_utils import get_effective_api_key
    effective_key = get_effective_api_key(db, model, current_user.apikey)
    service = ImageGenerationService(db, effective_key)

    # Read and validate every upload before sending anything to OSS, so a bad
    # file later in the list does not leave earlier images orphaned in storage.
    image_payloads = []
    for img in images:
        image_data = await img.read()
        if not image_data:
            raise HTTPException(status_code=400, detail="参考图片内容不能为空")
        image_payloads.append(image_data)

    # upload_image is a synchronous (non-awaited) OSS upload; calling it directly
    # inside this async endpoint would block the event loop for every image, so
    # it is offloaded to the threadpool one upload at a time.
    from fastapi.concurrency import run_in_threadpool
    image_urls = []
    for image_data in image_payloads:
        img_url = await run_in_threadpool(
            service.oss_service.upload_image, image_data, "ai-images/input"
        )
        image_urls.append(img_url)

    result = await service.image_to_image(
        user_id=current_user.id,
        image_urls=image_urls,
        prompt=prompt,
        model=model,
        n=n,
        size=size,
        negative_prompt=negative_prompt,
        prompt_extend=prompt_extend,
        watermark=watermark,
        seed=seed
    )
    if not result.success:
        error_detail = _normalize_error_detail(result.error)
        # Insufficient balance maps to 402 Payment Required; everything else is 400.
        if "余额不足" in error_detail:
            raise HTTPException(status_code=402, detail=error_detail)
        raise HTTPException(status_code=400, detail=error_detail)
    return ApiResponse(
        code=200,
        message="success",
        data=ImageToImageResponse(
            success=result.success,
            images=result.images,
            bill=result.bill,
            record_id=result.record_id
        )
    )
@router.get("/text-to-image/models", response_model=ApiResponse[List[ImageModelInfo]])
def get_text_to_image_models(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user_from_request)
):
    """Return the catalogue of text-to-image models.

    No API-key check is performed here; an empty key is handed to the service
    when the user has none configured.
    """
    def _as_info(entry):
        # Project the service's model record onto the public response schema.
        return ImageModelInfo(
            model_id=entry.model_id,
            model_name=entry.model_name,
            description=entry.description,
            price_per_image=entry.price_per_image,
            supported_sizes=entry.supported_sizes
        )

    catalogue = ImageGenerationService(db, current_user.apikey or "").get_text_to_image_models()
    model_infos = [_as_info(entry) for entry in catalogue]
    return ApiResponse(code=200, message="success", data=model_infos)
@router.get("/image-to-image/models", response_model=ApiResponse[List[ImageModelInfo]])
def get_image_to_image_models(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user_from_request)
):
    """Return the catalogue of image-to-image models.

    No API-key check is performed here; an empty key is handed to the service
    when the user has none configured.
    """
    lookup = ImageGenerationService(db, current_user.apikey or "")
    available = lookup.get_image_to_image_models()

    def _as_info(entry):
        # Project the service's model record onto the public response schema.
        return ImageModelInfo(
            model_id=entry.model_id,
            model_name=entry.model_name,
            description=entry.description,
            price_per_image=entry.price_per_image,
            supported_sizes=entry.supported_sizes
        )

    return ApiResponse(
        code=200,
        message="success",
        data=[_as_info(entry) for entry in available]
    )
@router.get("/history", response_model=ApiResponse[ImageHistoryResponse])
def get_image_history(
    page: int = 1,
    page_size: int = 20,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user_from_request)
):
    """Return one page of the current user's image-generation history.

    Args:
        page: Page number (default 1); must be >= 1.
        page_size: Items per page (default 20); must be between 1 and the
            ``max_history_page_size`` setting (default 100).

    Raises:
        HTTPException: 400 when ``page`` or ``page_size`` is out of range.
    """
    # Both values come straight from the query string; reject zero/negative
    # pages and unbounded page sizes before they reach the query layer.
    if page < 1:
        raise HTTPException(status_code=400, detail="页码必须大于等于1")
    max_page_size = get_config_int("max_history_page_size", 100)
    if page_size < 1 or page_size > max_page_size:
        raise HTTPException(status_code=400, detail=f"每页数量必须在1~{max_page_size}之间")

    service = ImageGenerationService(db, current_user.apikey or "")
    result = service.get_user_history(
        user_id=current_user.id,
        page=page,
        page_size=page_size
    )
    return ApiResponse(
        code=200,
        message="success",
        data=ImageHistoryResponse(
            items=[ImageHistoryItem.model_validate(item) for item in result.items],
            total=result.total,
            page=result.page,
            page_size=result.page_size
        )
    )
|