"""
Image generation service (image_service.py).

Provides text-to-image and image-to-image generation by calling the
DashScope API.
"""
import asyncio
import logging
from dataclasses import dataclass
from decimal import Decimal
from typing import List, Optional

import dashscope
import httpx
from dashscope import ImageSynthesis
from fastapi import HTTPException
from sqlalchemy.orm import Session

from app.models.ai_picture import AIPicture
from app.models.model import ModelNew, ModelCategory
from app.services.oss_service import get_oss_service
  17. logger = logging.getLogger(__name__)
  18. DASHSCOPE_BASE_URL = "https://dashscope.aliyuncs.com/api/v1"
  19. @dataclass
  20. class ImageGenerationResult:
  21. """图片生成结果"""
  22. success: bool
  23. images: List[str]
  24. bill: Decimal
  25. record_id: int
  26. error: Optional[str] = None
  27. @dataclass
  28. class ImageModelInfo:
  29. """图片模型信息"""
  30. model_id: str
  31. model_name: str
  32. description: str
  33. price_per_image: Decimal
  34. supported_sizes: List[str]
  35. @dataclass
  36. class ImageHistoryResult:
  37. """历史记录查询结果"""
  38. items: List[AIPicture]
  39. total: int
  40. page: int
  41. page_size: int
  42. class ImageGenerationService:
  43. """图片生成服务类"""
  44. # 文生图支持的尺寸
  45. SUPPORTED_SIZES = [
  46. "1024*1024",
  47. "720*1280",
  48. "1280*720"
  49. ]
  50. # 图生图(wan2.6-image)支持的尺寸
  51. IMAGE_EDIT_SIZES = [
  52. "1280*1280",
  53. "1024*1024",
  54. "800*1200",
  55. "1200*800",
  56. "960*1280",
  57. "1280*960",
  58. "720*1280",
  59. "1280*720",
  60. "1344*576"
  61. ]
  62. def __init__(self, db: Session, api_key: str):
  63. self.db = db
  64. self.api_key = api_key
  65. self.oss_service = get_oss_service()
  66. dashscope.base_http_api_url = DASHSCOPE_BASE_URL
  67. def _get_image_models(self):
  68. return self.db.query(ModelNew).filter(
  69. ModelNew.is_show_enabled == True,
  70. ModelNew.is_api_enabled == True,
  71. ModelNew.categories.any(int(ModelCategory.IMAGE_GEN)) |
  72. ModelNew.categories.any(int(ModelCategory.IMAGE_EDIT))
  73. ).all()
  74. def _is_wan26_model(self, model: str) -> bool:
  75. """判断是否为wan2.6系列模型(需要使用新版多模态接口)"""
  76. return model.startswith("wan2.6")
  77. async def text_to_image(
  78. self,
  79. user_id: str,
  80. prompt: str,
  81. model: str = "wanx2.1-t2i-turbo",
  82. n: int = 1,
  83. size: str = "1024*1024",
  84. negative_prompt: Optional[str] = None,
  85. prompt_extend: bool = True,
  86. watermark: bool = False,
  87. seed: Optional[int] = None
  88. ) -> ImageGenerationResult:
  89. """
  90. 文生图:调用DashScope API生成图片
  91. 根据模型版本自动选择调用方式:
  92. - wan2.6系列:使用multimodal-generation接口
  93. - 其他版本:使用ImageSynthesis接口
  94. """
  95. if not prompt or not prompt.strip():
  96. return ImageGenerationResult(
  97. success=False, images=[], bill=Decimal("0"), record_id=0,
  98. error="提示词不能为空"
  99. )
  100. try:
  101. # 优先使用模型自带的 api_key(爬虫同步的),没有则 fallback 到用户自己的 apikey
  102. from app.services.crypto_utils import get_effective_api_key
  103. effective_api_key = get_effective_api_key(self.db, model, self.api_key)
  104. if self._is_wan26_model(model):
  105. # wan2.6使用新版多模态接口
  106. oss_urls = await self._text_to_image_wan26(
  107. prompt=prompt, model=model, n=n, size=size,
  108. negative_prompt=negative_prompt, prompt_extend=prompt_extend,
  109. watermark=watermark, seed=seed, api_key=effective_api_key
  110. )
  111. else:
  112. # 旧版模型使用ImageSynthesis
  113. oss_urls = await self._text_to_image_legacy(
  114. prompt=prompt, model=model, n=n, size=size,
  115. negative_prompt=negative_prompt, api_key=effective_api_key
  116. )
  117. if not oss_urls:
  118. return ImageGenerationResult(
  119. success=False, images=[], bill=Decimal("0"), record_id=0,
  120. error="未能获取生成的图片"
  121. )
  122. bill = Decimal("0")
  123. record = self._save_generation_record(
  124. user_id=user_id, model_id=model, model_name=model,
  125. input_type="text", input_data=prompt,
  126. image_count=len(oss_urls), output_images=oss_urls, bill=bill
  127. )
  128. return ImageGenerationResult(
  129. success=True, images=oss_urls, bill=bill, record_id=record.id
  130. )
  131. except HTTPException:
  132. raise
  133. except Exception as e:
  134. logger.error(f"文生图失败: {str(e)}")
  135. return ImageGenerationResult(
  136. success=False, images=[], bill=Decimal("0"), record_id=0,
  137. error=str(e)
  138. )
  139. async def _text_to_image_wan26(
  140. self, prompt: str, model: str, n: int, size: str,
  141. negative_prompt: Optional[str], prompt_extend: bool,
  142. watermark: bool, seed: Optional[int], api_key: Optional[str] = None
  143. ) -> List[str]:
  144. """wan2.6系列文生图(使用multimodal-generation接口)"""
  145. effective_key = api_key or self.api_key
  146. request_body = {
  147. "model": model,
  148. "input": {
  149. "messages": [
  150. {"role": "user", "content": [{"text": prompt}]}
  151. ]
  152. },
  153. "parameters": {
  154. "n": n,
  155. "size": size,
  156. "prompt_extend": prompt_extend,
  157. "watermark": watermark
  158. }
  159. }
  160. if negative_prompt:
  161. request_body["parameters"]["negative_prompt"] = negative_prompt
  162. if seed is not None:
  163. request_body["parameters"]["seed"] = seed
  164. url = f"{DASHSCOPE_BASE_URL}/services/aigc/multimodal-generation/generation"
  165. headers = {
  166. "Content-Type": "application/json",
  167. "Authorization": f"Bearer {effective_key}"
  168. }
  169. async with httpx.AsyncClient(timeout=180.0) as client:
  170. response = await client.post(url, json=request_body, headers=headers)
  171. result = response.json()
  172. if "code" in result:
  173. raise Exception(f"API调用失败: {result.get('message', result.get('code'))}")
  174. oss_urls = []
  175. choices = result.get("output", {}).get("choices", [])
  176. for choice in choices:
  177. content_list = choice.get("message", {}).get("content", [])
  178. for item in content_list:
  179. if item.get("type") == "image":
  180. original_url = item.get("image")
  181. if original_url:
  182. oss_url = await self.oss_service.upload_from_url(original_url, "ai-images")
  183. oss_urls.append(oss_url)
  184. return oss_urls
  185. async def _text_to_image_legacy(
  186. self, prompt: str, model: str, n: int, size: str,
  187. negative_prompt: Optional[str], api_key: Optional[str] = None
  188. ) -> List[str]:
  189. """旧版模型文生图(使用ImageSynthesis接口)"""
  190. effective_key = api_key or self.api_key
  191. response = ImageSynthesis.call(
  192. api_key=effective_key,
  193. model=model,
  194. prompt=prompt,
  195. negative_prompt=negative_prompt,
  196. n=n,
  197. size=size
  198. )
  199. if response.status_code != 200:
  200. raise Exception(f"API调用失败: {response.message}")
  201. oss_urls = []
  202. for result in response.output.results:
  203. oss_url = await self.oss_service.upload_from_url(result.url, "ai-images")
  204. oss_urls.append(oss_url)
  205. return oss_urls
  206. async def image_to_image(
  207. self,
  208. user_id: str,
  209. image_urls: List[str],
  210. prompt: str,
  211. model: str = "wan2.6-image",
  212. n: int = 1,
  213. size: str = "1280*1280",
  214. negative_prompt: Optional[str] = None,
  215. prompt_extend: bool = True,
  216. watermark: bool = False,
  217. seed: Optional[int] = None
  218. ) -> ImageGenerationResult:
  219. """
  220. 图生图:基于参考图片生成新图片(适配wan2.6-image模型)
  221. 使用HTTP同步调用 multimodal-generation 接口
  222. Args:
  223. user_id: 用户ID
  224. image_urls: 参考图片URL列表(1~4张)
  225. prompt: 文本提示词
  226. model: 模型名称
  227. n: 生成图片数量(1~4)
  228. size: 图片尺寸
  229. negative_prompt: 反向提示词
  230. prompt_extend: 是否开启提示词智能改写
  231. watermark: 是否添加水印
  232. seed: 随机数种子
  233. """
  234. if not prompt or not prompt.strip():
  235. return ImageGenerationResult(
  236. success=False, images=[], bill=Decimal("0"), record_id=0,
  237. error="提示词不能为空"
  238. )
  239. if not image_urls or len(image_urls) < 1:
  240. return ImageGenerationResult(
  241. success=False, images=[], bill=Decimal("0"), record_id=0,
  242. error="图像编辑模式必须提供至少1张参考图片"
  243. )
  244. if len(image_urls) > 4:
  245. return ImageGenerationResult(
  246. success=False, images=[], bill=Decimal("0"), record_id=0,
  247. error="最多支持4张参考图片"
  248. )
  249. try:
  250. # 构建content数组:先text后images
  251. content = [{"text": prompt}]
  252. for img_url in image_urls:
  253. content.append({"image": img_url})
  254. # 构建请求体
  255. request_body = {
  256. "model": model,
  257. "input": {
  258. "messages": [
  259. {"role": "user", "content": content}
  260. ]
  261. },
  262. "parameters": {
  263. "n": n,
  264. "size": size,
  265. "enable_interleave": False,
  266. "prompt_extend": prompt_extend,
  267. "watermark": watermark
  268. }
  269. }
  270. if negative_prompt:
  271. request_body["parameters"]["negative_prompt"] = negative_prompt
  272. if seed is not None:
  273. request_body["parameters"]["seed"] = seed
  274. # 调用HTTP同步接口
  275. url = f"{DASHSCOPE_BASE_URL}/services/aigc/multimodal-generation/generation"
  276. headers = {
  277. "Content-Type": "application/json",
  278. "Authorization": f"Bearer {self.api_key}"
  279. }
  280. async with httpx.AsyncClient(timeout=180.0) as client:
  281. response = await client.post(url, json=request_body, headers=headers)
  282. result = response.json()
  283. # 检查错误
  284. if "code" in result:
  285. return ImageGenerationResult(
  286. success=False, images=[], bill=Decimal("0"), record_id=0,
  287. error=f"API调用失败: {result.get('message', result.get('code'))}"
  288. )
  289. # 解析响应
  290. output = result.get("output", {})
  291. choices = output.get("choices", [])
  292. oss_urls = []
  293. for choice in choices:
  294. message = choice.get("message", {})
  295. content_list = message.get("content", [])
  296. for item in content_list:
  297. # 兼容两种格式:{"type":"image","image":"url"} 和 {"image":"url"}
  298. if item.get("type") == "image" or ("image" in item and "text" not in item):
  299. original_url = item.get("image")
  300. if original_url:
  301. oss_url = await self.oss_service.upload_from_url(original_url, "ai-images")
  302. oss_urls.append(oss_url)
  303. if not oss_urls:
  304. logger.error(f"图生图响应解析失败,完整响应: {result}")
  305. return ImageGenerationResult(
  306. success=False, images=[], bill=Decimal("0"), record_id=0,
  307. error="未能获取生成的图片"
  308. )
  309. # 计算费用
  310. bill = Decimal("0")
  311. # 保存记录(图生图时input_data保存提示词,而不是图片URL)
  312. record = self._save_generation_record(
  313. user_id=user_id,
  314. model_id=model,
  315. model_name=model,
  316. input_type="image",
  317. input_data=prompt, # 保存提示词而不是图片URL
  318. image_count=len(oss_urls),
  319. output_images=oss_urls,
  320. bill=bill
  321. )
  322. return ImageGenerationResult(
  323. success=True,
  324. images=oss_urls,
  325. bill=bill,
  326. record_id=record.id
  327. )
  328. except httpx.TimeoutException:
  329. logger.error("图生图请求超时")
  330. return ImageGenerationResult(
  331. success=False, images=[], bill=Decimal("0"), record_id=0,
  332. error="请求超时,请稍后重试"
  333. )
  334. except HTTPException:
  335. raise
  336. except Exception as e:
  337. logger.error(f"图生图失败: {str(e)}")
  338. return ImageGenerationResult(
  339. success=False, images=[], bill=Decimal("0"), record_id=0,
  340. error=str(e)
  341. )
  342. def _save_generation_record(
  343. self,
  344. user_id: str,
  345. model_id: str,
  346. model_name: str,
  347. input_type: str,
  348. input_data: str,
  349. image_count: int,
  350. output_images: List[str],
  351. bill: Decimal
  352. ) -> AIPicture:
  353. """保存生成记录到数据库"""
  354. record = AIPicture(
  355. model_id=model_id,
  356. model_name=model_name,
  357. user_id=user_id,
  358. input_type=input_type,
  359. input_data=input_data,
  360. image_count=image_count,
  361. output_images=output_images,
  362. bill=bill,
  363. status="success"
  364. )
  365. self.db.add(record)
  366. self.db.commit()
  367. self.db.refresh(record)
  368. return record
  369. def get_text_to_image_models(self) -> List[ImageModelInfo]:
  370. """获取文生图模型列表:有 IMAGE_GEN 且没有 IMAGE_EDIT(不需要参考图)"""
  371. models = self._get_image_models()
  372. result = []
  373. for model in models:
  374. cats = model.categories or []
  375. if int(ModelCategory.IMAGE_GEN) not in cats:
  376. continue
  377. if int(ModelCategory.IMAGE_EDIT) in cats:
  378. continue
  379. price = Decimal("0")
  380. result.append(ImageModelInfo(
  381. model_id=model.model_code,
  382. model_name=model.display_name,
  383. description=model.description or "",
  384. price_per_image=price,
  385. supported_sizes=self.SUPPORTED_SIZES
  386. ))
  387. return result
  388. def get_image_to_image_models(self) -> List[ImageModelInfo]:
  389. """获取图生图模型列表(categories 包含 IMAGE_EDIT(6))"""
  390. models = self._get_image_models()
  391. result = []
  392. for model in models:
  393. cats = model.categories or []
  394. if int(ModelCategory.IMAGE_EDIT) not in cats:
  395. continue
  396. price = Decimal("0")
  397. result.append(ImageModelInfo(
  398. model_id=model.model_code,
  399. model_name=model.display_name,
  400. description=model.description or "",
  401. price_per_image=price,
  402. supported_sizes=self.SUPPORTED_SIZES
  403. ))
  404. return result
  405. def get_user_history(
  406. self,
  407. user_id: str,
  408. page: int = 1,
  409. page_size: int = 20
  410. ) -> ImageHistoryResult:
  411. """
  412. 获取用户历史记录
  413. Args:
  414. user_id: 用户ID
  415. page: 页码
  416. page_size: 每页数量
  417. """
  418. query = self.db.query(AIPicture).filter(
  419. AIPicture.user_id == user_id,
  420. AIPicture.review_status != 'rejected' # 排除被拒绝的内容
  421. ).order_by(AIPicture.created_at.desc())
  422. total = query.count()
  423. offset = (page - 1) * page_size
  424. items = query.offset(offset).limit(page_size).all()
  425. return ImageHistoryResult(
  426. items=items,
  427. total=total,
  428. page=page,
  429. page_size=page_size
  430. )