openai_compat_service.py 120 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706170717081709171017111712171317141715171617171718171917201721172217231724172517261727172817291730173117321733173417351736173717381739174017411742174317441745174617471748174917501751175217531754175517561757175817591760176117621763176417651766176717681769177017711772177317741775177617771778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600260126022603260426052606260726082609261026112612261326142615261626172618261926202621262226232624262526262627262826292630263126322633263426352636263726382639264026412642264326442645264626472648264926502651265226532654265526562657265826592660266126622663266426652666266726682669267026712672267326742675267626772678267926802681268226832684268526862687268826892690269126922693269426952696269726982699270027012702270327042705270627072708270927102711271227132714271527162717271827192720272127222723272427252726272727282729273027312732273327342735273627372738273927402741274227432744274527462747274827492750275127522753275427552756275727582759276027612762276327642765276627672768276927702771277227732774277527762777277827792780278127822783278427852786278727882789279027912792279327942795279627972798279928002801280228032804280528062807280828092810281128122813281428152816281728182819282028212822282328242825282628272828282928302831283228332834283528362837283828392840284128422843284428452846284728482849285028512852285328542855285628572858285928602861286228632864286528662867
  1. """
  2. OpenAI 兼容服务层
  3. 完整支持 /v1/chat/completions 和 /v1/models 接口
  4. 支持多种模型提供商的自动适配
  5. """
  6. import asyncio
  7. import logging
  8. import time
  9. import uuid
  10. from decimal import Decimal
  11. import httpx
  12. from fastapi import UploadFile
  13. from app.services.api_call_log_service import ApiCallLogService
  14. from app.services.image_service import ImageGenerationService
  15. from app.services.model_adapters import BaseAdapter, get_adapter, ModelProvider
  16. from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
  17. from sqlalchemy import desc
  18. from sqlalchemy.orm import Session
  19. from app.models.model import ModelNew as Model, ModelPriceNew as ModelPrice
  20. from app.models.user import User
  21. from app.schemas.openai_compat import (
  22. ChatCompletionsRequest,
  23. ModelInfo,
  24. EmbeddingsRequest,
  25. EmbeddingsResponse,
  26. EmbeddingData,
  27. Usage,
  28. ModelsListResponse,
  29. ImageGenerationRequest,
  30. ImageGenerationResponse,
  31. ImageData,
  32. AudioTranscriptionResponse,
  33. AudioSpeechRequest,
  34. VideoGenerationRequest,
  35. VideoGenerationResponse,
  36. RerankRequest,
  37. RerankResponse,
  38. RerankResult,
  39. )
  40. from app.services.crypto_utils import decrypt_api_key
  41. from app.services.system_config_manager import get_config_bool
  42. logger = logging.getLogger(__name__)
  43. # ─────────────────────────────────────────────
  44. # 工具函数
  45. # ─────────────────────────────────────────────
  46. def parse_video_size(size: str) -> Tuple[str, str]:
  47. """
  48. 解析视频尺寸,支持多种格式并转换为OpenAI标准格式和内部格式
  49. Args:
  50. size: 视频尺寸,支持以下格式:
  51. - OpenAI格式: "1280x720", "1920x1080", "720x1280"
  52. - 简写格式: "720P", "1080P", "720p", "1080p"
  53. Returns:
  54. (openai_format, internal_format)
  55. 例如: ("1280x720", "720P")
  56. Raises:
  57. ValueError: 如果格式无效
  58. """
  59. import re
  60. size = size.strip()
  61. # 如果是OpenAI格式 (widthxheight)
  62. if 'x' in size.lower():
  63. match = re.match(r'^(\d+)x(\d+)$', size.lower())
  64. if not match:
  65. raise ValueError(f"Invalid size format: {size}. Expected format: 1280x720")
  66. width, height = int(match.group(1)), int(match.group(2))
  67. # 推断简写格式(基于高度)
  68. if height == 720:
  69. internal = "720P"
  70. elif height == 1080:
  71. internal = "1080P"
  72. else:
  73. # 对于非标准分辨率,使用高度作为简写
  74. internal = f"{height}P"
  75. return (f"{width}x{height}", internal)
  76. # 如果是简写格式 (720P, 1080P)
  77. else:
  78. match = re.match(r'^(\d+)p$', size.lower())
  79. if not match:
  80. raise ValueError(f"Invalid size format: {size}. Expected format: 720P or 1280x720")
  81. height = int(match.group(1))
  82. size_upper = f"{height}P"
  83. # 标准分辨率映射(16:9比例)
  84. if height == 720:
  85. return ("1280x720", "720P")
  86. elif height == 1080:
  87. return ("1920x1080", "1080P")
  88. else:
  89. # 对于非标准分辨率,假设16:9比例
  90. width = int(height * 16 / 9)
  91. return (f"{width}x{height}", size_upper)
  92. class OpenAICompatError(Exception):
  93. """OpenAI 兼容服务错误"""
  94. def __init__(self, status_code: int, message: str, error_type: str = "invalid_request_error"):
  95. self.status_code = status_code
  96. self.message = message
  97. self.error_type = error_type
  98. super().__init__(message)
  99. class OpenAICompatService:
  100. """OpenAI API 兼容服务"""
  101. def __init__(self, db: Session):
  102. self.db = db
  103. self._user_cache: dict = {} # user_id → User,请求内缓存,避免重复查询
  104. def _get_user(self, user_id: str):
  105. """获取用户对象,同一请求内缓存,避免重复查询 users 表。"""
  106. if user_id not in self._user_cache:
  107. self._user_cache[user_id] = self.db.query(User).filter(User.id == user_id).first()
  108. return self._user_cache[user_id]
  109. # ─────────────────────────────────────────────
  110. # 主入口
  111. # ─────────────────────────────────────────────
  112. async def chat_completions(
  113. self,
  114. request: ChatCompletionsRequest,
  115. user_id: str,
  116. api_key_id: int,
  117. request_ip: Optional[str] = None,
  118. ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
  119. """
  120. 处理 Chat Completions 请求,包含日志记录与扣费。
  121. 权限检查流程:
  122. 1. 验证模型是否存在
  123. 2. 检查模型 is_api_enabled(云端模型)
  124. 3. 检查用户余额(云端模型)
  125. 4. 获取用户 API Key(云端模型)
  126. 5. 检查用户对本地模型的访问权限(本地模型)
  127. """
  128. log_service = ApiCallLogService(self.db)
  129. model = self._find_model(request.model, user_id)
  130. if not model:
  131. raise OpenAICompatError(
  132. status_code=404,
  133. message=f"The model '{request.model}' does not exist",
  134. error_type="model_not_found",
  135. )
  136. if not model.is_local and not model.is_api_enabled:
  137. raise OpenAICompatError(
  138. status_code=403,
  139. message=f"Model '{request.model}' does not support API calls",
  140. error_type="model_not_available",
  141. )
  142. # realtime 模型使用 WebSocket 实时流协议,不支持此 REST 接口
  143. if "realtime" in request.model.lower() and not model.is_local:
  144. raise OpenAICompatError(
  145. status_code=400,
  146. message=f"Model '{request.model}' is a real-time streaming model that uses WebSocket protocol. "
  147. f"It cannot be called via /api/v1/chat/completions. "
  148. f"Please use the WebSocket API instead.",
  149. error_type="model_not_supported",
  150. )
  151. # 检查本地模型的访问权限
  152. if model.is_local:
  153. from app.services.user_local_model_permission_service import UserLocalModelPermissionService
  154. permission_service = UserLocalModelPermissionService(self.db)
  155. if not await permission_service.check_user_model_access(user_id, model.id):
  156. raise OpenAICompatError(
  157. status_code=403,
  158. message=f"You don't have permission to access model '{request.model}'",
  159. error_type="permission_error",
  160. )
  161. user_api_key: Optional[str] = None
  162. if not model.is_local:
  163. user = self._get_user(user_id)
  164. if not user:
  165. raise OpenAICompatError(
  166. status_code=401,
  167. message="User not found",
  168. error_type="authentication_error",
  169. )
  170. # 优先使用模型自带的 api_key(爬虫同步的),没有则 fallback 到用户自己配置的 apikey
  171. if model.encrypted_api_key:
  172. from app.services.crypto_utils import decrypt_api_key
  173. decrypted = decrypt_api_key(model.encrypted_api_key)
  174. user_api_key = decrypted if decrypted else None
  175. if not user_api_key:
  176. user_api_key = user.apikey
  177. if not user_api_key:
  178. raise OpenAICompatError(
  179. status_code=403,
  180. message="User API key not configured. Please configure your API key in settings.",
  181. error_type="api_key_not_configured",
  182. )
  183. # ── OCR 模型校验:必须包含图片 ────────────────────────────────────
  184. OCR_MODELS = ("qwen-vl-ocr",)
  185. if model.model_code in OCR_MODELS and not model.is_local:
  186. has_image = any(
  187. isinstance(msg.content, list) and
  188. any(isinstance(part, dict) and part.get("type") == "image_url"
  189. for part in msg.content)
  190. for msg in request.messages
  191. )
  192. # 也兼容 Pydantic 对象形式
  193. if not has_image:
  194. from app.schemas.openai_compat import ContentPartImage
  195. has_image = any(
  196. isinstance(msg.content, list) and
  197. any(isinstance(part, ContentPartImage) for part in msg.content)
  198. for msg in request.messages
  199. )
  200. if not has_image:
  201. raise OpenAICompatError(
  202. status_code=400,
  203. message=f"Model '{model.model_code}' is an OCR model and requires at least one image in the messages. "
  204. f"Please include an image_url content part in your user message.",
  205. error_type="invalid_request_error",
  206. )
  207. # ── 流式 ──────────────────────────────────────────────────────────
  208. if request.stream:
  209. raw_stream = await self._call_local_model(model, request) if model.is_local \
  210. else await self._call_cloud_model(model, request, user_api_key)
  211. async def _stream_with_billing() -> AsyncGenerator[str, None]:
  212. input_text = "".join(
  213. [m.content for m in request.messages if isinstance(m.content, str)]
  214. )
  215. input_tokens = max(int(len(input_text) * 1.2), 1)
  216. output_tokens = 0
  217. stream_error: Optional[Exception] = None
  218. try:
  219. async for chunk in raw_stream:
  220. yield chunk
  221. if isinstance(chunk, str) and chunk.startswith("data: ") \
  222. and "data: [DONE]" not in chunk:
  223. try:
  224. import json as _json
  225. data_dict = _json.loads(chunk[6:])
  226. delta = data_dict.get("choices", [{}])[0] \
  227. .get("delta", {}).get("content", "")
  228. if delta:
  229. output_tokens += max(int(len(delta) * 1.2), 1)
  230. # 优先使用上游返回的真实 usage(部分模型在最后一个 chunk 里带)
  231. usage = data_dict.get("usage")
  232. if usage:
  233. input_tokens = usage.get("prompt_tokens", input_tokens)
  234. output_tokens = usage.get("completion_tokens", output_tokens)
  235. except Exception:
  236. pass
  237. except (GeneratorExit, asyncio.CancelledError):
  238. # 客户端断开/任务取消 - 仍按已产生的 token 扣费
  239. raise
  240. except Exception as exc:
  241. stream_error = exc
  242. raise
  243. finally:
  244. # 关键:无论流正常结束、客户端中断、还是异常都要扣费
  245. # 防止 "token 已被消耗但未扣费" 的资损
  246. try:
  247. if stream_error is not None:
  248. # 上游错误 - 记录失败日志
  249. log_service.create_log(
  250. user_id=user_id, api_key_id=api_key_id,
  251. model_id=model.id, model_name=request.model,
  252. is_local=model.is_local,
  253. input_tokens=0, output_tokens=0,
  254. bill=0, status="failed",
  255. error_message=str(stream_error), request_ip=request_ip,
  256. )
  257. else:
  258. bill = self.calculate_bill(model, input_tokens, output_tokens)
  259. log_service.create_log(
  260. user_id=user_id, api_key_id=api_key_id,
  261. model_id=model.id, model_name=request.model,
  262. is_local=model.is_local,
  263. input_tokens=input_tokens, output_tokens=output_tokens,
  264. bill=float(bill), status="success", request_ip=request_ip,
  265. )
  266. except Exception as fin_exc:
  267. logger.error("流式响应收尾日志记录失败: %s", fin_exc)
  268. return _stream_with_billing()
  269. # ── 非流式 ────────────────────────────────────────────────────────
  270. try:
  271. result = await self._call_local_model(model, request) if model.is_local \
  272. else await self._call_cloud_model(model, request, user_api_key)
  273. input_tokens, output_tokens = self.extract_usage_from_response(result)
  274. bill = self.calculate_bill(model, input_tokens, output_tokens)
  275. log_service.create_log(
  276. user_id=user_id, api_key_id=api_key_id,
  277. model_id=model.id, model_name=request.model,
  278. is_local=model.is_local,
  279. input_tokens=input_tokens, output_tokens=output_tokens,
  280. bill=float(bill), status="success", request_ip=request_ip,
  281. )
  282. return result
  283. except OpenAICompatError:
  284. raise
  285. except Exception as exc:
  286. error_msg = str(exc) or repr(exc)
  287. logger.warning(
  288. "[CHAT_COMPLETION_FAILED] model=%s user_id=%s error_type=%s error=%s",
  289. request.model, user_id, type(exc).__name__, error_msg,
  290. )
  291. log_service.create_log(
  292. user_id=user_id, api_key_id=api_key_id,
  293. model_id=model.id, model_name=request.model,
  294. is_local=model.is_local,
  295. input_tokens=0, output_tokens=0,
  296. bill=0, status="failed",
  297. error_message=error_msg, request_ip=request_ip,
  298. )
  299. raise OpenAICompatError(status_code=500, message=error_msg, error_type="upstream_error")
  300. # ─────────────────────────────────────────────
  301. # 构建请求体(完整参数透传)
  302. # ─────────────────────────────────────────────
  303. def _build_request_body(
  304. self, request: ChatCompletionsRequest, model_name: str
  305. ) -> Dict[str, Any]:
  306. """将 ChatCompletionsRequest 转换为上游 API 请求体,透传所有非 None 参数"""
  307. # 序列化消息(content 支持 str 或 list)
  308. messages = []
  309. for msg in request.messages:
  310. m: Dict[str, Any] = {"role": msg.role}
  311. if msg.content is None:
  312. m["content"] = None
  313. elif isinstance(msg.content, str):
  314. m["content"] = msg.content
  315. else:
  316. # 多模态内容列表
  317. parts = []
  318. for part in msg.content:
  319. part_dict = part.model_dump(exclude_none=True)
  320. # 校验 image_url 格式
  321. if part_dict.get("type") == "image_url":
  322. url = (part_dict.get("image_url") or {}).get("url", "")
  323. if url and not url.startswith("data:"):
  324. import os
  325. ext = os.path.splitext(url.split("?")[0].lower())[1]
  326. SUPPORTED_IMAGE_EXTS = {'.jpg', '.jpeg', '.png', '.webp', '.gif', '.bmp'}
  327. if ext and ext not in SUPPORTED_IMAGE_EXTS:
  328. raise OpenAICompatError(
  329. status_code=400,
  330. message=f"Unsupported image format '{ext}'. "
  331. f"Supported formats: jpg, jpeg, png, webp, gif, bmp.",
  332. error_type="invalid_request_error",
  333. )
  334. parts.append(part_dict)
  335. m["content"] = parts
  336. if msg.name is not None:
  337. m["name"] = msg.name
  338. if msg.tool_calls is not None:
  339. m["tool_calls"] = [tc.model_dump() for tc in msg.tool_calls]
  340. if msg.tool_call_id is not None:
  341. m["tool_call_id"] = msg.tool_call_id
  342. messages.append(m)
  343. body: Dict[str, Any] = {
  344. "model": model_name,
  345. "messages": messages,
  346. "stream": request.stream,
  347. }
  348. # 可选参数:只有非 None 才透传
  349. optional_fields = [
  350. "temperature", "top_p", "n", "stop",
  351. "presence_penalty", "frequency_penalty",
  352. "logit_bias", "logprobs", "top_logprobs",
  353. "seed", "user", "service_tier", "store", "metadata",
  354. "parallel_tool_calls",
  355. ]
  356. for field in optional_fields:
  357. val = getattr(request, field, None)
  358. if val is not None:
  359. body[field] = val
  360. # max_tokens / max_completion_tokens(优先新版)
  361. if request.max_completion_tokens is not None:
  362. body["max_completion_tokens"] = request.max_completion_tokens
  363. elif request.max_tokens is not None:
  364. body["max_tokens"] = request.max_tokens
  365. # 流式选项
  366. if request.stream and request.stream_options is not None:
  367. body["stream_options"] = request.stream_options.model_dump(exclude_none=True)
  368. # 工具调用
  369. if request.tools is not None:
  370. body["tools"] = [t.model_dump(exclude_none=True) for t in request.tools]
  371. if request.tool_choice is not None:
  372. if isinstance(request.tool_choice, str):
  373. body["tool_choice"] = request.tool_choice
  374. else:
  375. body["tool_choice"] = request.tool_choice.model_dump()
  376. # 响应格式
  377. if request.response_format is not None:
  378. body["response_format"] = request.response_format.model_dump(exclude_none=True)
  379. return body
  380. # ─────────────────────────────────────────────
  381. # 本地模型调用
  382. # ─────────────────────────────────────────────
  383. async def _call_local_model(
  384. self, model: Model, request: ChatCompletionsRequest
  385. ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
  386. # 从缓存获取模型信息
  387. from app.services.cache_service import CacheService
  388. model_data = await CacheService.get_model(model.id)
  389. if model_data:
  390. base_url = model_data.get("base_url", "").rstrip("/")
  391. local_api_key = model_data.get("local_api_key")
  392. else:
  393. # 从数据库获取
  394. base_url = (model.base_url or "").rstrip("/")
  395. local_api_key = model.local_api_key
  396. # 缓存模型信息
  397. await CacheService.set_model(model.id, {
  398. "base_url": base_url,
  399. "local_api_key": local_api_key,
  400. "is_local": model.is_local,
  401. "name": model.name
  402. })
  403. if not base_url:
  404. raise OpenAICompatError(
  405. status_code=500,
  406. message="本地模型未配置 Base URL",
  407. error_type="configuration_error",
  408. )
  409. headers: Dict[str, str] = {"Content-Type": "application/json"}
  410. # 本地模型不使用用户的API密钥,而是使用模型配置的API密钥
  411. if local_api_key:
  412. api_key = decrypt_api_key(local_api_key)
  413. if api_key:
  414. headers["Authorization"] = f"Bearer {api_key}"
  415. # 使用模型的name字段作为实际模型名称
  416. actual_name = model.name
  417. body = self._build_request_body(request, actual_name)
  418. api_url = f"{base_url}/chat/completions"
  419. if request.stream:
  420. return self._stream_response(api_url, headers, body)
  421. else:
  422. return await self._non_stream_response(api_url, headers, body)
  423. # ─────────────────────────────────────────────
  424. # 云端模型调用(阿里云百炼 OpenAI 兼容模式)
  425. # ─────────────────────────────────────────────
  426. async def _call_cloud_model(
  427. self, model: Model, request: ChatCompletionsRequest, user_api_key: str
  428. ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
  429. api_url = "https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions"
  430. headers = {
  431. "Content-Type": "application/json",
  432. "Authorization": f"Bearer {user_api_key}",
  433. }
  434. body = self._build_request_body(request, model.model_code)
  435. if request.stream:
  436. return self._stream_response(api_url, headers, body)
  437. else:
  438. return await self._non_stream_response(api_url, headers, body)
  439. # ─────────────────────────────────────────────
  440. # HTTP 请求封装
  441. # ─────────────────────────────────────────────
  442. async def _non_stream_response(
  443. self, api_url: str, headers: Dict[str, str], body: Dict[str, Any],
  444. timeout: float = 300.0,
  445. ) -> Dict[str, Any]:
  446. model_name = body.get("model", "unknown")
  447. try:
  448. async with httpx.AsyncClient(timeout=timeout) as client:
  449. resp = await client.post(api_url, headers=headers, json=body)
  450. if resp.status_code >= 400:
  451. error_msg = self._extract_upstream_error(resp)
  452. logger.warning(
  453. "[UPSTREAM_ERROR] model=%s status=%d url=%s response=%s",
  454. model_name, resp.status_code, api_url, resp.text[:500],
  455. )
  456. raise OpenAICompatError(
  457. status_code=resp.status_code,
  458. message=error_msg,
  459. error_type="upstream_error",
  460. )
  461. return resp.json()
  462. except httpx.ReadTimeout:
  463. logger.warning("[UPSTREAM_TIMEOUT] model=%s url=%s timeout=%ss", model_name, api_url, timeout)
  464. raise OpenAICompatError(
  465. status_code=504,
  466. message=f"模型 '{model_name}' 响应超时({timeout}s),请稍后重试或换个模型",
  467. error_type="timeout_error",
  468. )
  469. async def _stream_response(
  470. self, api_url: str, headers: Dict[str, str], body: Dict[str, Any]
  471. ) -> AsyncGenerator[str, None]:
  472. """SSE 流式响应生成器"""
  473. model_name = body.get("model", "unknown")
  474. try:
  475. async with httpx.AsyncClient(
  476. timeout=httpx.Timeout(30.0, read=None)
  477. ) as client:
  478. async with client.stream("POST", api_url, headers=headers, json=body) as resp:
  479. if resp.status_code >= 400:
  480. error_body = await resp.aread()
  481. error_text = error_body.decode("utf-8", errors="replace")
  482. logger.warning(
  483. "[UPSTREAM_STREAM_ERROR] model=%s status=%d url=%s response=%s",
  484. model_name, resp.status_code, api_url, error_text[:500],
  485. )
  486. # 使用统一的错误提取方法
  487. error_detail = self._extract_upstream_error(resp)
  488. # 针对特定状态码提供友好提示
  489. if resp.status_code == 401:
  490. raise OpenAICompatError(
  491. status_code=500,
  492. message=f"模型 '{model_name}' 认证失败(401 Unauthorized)。\n"
  493. f"请在管理后台检查该模型的 API Key 配置是否正确。\n"
  494. f"错误详情: {error_detail}",
  495. error_type="authentication_error",
  496. )
  497. elif resp.status_code == 404:
  498. raise OpenAICompatError(
  499. status_code=500,
  500. message=f"模型 '{model_name}' 接口不存在(404 Not Found)。请检查 Base URL 和模型名称是否正确。\n"
  501. f"错误详情: {error_detail}",
  502. error_type="not_found_error",
  503. )
  504. else:
  505. raise OpenAICompatError(
  506. status_code=resp.status_code,
  507. message=f"模型 '{model_name}' 调用失败({resp.status_code}): {error_detail}",
  508. error_type="upstream_error",
  509. )
  510. async for line in resp.aiter_lines():
  511. if line.startswith("data: "):
  512. data = line[6:]
  513. if data.strip() == "[DONE]":
  514. yield "data: [DONE]\n\n"
  515. return
  516. yield f"data: {data}\n\n"
  517. elif line.strip():
  518. yield f"data: {line}\n\n"
  519. except httpx.ConnectTimeout:
  520. logger.warning("[UPSTREAM_CONNECT_TIMEOUT] model=%s url=%s", model_name, api_url)
  521. raise OpenAICompatError(
  522. status_code=504,
  523. message=f"模型 '{model_name}' 连接超时,请稍后重试",
  524. error_type="timeout_error",
  525. )
  526. def _extract_upstream_error(self, resp: httpx.Response) -> str:
  527. """从上游错误响应中提取错误信息,确保始终返回有意义的内容"""
  528. raw_text = ""
  529. try:
  530. raw_text = resp.text
  531. data = resp.json()
  532. # 标准 OpenAI 格式: {"error": {"message": "...", "type": "...", "code": "..."}}
  533. err = data.get("error", {})
  534. if isinstance(err, dict):
  535. msg = err.get("message", "")
  536. err_type = err.get("type", "")
  537. code = err.get("code", "")
  538. parts = [p for p in [msg, f"type={err_type}" if err_type else "", f"code={code}" if code else ""] if p]
  539. if parts:
  540. return " | ".join(parts)
  541. # DashScope 格式: {"code": "...", "message": "...", "request_id": "..."}
  542. code = data.get("code", "")
  543. msg = data.get("message", "")
  544. request_id = data.get("request_id", "")
  545. if msg:
  546. parts = [msg]
  547. if code:
  548. parts.append(f"code={code}")
  549. if request_id:
  550. parts.append(f"request_id={request_id}")
  551. return " | ".join(parts)
  552. # 兜底:返回整个 JSON
  553. if data:
  554. return str(data)
  555. except Exception:
  556. pass
  557. # JSON 解析失败或为空,返回原始文本
  558. if raw_text:
  559. return raw_text[:500]
  560. return f"Upstream error {resp.status_code} (empty response body)"
  561. # ─────────────────────────────────────────────
  562. # Models 列表
  563. # ─────────────────────────────────────────────
  564. def get_available_models(self, user_id: str, key_type: str = "public") -> ModelsListResponse:
  565. """返回用户可用的模型,根据API密钥类型过滤"""
  566. models_data: List[ModelInfo] = []
  567. # 根据密钥类型返回相应的模型
  568. if key_type == "public":
  569. # 公钥只能访问云端模型
  570. cloud_models = (
  571. self.db.query(Model)
  572. .filter(Model.is_local == False, Model.is_api_enabled == True)
  573. .all()
  574. )
  575. for m in cloud_models:
  576. models_data.append(
  577. ModelInfo(
  578. id=m.model_code,
  579. object="model",
  580. created=int(m.created_at.timestamp()) if m.created_at else int(time.time()),
  581. owned_by=m.supplier or "platform",
  582. )
  583. )
  584. elif key_type == "local":
  585. # 检查本地模型是否启用
  586. if get_config_bool("enable_local_models", True):
  587. # 如果本地模型启用,返回所有本地模型
  588. local_models = (
  589. self.db.query(Model)
  590. .filter(
  591. Model.is_local == True
  592. )
  593. .all()
  594. )
  595. for m in local_models:
  596. models_data.append(
  597. ModelInfo(
  598. id=f"local:{m.id}",
  599. object="model",
  600. created=int(m.created_at.timestamp()) if m.created_at else int(time.time()),
  601. owned_by="local",
  602. )
  603. )
  604. else:
  605. # 如果本地模型未启用,返回用户有权限的模型
  606. from app.services.user_local_model_permission_service import UserLocalModelPermissionService
  607. permission_service = UserLocalModelPermissionService(self.db)
  608. permissions = permission_service.get_user_model_permissions(user_id)
  609. permitted_model_ids = [perm["model_id"] for perm in permissions if perm["has_access"]]
  610. local_models = (
  611. self.db.query(Model)
  612. .filter(
  613. Model.is_local == True,
  614. Model.id.in_(permitted_model_ids)
  615. )
  616. .all()
  617. )
  618. for m in local_models:
  619. models_data.append(
  620. ModelInfo(
  621. id=f"local:{m.id}",
  622. object="model",
  623. created=int(m.created_at.timestamp()) if m.created_at else int(time.time()),
  624. owned_by="local",
  625. )
  626. )
  627. return ModelsListResponse(object="list", data=models_data)
  628. # ─────────────────────────────────────────────
  629. # 工具方法
  630. # ─────────────────────────────────────────────
  631. async def _handle_local_model_request(
  632. self,
  633. api_url: str,
  634. headers: Dict[str, str],
  635. payload: Dict[str, Any],
  636. model_name: str,
  637. base_url: str,
  638. endpoint_type: str = "chat",
  639. timeout: float = 60.0,
  640. return_raw_response: bool = False
  641. ) -> Union[Dict[str, Any], httpx.Response]:
  642. """
  643. 统一处理本地模型 HTTP 请求,直接透传 OpenAI 格式
  644. 注意:本地模型必须是 OpenAI 兼容的,不做任何格式适配
  645. Args:
  646. api_url: 请求URL
  647. headers: 请求头
  648. payload: 请求体(OpenAI 格式)
  649. model_name: 模型名称(用于错误提示)
  650. base_url: 模型的 base_url
  651. endpoint_type: 端点类型(保留参数,暂未使用)
  652. timeout: 超时时间
  653. return_raw_response: 是否返回原始响应对象(用于处理音频等二进制数据)
  654. Returns:
  655. 响应的 JSON 数据(OpenAI 格式)或原始 httpx.Response 对象
  656. Raises:
  657. OpenAICompatError: 请求失败时抛出
  658. """
  659. async with httpx.AsyncClient(timeout=timeout) as client:
  660. try:
  661. response = await client.post(api_url, headers=headers, json=payload)
  662. response.raise_for_status()
  663. # 如果需要原始响应(例如音频数据),直接返回
  664. if return_raw_response:
  665. return response
  666. # 直接返回 JSON 响应,不做任何转换
  667. result = response.json()
  668. return result
  669. except httpx.HTTPStatusError as e:
  670. if e.response.status_code == 401:
  671. # 认证失败
  672. error_detail = ""
  673. try:
  674. error_data = e.response.json()
  675. if isinstance(error_data, dict):
  676. error_detail = error_data.get("error", {}).get("message", "") or error_data.get("message", "")
  677. except:
  678. error_detail = e.response.text[:200]
  679. raise OpenAICompatError(
  680. status_code=500,
  681. message=f"本地模型 '{model_name}' 认证失败(401 Unauthorized)。\n"
  682. f"请在管理后台检查该模型的 API Key 配置是否正确。\n"
  683. f"错误详情: {error_detail}",
  684. error_type="authentication_error",
  685. )
  686. elif e.response.status_code == 404:
  687. raise OpenAICompatError(
  688. status_code=500,
  689. message=f"本地模型 '{model_name}' 接口不存在(404 Not Found)。请检查 Base URL 和模型名称是否正确。",
  690. error_type="not_found_error",
  691. )
  692. else:
  693. # 其他 HTTP 错误
  694. error_detail = ""
  695. try:
  696. error_data = e.response.json()
  697. if isinstance(error_data, dict):
  698. error_detail = error_data.get("error", {}).get("message", "") or error_data.get("message", "")
  699. except:
  700. error_detail = e.response.text[:200]
  701. raise OpenAICompatError(
  702. status_code=e.response.status_code,
  703. message=f"本地模型调用失败({e.response.status_code}): {error_detail}",
  704. error_type="upstream_error",
  705. )
  706. except httpx.TimeoutException:
  707. raise OpenAICompatError(
  708. status_code=504,
  709. message=f"本地模型 '{model_name}' 请求超时。请检查网络连接或增加超时时间。",
  710. error_type="timeout_error",
  711. )
  712. except httpx.RequestError as e:
  713. raise OpenAICompatError(
  714. status_code=500,
  715. message=f"本地模型 '{model_name}' 请求失败: {str(e)}",
  716. error_type="request_error",
  717. )
  718. def _find_model(self, model_name: str, user_id: str) -> Optional[Model]:
  719. # 优先识别 local:{id} 格式(精确匹配)
  720. if model_name.startswith("local:"):
  721. try:
  722. model_id = int(model_name[6:])
  723. except ValueError:
  724. return None
  725. return (
  726. self.db.query(Model)
  727. .filter(Model.id == model_id, Model.is_local == True)
  728. .first()
  729. )
  730. # supplier/name 格式(本地模型)
  731. if "/" in model_name:
  732. parts = model_name.split("/", 1)
  733. if len(parts) == 2:
  734. supplier, name = parts
  735. local_with_supplier = (
  736. self.db.query(Model)
  737. .filter(
  738. Model.supplier == supplier,
  739. Model.display_name == name,
  740. Model.is_local == True,
  741. )
  742. .order_by(Model.created_at.desc())
  743. .first()
  744. )
  745. if local_with_supplier:
  746. return local_with_supplier
  747. # 云端模型按 model_code 精确匹配
  748. cloud = (
  749. self.db.query(Model)
  750. .filter(Model.model_code == model_name, Model.is_local == False)
  751. .first()
  752. )
  753. if cloud:
  754. return cloud
  755. # 本地模型按 display_name 查找
  756. return (
  757. self.db.query(Model)
  758. .filter(Model.display_name == model_name, Model.is_local == True)
  759. .order_by(Model.created_at.desc())
  760. .first()
  761. )
  762. def calculate_bill(
  763. self, model: Model, input_tokens: int, output_tokens: int
  764. ) -> Decimal:
  765. """API 调用免费,始终返回 0"""
  766. return Decimal("0")
  767. def extract_usage_from_response(self, response: Dict[str, Any]) -> Tuple[int, int]:
  768. usage = response.get("usage", {})
  769. input_tokens = usage.get("prompt_tokens", 0)
  770. output_tokens = usage.get("completion_tokens", 0)
  771. return input_tokens, output_tokens
  772. async def _validate_model_and_balance(self, model_name: str, user_id: str) -> Model:
  773. """
  774. 验证模型状态
  775. Args:
  776. model_name: 模型名称
  777. user_id: 用户ID
  778. Returns:
  779. Model: 验证通过的模型对象
  780. Raises:
  781. OpenAICompatError: 模型不存在或不可用时抛出
  782. """
  783. model = self._find_model(model_name, user_id)
  784. if not model:
  785. raise OpenAICompatError(
  786. status_code=404,
  787. message=f"Model '{model_name}' not found",
  788. error_type="model_not_found"
  789. )
  790. if not model.is_local and not model.is_api_enabled:
  791. raise OpenAICompatError(
  792. status_code=403,
  793. message=f"Model '{model_name}' does not support API calls",
  794. error_type="model_not_available"
  795. )
  796. return model
  797. async def embeddings(
  798. self,
  799. request: EmbeddingsRequest,
  800. user_id: str,
  801. api_key_id: int,
  802. request_ip: Optional[str] = None
  803. ) -> EmbeddingsResponse:
  804. """
  805. 处理文本嵌入(Embeddings)请求
  806. Args:
  807. request: 嵌入请求对象
  808. user_id: 用户ID
  809. api_key_id: 用于日志记录
  810. request_ip: 请求来源IP
  811. Returns:
  812. EmbeddingsResponse: 包含向量数据的响应对象
  813. Raises:
  814. OpenAICompatError: 处理失败或鉴权失败时抛出
  815. """
  816. log_service = ApiCallLogService(self.db)
  817. try:
  818. model = await self._validate_model_and_balance(request.model, user_id)
  819. # 检查模型类型是否支持向量嵌入
  820. from app.models.model import ModelCategory
  821. if not any(int(c) in [int(ModelCategory.EMBEDDING), int(ModelCategory.LLM), int(ModelCategory.MULTIMODAL)] for c in (model.categories or [])):
  822. raise OpenAICompatError(
  823. status_code=400,
  824. message=f"Model '{request.model}' does not support embeddings",
  825. error_type="model_not_supported",
  826. )
  827. # 检查本地模型权限
  828. if model.is_local:
  829. from app.services.user_local_model_permission_service import UserLocalModelPermissionService
  830. permission_service = UserLocalModelPermissionService(self.db)
  831. if not await permission_service.check_user_model_access(user_id, model.id):
  832. raise OpenAICompatError(
  833. status_code=403,
  834. message=f"You don't have permission to access model '{request.model}'",
  835. error_type="permission_error",
  836. )
  837. else:
  838. user = self._get_user(user_id)
  839. if not user:
  840. raise OpenAICompatError(
  841. status_code=403,
  842. message="User API key not configured.",
  843. error_type="api_key_not_configured"
  844. )
  845. # 优先使用模型自带的 api_key(爬虫同步的),没有则 fallback 到用户自己配置的 apikey
  846. effective_api_key: Optional[str] = None
  847. if model.encrypted_api_key:
  848. from app.services.crypto_utils import decrypt_api_key as _decrypt
  849. decrypted = _decrypt(model.encrypted_api_key)
  850. effective_api_key = decrypted if decrypted else None
  851. if not effective_api_key:
  852. effective_api_key = user.apikey
  853. if not effective_api_key:
  854. raise OpenAICompatError(
  855. status_code=403,
  856. message="User API key not configured.",
  857. error_type="api_key_not_configured"
  858. )
  859. texts = [request.input] if isinstance(request.input, str) else request.input
  860. if model.is_local:
  861. # 从缓存获取模型信息
  862. from app.services.cache_service import CacheService
  863. from app.services.crypto_utils import decrypt_api_key
  864. model_data = await CacheService.get_model(model.id)
  865. if model_data:
  866. base_url = model_data.get("base_url", "").rstrip("/")
  867. local_api_key = model_data.get("local_api_key")
  868. else:
  869. # 从数据库获取
  870. base_url = (model.base_url or "").rstrip("/")
  871. local_api_key = model.local_api_key
  872. # 缓存模型信息
  873. await CacheService.set_model(model.id, {
  874. "base_url": base_url,
  875. "local_api_key": local_api_key,
  876. "is_local": model.is_local,
  877. "name": model.name
  878. })
  879. if not base_url:
  880. raise OpenAICompatError(
  881. status_code=500,
  882. message="本地模型未配置 Base URL",
  883. error_type="configuration_error",
  884. )
  885. # 构建请求头
  886. headers = {"Content-Type": "application/json"}
  887. if local_api_key:
  888. api_key = decrypt_api_key(local_api_key)
  889. if api_key:
  890. headers["Authorization"] = f"Bearer {api_key}"
  891. # 构建请求体(OpenAI 格式)
  892. payload = {
  893. "model": model.name,
  894. "input": texts
  895. }
  896. if request.dimensions:
  897. payload["dimensions"] = request.dimensions
  898. if request.user:
  899. payload["user"] = request.user
  900. # 使用统一的方法发送请求
  901. api_url = f"{base_url}/embeddings"
  902. result_data = await self._handle_local_model_request(
  903. api_url=api_url,
  904. headers=headers,
  905. payload=payload,
  906. model_name=model.name,
  907. base_url=base_url,
  908. endpoint_type="embedding",
  909. timeout=30.0
  910. )
  911. # 处理响应
  912. embeddings_list = result_data.get("data", [])
  913. usage_data = result_data.get("usage", {})
  914. total_tokens = usage_data.get("total_tokens", 0)
  915. data_list = []
  916. for item in embeddings_list:
  917. data_list.append(
  918. EmbeddingData(
  919. index=item.get("index", 0),
  920. embedding=item.get("embedding", [])
  921. )
  922. )
  923. else:
  924. # 云端模型处理
  925. # 根据模型类型选择端点:多模态 embedding vs 文本 embedding
  926. # 多模态模型:名称含 vl/vision/multimodal
  927. code_lower = model.model_code.lower()
  928. is_multimodal = any(kw in code_lower for kw in ("vl", "vision", "multimodal"))
  929. if is_multimodal:
  930. api_url = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/multimodal-embedding/multimodal-embedding"
  931. payload = {
  932. "model": model.model_code,
  933. "input": {
  934. "contents": [{"text": t} for t in texts]
  935. }
  936. }
  937. else:
  938. api_url = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
  939. payload = {
  940. "model": model.model_code,
  941. "input": {
  942. "texts": texts
  943. }
  944. }
  945. if request.dimensions:
  946. payload.setdefault("parameters", {})["dimension"] = request.dimensions
  947. headers = {
  948. "Content-Type": "application/json",
  949. "Authorization": f"Bearer {effective_api_key}"
  950. }
  951. # 多模态 embedding 的 dimension 放在 parameters 里
  952. if is_multimodal and request.dimensions:
  953. payload.setdefault("parameters", {})["dimension"] = request.dimensions
  954. async with httpx.AsyncClient(timeout=30.0) as client:
  955. response = await client.post(api_url, headers=headers, json=payload)
  956. response.raise_for_status()
  957. result_data = response.json()
  958. output = result_data.get("output", {})
  959. embeddings_list = output.get("embeddings", [])
  960. usage_data = result_data.get("usage", {})
  961. total_tokens = usage_data.get("total_tokens", 0)
  962. data_list = []
  963. for item in embeddings_list:
  964. data_list.append(
  965. EmbeddingData(
  966. index=item.get("text_index", 0),
  967. embedding=item.get("embedding", [])
  968. )
  969. )
  970. bill = Decimal("0")
  971. log_service.create_log(
  972. user_id=user_id,
  973. api_key_id=api_key_id,
  974. model_id=model.id if model else None,
  975. model_name=request.model,
  976. is_local=model.is_local if model else False,
  977. input_tokens=total_tokens,
  978. output_tokens=0,
  979. bill=float(bill),
  980. status="success",
  981. request_ip=request_ip
  982. )
  983. return EmbeddingsResponse(
  984. model=request.model,
  985. data=data_list,
  986. usage=Usage(
  987. prompt_tokens=total_tokens,
  988. completion_tokens=0,
  989. total_tokens=total_tokens
  990. )
  991. )
  992. except Exception as e:
  993. error_msg = str(e) if not isinstance(e, OpenAICompatError) else e.message
  994. model = self._find_model(request.model, user_id)
  995. log_service.create_log(
  996. user_id=user_id,
  997. api_key_id=api_key_id,
  998. model_id=model.id if model else None,
  999. model_name=request.model,
  1000. is_local=model.is_local if model else False,
  1001. input_tokens=0,
  1002. output_tokens=0,
  1003. bill=0,
  1004. status="failed",
  1005. error_message=error_msg,
  1006. request_ip=request_ip
  1007. )
  1008. if isinstance(e, OpenAICompatError):
  1009. raise e
  1010. raise OpenAICompatError(status_code=500, message=error_msg, error_type="embeddings_error")
  1011. async def image_generations(
  1012. self, request: ImageGenerationRequest, user_id: str, api_key_id: int, request_ip: str = None
  1013. ) -> ImageGenerationResponse:
  1014. """
  1015. 处理图像生成请求
  1016. 调用底层ImageGenerationService完成真实的图像生成与业务扣费。
  1017. Args:
  1018. request: 图像生成请求对象
  1019. user_id: 用户ID
  1020. api_key_id: API Key ID
  1021. request_ip: 请求来源IP
  1022. Returns:
  1023. ImageGenerationResponse: 符合OpenAI规范的图像响应对象
  1024. Raises:
  1025. OpenAICompatError: 处理失败或鉴权失败时抛出
  1026. """
  1027. log_service = ApiCallLogService(self.db)
  1028. try:
  1029. # 验证模型状态与用户基础余额
  1030. model = await self._validate_model_and_balance(request.model, user_id)
  1031. # 检查模型类型是否支持图像生成(文生图)
  1032. from app.models.model import ModelCategory
  1033. if not any(int(c) in [int(ModelCategory.IMAGE_GEN), int(ModelCategory.MULTIMODAL)] for c in (model.categories or [])):
  1034. raise OpenAICompatError(
  1035. status_code=400,
  1036. message=f"Model '{request.model}' does not support image generation. Use a model with category IMAGE_GEN or MULTIMODAL.",
  1037. error_type="model_not_supported",
  1038. )
  1039. # 检查本地模型权限
  1040. if model.is_local:
  1041. from app.services.user_local_model_permission_service import UserLocalModelPermissionService
  1042. permission_service = UserLocalModelPermissionService(self.db)
  1043. if not await permission_service.check_user_model_access(user_id, model.id):
  1044. raise OpenAICompatError(
  1045. status_code=403,
  1046. message=f"You don't have permission to access model '{request.model}'",
  1047. error_type="permission_error",
  1048. )
  1049. if model.is_local:
  1050. # 从缓存获取模型信息
  1051. from app.services.cache_service import CacheService
  1052. from app.services.crypto_utils import decrypt_api_key
  1053. model_data = await CacheService.get_model(model.id)
  1054. if model_data:
  1055. base_url = model_data.get("base_url", "").rstrip("/")
  1056. local_api_key = model_data.get("local_api_key")
  1057. else:
  1058. # 从数据库获取
  1059. base_url = (model.base_url or "").rstrip("/")
  1060. local_api_key = model.local_api_key
  1061. # 缓存模型信息
  1062. await CacheService.set_model(model.id, {
  1063. "base_url": base_url,
  1064. "local_api_key": local_api_key,
  1065. "is_local": model.is_local,
  1066. "name": model.name
  1067. })
  1068. if not base_url:
  1069. raise OpenAICompatError(
  1070. status_code=500,
  1071. message="本地模型未配置 Base URL",
  1072. error_type="configuration_error",
  1073. )
  1074. # 构建请求头
  1075. headers = {"Content-Type": "application/json"}
  1076. if local_api_key:
  1077. try:
  1078. api_key = decrypt_api_key(local_api_key)
  1079. if api_key:
  1080. headers["Authorization"] = f"Bearer {api_key}"
  1081. else:
  1082. raise OpenAICompatError(
  1083. status_code=500,
  1084. message="本地模型 API Key 解密失败",
  1085. error_type="configuration_error",
  1086. )
  1087. except Exception as e:
  1088. raise OpenAICompatError(
  1089. status_code=500,
  1090. message=f"本地模型 API Key 处理失败: {str(e)}",
  1091. error_type="configuration_error",
  1092. )
  1093. # 构建请求体(OpenAI 格式)
  1094. payload = {
  1095. "model": model.name,
  1096. "prompt": request.prompt,
  1097. "n": request.n or 1,
  1098. "size": request.size or "1024x1024"
  1099. }
  1100. if request.quality:
  1101. payload["quality"] = request.quality
  1102. if request.response_format:
  1103. payload["response_format"] = request.response_format
  1104. if request.style:
  1105. payload["style"] = request.style
  1106. if request.user:
  1107. payload["user"] = request.user
  1108. # 使用统一的方法发送请求
  1109. api_url = f"{base_url}/images/generations"
  1110. result_data = await self._handle_local_model_request(
  1111. api_url=api_url,
  1112. headers=headers,
  1113. payload=payload,
  1114. model_name=model.name,
  1115. base_url=base_url,
  1116. endpoint_type="image",
  1117. timeout=60.0
  1118. )
  1119. # 处理响应
  1120. images = []
  1121. for item in result_data.get("data", []):
  1122. if item.get("url"):
  1123. images.append(item.get("url"))
  1124. elif item.get("b64_json"):
  1125. # 处理base64编码的图像
  1126. import base64
  1127. from app.services.oss_service import get_oss_service
  1128. oss_service = get_oss_service()
  1129. image_bytes = base64.b64decode(item.get("b64_json"))
  1130. url = oss_service.upload_file(
  1131. image_bytes,
  1132. prefix="ai-images/generations",
  1133. original_filename=f"generated_{time.time()}.png"
  1134. )
  1135. images.append(url)
  1136. if not images:
  1137. raise OpenAICompatError(
  1138. status_code=500,
  1139. message="图像生成失败:未返回图像",
  1140. error_type="image_generation_error",
  1141. )
  1142. result = ImageGenerationResponse(
  1143. created=int(time.time()),
  1144. data=[ImageData(url=url) for url in images]
  1145. )
  1146. bill = Decimal("0")
  1147. else:
  1148. # 云端模型处理
  1149. # 获取用户API Key并实例化底层服务
  1150. user = self._get_user(user_id)
  1151. dashscope_api_key = user.apikey if user and user.apikey else ""
  1152. real_image_service = ImageGenerationService(self.db, api_key=dashscope_api_key)
  1153. # 适配尺寸参数
  1154. mapped_size = request.size.replace("x", "*") if request.size else "1024*1024"
  1155. # 调用底层图像生成服务
  1156. result_obj = await real_image_service.text_to_image(
  1157. user_id=user_id,
  1158. prompt=request.prompt,
  1159. model=model.model_code,
  1160. n=request.n or 1,
  1161. size=mapped_size
  1162. )
  1163. if not result_obj.success:
  1164. raise OpenAICompatError(
  1165. status_code=500,
  1166. message=result_obj.error or "图像生成失败",
  1167. error_type="image_generation_error"
  1168. )
  1169. result = ImageGenerationResponse(
  1170. created=int(time.time()),
  1171. data=[ImageData(url=url) for url in result_obj.images]
  1172. )
  1173. bill = result_obj.bill
  1174. # 记录日志
  1175. log_service.create_log(
  1176. user_id=user_id,
  1177. api_key_id=api_key_id,
  1178. model_id=model.id if model else None,
  1179. model_name=request.model,
  1180. is_local=model.is_local if model else False,
  1181. input_tokens=0,
  1182. output_tokens=len(result.data),
  1183. bill=float(bill),
  1184. status="success",
  1185. request_ip=request_ip
  1186. )
  1187. return result
  1188. except Exception as e:
  1189. error_msg = str(e) if not isinstance(e, OpenAICompatError) else e.message
  1190. model = self._find_model(request.model, user_id)
  1191. log_service.create_log(
  1192. user_id=user_id,
  1193. api_key_id=api_key_id,
  1194. model_id=model.id if model else None,
  1195. model_name=request.model,
  1196. is_local=model.is_local if model else False,
  1197. input_tokens=0,
  1198. output_tokens=0,
  1199. bill=0,
  1200. status="failed",
  1201. error_message=error_msg,
  1202. request_ip=request_ip
  1203. )
  1204. if isinstance(e, OpenAICompatError):
  1205. raise e
  1206. raise OpenAICompatError(status_code=500, message=error_msg, error_type="image_generation_error")
  1207. async def image_edits(
  1208. self,
  1209. image: Union[str, UploadFile],
  1210. prompt: str,
  1211. mask: Optional[Union[str, UploadFile]],
  1212. model_name: str,
  1213. n: int,
  1214. size: str,
  1215. user_id: str,
  1216. api_key_id: int,
  1217. request_ip: Optional[str] = None
  1218. ) -> ImageGenerationResponse:
  1219. """
  1220. 处理图像编辑(图生图)请求
  1221. 接收上传图片,转存OSS后调用底层ImageGenerationService处理。
  1222. Args:
  1223. image: 用户上传的原始图片
  1224. prompt: 对新图像的文本描述
  1225. mask: 可选的遮罩图
  1226. model_name: 模型名称
  1227. n: 生成数量
  1228. size: 生成尺寸
  1229. user_id: 用户ID
  1230. api_key_id: API Key ID
  1231. request_ip: 请求来源IP
  1232. Returns:
  1233. ImageGenerationResponse: 包含生成图片URL的响应对象
  1234. Raises:
  1235. OpenAICompatError: 处理失败或鉴权失败时抛出
  1236. 需求: OpenAI兼容-图生图
  1237. """
  1238. log_service = ApiCallLogService(self.db)
  1239. try:
  1240. model = await self._validate_model_and_balance(model_name, user_id)
  1241. # 检查模型类型是否支持图像编辑(图生图)
  1242. from app.models.model import ModelCategory
  1243. if not any(int(c) in [int(ModelCategory.IMAGE_EDIT), int(ModelCategory.MULTIMODAL)] for c in (model.categories or [])):
  1244. raise OpenAICompatError(
  1245. status_code=400,
  1246. message=f"Model '{model_name}' does not support image editing. Use a model with category IMAGE_EDIT or MULTIMODAL.",
  1247. error_type="model_not_supported",
  1248. )
  1249. # 检查本地模型权限
  1250. if model.is_local:
  1251. from app.services.user_local_model_permission_service import UserLocalModelPermissionService
  1252. permission_service = UserLocalModelPermissionService(self.db)
  1253. if not await permission_service.check_user_model_access(user_id, model.id):
  1254. raise OpenAICompatError(
  1255. status_code=403,
  1256. message=f"You don't have permission to access model '{model_name}'",
  1257. error_type="permission_error",
  1258. )
  1259. import base64
  1260. from app.services.oss_service import get_oss_service
  1261. oss_service = get_oss_service()
  1262. # 处理图像数据
  1263. if isinstance(image, str):
  1264. # 检查是否是URL
  1265. if image.startswith(('http://', 'https://')):
  1266. # 直接使用URL
  1267. image_url = image
  1268. else:
  1269. # 解码base64字符串
  1270. image_bytes = base64.b64decode(image)
  1271. image_url = oss_service.upload_file(
  1272. image_bytes,
  1273. prefix="ai-images/edits",
  1274. original_filename="edit_source.png"
  1275. )
  1276. else:
  1277. # 处理UploadFile对象
  1278. image_bytes = await image.read()
  1279. image_url = oss_service.upload_file(
  1280. image_bytes,
  1281. prefix="ai-images/edits",
  1282. original_filename=image.filename or "edit_source.png"
  1283. )
  1284. image_urls = [image_url]
  1285. # 处理遮罩数据
  1286. if mask:
  1287. if isinstance(mask, str):
  1288. # 检查是否是URL
  1289. if mask.startswith(('http://', 'https://')):
  1290. # 直接使用URL
  1291. mask_url = mask
  1292. else:
  1293. # 解码base64字符串
  1294. mask_bytes = base64.b64decode(mask)
  1295. mask_url = oss_service.upload_file(
  1296. mask_bytes,
  1297. prefix="ai-images/edits",
  1298. original_filename="edit_mask.png"
  1299. )
  1300. else:
  1301. # 处理UploadFile对象
  1302. mask_bytes = await mask.read()
  1303. mask_url = oss_service.upload_file(
  1304. mask_bytes,
  1305. prefix="ai-images/edits",
  1306. original_filename=mask.filename or "edit_mask.png"
  1307. )
  1308. image_urls.append(mask_url)
  1309. if model.is_local:
  1310. # 从缓存获取模型信息
  1311. from app.services.cache_service import CacheService
  1312. from app.services.crypto_utils import decrypt_api_key
  1313. model_data = await CacheService.get_model(model.id)
  1314. if model_data:
  1315. base_url = model_data.get("base_url", "").rstrip("/")
  1316. local_api_key = model_data.get("local_api_key")
  1317. else:
  1318. # 从数据库获取
  1319. base_url = (model.base_url or "").rstrip("/")
  1320. local_api_key = model.local_api_key
  1321. # 缓存模型信息
  1322. await CacheService.set_model(model.id, {
  1323. "base_url": base_url,
  1324. "local_api_key": local_api_key,
  1325. "is_local": model.is_local,
  1326. "name": model.name
  1327. })
  1328. if not base_url:
  1329. raise OpenAICompatError(
  1330. status_code=500,
  1331. message="本地模型未配置 Base URL",
  1332. error_type="configuration_error",
  1333. )
  1334. # 构建请求头
  1335. headers = {"Content-Type": "application/json"}
  1336. if local_api_key:
  1337. api_key = decrypt_api_key(local_api_key)
  1338. if api_key:
  1339. headers["Authorization"] = f"Bearer {api_key}"
  1340. # 构建请求体(OpenAI 格式)
  1341. payload = {
  1342. "model": model.name,
  1343. "prompt": prompt,
  1344. "n": n,
  1345. "size": size or "1024x1024"
  1346. }
  1347. # 处理图像和遮罩
  1348. if len(image_urls) == 1:
  1349. payload["image"] = image_urls[0]
  1350. elif len(image_urls) == 2:
  1351. payload["image"] = image_urls[0]
  1352. payload["mask"] = image_urls[1]
  1353. if prompt:
  1354. payload["prompt"] = prompt
  1355. # 使用统一的方法发送请求
  1356. api_url = f"{base_url}/images/edits"
  1357. result_data = await self._handle_local_model_request(
  1358. api_url=api_url,
  1359. headers=headers,
  1360. payload=payload,
  1361. model_name=model.name,
  1362. base_url=base_url,
  1363. endpoint_type="image",
  1364. timeout=60.0
  1365. )
  1366. # 处理响应
  1367. images = []
  1368. for item in result_data.get("data", []):
  1369. if item.get("url"):
  1370. images.append(item.get("url"))
  1371. elif item.get("b64_json"):
  1372. # 处理base64编码的图像
  1373. image_bytes = base64.b64decode(item.get("b64_json"))
  1374. url = oss_service.upload_file(
  1375. image_bytes,
  1376. prefix="ai-images/edits",
  1377. original_filename=f"edited_{time.time()}.png"
  1378. )
  1379. images.append(url)
  1380. if not images:
  1381. raise OpenAICompatError(
  1382. status_code=500,
  1383. message="图像编辑失败:未返回图像",
  1384. error_type="image_edit_error",
  1385. )
  1386. result = ImageGenerationResponse(
  1387. created=int(time.time()),
  1388. data=[ImageData(url=url) for url in images]
  1389. )
  1390. bill = Decimal("0")
  1391. else:
  1392. # 云端模型处理
  1393. user = self._get_user(user_id)
  1394. dashscope_api_key = user.apikey if user and user.apikey else ""
  1395. from app.services.image_service import ImageGenerationService
  1396. real_image_service = ImageGenerationService(self.db, api_key=dashscope_api_key)
  1397. mapped_size = size.replace("x", "*") if size else "1024*1024"
  1398. result_obj = await real_image_service.image_to_image(
  1399. user_id=user_id,
  1400. image_urls=image_urls,
  1401. prompt=prompt,
  1402. model=model.model_code,
  1403. n=n,
  1404. size=mapped_size
  1405. )
  1406. if not result_obj.success:
  1407. raise OpenAICompatError(
  1408. status_code=500,
  1409. message=result_obj.error or "图生图编辑失败",
  1410. error_type="image_edit_error"
  1411. )
  1412. result = ImageGenerationResponse(
  1413. created=int(time.time()),
  1414. data=[ImageData(url=url) for url in result_obj.images]
  1415. )
  1416. bill = result_obj.bill
  1417. log_service.create_log(
  1418. user_id=user_id,
  1419. api_key_id=api_key_id,
  1420. model_id=model.id if model else None,
  1421. model_name=model_name,
  1422. is_local=model.is_local if model else False,
  1423. input_tokens=0,
  1424. output_tokens=len(result.data),
  1425. bill=float(bill),
  1426. status="success",
  1427. request_ip=request_ip
  1428. )
  1429. return result
  1430. except Exception as e:
  1431. error_msg = str(e) if not isinstance(e, OpenAICompatError) else e.message
  1432. model_obj = self._find_model(model_name, user_id)
  1433. log_service.create_log(
  1434. user_id=user_id,
  1435. api_key_id=api_key_id,
  1436. model_id=model_obj.id if model_obj else None,
  1437. model_name=model_name,
  1438. is_local=model_obj.is_local if model_obj else False,
  1439. input_tokens=0,
  1440. output_tokens=0,
  1441. bill=0,
  1442. status="failed",
  1443. error_message=error_msg,
  1444. request_ip=request_ip
  1445. )
  1446. if isinstance(e, OpenAICompatError):
  1447. raise e
  1448. raise OpenAICompatError(status_code=500, message=error_msg, error_type="image_edit_error")
  1449. async def audio_transcriptions(
  1450. self,
  1451. file: Union[str, UploadFile],
  1452. model_name: str,
  1453. language: Optional[str],
  1454. user_id: str,
  1455. api_key_id: int,
  1456. request_ip: Optional[str] = None
  1457. ) -> AudioTranscriptionResponse:
  1458. """
  1459. 处理语音识别(STT)请求
  1460. 接收上传的音频文件,转换为Base64编码后调用底层ASRService处理。
  1461. 包含模型名称从OpenAI(whisper-1)到DashScope原生模型的映射。
  1462. Args:
  1463. file: 客户端上传的音频文件
  1464. model_name: 模型名称
  1465. language: 语言代码 (ISO-639-1)
  1466. user_id: 用户ID
  1467. api_key_id: API Key ID
  1468. request_ip: 请求来源IP
  1469. Returns:
  1470. AudioTranscriptionResponse: 包含识别文本的响应对象
  1471. Raises:
  1472. OpenAICompatError: 处理失败或鉴权失败时抛出
  1473. 需求: OpenAI兼容-语音转文字
  1474. """
  1475. log_service = ApiCallLogService(self.db)
  1476. try:
  1477. actual_model = "qwen3-asr-flash" if model_name in ["whisper-1", "whisper-large-v3"] else model_name
  1478. # realtime 模型使用 WebSocket 实时流协议,不支持此文件上传接口
  1479. if "realtime" in actual_model.lower():
  1480. raise OpenAICompatError(
  1481. status_code=400,
  1482. message=f"Model '{actual_model}' is a real-time streaming model that uses WebSocket protocol. "
  1483. f"It cannot be called via /api/v1/audio/transcriptions. "
  1484. f"Please use the WebSocket API instead.",
  1485. error_type="model_not_supported",
  1486. )
  1487. model = await self._validate_model_and_balance(actual_model, user_id)
  1488. # 检查模型类型是否支持语音识别
  1489. from app.models.model import ModelCategory
  1490. if not any(int(c) in [int(ModelCategory.STT), int(ModelCategory.MULTIMODAL)] for c in (model.categories or [])):
  1491. raise OpenAICompatError(
  1492. status_code=400,
  1493. message=f"Model '{model_name}' does not support speech transcription",
  1494. error_type="model_not_supported",
  1495. )
  1496. # 检查本地模型权限
  1497. if model.is_local:
  1498. from app.services.user_local_model_permission_service import UserLocalModelPermissionService
  1499. permission_service = UserLocalModelPermissionService(self.db)
  1500. if not await permission_service.check_user_model_access(user_id, model.id):
  1501. raise OpenAICompatError(
  1502. status_code=403,
  1503. message=f"You don't have permission to access model '{model_name}'",
  1504. error_type="permission_error",
  1505. )
  1506. import base64
  1507. import httpx
  1508. if isinstance(file, str):
  1509. # 检查是否是URL
  1510. if file.startswith(('http://', 'https://')):
  1511. # 从URL下载音频文件并转换为base64
  1512. async with httpx.AsyncClient() as client:
  1513. response = await client.get(file)
  1514. response.raise_for_status()
  1515. audio_bytes = response.content
  1516. audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
  1517. else:
  1518. # 使用base64字符串
  1519. audio_base64 = file
  1520. else:
  1521. # 处理UploadFile对象
  1522. audio_bytes = await file.read()
  1523. audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
  1524. if model.is_local:
  1525. # 从缓存获取模型信息
  1526. from app.services.cache_service import CacheService
  1527. from app.services.crypto_utils import decrypt_api_key
  1528. model_data = await CacheService.get_model(model.id)
  1529. if model_data:
  1530. base_url = model_data.get("base_url", "").rstrip("/")
  1531. local_api_key = model_data.get("local_api_key")
  1532. else:
  1533. # 从数据库获取
  1534. base_url = (model.base_url or "").rstrip("/")
  1535. local_api_key = model.local_api_key
  1536. # 缓存模型信息
  1537. await CacheService.set_model(model.id, {
  1538. "base_url": base_url,
  1539. "local_api_key": local_api_key,
  1540. "is_local": model.is_local,
  1541. "name": model.name
  1542. })
  1543. if not base_url:
  1544. raise OpenAICompatError(
  1545. status_code=500,
  1546. message="本地模型未配置 Base URL",
  1547. error_type="configuration_error",
  1548. )
  1549. # 构建请求头
  1550. headers = {"Content-Type": "application/json"}
  1551. if local_api_key:
  1552. api_key = decrypt_api_key(local_api_key)
  1553. if api_key:
  1554. headers["Authorization"] = f"Bearer {api_key}"
  1555. # 构建请求体(OpenAI 格式)
  1556. payload = {
  1557. "model": model.name,
  1558. "file": audio_base64
  1559. }
  1560. if language:
  1561. payload["language"] = language
  1562. # 使用统一的方法发送请求
  1563. api_url = f"{base_url}/audio/transcriptions"
  1564. result_data = await self._handle_local_model_request(
  1565. api_url=api_url,
  1566. headers=headers,
  1567. payload=payload,
  1568. model_name=model.name,
  1569. base_url=base_url,
  1570. endpoint_type="audio_stt",
  1571. timeout=60.0
  1572. )
  1573. # 处理响应
  1574. text = result_data.get("text", "")
  1575. duration_seconds = len(audio_base64) // 10000 # 粗略估算
  1576. text_length = len(text)
  1577. bill = Decimal("0")
  1578. else:
  1579. # 云端模型处理
  1580. user = self._get_user(user_id)
  1581. dashscope_api_key = user.apikey if user and user.apikey else ""
  1582. from app.services.asr_service import ASRService
  1583. from app.schemas.audio_schema import ASRRequest
  1584. real_asr_service = ASRService(self.db, user_id=user_id, api_key=dashscope_api_key)
  1585. internal_req = ASRRequest(
  1586. model=actual_model,
  1587. audio_base64=audio_base64,
  1588. language=language
  1589. )
  1590. asr_response = await real_asr_service.recognize(internal_req)
  1591. text = asr_response.text
  1592. duration_seconds = asr_response.usage.seconds if asr_response.usage else 0
  1593. text_length = len(text) if text else 0
  1594. bill = Decimal("0")
  1595. log_service.create_log(
  1596. user_id=user_id,
  1597. api_key_id=api_key_id,
  1598. model_id=model.id if model else None,
  1599. model_name=model_name,
  1600. is_local=model.is_local if model else False,
  1601. input_tokens=duration_seconds,
  1602. output_tokens=text_length,
  1603. bill=float(bill),
  1604. status="success",
  1605. request_ip=request_ip
  1606. )
  1607. return AudioTranscriptionResponse(text=text)
  1608. except Exception as e:
  1609. error_msg = str(e) if not isinstance(e, OpenAICompatError) else e.message
  1610. model_obj = self._find_model(model_name, user_id)
  1611. log_service.create_log(
  1612. user_id=user_id,
  1613. api_key_id=api_key_id,
  1614. model_id=model_obj.id if model_obj else None,
  1615. model_name=model_name,
  1616. is_local=model_obj.is_local if model_obj else False,
  1617. input_tokens=0,
  1618. output_tokens=0,
  1619. bill=0,
  1620. status="failed",
  1621. error_message=error_msg,
  1622. request_ip=request_ip
  1623. )
  1624. if isinstance(e, OpenAICompatError):
  1625. raise e
  1626. raise OpenAICompatError(status_code=500, message=error_msg, error_type="stt_error")
  1627. async def audio_speech(
  1628. self,
  1629. request: AudioSpeechRequest,
  1630. user_id: str,
  1631. api_key_id: int,
  1632. request_ip: Optional[str] = None
  1633. ) -> Tuple[AsyncGenerator, str]:
  1634. """
  1635. 处理文字转语音(TTS)请求
  1636. 执行全量音色映射与模型能力校验。
  1637. 通过底层 TTSService 生成语音并转存 OSS,随后转换为流式下发。
  1638. Args:
  1639. request: TTS请求对象
  1640. user_id: 用户ID
  1641. api_key_id: API Key ID
  1642. request_ip: 请求来源IP
  1643. Returns:
  1644. Tuple[AsyncGenerator, str]: 音频二进制生成器与 MIME 类型
  1645. Raises:
  1646. OpenAICompatError: 模型不支持所选音色或处理失败时抛出
  1647. 需求: OpenAI兼容-文字转语音
  1648. """
  1649. log_service = ApiCallLogService(self.db)
  1650. try:
  1651. actual_model = "cosyvoice-v3-flash" if request.model in ["tts-1", "tts-1-hd"] else request.model
  1652. # realtime 模型使用 WebSocket 实时流协议,不支持此 REST 接口
  1653. if "realtime" in actual_model.lower():
  1654. raise OpenAICompatError(
  1655. status_code=400,
  1656. message=f"Model '{actual_model}' is a real-time streaming model that uses WebSocket protocol. "
  1657. f"It cannot be called via /api/v1/audio/speech. "
  1658. f"Please use the WebSocket API instead.",
  1659. error_type="model_not_supported",
  1660. )
  1661. # cosyvoice-clone 系列:voice 参数就是 voice_id,不做映射,直接透传
  1662. is_clone = "clone" in actual_model.lower()
  1663. voice_map = {
  1664. "alloy": "longxiaochun_v3",
  1665. "echo": "longcheng_v3",
  1666. "shimmer": "longwan_v3",
  1667. "onyx": "longhua_v3",
  1668. "nova": "longxiaoxia_v3",
  1669. "fable": "longshu_v3",
  1670. "sunny": "longanyang",
  1671. "lively": "longanhuan",
  1672. "cute_girl": "longhuhu_v3",
  1673. "cute_boy": "longniuniu_v3",
  1674. "bubble": "longpaopao_v3",
  1675. "naughty": "longjielidou_v3",
  1676. "bold_girl": "longxian_v3",
  1677. "cantonese_f": "longjiaxin_v3",
  1678. "cantonese_m": "longanyue_v3",
  1679. "dongbei": "longlaotie_v3",
  1680. "shanbei": "longshange_v3",
  1681. "korean": "loongkyong_v3",
  1682. "japanese": "loongriko_v3",
  1683. "news_m": "longfei_v3",
  1684. "news_f": "longxiaoxia_v3",
  1685. "story_m": "longxiu_v3",
  1686. "story_f": "longmiao_v3",
  1687. "customer_service": "longyingxiao_v3",
  1688. "monkey": "longhouge_v3",
  1689. "robot": "longjiqi_v3",
  1690. "daiyu": "longdaiyu_v3",
  1691. "uncle": "longlaobo_v3",
  1692. "aunt": "longlaoyi_v3"
  1693. }
  1694. actual_voice = request.voice if is_clone else voice_map.get(request.voice.lower(), request.voice)
  1695. if "plus" in actual_model.lower():
  1696. plus_allowed_voices = ["longanyang", "longanhuan"]
  1697. if actual_voice not in plus_allowed_voices:
  1698. raise OpenAICompatError(
  1699. status_code=400,
  1700. message=f"Model '{actual_model}' only supports voices: {plus_allowed_voices}. Requested: '{actual_voice}'.",
  1701. error_type="invalid_request_error"
  1702. )
  1703. model = await self._validate_model_and_balance(actual_model, user_id)
  1704. # 检查模型类型是否支持语音合成
  1705. from app.models.model import ModelCategory
  1706. if not any(int(c) in [int(ModelCategory.TTS), int(ModelCategory.MULTIMODAL)] for c in (model.categories or [])):
  1707. raise OpenAICompatError(
  1708. status_code=400,
  1709. message=f"Model '{request.model}' does not support speech synthesis",
  1710. error_type="model_not_supported",
  1711. )
  1712. # 检查本地模型权限
  1713. if model.is_local:
  1714. from app.services.user_local_model_permission_service import UserLocalModelPermissionService
  1715. permission_service = UserLocalModelPermissionService(self.db)
  1716. if not await permission_service.check_user_model_access(user_id, model.id):
  1717. raise OpenAICompatError(
  1718. status_code=403,
  1719. message=f"You don't have permission to access model '{request.model}'",
  1720. error_type="permission_error",
  1721. )
  1722. if model.is_local:
  1723. # 从缓存获取模型信息
  1724. from app.services.cache_service import CacheService
  1725. from app.services.crypto_utils import decrypt_api_key
  1726. model_data = await CacheService.get_model(model.id)
  1727. if model_data:
  1728. base_url = model_data.get("base_url", "").rstrip("/")
  1729. local_api_key = model_data.get("local_api_key")
  1730. else:
  1731. # 从数据库获取
  1732. base_url = (model.base_url or "").rstrip("/")
  1733. local_api_key = model.local_api_key
  1734. # 缓存模型信息
  1735. await CacheService.set_model(model.id, {
  1736. "base_url": base_url,
  1737. "local_api_key": local_api_key,
  1738. "is_local": model.is_local,
  1739. "name": model.name
  1740. })
  1741. if not base_url:
  1742. raise OpenAICompatError(
  1743. status_code=500,
  1744. message="本地模型未配置 Base URL",
  1745. error_type="configuration_error",
  1746. )
  1747. # 构建请求头
  1748. headers = {"Content-Type": "application/json"}
  1749. if local_api_key:
  1750. api_key = decrypt_api_key(local_api_key)
  1751. if api_key:
  1752. headers["Authorization"] = f"Bearer {api_key}"
  1753. # 构建请求体(OpenAI 格式)
  1754. # 注意:本地模型使用原始的 OpenAI 音色名称,不进行映射
  1755. payload = {
  1756. "model": model.name,
  1757. "input": request.input,
  1758. "voice": request.voice, # 使用原始音色名称,不映射
  1759. "response_format": request.response_format or "mp3",
  1760. "speed": request.speed or 1.0
  1761. }
  1762. # 使用统一的适配器方法发送请求(获取原始响应)
  1763. api_url = f"{base_url}/audio/speech"
  1764. response = await self._handle_local_model_request(
  1765. api_url=api_url,
  1766. headers=headers,
  1767. payload=payload,
  1768. model_name=model.name,
  1769. base_url=base_url,
  1770. endpoint_type="audio_tts",
  1771. timeout=60.0,
  1772. return_raw_response=True
  1773. )
  1774. # 检查响应类型
  1775. content_type = response.headers.get("content-type", "")
  1776. import httpx
  1777. if "application/json" in content_type:
  1778. # 响应是 JSON,包含音频 URL
  1779. result_data = response.json()
  1780. audio_url = result_data.get("audio_url") or result_data.get("url")
  1781. if not audio_url:
  1782. raise OpenAICompatError(
  1783. status_code=500,
  1784. message="本地模型返回的 JSON 中未找到音频 URL",
  1785. error_type="invalid_response",
  1786. )
  1787. # 下载音频并流式返回
  1788. async def generate_audio():
  1789. async with httpx.AsyncClient() as client:
  1790. async with client.stream("GET", audio_url) as audio_response:
  1791. audio_response.raise_for_status()
  1792. async for chunk in audio_response.aiter_bytes():
  1793. yield chunk
  1794. else:
  1795. # 响应直接是音频数据(如硅基流动)
  1796. audio_bytes = response.content
  1797. async def generate_audio():
  1798. # 直接返回音频数据
  1799. yield audio_bytes
  1800. media_type = f"audio/{request.response_format or 'mp3'}" if (request.response_format or 'mp3') != 'mp3' else "audio/mpeg"
  1801. bill = Decimal("0")
  1802. else:
  1803. # 云端模型处理
  1804. user = self._get_user(user_id)
  1805. dashscope_api_key = user.apikey if user and user.apikey else ""
  1806. from app.services.tts_service import TTSService
  1807. from app.schemas.audio_schema import TTSRequest
  1808. real_tts_service = TTSService(self.db, user_id=user_id, api_key=dashscope_api_key)
  1809. internal_req = TTSRequest(
  1810. text=request.input,
  1811. model=actual_model,
  1812. voice=actual_voice,
  1813. format=request.response_format or "mp3",
  1814. sample_rate=24000
  1815. )
  1816. tts_response = await real_tts_service.synthesize(internal_req)
  1817. import httpx
  1818. async def generate_audio():
  1819. async with httpx.AsyncClient() as client:
  1820. async with client.stream("GET", tts_response.audio_url) as response:
  1821. response.raise_for_status()
  1822. async for chunk in response.aiter_bytes():
  1823. yield chunk
  1824. media_type = f"audio/{request.response_format}" if request.response_format != 'mp3' else "audio/mpeg"
  1825. bill = Decimal("0")
  1826. log_service.create_log(
  1827. user_id=user_id,
  1828. api_key_id=api_key_id,
  1829. model_id=model.id if model else None,
  1830. model_name=request.model,
  1831. is_local=model.is_local if model else False,
  1832. input_tokens=len(request.input),
  1833. output_tokens=0,
  1834. bill=float(bill),
  1835. status="success",
  1836. request_ip=request_ip
  1837. )
  1838. return generate_audio(), media_type
  1839. except Exception as e:
  1840. error_msg = str(e) if not isinstance(e, OpenAICompatError) else e.message
  1841. model_obj = self._find_model(request.model, user_id)
  1842. log_service.create_log(
  1843. user_id=user_id,
  1844. api_key_id=api_key_id,
  1845. model_id=model_obj.id if model_obj else None,
  1846. model_name=request.model,
  1847. is_local=model_obj.is_local if model_obj else False,
  1848. input_tokens=0,
  1849. output_tokens=0,
  1850. bill=0,
  1851. status="failed",
  1852. error_message=error_msg,
  1853. request_ip=request_ip
  1854. )
  1855. if isinstance(e, OpenAICompatError):
  1856. raise e
  1857. raise OpenAICompatError(status_code=500, message=error_msg, error_type="tts_error")
  1858. async def video_generations(
  1859. self,
  1860. request: VideoGenerationRequest,
  1861. user_id: str,
  1862. api_key_id: int,
  1863. request_ip: Optional[str] = None
  1864. ) -> VideoGenerationResponse:
  1865. """
  1866. 处理视频生成请求
  1867. 调用底层VideoService提交异步任务,并通过轮询将其封装为同步阻塞接口。
  1868. Args:
  1869. request: 视频生成请求对象
  1870. user_id: 用户ID
  1871. api_key_id: API Key ID
  1872. request_ip: 请求来源IP
  1873. Returns:
  1874. VideoGenerationResponse: 包含最终视频URL的响应对象
  1875. Raises:
  1876. OpenAICompatError: 模型不支持或生成失败时抛出
  1877. 需求: OpenAI兼容-视频生成
  1878. """
  1879. import time
  1880. import asyncio
  1881. from app.services.video_service import VideoService
  1882. from app.schemas.video_schema import VideoGenerateRequest
  1883. log_service = ApiCallLogService(self.db)
  1884. try:
  1885. model = await self._validate_model_and_balance(request.model, user_id)
  1886. # 检查模型类型是否支持视频生成
  1887. from app.models.model import ModelCategory
  1888. if not any(int(c) in [int(ModelCategory.VIDEO_GEN), int(ModelCategory.MULTIMODAL), int(ModelCategory.LLM)] for c in (model.categories or [])):
  1889. raise OpenAICompatError(
  1890. status_code=400,
  1891. message=f"Model '{request.model}' does not support video generation",
  1892. error_type="model_not_supported",
  1893. )
  1894. # 检查本地模型权限
  1895. if model.is_local:
  1896. from app.services.user_local_model_permission_service import UserLocalModelPermissionService
  1897. permission_service = UserLocalModelPermissionService(self.db)
  1898. if not await permission_service.check_user_model_access(user_id, model.id):
  1899. raise OpenAICompatError(
  1900. status_code=403,
  1901. message=f"You don't have permission to access model '{request.model}'",
  1902. error_type="permission_error",
  1903. )
  1904. if model.is_local:
  1905. # 从缓存获取模型信息
  1906. from app.services.cache_service import CacheService
  1907. from app.services.crypto_utils import decrypt_api_key
  1908. model_data = await CacheService.get_model(model.id)
  1909. if model_data:
  1910. base_url = model_data.get("base_url", "").rstrip("/")
  1911. local_api_key = model_data.get("local_api_key")
  1912. else:
  1913. # 从数据库获取
  1914. base_url = (model.base_url or "").rstrip("/")
  1915. local_api_key = model.local_api_key
  1916. # 缓存模型信息
  1917. await CacheService.set_model(model.id, {
  1918. "base_url": base_url,
  1919. "local_api_key": local_api_key,
  1920. "is_local": model.is_local,
  1921. "name": model.name
  1922. })
  1923. if not base_url:
  1924. raise OpenAICompatError(
  1925. status_code=500,
  1926. message="本地模型未配置 Base URL",
  1927. error_type="configuration_error",
  1928. )
  1929. # 构建请求头
  1930. headers = {"Content-Type": "application/json"}
  1931. if local_api_key:
  1932. api_key = decrypt_api_key(local_api_key)
  1933. if api_key:
  1934. headers["Authorization"] = f"Bearer {api_key}"
  1935. # 构建请求体(OpenAI 格式)
  1936. payload = {
  1937. "model": model.name,
  1938. "prompt": request.prompt,
  1939. "size": request.size or "1280x720", # 使用OpenAI标准格式
  1940. "duration": request.duration or 5
  1941. }
  1942. # 使用统一的适配器方法发送请求
  1943. api_url = f"{base_url}/videos/generations"
  1944. result_data = await self._handle_local_model_request(
  1945. api_url=api_url,
  1946. headers=headers,
  1947. payload=payload,
  1948. model_name=model.name,
  1949. base_url=base_url,
  1950. endpoint_type="video",
  1951. timeout=60.0
  1952. )
  1953. # 处理响应
  1954. videos = []
  1955. for item in result_data.get("data", []):
  1956. if item.get("url"):
  1957. videos.append(item.get("url"))
  1958. if not videos:
  1959. raise OpenAICompatError(
  1960. status_code=500,
  1961. message="视频生成失败:未返回视频",
  1962. error_type="video_generation_error",
  1963. )
  1964. # 组装结果
  1965. from app.schemas.openai_compat import VideoData
  1966. result = VideoGenerationResponse(
  1967. created=int(time.time()),
  1968. data=[VideoData(url=url, content_type="video/mp4") for url in videos]
  1969. )
  1970. bill = Decimal("0")
  1971. else:
  1972. # 云端模型处理
  1973. user = self._get_user(user_id)
  1974. from app.services.crypto_utils import get_effective_api_key
  1975. dashscope_api_key = get_effective_api_key(self.db, request.model, user.apikey if user else "") if user else ""
  1976. real_video_service = VideoService(self.db, user_id=int(user_id) if str(user_id).isdigit() else user_id, api_key=dashscope_api_key)
  1977. # 解析并转换size格式
  1978. try:
  1979. openai_size, internal_size = parse_video_size(request.size or "1280x720")
  1980. except ValueError as e:
  1981. raise OpenAICompatError(
  1982. status_code=400,
  1983. message=str(e),
  1984. error_type="invalid_request_error"
  1985. )
  1986. internal_req = VideoGenerateRequest(
  1987. prompt=request.prompt,
  1988. resolution=internal_size, # 使用内部格式 "720P"
  1989. duration=request.duration or 5,
  1990. prompt_extend=True,
  1991. watermark=False
  1992. )
  1993. # 提交异步任务
  1994. task_resp = await real_video_service.generate(internal_req)
  1995. task_id = task_resp.task_id
  1996. # 阻塞轮询
  1997. max_retries = 120
  1998. poll_interval = 5
  1999. final_video_url = None
  2000. for _ in range(max_retries):
  2001. await asyncio.sleep(poll_interval)
  2002. status_result = await real_video_service.get_task_status(task_id)
  2003. if status_result.task_status == "SUCCEEDED":
  2004. final_video_url = status_result.video_url
  2005. break
  2006. elif status_result.task_status == "FAILED":
  2007. raise OpenAICompatError(
  2008. status_code=500,
  2009. message=status_result.error_message or "底层视频生成任务失败",
  2010. error_type="video_generation_error"
  2011. )
  2012. if not final_video_url:
  2013. raise OpenAICompatError(
  2014. status_code=504,
  2015. message="视频生成任务超时,请稍后再试或通过平台任务列表查看结果",
  2016. error_type="timeout_error"
  2017. )
  2018. # 组装结果
  2019. from app.schemas.openai_compat import VideoData
  2020. result = VideoGenerationResponse(
  2021. created=int(time.time()),
  2022. data=[VideoData(url=final_video_url, content_type="video/mp4")]
  2023. )
  2024. bill = Decimal("0")
  2025. log_service.create_log(
  2026. user_id=user_id,
  2027. api_key_id=api_key_id,
  2028. model_id=model.id if model else None,
  2029. model_name=request.model,
  2030. is_local=model.is_local if model else False,
  2031. input_tokens=0,
  2032. output_tokens=request.duration or 5,
  2033. bill=float(bill),
  2034. status="success",
  2035. request_ip=request_ip
  2036. )
  2037. return result
  2038. except Exception as e:
  2039. error_msg = str(e) if not isinstance(e, OpenAICompatError) else e.message
  2040. model_obj = self._find_model(request.model, user_id)
  2041. log_service.create_log(
  2042. user_id=user_id,
  2043. api_key_id=api_key_id,
  2044. model_id=model_obj.id if model_obj else None,
  2045. model_name=request.model,
  2046. is_local=model_obj.is_local if model_obj else False,
  2047. input_tokens=0,
  2048. output_tokens=0,
  2049. bill=0,
  2050. status="failed",
  2051. error_message=error_msg,
  2052. request_ip=request_ip
  2053. )
  2054. if isinstance(e, OpenAICompatError):
  2055. raise e
  2056. raise OpenAICompatError(status_code=500, message=error_msg, error_type="video_generation_error")
  2057. async def image_to_video_generations(
  2058. self,
  2059. image: Union[str, UploadFile],
  2060. prompt: str,
  2061. model_name: str,
  2062. size: str,
  2063. user_id: str,
  2064. api_key_id: int,
  2065. request_ip: Optional[str] = None
  2066. ) -> VideoGenerationResponse:
  2067. """
  2068. 处理图生视频(I2V)请求
  2069. 接收上传图片,转存OSS获取URL,调用底层VideoService提交图生视频异步任务,
  2070. 并轮询任务状态封装为同步接口返回。
  2071. Args:
  2072. image: 客户端上传的参考图像
  2073. prompt: 对视频的文本描述
  2074. model_name: 模型名称 (如 wan2.6-i2v)
  2075. size: 视频分辨率 (如 720P)
  2076. user_id: 用户ID
  2077. api_key_id: API Key ID
  2078. request_ip: 请求来源IP
  2079. Returns:
  2080. VideoGenerationResponse: 包含最终视频URL的响应对象
  2081. Raises:
  2082. OpenAICompatError: 处理失败时抛出
  2083. 需求: OpenAI兼容-图生视频
  2084. """
  2085. import time
  2086. import asyncio
  2087. import base64
  2088. from app.services.oss_service import get_oss_service
  2089. log_service = ApiCallLogService(self.db)
  2090. try:
  2091. model = await self._validate_model_and_balance(model_name, user_id)
  2092. # 检查模型类型是否支持视频生成
  2093. from app.models.model import ModelCategory
  2094. if not any(int(c) in [int(ModelCategory.VIDEO_GEN), int(ModelCategory.MULTIMODAL), int(ModelCategory.LLM)] for c in (model.categories or [])):
  2095. raise OpenAICompatError(
  2096. status_code=400,
  2097. message=f"Model '{model_name}' does not support video generation",
  2098. error_type="model_not_supported",
  2099. )
  2100. # 检查本地模型权限
  2101. if model.is_local:
  2102. from app.services.user_local_model_permission_service import UserLocalModelPermissionService
  2103. permission_service = UserLocalModelPermissionService(self.db)
  2104. if not await permission_service.check_user_model_access(user_id, model.id):
  2105. raise OpenAICompatError(
  2106. status_code=403,
  2107. message=f"You don't have permission to access model '{model_name}'",
  2108. error_type="permission_error",
  2109. )
  2110. # 处理图像数据
  2111. oss_service = get_oss_service()
  2112. if isinstance(image, str):
  2113. # 检查是否是URL
  2114. if image.startswith(('http://', 'https://')):
  2115. # 直接使用URL
  2116. image_url = image
  2117. else:
  2118. # 解码base64字符串
  2119. image_bytes = base64.b64decode(image)
  2120. image_url = oss_service.upload_file(
  2121. image_bytes,
  2122. prefix="ai-videos/i2v-source",
  2123. original_filename="i2v_source.png"
  2124. )
  2125. else:
  2126. # 处理UploadFile对象
  2127. image_bytes = await image.read()
  2128. image_url = oss_service.upload_file(
  2129. image_bytes,
  2130. prefix="ai-videos/i2v-source",
  2131. original_filename=image.filename or "i2v_source.png"
  2132. )
  2133. if model.is_local:
  2134. # 从缓存获取模型信息
  2135. from app.services.cache_service import CacheService
  2136. from app.services.crypto_utils import decrypt_api_key
  2137. model_data = await CacheService.get_model(model.id)
  2138. if model_data:
  2139. base_url = model_data.get("base_url", "").rstrip("/")
  2140. local_api_key = model_data.get("local_api_key")
  2141. else:
  2142. # 从数据库获取
  2143. base_url = (model.base_url or "").rstrip("/")
  2144. local_api_key = model.local_api_key
  2145. # 缓存模型信息
  2146. await CacheService.set_model(model.id, {
  2147. "base_url": base_url,
  2148. "local_api_key": local_api_key,
  2149. "is_local": model.is_local,
  2150. "name": model.name
  2151. })
  2152. if not base_url:
  2153. raise OpenAICompatError(
  2154. status_code=500,
  2155. message="本地模型未配置 Base URL",
  2156. error_type="configuration_error",
  2157. )
  2158. # 构建请求头
  2159. headers = {"Content-Type": "application/json"}
  2160. if local_api_key:
  2161. api_key = decrypt_api_key(local_api_key)
  2162. if api_key:
  2163. headers["Authorization"] = f"Bearer {api_key}"
  2164. # 构建请求体(OpenAI 格式)
  2165. payload = {
  2166. "model": model.name,
  2167. "prompt": prompt,
  2168. "image": image_url,
  2169. "size": size or "1280x720", # 使用OpenAI标准格式
  2170. "duration": 5
  2171. }
  2172. # 使用统一的适配器方法发送请求
  2173. api_url = f"{base_url}/videos/generations"
  2174. result_data = await self._handle_local_model_request(
  2175. api_url=api_url,
  2176. headers=headers,
  2177. payload=payload,
  2178. model_name=model.name,
  2179. base_url=base_url,
  2180. endpoint_type="video",
  2181. timeout=60.0
  2182. )
  2183. # 处理响应
  2184. videos = []
  2185. for item in result_data.get("data", []):
  2186. if item.get("url"):
  2187. videos.append(item.get("url"))
  2188. if not videos:
  2189. raise OpenAICompatError(
  2190. status_code=500,
  2191. message="图生视频失败:未返回视频",
  2192. error_type="video_generation_error",
  2193. )
  2194. # 组装结果
  2195. from app.schemas.openai_compat import VideoData
  2196. result = VideoGenerationResponse(
  2197. created=int(time.time()),
  2198. data=[VideoData(url=url, content_type="video/mp4") for url in videos]
  2199. )
  2200. bill = Decimal("0")
  2201. else:
  2202. # 云端模型处理
  2203. user = self._get_user(user_id)
  2204. from app.services.crypto_utils import get_effective_api_key
  2205. dashscope_api_key = get_effective_api_key(self.db, model_name, user.apikey if user else "") if user else ""
  2206. from app.services.video_service import VideoService
  2207. from app.schemas.video_schema import VideoGenerateRequest
  2208. real_video_service = VideoService(
  2209. self.db,
  2210. user_id=int(user_id) if str(user_id).isdigit() else user_id,
  2211. api_key=dashscope_api_key
  2212. )
  2213. # 解析并转换size格式
  2214. try:
  2215. openai_size, internal_size = parse_video_size(size or "1280x720")
  2216. except ValueError as e:
  2217. raise OpenAICompatError(
  2218. status_code=400,
  2219. message=str(e),
  2220. error_type="invalid_request_error"
  2221. )
  2222. internal_req = VideoGenerateRequest(
  2223. prompt=prompt,
  2224. first_frame_url=image_url,
  2225. resolution=internal_size, # 使用内部格式 "720P"
  2226. duration=5,
  2227. prompt_extend=True,
  2228. watermark=False
  2229. )
  2230. # 提交异步任务
  2231. task_resp = await real_video_service.generate(internal_req)
  2232. task_id = task_resp.task_id
  2233. # 阻塞轮询
  2234. max_retries = 120
  2235. poll_interval = 5
  2236. final_video_url = None
  2237. for _ in range(max_retries):
  2238. await asyncio.sleep(poll_interval)
  2239. status_result = await real_video_service.get_task_status(task_id)
  2240. if status_result.task_status == "SUCCEEDED":
  2241. final_video_url = status_result.video_url
  2242. break
  2243. elif status_result.task_status == "FAILED":
  2244. raise OpenAICompatError(
  2245. status_code=500,
  2246. message=status_result.error_message or "底层图生视频任务失败",
  2247. error_type="video_generation_error"
  2248. )
  2249. if not final_video_url:
  2250. raise OpenAICompatError(
  2251. status_code=504,
  2252. message="图生视频任务超时,请稍后再试",
  2253. error_type="timeout_error"
  2254. )
  2255. # 组装结果
  2256. from app.schemas.openai_compat import VideoData
  2257. result = VideoGenerationResponse(
  2258. created=int(time.time()),
  2259. data=[VideoData(url=final_video_url, content_type="video/mp4")]
  2260. )
  2261. bill = Decimal("0")
  2262. # 记录兼容层日志
  2263. log_service.create_log(
  2264. user_id=user_id,
  2265. api_key_id=api_key_id,
  2266. model_id=model.id if model else None,
  2267. model_name=model_name,
  2268. is_local=model.is_local if model else False,
  2269. input_tokens=0,
  2270. output_tokens=5,
  2271. bill=float(bill),
  2272. status="success",
  2273. request_ip=request_ip
  2274. )
  2275. return result
  2276. except Exception as e:
  2277. error_msg = str(e) if not isinstance(e, OpenAICompatError) else e.message
  2278. model_obj = self._find_model(model_name, user_id)
  2279. log_service.create_log(
  2280. user_id=user_id,
  2281. api_key_id=api_key_id,
  2282. model_id=model_obj.id if model_obj else None,
  2283. model_name=model_name,
  2284. is_local=model_obj.is_local if model_obj else False,
  2285. input_tokens=0,
  2286. output_tokens=0,
  2287. bill=0,
  2288. status="failed",
  2289. error_message=error_msg,
  2290. request_ip=request_ip
  2291. )
  2292. if isinstance(e, OpenAICompatError):
  2293. raise e
  2294. raise OpenAICompatError(status_code=500, message=error_msg, error_type="video_generation_error")
  2295. async def rerank(
  2296. self,
  2297. request: RerankRequest,
  2298. user_id: str,
  2299. api_key_id: int,
  2300. request_ip: Optional[str] = None
  2301. ) -> RerankResponse:
  2302. """
  2303. 处理重排序(Rerank)请求
  2304. Args:
  2305. request: 重排序请求对象
  2306. user_id: 用户ID
  2307. api_key_id: 用于日志记录
  2308. request_ip: 请求来源IP
  2309. Returns:
  2310. RerankResponse: 包含排序结果的响应对象
  2311. Raises:
  2312. OpenAICompatError: 处理失败或鉴权失败时抛出
  2313. """
  2314. log_service = ApiCallLogService(self.db)
  2315. try:
  2316. model = await self._validate_model_and_balance(request.model, user_id)
  2317. # 检查模型类型是否支持重排序
  2318. from app.models.model import ModelCategory
  2319. if not any(int(c) in [int(ModelCategory.RERANK), int(ModelCategory.EMBEDDING), int(ModelCategory.LLM)] for c in (model.categories or [])):
  2320. raise OpenAICompatError(
  2321. status_code=400,
  2322. message=f"Model '{request.model}' does not support reranking",
  2323. error_type="model_not_supported",
  2324. )
  2325. # 检查本地模型权限
  2326. if model.is_local:
  2327. from app.services.user_local_model_permission_service import UserLocalModelPermissionService
  2328. permission_service = UserLocalModelPermissionService(self.db)
  2329. if not await permission_service.check_user_model_access(user_id, model.id):
  2330. raise OpenAICompatError(
  2331. status_code=403,
  2332. message=f"You don't have permission to access model '{request.model}'",
  2333. error_type="permission_error",
  2334. )
  2335. else:
  2336. user = self._get_user(user_id)
  2337. if not user or not user.apikey:
  2338. raise OpenAICompatError(
  2339. status_code=403,
  2340. message="User API key not configured.",
  2341. error_type="api_key_not_configured"
  2342. )
  2343. if model.is_local:
  2344. # 本地模型处理
  2345. from app.services.cache_service import CacheService
  2346. from app.services.crypto_utils import decrypt_api_key
  2347. model_data = await CacheService.get_model(model.id)
  2348. if model_data:
  2349. base_url = model_data.get("base_url", "").rstrip("/")
  2350. local_api_key = model_data.get("local_api_key")
  2351. else:
  2352. # 从数据库获取
  2353. base_url = (model.base_url or "").rstrip("/")
  2354. local_api_key = model.local_api_key
  2355. # 缓存模型信息
  2356. await CacheService.set_model(model.id, {
  2357. "base_url": base_url,
  2358. "local_api_key": local_api_key,
  2359. "is_local": model.is_local,
  2360. "name": model.name
  2361. })
  2362. if not base_url:
  2363. raise OpenAICompatError(
  2364. status_code=500,
  2365. message="本地模型未配置 Base URL",
  2366. error_type="configuration_error",
  2367. )
  2368. # 构建请求头
  2369. headers = {"Content-Type": "application/json"}
  2370. if local_api_key:
  2371. api_key = decrypt_api_key(local_api_key)
  2372. if api_key:
  2373. headers["Authorization"] = f"Bearer {api_key}"
  2374. # 构建请求体(OpenAI 格式)
  2375. payload = {
  2376. "model": model.name,
  2377. "query": request.query,
  2378. "documents": request.documents
  2379. }
  2380. if request.top_n is not None:
  2381. payload["top_n"] = request.top_n
  2382. if request.return_documents is not None:
  2383. payload["return_documents"] = request.return_documents
  2384. if request.user:
  2385. payload["user"] = request.user
  2386. # 使用统一的方法发送请求
  2387. api_url = f"{base_url}/rerank"
  2388. result_data = await self._handle_local_model_request(
  2389. api_url=api_url,
  2390. headers=headers,
  2391. payload=payload,
  2392. model_name=model.name,
  2393. base_url=base_url,
  2394. endpoint_type="rerank",
  2395. timeout=30.0
  2396. )
  2397. # 处理响应
  2398. import logging
  2399. logger = logging.getLogger(__name__)
  2400. results_list = result_data.get("data", [])
  2401. usage_data = result_data.get("usage", {})
  2402. # 如果 data 为空,尝试从 results 字段获取(某些模型如硅基流动使用 results)
  2403. if not results_list:
  2404. results_list = result_data.get("results", [])
  2405. logger.debug(f"Using 'results' field, found {len(results_list)} items")
  2406. # 尝试从不同位置获取 token 信息
  2407. if not usage_data or usage_data.get("total_tokens", 0) == 0:
  2408. # 尝试从 meta.tokens 获取(硅基流动格式)
  2409. meta = result_data.get("meta", {})
  2410. tokens = meta.get("tokens", {})
  2411. if tokens:
  2412. input_tokens = tokens.get("input_tokens", 0)
  2413. output_tokens = tokens.get("output_tokens", 0)
  2414. total_tokens = input_tokens + output_tokens
  2415. logger.debug(f"Using meta.tokens: input={input_tokens}, output={output_tokens}, total={total_tokens}")
  2416. else:
  2417. total_tokens = 0
  2418. else:
  2419. total_tokens = usage_data.get("total_tokens", 0)
  2420. data_list = []
  2421. for item in results_list:
  2422. # 获取文档内容(处理不同格式)
  2423. doc_content = None
  2424. if request.return_documents:
  2425. doc = item.get("document")
  2426. if isinstance(doc, dict):
  2427. # 如果 document 是对象,尝试获取 text 字段
  2428. doc_content = doc.get("text", "")
  2429. elif isinstance(doc, str):
  2430. # 如果 document 是字符串,直接使用
  2431. doc_content = doc
  2432. result_item = RerankResult(
  2433. index=item.get("index", 0),
  2434. relevance_score=item.get("relevance_score", 0.0)
  2435. )
  2436. if doc_content:
  2437. result_item.document = doc_content
  2438. data_list.append(result_item)
  2439. else:
  2440. # 云端模型处理(阿里云百炼)
  2441. api_url = "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank"
  2442. headers = {
  2443. "Content-Type": "application/json",
  2444. "Authorization": f"Bearer {user.apikey}"
  2445. }
  2446. payload = {
  2447. "model": model.model_code,
  2448. "input": {
  2449. "query": request.query,
  2450. "documents": request.documents
  2451. }
  2452. }
  2453. if request.top_n is not None:
  2454. payload["parameters"] = {"top_n": request.top_n}
  2455. async with httpx.AsyncClient(timeout=30.0) as client:
  2456. response = await client.post(api_url, headers=headers, json=payload)
  2457. response.raise_for_status()
  2458. result_data = response.json()
  2459. output = result_data.get("output", {})
  2460. results_list = output.get("results", [])
  2461. usage_data = result_data.get("usage", {})
  2462. total_tokens = usage_data.get("total_tokens", 0)
  2463. data_list = []
  2464. for item in results_list:
  2465. result_item = RerankResult(
  2466. index=item.get("index", 0),
  2467. relevance_score=item.get("relevance_score", 0.0)
  2468. )
  2469. if request.return_documents:
  2470. result_item.document = request.documents[item.get("index", 0)]
  2471. data_list.append(result_item)
  2472. # 记录日志
  2473. bill = Decimal("0")
  2474. call_log = log_service.create_log(
  2475. user_id=user_id,
  2476. api_key_id=api_key_id,
  2477. model_id=model.id if model else None,
  2478. model_name=request.model,
  2479. is_local=model.is_local if model else False,
  2480. input_tokens=total_tokens,
  2481. output_tokens=0,
  2482. bill=float(bill),
  2483. status="success",
  2484. request_ip=request_ip
  2485. )
  2486. return RerankResponse(
  2487. model=request.model,
  2488. data=data_list,
  2489. usage=Usage(
  2490. prompt_tokens=total_tokens,
  2491. completion_tokens=0,
  2492. total_tokens=total_tokens
  2493. )
  2494. )
  2495. except Exception as e:
  2496. error_msg = str(e) if not isinstance(e, OpenAICompatError) else e.message
  2497. model = self._find_model(request.model, user_id)
  2498. log_service.create_log(
  2499. user_id=user_id,
  2500. api_key_id=api_key_id,
  2501. model_id=model.id if model else None,
  2502. model_name=request.model,
  2503. is_local=model.is_local if model else False,
  2504. input_tokens=0,
  2505. output_tokens=0,
  2506. bill=0,
  2507. status="failed",
  2508. error_message=error_msg,
  2509. request_ip=request_ip
  2510. )
  2511. if isinstance(e, OpenAICompatError):
  2512. raise e
  2513. raise OpenAICompatError(status_code=500, message=error_msg, error_type="rerank_error")