| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490 |
- """
- 图片生成服务
- 提供文生图和图生图功能,调用DashScope API生成图片
- """
- import logging
- import httpx
- from decimal import Decimal
- from typing import List, Optional
- from dataclasses import dataclass
- import dashscope
- from dashscope import ImageSynthesis
- from sqlalchemy.orm import Session
- from fastapi import HTTPException
- from app.models.ai_picture import AIPicture
- from app.models.model import ModelNew, ModelCategory
- from app.services.oss_service import get_oss_service
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Root URL of the DashScope HTTP API; endpoint paths are appended to it.
DASHSCOPE_BASE_URL = "https://dashscope.aliyuncs.com/api/v1"
@dataclass
class ImageGenerationResult:
    """Outcome of a single image generation request."""

    # True when at least one image was generated and recorded.
    success: bool
    # URLs of the generated images; empty on failure.
    images: List[str]
    # Amount charged for the request.
    bill: Decimal
    # Identifier of the saved generation record; 0 when nothing was saved.
    record_id: int
    # Failure description; None when no error is reported.
    error: Optional[str] = None
@dataclass
class ImageModelInfo:
    """Description of one image model offered to callers."""

    # Code that identifies the model when making a generation request.
    model_id: str
    # Human-readable model name.
    model_name: str
    # Free-text description of the model.
    description: str
    # Price charged for each generated image.
    price_per_image: Decimal
    # Output sizes the model accepts, formatted as "width*height".
    supported_sizes: List[str]
@dataclass
class ImageHistoryResult:
    """One page of a user's generation history."""

    # Records on the requested page.
    items: List[AIPicture]
    # Number of records matching the query across all pages.
    total: int
    # Page number this result corresponds to.
    page: int
    # Maximum number of records per page.
    page_size: int
class ImageGenerationService:
    """Generate images through DashScope and record each successful generation.

    Text-to-image and image-to-image requests are sent to DashScope, the
    returned image URLs are handed to the OSS service for upload, and one
    ``AIPicture`` row is saved per successful request. The class also lists
    the available image models and pages through a user's history.
    """

    # Sizes offered for text-to-image models.
    SUPPORTED_SIZES = [
        "1024*1024",
        "720*1280",
        "1280*720",
    ]

    # Sizes offered for image-to-image (wan2.6-image) models.
    IMAGE_EDIT_SIZES = [
        "1280*1280",
        "1024*1024",
        "800*1200",
        "1200*800",
        "960*1280",
        "1280*960",
        "720*1280",
        "1280*720",
        "1344*576",
    ]

    # Path of the multimodal generation endpoint, relative to DASHSCOPE_BASE_URL.
    _MULTIMODAL_PATH = "/services/aigc/multimodal-generation/generation"
    # Timeout, in seconds, applied to multimodal HTTP requests.
    _REQUEST_TIMEOUT_SECONDS = 180.0
    # Folder argument passed to the OSS service for every uploaded image.
    _OSS_FOLDER = "ai-images"
    # Upper bound on reference images accepted by image_to_image.
    _MAX_REFERENCE_IMAGES = 4

    def __init__(self, db: "Session", api_key: str):
        """Bind the service to a database session and a fallback API key.

        Args:
            db: SQLAlchemy session used for model lookups and record storage.
            api_key: Key used when no model-specific key is resolved.
        """
        self.db = db
        self.api_key = api_key
        self.oss_service = get_oss_service()
        # Process-wide SDK setting: points the dashscope SDK at the same base URL.
        dashscope.base_http_api_url = DASHSCOPE_BASE_URL

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _get_image_models(self) -> "List[ModelNew]":
        """Return enabled models tagged IMAGE_GEN and/or IMAGE_EDIT."""
        return self.db.query(ModelNew).filter(
            ModelNew.is_show_enabled == True,  # noqa: E712 - SQL expression
            ModelNew.is_api_enabled == True,  # noqa: E712 - SQL expression
            ModelNew.categories.any(int(ModelCategory.IMAGE_GEN))
            | ModelNew.categories.any(int(ModelCategory.IMAGE_EDIT)),
        ).all()

    def _is_wan26_model(self, model: str) -> bool:
        """Tell whether ``model`` belongs to the wan2.6 series.

        wan2.6 models are called through the multimodal-generation endpoint
        instead of the ImageSynthesis SDK.
        """
        return model.startswith("wan2.6")

    @staticmethod
    def _failure(error: str) -> "ImageGenerationResult":
        """Build the result returned for every failed request."""
        return ImageGenerationResult(
            success=False, images=[], bill=Decimal("0"), record_id=0,
            error=error,
        )

    def _resolve_api_key(self, model: str) -> str:
        """Pick the API key for ``model``.

        Delegates to ``get_effective_api_key``, which is given the model and
        this service's own key as the fallback.
        """
        # Imported lazily, exactly as the original call site did.
        from app.services.crypto_utils import get_effective_api_key

        return get_effective_api_key(self.db, model, self.api_key)

    @staticmethod
    def _build_multimodal_request(
        model: str,
        content: list,
        parameters: dict,
        negative_prompt: Optional[str],
        seed: Optional[int],
    ) -> dict:
        """Assemble the JSON body for the multimodal-generation endpoint.

        ``negative_prompt`` and ``seed`` are only included when provided.
        """
        params = dict(parameters)
        if negative_prompt:
            params["negative_prompt"] = negative_prompt
        if seed is not None:
            params["seed"] = seed
        return {
            "model": model,
            "input": {"messages": [{"role": "user", "content": content}]},
            "parameters": params,
        }

    async def _post_multimodal(self, request_body: dict, api_key: str) -> dict:
        """POST ``request_body`` to the multimodal endpoint and return the JSON.

        Raises:
            httpx.TimeoutException: When the request times out.
            RuntimeError: When the body is not a JSON object or the response
                carries a ``code`` field (treated as an API error).
        """
        url = f"{DASHSCOPE_BASE_URL}{self._MULTIMODAL_PATH}"
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}",
        }
        async with httpx.AsyncClient(timeout=self._REQUEST_TIMEOUT_SECONDS) as client:
            response = await client.post(url, json=request_body, headers=headers)

        try:
            result = response.json()
        except ValueError as exc:
            # e.g. an HTML error page from a gateway: report the status instead
            # of a bare JSON decoding error.
            raise RuntimeError(f"API调用失败: HTTP {response.status_code}") from exc
        if not isinstance(result, dict):
            raise RuntimeError(f"API调用失败: HTTP {response.status_code}")
        if "code" in result:
            raise RuntimeError(
                f"API调用失败: {result.get('message', result.get('code'))}"
            )
        return result

    @staticmethod
    def _extract_image_urls(result: dict) -> List[str]:
        """Collect image URLs from a multimodal-generation response.

        Accepts both item shapes: ``{"type": "image", "image": url}`` and
        ``{"image": url}``. Items without a non-empty URL are skipped.
        """
        urls = []
        choices = (result.get("output") or {}).get("choices") or []
        for choice in choices:
            content_items = (choice.get("message") or {}).get("content") or []
            for item in content_items:
                is_image = item.get("type") == "image" or (
                    "image" in item and "text" not in item
                )
                url = item.get("image")
                if is_image and url:
                    urls.append(url)
        return urls

    async def _upload_images(self, source_urls: List[str]) -> List[str]:
        """Pass each URL to the OSS service, in order, and return its results."""
        return [
            await self.oss_service.upload_from_url(source_url, self._OSS_FOLDER)
            for source_url in source_urls
        ]

    @staticmethod
    def _page_window(page: int, page_size: int) -> tuple:
        """Clamp paging arguments and return ``(page, page_size, offset)``.

        ``page`` is raised to at least 1 and ``page_size`` to at least 0 so the
        query never receives a negative OFFSET or LIMIT.
        """
        page = max(page, 1)
        page_size = max(page_size, 0)
        return page, page_size, (page - 1) * page_size

    # ------------------------------------------------------------------
    # Text-to-image
    # ------------------------------------------------------------------

    async def text_to_image(
        self,
        user_id: str,
        prompt: str,
        model: str = "wanx2.1-t2i-turbo",
        n: int = 1,
        size: str = "1024*1024",
        negative_prompt: Optional[str] = None,
        prompt_extend: bool = True,
        watermark: bool = False,
        seed: Optional[int] = None
    ) -> "ImageGenerationResult":
        """Generate images from a text prompt.

        The call path depends on the model: wan2.6 models use the
        multimodal-generation endpoint, every other model uses the
        ImageSynthesis SDK.

        Args:
            user_id: Owner of the generation record.
            prompt: Text prompt; must not be blank.
            model: Model code.
            n: Number of images requested.
            size: Image size as "width*height".
            negative_prompt: Optional negative prompt.
            prompt_extend: Sent as ``prompt_extend`` (wan2.6 path only).
            watermark: Sent as ``watermark`` (wan2.6 path only).
            seed: Optional random seed (wan2.6 path only).

        Returns:
            A result with ``success=True`` and the uploaded image URLs, or a
            failed result whose ``error`` describes the problem.

        Raises:
            HTTPException: Re-raised unchanged when raised by a collaborator.
        """
        if not prompt or not prompt.strip():
            return self._failure("提示词不能为空")

        try:
            # Prefer the model's own key; fall back to this service's key.
            effective_api_key = self._resolve_api_key(model)

            if self._is_wan26_model(model):
                oss_urls = await self._text_to_image_wan26(
                    prompt=prompt, model=model, n=n, size=size,
                    negative_prompt=negative_prompt, prompt_extend=prompt_extend,
                    watermark=watermark, seed=seed, api_key=effective_api_key
                )
            else:
                oss_urls = await self._text_to_image_legacy(
                    prompt=prompt, model=model, n=n, size=size,
                    negative_prompt=negative_prompt, api_key=effective_api_key
                )

            if not oss_urls:
                return self._failure("未能获取生成的图片")

            bill = Decimal("0")
            record = self._save_generation_record(
                user_id=user_id, model_id=model, model_name=model,
                input_type="text", input_data=prompt,
                image_count=len(oss_urls), output_images=oss_urls, bill=bill
            )
            return ImageGenerationResult(
                success=True, images=oss_urls, bill=bill, record_id=record.id
            )
        except httpx.TimeoutException:
            # str() of a timeout exception is often empty, so report it explicitly.
            logger.error("文生图请求超时")
            return self._failure("请求超时,请稍后重试")
        except HTTPException:
            raise
        except Exception as e:
            logger.exception("文生图失败: %s", e)
            return self._failure(str(e))

    async def _text_to_image_wan26(
        self, prompt: str, model: str, n: int, size: str,
        negative_prompt: Optional[str], prompt_extend: bool,
        watermark: bool, seed: Optional[int], api_key: Optional[str] = None
    ) -> List[str]:
        """Run text-to-image for a wan2.6 model via the multimodal endpoint.

        Returns:
            The upload results for every image found in the response.

        Raises:
            RuntimeError: When the API reports an error.
        """
        request_body = self._build_multimodal_request(
            model=model,
            content=[{"text": prompt}],
            parameters={
                "n": n,
                "size": size,
                "prompt_extend": prompt_extend,
                "watermark": watermark,
            },
            negative_prompt=negative_prompt,
            seed=seed,
        )
        result = await self._post_multimodal(request_body, api_key or self.api_key)
        return await self._upload_images(self._extract_image_urls(result))

    async def _text_to_image_legacy(
        self, prompt: str, model: str, n: int, size: str,
        negative_prompt: Optional[str], api_key: Optional[str] = None
    ) -> List[str]:
        """Run text-to-image for non-wan2.6 models via the ImageSynthesis SDK.

        ``ImageSynthesis.call`` is a blocking call, so it is executed in the
        event loop's default executor to keep the loop responsive.

        Raises:
            RuntimeError: When the SDK response status is not 200.
        """
        import asyncio
        from functools import partial

        synthesize = partial(
            ImageSynthesis.call,
            api_key=api_key or self.api_key,
            model=model,
            prompt=prompt,
            negative_prompt=negative_prompt,
            n=n,
            size=size,
        )
        response = await asyncio.get_running_loop().run_in_executor(None, synthesize)
        if response.status_code != 200:
            raise RuntimeError(f"API调用失败: {response.message}")
        return await self._upload_images(
            [generated.url for generated in response.output.results]
        )

    # ------------------------------------------------------------------
    # Image-to-image
    # ------------------------------------------------------------------

    async def image_to_image(
        self,
        user_id: str,
        image_urls: List[str],
        prompt: str,
        model: str = "wan2.6-image",
        n: int = 1,
        size: str = "1280*1280",
        negative_prompt: Optional[str] = None,
        prompt_extend: bool = True,
        watermark: bool = False,
        seed: Optional[int] = None
    ) -> "ImageGenerationResult":
        """Generate images from reference images plus a text prompt.

        Uses the multimodal-generation endpoint. The API key is resolved the
        same way as in ``text_to_image``.

        Args:
            user_id: Owner of the generation record.
            image_urls: Reference image URLs (1 to 4).
            prompt: Text prompt; must not be blank.
            model: Model code.
            n: Number of images requested.
            size: Image size as "width*height".
            negative_prompt: Optional negative prompt.
            prompt_extend: Sent as the ``prompt_extend`` parameter.
            watermark: Sent as the ``watermark`` parameter.
            seed: Optional random seed.

        Returns:
            A result with ``success=True`` and the uploaded image URLs, or a
            failed result whose ``error`` describes the problem.

        Raises:
            HTTPException: Re-raised unchanged when raised by a collaborator.
        """
        if not prompt or not prompt.strip():
            return self._failure("提示词不能为空")
        if not image_urls:
            return self._failure("图像编辑模式必须提供至少1张参考图片")
        if len(image_urls) > self._MAX_REFERENCE_IMAGES:
            return self._failure("最多支持4张参考图片")

        try:
            # Prefer the model's own key; fall back to this service's key.
            effective_api_key = self._resolve_api_key(model)

            # Content order: the text prompt first, then the reference images.
            content = [{"text": prompt}]
            content.extend({"image": img_url} for img_url in image_urls)

            request_body = self._build_multimodal_request(
                model=model,
                content=content,
                parameters={
                    "n": n,
                    "size": size,
                    "enable_interleave": False,
                    "prompt_extend": prompt_extend,
                    "watermark": watermark,
                },
                negative_prompt=negative_prompt,
                seed=seed,
            )
            result = await self._post_multimodal(request_body, effective_api_key)
            oss_urls = await self._upload_images(self._extract_image_urls(result))

            if not oss_urls:
                logger.error(f"图生图响应解析失败,完整响应: {result}")
                return self._failure("未能获取生成的图片")

            bill = Decimal("0")
            # input_data stores the prompt, not the reference image URLs.
            record = self._save_generation_record(
                user_id=user_id,
                model_id=model,
                model_name=model,
                input_type="image",
                input_data=prompt,
                image_count=len(oss_urls),
                output_images=oss_urls,
                bill=bill
            )
            return ImageGenerationResult(
                success=True,
                images=oss_urls,
                bill=bill,
                record_id=record.id
            )
        except httpx.TimeoutException:
            logger.error("图生图请求超时")
            return self._failure("请求超时,请稍后重试")
        except HTTPException:
            raise
        except Exception as e:
            logger.exception("图生图失败: %s", e)
            return self._failure(str(e))

    # ------------------------------------------------------------------
    # Persistence and queries
    # ------------------------------------------------------------------

    def _save_generation_record(
        self,
        user_id: str,
        model_id: str,
        model_name: str,
        input_type: str,
        input_data: str,
        image_count: int,
        output_images: List[str],
        bill: Decimal
    ) -> "AIPicture":
        """Insert an ``AIPicture`` row with status "success" and return it.

        The session is rolled back before the error is re-raised when the
        commit fails, so the session stays usable afterwards.
        """
        record = AIPicture(
            model_id=model_id,
            model_name=model_name,
            user_id=user_id,
            input_type=input_type,
            input_data=input_data,
            image_count=image_count,
            output_images=output_images,
            bill=bill,
            status="success"
        )
        self.db.add(record)
        try:
            self.db.commit()
        except Exception:
            self.db.rollback()
            raise
        self.db.refresh(record)
        return record

    @staticmethod
    def _to_model_info(model, supported_sizes: List[str]) -> "ImageModelInfo":
        """Convert a model row into ``ImageModelInfo`` with a zero price."""
        return ImageModelInfo(
            model_id=model.model_code,
            model_name=model.display_name,
            description=model.description or "",
            price_per_image=Decimal("0"),
            # Copy so callers cannot mutate the class-level size list.
            supported_sizes=list(supported_sizes),
        )

    def get_text_to_image_models(self) -> "List[ImageModelInfo]":
        """List models tagged IMAGE_GEN but not IMAGE_EDIT.

        Models tagged IMAGE_EDIT are excluded here because they need a
        reference image.
        """
        image_gen = int(ModelCategory.IMAGE_GEN)
        image_edit = int(ModelCategory.IMAGE_EDIT)
        infos = []
        for model in self._get_image_models():
            categories = model.categories or []
            if image_gen in categories and image_edit not in categories:
                infos.append(self._to_model_info(model, self.SUPPORTED_SIZES))
        return infos

    def get_image_to_image_models(self) -> "List[ImageModelInfo]":
        """List models tagged IMAGE_EDIT, advertising the image-edit sizes."""
        image_edit = int(ModelCategory.IMAGE_EDIT)
        return [
            self._to_model_info(model, self.IMAGE_EDIT_SIZES)
            for model in self._get_image_models()
            if image_edit in (model.categories or [])
        ]

    def get_user_history(
        self,
        user_id: str,
        page: int = 1,
        page_size: int = 20
    ) -> "ImageHistoryResult":
        """Return one page of a user's generation records, newest first.

        Args:
            user_id: User whose records are listed.
            page: 1-based page number; values below 1 are treated as 1.
            page_size: Records per page; negative values are treated as 0.

        Returns:
            The page of records together with the total match count and the
            (clamped) paging values that were applied.
        """
        page, page_size, offset = self._page_window(page, page_size)

        # Rejected content is hidden from the user's history.
        # NOTE(review): under SQL three-valued logic this comparison also drops
        # rows whose review_status is NULL - confirm the column has a default.
        query = (
            self.db.query(AIPicture)
            .filter(
                AIPicture.user_id == user_id,
                AIPicture.review_status != 'rejected',
            )
            .order_by(AIPicture.created_at.desc())
        )
        total = query.count()
        items = query.offset(offset).limit(page_size).all()
        return ImageHistoryResult(
            items=items,
            total=total,
            page=page,
            page_size=page_size
        )
|