"""
Image generation API routes.

Provides the RESTful API endpoints for AI image generation.
"""
from typing import List, NoReturn, Optional

from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form, Query
from sqlalchemy.orm import Session

from app.database import get_db
from app.models.user import User
from app.middleware import get_current_user_from_request
from app.services.image_service import ImageGenerationService
from app.services.system_config_manager import get_config_int
from app.schemas.model_schema import ApiResponse
from app.schemas.image_schema import (
    TextToImageRequest,
    TextToImageResponse,
    ImageToImageResponse,
    ImageModelInfo,
    ImageHistoryItem,
    ImageHistoryResponse
)

router = APIRouter(prefix="/api/image", tags=["图片生成"])

# Fallback shown to the client when the service reports no usable error text.
_DEFAULT_ERROR_MESSAGE = "操作失败,请稍后重试"


def _normalize_error_detail(detail: Optional[str]) -> str:
    """Strip a leading status-code prefix such as ``'402: '`` from an error message.

    Args:
        detail: Raw error text reported by the service, or ``None``.

    Returns:
        The message without its three-digit ``NNN:`` prefix. Returns the
        generic fallback message when ``detail`` is ``None``, empty, or
        whitespace-only. If nothing follows the prefix, the stripped input
        is returned unchanged.
    """
    normalized = detail.strip() if detail else ""
    if not normalized:
        # Whitespace-only input used to leak through as an empty message.
        return _DEFAULT_ERROR_MESSAGE

    prefix, separator, rest = normalized.partition(":")
    code = prefix.strip()
    if separator and len(code) == 3 and code.isdigit():
        return rest.strip() or normalized
    return normalized


def _ensure_api_key_configured(user: User) -> None:
    """Raise HTTP 403 when the user has no API key configured."""
    if not user.apikey:
        raise HTTPException(status_code=403, detail="未配置API密钥,请在用户设置中配置apikey")


def _ensure_image_count_allowed(n: int) -> None:
    """Raise HTTP 400 when ``n`` exceeds the ``max_images_per_request`` setting.

    The limit is read via ``get_config_int`` on every call, with 4 as the
    default passed to it.
    """
    max_images = get_config_int("max_images_per_request", 4)
    if n > max_images:
        raise HTTPException(status_code=400, detail=f"单次最多生成{max_images}张图片")


def _build_generation_service(db: Session, model: str, user: User) -> ImageGenerationService:
    """Create an ``ImageGenerationService`` using the effective API key for ``model``."""
    # Kept as a function-scope import, exactly as the original handlers did
    # (presumably to avoid an import cycle - TODO confirm before hoisting).
    from app.services.crypto_utils import get_effective_api_key

    effective_key = get_effective_api_key(db, model, user.apikey)
    return ImageGenerationService(db, effective_key)


def _raise_generation_error(error: Optional[str]) -> NoReturn:
    """Translate a failed generation result into an ``HTTPException``.

    Raises:
        HTTPException: 402 when the normalized message contains "余额不足"
            (insufficient balance), otherwise 400.
    """
    error_detail = _normalize_error_detail(error)
    status_code = 402 if "余额不足" in error_detail else 400
    raise HTTPException(status_code=status_code, detail=error_detail)


def _to_model_info_list(models) -> List[ImageModelInfo]:
    """Convert the service's model records into ``ImageModelInfo`` schema objects."""
    return [
        ImageModelInfo(
            model_id=m.model_id,
            model_name=m.model_name,
            description=m.description,
            price_per_image=m.price_per_image,
            supported_sizes=m.supported_sizes
        )
        for m in models
    ]


@router.post("/text-to-image", response_model=ApiResponse[TextToImageResponse])
async def text_to_image(
    request: TextToImageRequest,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user_from_request)
):
    """Text-to-image endpoint.

    Raises:
        HTTPException: 403 if the user has no API key; 400 if ``request.n``
            exceeds the configured per-request limit or generation fails;
            402 if the failure message reports insufficient balance.
    """
    _ensure_api_key_configured(current_user)
    _ensure_image_count_allowed(request.n)

    service = _build_generation_service(db, request.model, current_user)
    result = await service.text_to_image(
        user_id=current_user.id,
        prompt=request.prompt,
        model=request.model,
        n=request.n,
        size=request.size,
        negative_prompt=request.negative_prompt,
        prompt_extend=request.prompt_extend,
        watermark=request.watermark,
        seed=request.seed
    )

    if not result.success:
        _raise_generation_error(result.error)

    return ApiResponse(
        code=200,
        message="success",
        data=TextToImageResponse(
            success=result.success,
            images=result.images,
            bill=result.bill,
            record_id=result.record_id
        )
    )


@router.post("/image-to-image", response_model=ApiResponse[ImageToImageResponse])
async def image_to_image(
    images: List[UploadFile] = File(..., description="参考图片(1~4张)"),
    prompt: str = Form(..., description="文本提示词"),
    model: str = Form(default="wan2.6-image", description="模型名称"),
    n: int = Form(default=1, ge=1, le=4, description="生成图片数量"),
    size: str = Form(default="1280*1280", description="图片尺寸"),
    negative_prompt: Optional[str] = Form(default=None, description="反向提示词"),
    prompt_extend: bool = Form(default=True, description="是否开启提示词智能改写"),
    watermark: bool = Form(default=False, description="是否添加水印"),
    seed: Optional[int] = Form(default=None, description="随机数种子"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user_from_request)
):
    """Image-to-image endpoint (adapted to the wan2.6-image model).

    Each reference image is uploaded through the service's OSS helper and the
    resulting URLs are handed to the generation service.

    Raises:
        HTTPException: 403 if the user has no API key; 400 if ``n`` exceeds
            the configured per-request limit, the number of reference images
            is not between 1 and 4, or generation fails; 402 if the failure
            message reports insufficient balance.
    """
    _ensure_api_key_configured(current_user)
    # NOTE: the Form constraint above already caps n at 4, so a configured
    # limit greater than 4 has no effect on this endpoint.
    _ensure_image_count_allowed(n)

    if not 1 <= len(images) <= 4:
        raise HTTPException(status_code=400, detail="参考图片数量必须为1~4张")

    service = _build_generation_service(db, model, current_user)

    # Upload every reference image to OSS to obtain its URL.
    # NOTE(review): upload_image is called without await, so it looks like a
    # blocking call on the event loop - consider a threadpool once its
    # thread-safety is confirmed.
    image_urls = []
    for upload in images:
        image_data = await upload.read()
        image_urls.append(service.oss_service.upload_image(image_data, "ai-images/input"))

    result = await service.image_to_image(
        user_id=current_user.id,
        image_urls=image_urls,
        prompt=prompt,
        model=model,
        n=n,
        size=size,
        negative_prompt=negative_prompt,
        prompt_extend=prompt_extend,
        watermark=watermark,
        seed=seed
    )

    if not result.success:
        _raise_generation_error(result.error)

    return ApiResponse(
        code=200,
        message="success",
        data=ImageToImageResponse(
            success=result.success,
            images=result.images,
            bill=result.bill,
            record_id=result.record_id
        )
    )


@router.get("/text-to-image/models", response_model=ApiResponse[List[ImageModelInfo]])
def get_text_to_image_models(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user_from_request)
):
    """Return the list of text-to-image models."""
    service = ImageGenerationService(db, current_user.apikey or "")
    return ApiResponse(
        code=200,
        message="success",
        data=_to_model_info_list(service.get_text_to_image_models())
    )


@router.get("/image-to-image/models", response_model=ApiResponse[List[ImageModelInfo]])
def get_image_to_image_models(
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user_from_request)
):
    """Return the list of image-to-image models."""
    service = ImageGenerationService(db, current_user.apikey or "")
    return ApiResponse(
        code=200,
        message="success",
        data=_to_model_info_list(service.get_image_to_image_models())
    )


@router.get("/history", response_model=ApiResponse[ImageHistoryResponse])
def get_image_history(
    page: int = Query(default=1, ge=1),
    page_size: int = Query(default=20, ge=1),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user_from_request)
):
    """Return the current user's image generation history, one page at a time.

    Args:
        page: Page number to fetch; must be at least 1.
        page_size: Number of records per page; must be at least 1.
    """
    service = ImageGenerationService(db, current_user.apikey or "")
    result = service.get_user_history(
        user_id=current_user.id,
        page=page,
        page_size=page_size
    )
    return ApiResponse(
        code=200,
        message="success",
        data=ImageHistoryResponse(
            items=[ImageHistoryItem.model_validate(item) for item in result.items],
            total=result.total,
            page=result.page,
            page_size=result.page_size
        )
    )