| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762
27722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600260126022603260426052606260726082609261026112612261326142615261626172618261926202621262226232624262526262627262826292630263126322633263426352636263726382639264026412642264326442645264626472648264926502651265226532654265526562657265826592660266126622663266426652666266726682669267026712672267326742675267626772678267926802681268226832684268526862687268826892690269126922693269426952696269726982699270027012702270327042705270627072708270927102711271227132714271527162717271827192720272127222723272427252726272727282729273027312732273327342735273627372738273927402741274227432744274527462747274827492750275127522753275427552756275727582759276027612762276327642765276627672768276927702771277227732774277527762
777277827792780278127822783278427852786278727882789279027912792279327942795279627972798279928002801280228032804280528062807280828092810281128122813281428152816281728182819282028212822282328242825282628272828282928302831283228332834283528362837283828392840284128422843284428452846284728482849285028512852285328542855285628572858285928602861286228632864286528662867 |
- """
- OpenAI 兼容服务层
- 完整支持 /v1/chat/completions 和 /v1/models 接口
- 支持多种模型提供商的自动适配
- """
- import asyncio
- import logging
- import time
- import uuid
- from decimal import Decimal
- import httpx
- from fastapi import UploadFile
- from app.services.api_call_log_service import ApiCallLogService
- from app.services.image_service import ImageGenerationService
- from app.services.model_adapters import BaseAdapter, get_adapter, ModelProvider
- from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
- from sqlalchemy import desc
- from sqlalchemy.orm import Session
- from app.models.model import ModelNew as Model, ModelPriceNew as ModelPrice
- from app.models.user import User
- from app.schemas.openai_compat import (
- ChatCompletionsRequest,
- ModelInfo,
- EmbeddingsRequest,
- EmbeddingsResponse,
- EmbeddingData,
- Usage,
- ModelsListResponse,
- ImageGenerationRequest,
- ImageGenerationResponse,
- ImageData,
- AudioTranscriptionResponse,
- AudioSpeechRequest,
- VideoGenerationRequest,
- VideoGenerationResponse,
- RerankRequest,
- RerankResponse,
- RerankResult,
- )
- from app.services.crypto_utils import decrypt_api_key
- from app.services.system_config_manager import get_config_bool
- logger = logging.getLogger(__name__)
- # ─────────────────────────────────────────────
- # 工具函数
- # ─────────────────────────────────────────────
def parse_video_size(size: str) -> Tuple[str, str]:
    """Normalize a video size string into OpenAI and shorthand formats.

    Two spellings are accepted (case-insensitive, surrounding whitespace
    ignored):

    * ``WIDTHxHEIGHT`` such as ``"1280x720"``, ``"1920x1080"``, ``"720x1280"``.
    * ``HEIGHTp`` shorthand such as ``"720P"`` or ``"1080p"``.

    Args:
        size: The size string supplied by the caller.

    Returns:
        A ``(openai_format, internal_format)`` tuple, for example
        ``("1280x720", "720P")``. The internal format is always derived from
        the height. For the shorthand form the width assumes a 16:9 aspect
        ratio and is rounded to the nearest even number.

    Raises:
        ValueError: If ``size`` matches neither format or a dimension is zero.
    """
    import re

    size = size.strip()
    lowered = size.lower()

    # OpenAI format: WIDTHxHEIGHT
    if 'x' in lowered:
        match = re.fullmatch(r'(\d+)x(\d+)', lowered)
        if not match:
            raise ValueError(f"Invalid size format: {size}. Expected format: 1280x720")

        width, height = int(match.group(1)), int(match.group(2))
        if width == 0 or height == 0:
            # A zero-sized frame can never be a valid video resolution.
            raise ValueError(f"Invalid size format: {size}. Expected format: 1280x720")

        # The shorthand is always derived from the height (720 -> "720P").
        return (f"{width}x{height}", f"{height}P")

    # Shorthand format: HEIGHTp
    match = re.fullmatch(r'(\d+)p', lowered)
    if not match:
        raise ValueError(f"Invalid size format: {size}. Expected format: 720P or 1280x720")

    height = int(match.group(1))
    if height == 0:
        raise ValueError(f"Invalid size format: {size}. Expected format: 720P or 1280x720")

    # Assume 16:9 and round to the nearest even width using integer math.
    # Plain truncation of height * 16 / 9 yields odd widths for some heights
    # (480P -> 853), whereas 720 and 1080 still map to 1280 and 1920.
    width = 2 * ((16 * height + 9) // 18)
    return (f"{width}x{height}", f"{height}P")
class OpenAICompatError(Exception):
    """Error raised by the OpenAI-compatible service layer.

    Attributes:
        status_code: Status code associated with the failure.
        message: Human-readable description (also the exception text).
        error_type: Error category string; defaults to
            ``"invalid_request_error"``.
    """

    def __init__(self, status_code: int, message: str, error_type: str = "invalid_request_error"):
        # Initialise the base exception first so str(err) yields the message.
        super().__init__(message)
        self.status_code, self.message, self.error_type = (
            status_code,
            message,
            error_type,
        )
- class OpenAICompatService:
- """OpenAI API 兼容服务"""
- def __init__(self, db: Session):
- self.db = db
- self._user_cache: dict = {} # user_id → User,请求内缓存,避免重复查询
- def _get_user(self, user_id: str):
- """获取用户对象,同一请求内缓存,避免重复查询 users 表。"""
- if user_id not in self._user_cache:
- self._user_cache[user_id] = self.db.query(User).filter(User.id == user_id).first()
- return self._user_cache[user_id]
- # ─────────────────────────────────────────────
- # 主入口
- # ─────────────────────────────────────────────
- async def chat_completions(
- self,
- request: ChatCompletionsRequest,
- user_id: str,
- api_key_id: int,
- request_ip: Optional[str] = None,
- ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
- """
- 处理 Chat Completions 请求,包含日志记录与扣费。
- 权限检查流程:
- 1. 验证模型是否存在
- 2. 检查模型 is_api_enabled(云端模型)
- 3. 检查用户余额(云端模型)
- 4. 获取用户 API Key(云端模型)
- 5. 检查用户对本地模型的访问权限(本地模型)
- """
- log_service = ApiCallLogService(self.db)
- model = self._find_model(request.model, user_id)
- if not model:
- raise OpenAICompatError(
- status_code=404,
- message=f"The model '{request.model}' does not exist",
- error_type="model_not_found",
- )
- if not model.is_local and not model.is_api_enabled:
- raise OpenAICompatError(
- status_code=403,
- message=f"Model '{request.model}' does not support API calls",
- error_type="model_not_available",
- )
- # realtime 模型使用 WebSocket 实时流协议,不支持此 REST 接口
- if "realtime" in request.model.lower() and not model.is_local:
- raise OpenAICompatError(
- status_code=400,
- message=f"Model '{request.model}' is a real-time streaming model that uses WebSocket protocol. "
- f"It cannot be called via /api/v1/chat/completions. "
- f"Please use the WebSocket API instead.",
- error_type="model_not_supported",
- )
- # 检查本地模型的访问权限
- if model.is_local:
- from app.services.user_local_model_permission_service import UserLocalModelPermissionService
- permission_service = UserLocalModelPermissionService(self.db)
- if not await permission_service.check_user_model_access(user_id, model.id):
- raise OpenAICompatError(
- status_code=403,
- message=f"You don't have permission to access model '{request.model}'",
- error_type="permission_error",
- )
- user_api_key: Optional[str] = None
- if not model.is_local:
- user = self._get_user(user_id)
- if not user:
- raise OpenAICompatError(
- status_code=401,
- message="User not found",
- error_type="authentication_error",
- )
- # 优先使用模型自带的 api_key(爬虫同步的),没有则 fallback 到用户自己配置的 apikey
- if model.encrypted_api_key:
- from app.services.crypto_utils import decrypt_api_key
- decrypted = decrypt_api_key(model.encrypted_api_key)
- user_api_key = decrypted if decrypted else None
- if not user_api_key:
- user_api_key = user.apikey
- if not user_api_key:
- raise OpenAICompatError(
- status_code=403,
- message="User API key not configured. Please configure your API key in settings.",
- error_type="api_key_not_configured",
- )
- # ── OCR 模型校验:必须包含图片 ────────────────────────────────────
- OCR_MODELS = ("qwen-vl-ocr",)
- if model.model_code in OCR_MODELS and not model.is_local:
- has_image = any(
- isinstance(msg.content, list) and
- any(isinstance(part, dict) and part.get("type") == "image_url"
- for part in msg.content)
- for msg in request.messages
- )
- # 也兼容 Pydantic 对象形式
- if not has_image:
- from app.schemas.openai_compat import ContentPartImage
- has_image = any(
- isinstance(msg.content, list) and
- any(isinstance(part, ContentPartImage) for part in msg.content)
- for msg in request.messages
- )
- if not has_image:
- raise OpenAICompatError(
- status_code=400,
- message=f"Model '{model.model_code}' is an OCR model and requires at least one image in the messages. "
- f"Please include an image_url content part in your user message.",
- error_type="invalid_request_error",
- )
- # ── 流式 ──────────────────────────────────────────────────────────
- if request.stream:
- raw_stream = await self._call_local_model(model, request) if model.is_local \
- else await self._call_cloud_model(model, request, user_api_key)
- async def _stream_with_billing() -> AsyncGenerator[str, None]:
- input_text = "".join(
- [m.content for m in request.messages if isinstance(m.content, str)]
- )
- input_tokens = max(int(len(input_text) * 1.2), 1)
- output_tokens = 0
- stream_error: Optional[Exception] = None
- try:
- async for chunk in raw_stream:
- yield chunk
- if isinstance(chunk, str) and chunk.startswith("data: ") \
- and "data: [DONE]" not in chunk:
- try:
- import json as _json
- data_dict = _json.loads(chunk[6:])
- delta = data_dict.get("choices", [{}])[0] \
- .get("delta", {}).get("content", "")
- if delta:
- output_tokens += max(int(len(delta) * 1.2), 1)
- # 优先使用上游返回的真实 usage(部分模型在最后一个 chunk 里带)
- usage = data_dict.get("usage")
- if usage:
- input_tokens = usage.get("prompt_tokens", input_tokens)
- output_tokens = usage.get("completion_tokens", output_tokens)
- except Exception:
- pass
- except (GeneratorExit, asyncio.CancelledError):
- # 客户端断开/任务取消 - 仍按已产生的 token 扣费
- raise
- except Exception as exc:
- stream_error = exc
- raise
- finally:
- # 关键:无论流正常结束、客户端中断、还是异常都要扣费
- # 防止 "token 已被消耗但未扣费" 的资损
- try:
- if stream_error is not None:
- # 上游错误 - 记录失败日志
- log_service.create_log(
- user_id=user_id, api_key_id=api_key_id,
- model_id=model.id, model_name=request.model,
- is_local=model.is_local,
- input_tokens=0, output_tokens=0,
- bill=0, status="failed",
- error_message=str(stream_error), request_ip=request_ip,
- )
- else:
- bill = self.calculate_bill(model, input_tokens, output_tokens)
- log_service.create_log(
- user_id=user_id, api_key_id=api_key_id,
- model_id=model.id, model_name=request.model,
- is_local=model.is_local,
- input_tokens=input_tokens, output_tokens=output_tokens,
- bill=float(bill), status="success", request_ip=request_ip,
- )
- except Exception as fin_exc:
- logger.error("流式响应收尾日志记录失败: %s", fin_exc)
- return _stream_with_billing()
- # ── 非流式 ────────────────────────────────────────────────────────
- try:
- result = await self._call_local_model(model, request) if model.is_local \
- else await self._call_cloud_model(model, request, user_api_key)
- input_tokens, output_tokens = self.extract_usage_from_response(result)
- bill = self.calculate_bill(model, input_tokens, output_tokens)
- log_service.create_log(
- user_id=user_id, api_key_id=api_key_id,
- model_id=model.id, model_name=request.model,
- is_local=model.is_local,
- input_tokens=input_tokens, output_tokens=output_tokens,
- bill=float(bill), status="success", request_ip=request_ip,
- )
- return result
- except OpenAICompatError:
- raise
- except Exception as exc:
- error_msg = str(exc) or repr(exc)
- logger.warning(
- "[CHAT_COMPLETION_FAILED] model=%s user_id=%s error_type=%s error=%s",
- request.model, user_id, type(exc).__name__, error_msg,
- )
- log_service.create_log(
- user_id=user_id, api_key_id=api_key_id,
- model_id=model.id, model_name=request.model,
- is_local=model.is_local,
- input_tokens=0, output_tokens=0,
- bill=0, status="failed",
- error_message=error_msg, request_ip=request_ip,
- )
- raise OpenAICompatError(status_code=500, message=error_msg, error_type="upstream_error")
- # ─────────────────────────────────────────────
- # 构建请求体(完整参数透传)
- # ─────────────────────────────────────────────
- def _build_request_body(
- self, request: ChatCompletionsRequest, model_name: str
- ) -> Dict[str, Any]:
- """将 ChatCompletionsRequest 转换为上游 API 请求体,透传所有非 None 参数"""
- # 序列化消息(content 支持 str 或 list)
- messages = []
- for msg in request.messages:
- m: Dict[str, Any] = {"role": msg.role}
- if msg.content is None:
- m["content"] = None
- elif isinstance(msg.content, str):
- m["content"] = msg.content
- else:
- # 多模态内容列表
- parts = []
- for part in msg.content:
- part_dict = part.model_dump(exclude_none=True)
- # 校验 image_url 格式
- if part_dict.get("type") == "image_url":
- url = (part_dict.get("image_url") or {}).get("url", "")
- if url and not url.startswith("data:"):
- import os
- ext = os.path.splitext(url.split("?")[0].lower())[1]
- SUPPORTED_IMAGE_EXTS = {'.jpg', '.jpeg', '.png', '.webp', '.gif', '.bmp'}
- if ext and ext not in SUPPORTED_IMAGE_EXTS:
- raise OpenAICompatError(
- status_code=400,
- message=f"Unsupported image format '{ext}'. "
- f"Supported formats: jpg, jpeg, png, webp, gif, bmp.",
- error_type="invalid_request_error",
- )
- parts.append(part_dict)
- m["content"] = parts
- if msg.name is not None:
- m["name"] = msg.name
- if msg.tool_calls is not None:
- m["tool_calls"] = [tc.model_dump() for tc in msg.tool_calls]
- if msg.tool_call_id is not None:
- m["tool_call_id"] = msg.tool_call_id
- messages.append(m)
- body: Dict[str, Any] = {
- "model": model_name,
- "messages": messages,
- "stream": request.stream,
- }
- # 可选参数:只有非 None 才透传
- optional_fields = [
- "temperature", "top_p", "n", "stop",
- "presence_penalty", "frequency_penalty",
- "logit_bias", "logprobs", "top_logprobs",
- "seed", "user", "service_tier", "store", "metadata",
- "parallel_tool_calls",
- ]
- for field in optional_fields:
- val = getattr(request, field, None)
- if val is not None:
- body[field] = val
- # max_tokens / max_completion_tokens(优先新版)
- if request.max_completion_tokens is not None:
- body["max_completion_tokens"] = request.max_completion_tokens
- elif request.max_tokens is not None:
- body["max_tokens"] = request.max_tokens
- # 流式选项
- if request.stream and request.stream_options is not None:
- body["stream_options"] = request.stream_options.model_dump(exclude_none=True)
- # 工具调用
- if request.tools is not None:
- body["tools"] = [t.model_dump(exclude_none=True) for t in request.tools]
- if request.tool_choice is not None:
- if isinstance(request.tool_choice, str):
- body["tool_choice"] = request.tool_choice
- else:
- body["tool_choice"] = request.tool_choice.model_dump()
- # 响应格式
- if request.response_format is not None:
- body["response_format"] = request.response_format.model_dump(exclude_none=True)
- return body
- # ─────────────────────────────────────────────
- # 本地模型调用
- # ─────────────────────────────────────────────
- async def _call_local_model(
- self, model: Model, request: ChatCompletionsRequest
- ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
- # 从缓存获取模型信息
- from app.services.cache_service import CacheService
- model_data = await CacheService.get_model(model.id)
-
- if model_data:
- base_url = model_data.get("base_url", "").rstrip("/")
- local_api_key = model_data.get("local_api_key")
- else:
- # 从数据库获取
- base_url = (model.base_url or "").rstrip("/")
- local_api_key = model.local_api_key
- # 缓存模型信息
- await CacheService.set_model(model.id, {
- "base_url": base_url,
- "local_api_key": local_api_key,
- "is_local": model.is_local,
- "name": model.name
- })
- if not base_url:
- raise OpenAICompatError(
- status_code=500,
- message="本地模型未配置 Base URL",
- error_type="configuration_error",
- )
- headers: Dict[str, str] = {"Content-Type": "application/json"}
- # 本地模型不使用用户的API密钥,而是使用模型配置的API密钥
- if local_api_key:
- api_key = decrypt_api_key(local_api_key)
- if api_key:
- headers["Authorization"] = f"Bearer {api_key}"
- # 使用模型的name字段作为实际模型名称
- actual_name = model.name
- body = self._build_request_body(request, actual_name)
- api_url = f"{base_url}/chat/completions"
- if request.stream:
- return self._stream_response(api_url, headers, body)
- else:
- return await self._non_stream_response(api_url, headers, body)
- # ─────────────────────────────────────────────
- # 云端模型调用(阿里云百炼 OpenAI 兼容模式)
- # ─────────────────────────────────────────────
- async def _call_cloud_model(
- self, model: Model, request: ChatCompletionsRequest, user_api_key: str
- ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
- api_url = "https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions"
- headers = {
- "Content-Type": "application/json",
- "Authorization": f"Bearer {user_api_key}",
- }
- body = self._build_request_body(request, model.model_code)
- if request.stream:
- return self._stream_response(api_url, headers, body)
- else:
- return await self._non_stream_response(api_url, headers, body)
- # ─────────────────────────────────────────────
- # HTTP 请求封装
- # ─────────────────────────────────────────────
- async def _non_stream_response(
- self, api_url: str, headers: Dict[str, str], body: Dict[str, Any],
- timeout: float = 300.0,
- ) -> Dict[str, Any]:
- model_name = body.get("model", "unknown")
- try:
- async with httpx.AsyncClient(timeout=timeout) as client:
- resp = await client.post(api_url, headers=headers, json=body)
- if resp.status_code >= 400:
- error_msg = self._extract_upstream_error(resp)
- logger.warning(
- "[UPSTREAM_ERROR] model=%s status=%d url=%s response=%s",
- model_name, resp.status_code, api_url, resp.text[:500],
- )
- raise OpenAICompatError(
- status_code=resp.status_code,
- message=error_msg,
- error_type="upstream_error",
- )
- return resp.json()
- except httpx.ReadTimeout:
- logger.warning("[UPSTREAM_TIMEOUT] model=%s url=%s timeout=%ss", model_name, api_url, timeout)
- raise OpenAICompatError(
- status_code=504,
- message=f"模型 '{model_name}' 响应超时({timeout}s),请稍后重试或换个模型",
- error_type="timeout_error",
- )
- async def _stream_response(
- self, api_url: str, headers: Dict[str, str], body: Dict[str, Any]
- ) -> AsyncGenerator[str, None]:
- """SSE 流式响应生成器"""
- model_name = body.get("model", "unknown")
- try:
- async with httpx.AsyncClient(
- timeout=httpx.Timeout(30.0, read=None)
- ) as client:
- async with client.stream("POST", api_url, headers=headers, json=body) as resp:
- if resp.status_code >= 400:
- error_body = await resp.aread()
- error_text = error_body.decode("utf-8", errors="replace")
- logger.warning(
- "[UPSTREAM_STREAM_ERROR] model=%s status=%d url=%s response=%s",
- model_name, resp.status_code, api_url, error_text[:500],
- )
- # 使用统一的错误提取方法
- error_detail = self._extract_upstream_error(resp)
- # 针对特定状态码提供友好提示
- if resp.status_code == 401:
- raise OpenAICompatError(
- status_code=500,
- message=f"模型 '{model_name}' 认证失败(401 Unauthorized)。\n"
- f"请在管理后台检查该模型的 API Key 配置是否正确。\n"
- f"错误详情: {error_detail}",
- error_type="authentication_error",
- )
- elif resp.status_code == 404:
- raise OpenAICompatError(
- status_code=500,
- message=f"模型 '{model_name}' 接口不存在(404 Not Found)。请检查 Base URL 和模型名称是否正确。\n"
- f"错误详情: {error_detail}",
- error_type="not_found_error",
- )
- else:
- raise OpenAICompatError(
- status_code=resp.status_code,
- message=f"模型 '{model_name}' 调用失败({resp.status_code}): {error_detail}",
- error_type="upstream_error",
- )
- async for line in resp.aiter_lines():
- if line.startswith("data: "):
- data = line[6:]
- if data.strip() == "[DONE]":
- yield "data: [DONE]\n\n"
- return
- yield f"data: {data}\n\n"
- elif line.strip():
- yield f"data: {line}\n\n"
- except httpx.ConnectTimeout:
- logger.warning("[UPSTREAM_CONNECT_TIMEOUT] model=%s url=%s", model_name, api_url)
- raise OpenAICompatError(
- status_code=504,
- message=f"模型 '{model_name}' 连接超时,请稍后重试",
- error_type="timeout_error",
- )
- def _extract_upstream_error(self, resp: httpx.Response) -> str:
- """从上游错误响应中提取错误信息,确保始终返回有意义的内容"""
- raw_text = ""
- try:
- raw_text = resp.text
- data = resp.json()
- # 标准 OpenAI 格式: {"error": {"message": "...", "type": "...", "code": "..."}}
- err = data.get("error", {})
- if isinstance(err, dict):
- msg = err.get("message", "")
- err_type = err.get("type", "")
- code = err.get("code", "")
- parts = [p for p in [msg, f"type={err_type}" if err_type else "", f"code={code}" if code else ""] if p]
- if parts:
- return " | ".join(parts)
- # DashScope 格式: {"code": "...", "message": "...", "request_id": "..."}
- code = data.get("code", "")
- msg = data.get("message", "")
- request_id = data.get("request_id", "")
- if msg:
- parts = [msg]
- if code:
- parts.append(f"code={code}")
- if request_id:
- parts.append(f"request_id={request_id}")
- return " | ".join(parts)
- # 兜底:返回整个 JSON
- if data:
- return str(data)
- except Exception:
- pass
- # JSON 解析失败或为空,返回原始文本
- if raw_text:
- return raw_text[:500]
- return f"Upstream error {resp.status_code} (empty response body)"
- # ─────────────────────────────────────────────
- # Models 列表
- # ─────────────────────────────────────────────
- def get_available_models(self, user_id: str, key_type: str = "public") -> ModelsListResponse:
- """返回用户可用的模型,根据API密钥类型过滤"""
- models_data: List[ModelInfo] = []
- # 根据密钥类型返回相应的模型
- if key_type == "public":
- # 公钥只能访问云端模型
- cloud_models = (
- self.db.query(Model)
- .filter(Model.is_local == False, Model.is_api_enabled == True)
- .all()
- )
- for m in cloud_models:
- models_data.append(
- ModelInfo(
- id=m.model_code,
- object="model",
- created=int(m.created_at.timestamp()) if m.created_at else int(time.time()),
- owned_by=m.supplier or "platform",
- )
- )
- elif key_type == "local":
- # 检查本地模型是否启用
- if get_config_bool("enable_local_models", True):
- # 如果本地模型启用,返回所有本地模型
- local_models = (
- self.db.query(Model)
- .filter(
- Model.is_local == True
- )
- .all()
- )
- for m in local_models:
- models_data.append(
- ModelInfo(
- id=f"local:{m.id}",
- object="model",
- created=int(m.created_at.timestamp()) if m.created_at else int(time.time()),
- owned_by="local",
- )
- )
- else:
- # 如果本地模型未启用,返回用户有权限的模型
- from app.services.user_local_model_permission_service import UserLocalModelPermissionService
- permission_service = UserLocalModelPermissionService(self.db)
- permissions = permission_service.get_user_model_permissions(user_id)
- permitted_model_ids = [perm["model_id"] for perm in permissions if perm["has_access"]]
-
- local_models = (
- self.db.query(Model)
- .filter(
- Model.is_local == True,
- Model.id.in_(permitted_model_ids)
- )
- .all()
- )
- for m in local_models:
- models_data.append(
- ModelInfo(
- id=f"local:{m.id}",
- object="model",
- created=int(m.created_at.timestamp()) if m.created_at else int(time.time()),
- owned_by="local",
- )
- )
- return ModelsListResponse(object="list", data=models_data)
- # ─────────────────────────────────────────────
- # 工具方法
- # ─────────────────────────────────────────────
def _upstream_error_detail(self, response: Any) -> str:
    """Extract a short, human-readable error message from an upstream error response.

    Looks for ``error.message`` and then a top-level ``message`` in a JSON
    object body. A plain-string ``error`` field is returned as-is. When the
    body is not JSON, or ``error`` has an unexpected shape, the first 200
    characters of the raw body are returned instead.

    Args:
        response: The upstream response object; only ``.json()`` and ``.text`` are used.

    Returns:
        The extracted detail, possibly an empty string.
    """
    try:
        error_data = response.json()
    except Exception:
        # Not a JSON body: fall back to a truncated copy of the raw text.
        return response.text[:200]

    if not isinstance(error_data, dict):
        return ""

    error_field = error_data.get("error", {})
    if isinstance(error_field, dict):
        return error_field.get("message", "") or error_data.get("message", "")
    if isinstance(error_field, str) and error_field:
        return error_field
    # Unexpected shape for "error" (null, list, ...): show the raw body instead.
    return response.text[:200]

async def _handle_local_model_request(
    self,
    api_url: str,
    headers: Dict[str, str],
    payload: Dict[str, Any],
    model_name: str,
    base_url: str,
    endpoint_type: str = "chat",
    timeout: float = 60.0,
    return_raw_response: bool = False
) -> Union[Dict[str, Any], httpx.Response]:
    """
    Send a JSON POST to a local model endpoint and pass the response through.

    The payload is forwarded untouched and the JSON response is returned
    untouched: no format adaptation is performed in either direction.

    Args:
        api_url: Full URL to POST to.
        headers: Request headers.
        payload: Request body, sent as JSON.
        model_name: Model name, used only in error messages.
        base_url: Base URL of the model (accepted for interface compatibility; not used here).
        endpoint_type: Endpoint type (reserved; not used here).
        timeout: Timeout in seconds passed to the HTTP client.
        return_raw_response: When True, return the ``httpx.Response`` itself
            (e.g. for binary audio) instead of the decoded JSON.

    Returns:
        The decoded JSON body, or the raw ``httpx.Response`` when
        ``return_raw_response`` is True.

    Raises:
        OpenAICompatError: On an upstream HTTP error status, a timeout, or a
            transport-level request failure.
    """
    async with httpx.AsyncClient(timeout=timeout) as client:
        try:
            response = await client.post(api_url, headers=headers, json=payload)
            response.raise_for_status()

            # Callers that need the raw bytes/headers get the response object itself.
            if return_raw_response:
                return response

            return response.json()

        except httpx.HTTPStatusError as e:
            upstream_status = e.response.status_code

            if upstream_status == 404:
                raise OpenAICompatError(
                    status_code=500,
                    message=f"本地模型 '{model_name}' 接口不存在(404 Not Found)。请检查 Base URL 和模型名称是否正确。",
                    error_type="not_found_error",
                ) from e

            error_detail = self._upstream_error_detail(e.response)

            if upstream_status == 401:
                # An upstream credential problem is a server-side misconfiguration
                # from the caller's point of view, hence the 500.
                raise OpenAICompatError(
                    status_code=500,
                    message=f"本地模型 '{model_name}' 认证失败(401 Unauthorized)。\n"
                            f"请在管理后台检查该模型的 API Key 配置是否正确。\n"
                            f"错误详情: {error_detail}",
                    error_type="authentication_error",
                ) from e

            # Any other upstream status is relayed with its original code.
            raise OpenAICompatError(
                status_code=upstream_status,
                message=f"本地模型调用失败({upstream_status}): {error_detail}",
                error_type="upstream_error",
            ) from e
        except httpx.TimeoutException as e:
            raise OpenAICompatError(
                status_code=504,
                message=f"本地模型 '{model_name}' 请求超时。请检查网络连接或增加超时时间。",
                error_type="timeout_error",
            ) from e
        except httpx.RequestError as e:
            raise OpenAICompatError(
                status_code=500,
                message=f"本地模型 '{model_name}' 请求失败: {str(e)}",
                error_type="request_error",
            ) from e
def _find_model(self, model_name: str, user_id: str) -> Optional[Model]:
    """
    Resolve a client-supplied model name to a Model row.

    Lookup order:
      1. ``local:{id}``     -> local model with that primary key (no fallback).
      2. ``supplier/name``  -> newest local model matching supplier and display name.
      3. exact ``model_code`` of a non-local (cloud) model.
      4. newest local model whose ``display_name`` equals the name.

    Args:
        model_name: Name as sent by the client.
        user_id: Accepted for interface compatibility; not used by the lookup.

    Returns:
        The matching Model, or None when nothing matches.
    """
    local_prefix = "local:"
    if model_name.startswith(local_prefix):
        raw_id = model_name[len(local_prefix):]
        try:
            local_id = int(raw_id)
        except ValueError:
            # Malformed id: explicit "local:" names never fall through to other lookups.
            return None
        by_id = self.db.query(Model).filter(Model.id == local_id, Model.is_local == True)
        return by_id.first()

    # "supplier/name" form; only the first slash separates the two parts.
    supplier, slash, name = model_name.partition("/")
    if slash:
        supplier_match = (
            self.db.query(Model)
            .filter(
                Model.supplier == supplier,
                Model.display_name == name,
                Model.is_local == True,
            )
            .order_by(Model.created_at.desc())
            .first()
        )
        if supplier_match:
            return supplier_match

    # Cloud models are matched exactly on model_code.
    cloud_match = (
        self.db.query(Model)
        .filter(Model.model_code == model_name, Model.is_local == False)
        .first()
    )
    if cloud_match:
        return cloud_match

    # Last resort: a local model looked up by display name, newest first.
    by_display_name = (
        self.db.query(Model)
        .filter(Model.display_name == model_name, Model.is_local == True)
        .order_by(Model.created_at.desc())
    )
    return by_display_name.first()
def calculate_bill(
    self, model: Model, input_tokens: int, output_tokens: int
) -> Decimal:
    """Return the charge for an API call, which is always zero.

    The model and token counts are accepted but do not influence the result.
    """
    no_charge = Decimal("0")
    return no_charge
def extract_usage_from_response(self, response: Dict[str, Any]) -> Tuple[int, int]:
    """Read prompt and completion token counts from a response dict.

    A missing or null ``usage`` object, and missing or null counters inside
    it, are all treated as zero.

    Args:
        response: Response dict; its ``usage`` entry is read if present.

    Returns:
        ``(input_tokens, output_tokens)``.
    """
    # `or {}` also covers an explicit `"usage": null`, which .get's default does not.
    usage = response.get("usage") or {}
    input_tokens = usage.get("prompt_tokens") or 0
    output_tokens = usage.get("completion_tokens") or 0
    return input_tokens, output_tokens
async def _validate_model_and_balance(self, model_name: str, user_id: str) -> Model:
    """
    Resolve a model by name and check that it may be called through the API.

    Despite the name, no balance check is performed here.

    Args:
        model_name: Model name as sent by the client.
        user_id: User ID, forwarded to the model lookup.

    Returns:
        Model: The resolved model.

    Raises:
        OpenAICompatError: 404 when no model matches; 403 when the model is
            neither local nor API-enabled.
    """
    resolved = self._find_model(model_name, user_id)

    if not resolved:
        raise OpenAICompatError(
            status_code=404,
            message=f"Model '{model_name}' not found",
            error_type="model_not_found"
        )

    # Local models are always callable; cloud models must be API-enabled.
    callable_via_api = resolved.is_local or resolved.is_api_enabled
    if not callable_via_api:
        raise OpenAICompatError(
            status_code=403,
            message=f"Model '{model_name}' does not support API calls",
            error_type="model_not_available"
        )

    return resolved
-
async def embeddings(
    self,
    request: EmbeddingsRequest,
    user_id: str,
    api_key_id: int,
    request_ip: Optional[str] = None
) -> EmbeddingsResponse:
    """
    Handle a text embeddings request.

    Local models are called at ``{base_url}/embeddings`` with an OpenAI-style
    payload. Cloud models are called on the DashScope text or multimodal
    embedding endpoint, chosen by keywords in the model code. Every call is
    written to the API call log as success or failure; the bill is always zero.

    Args:
        request: Embeddings request object.
        user_id: User ID.
        api_key_id: API key ID, used for logging.
        request_ip: Source IP of the request.

    Returns:
        EmbeddingsResponse: The embedding vectors plus token usage.

    Raises:
        OpenAICompatError: On validation, permission, configuration or
            upstream failures. Any other exception is wrapped in a 500
            ``embeddings_error``.
    """
    log_service = ApiCallLogService(self.db)
    model = None  # resolved inside the try; reused by the failure logger

    try:
        model = await self._validate_model_and_balance(request.model, user_id)

        # The model must belong to a category that can produce embeddings.
        from app.models.model import ModelCategory
        embedding_capable = (
            int(ModelCategory.EMBEDDING),
            int(ModelCategory.LLM),
            int(ModelCategory.MULTIMODAL),
        )
        if not any(int(c) in embedding_capable for c in (model.categories or [])):
            raise OpenAICompatError(
                status_code=400,
                message=f"Model '{request.model}' does not support embeddings",
                error_type="model_not_supported",
            )

        if model.is_local:
            # Local models are gated by a per-user permission.
            from app.services.user_local_model_permission_service import UserLocalModelPermissionService
            permission_service = UserLocalModelPermissionService(self.db)
            if not await permission_service.check_user_model_access(user_id, model.id):
                raise OpenAICompatError(
                    status_code=403,
                    message=f"You don't have permission to access model '{request.model}'",
                    error_type="permission_error",
                )

            # Normalise the input to a list of strings.
            texts = [request.input] if isinstance(request.input, str) else request.input

            # Endpoint settings come from the cache, falling back to the DB row.
            from app.services.cache_service import CacheService
            from app.services.crypto_utils import decrypt_api_key
            model_data = await CacheService.get_model(model.id)

            if model_data:
                base_url = (model_data.get("base_url") or "").rstrip("/")
                local_api_key = model_data.get("local_api_key")
            else:
                base_url = (model.base_url or "").rstrip("/")
                local_api_key = model.local_api_key
                # Populate the cache for subsequent requests.
                await CacheService.set_model(model.id, {
                    "base_url": base_url,
                    "local_api_key": local_api_key,
                    "is_local": model.is_local,
                    "name": model.name
                })
            if not base_url:
                raise OpenAICompatError(
                    status_code=500,
                    message="本地模型未配置 Base URL",
                    error_type="configuration_error",
                )

            headers = {"Content-Type": "application/json"}
            if local_api_key:
                api_key = decrypt_api_key(local_api_key)
                if api_key:
                    headers["Authorization"] = f"Bearer {api_key}"

            # OpenAI-format request body.
            payload = {
                "model": model.name,
                "input": texts
            }
            if request.dimensions:
                payload["dimensions"] = request.dimensions
            if request.user:
                payload["user"] = request.user

            result_data = await self._handle_local_model_request(
                api_url=f"{base_url}/embeddings",
                headers=headers,
                payload=payload,
                model_name=model.name,
                base_url=base_url,
                endpoint_type="embedding",
                timeout=30.0
            )

            # `or` guards against explicit nulls in the upstream JSON.
            embeddings_list = result_data.get("data") or []
            usage_data = result_data.get("usage") or {}
            total_tokens = usage_data.get("total_tokens") or 0

            # Fall back to list position when an item carries no index, so
            # multiple vectors never collapse onto index 0.
            data_list = [
                EmbeddingData(
                    index=item.get("index", position),
                    embedding=item.get("embedding", [])
                )
                for position, item in enumerate(embeddings_list)
            ]
        else:
            user = self._get_user(user_id)
            if not user:
                raise OpenAICompatError(
                    status_code=403,
                    message="User API key not configured.",
                    error_type="api_key_not_configured"
                )
            # Prefer the key stored on the model; otherwise use the user's own key.
            effective_api_key: Optional[str] = None
            if model.encrypted_api_key:
                from app.services.crypto_utils import decrypt_api_key as _decrypt
                effective_api_key = _decrypt(model.encrypted_api_key) or None
            if not effective_api_key:
                effective_api_key = user.apikey
            if not effective_api_key:
                raise OpenAICompatError(
                    status_code=403,
                    message="User API key not configured.",
                    error_type="api_key_not_configured"
                )

            texts = [request.input] if isinstance(request.input, str) else request.input

            # Models whose code mentions vl/vision/multimodal use the multimodal endpoint.
            code_lower = model.model_code.lower()
            is_multimodal = any(kw in code_lower for kw in ("vl", "vision", "multimodal"))
            if is_multimodal:
                api_url = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/multimodal-embedding/multimodal-embedding"
                payload = {
                    "model": model.model_code,
                    "input": {
                        "contents": [{"text": t} for t in texts]
                    }
                }
            else:
                api_url = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
                payload = {
                    "model": model.model_code,
                    "input": {
                        "texts": texts
                    }
                }
            # Both endpoints take the requested dimension under "parameters".
            if request.dimensions:
                payload["parameters"] = {"dimension": request.dimensions}

            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {effective_api_key}"
            }
            async with httpx.AsyncClient(timeout=30.0) as client:
                response = await client.post(api_url, headers=headers, json=payload)
                response.raise_for_status()
                result_data = response.json()

            output = result_data.get("output") or {}
            embeddings_list = output.get("embeddings") or []
            usage_data = result_data.get("usage") or {}
            total_tokens = usage_data.get("total_tokens") or 0

            # NOTE(review): the multimodal endpoint is assumed to report item
            # positions under "index" rather than "text_index" - confirm
            # against the DashScope response. List position is the last resort.
            data_list = [
                EmbeddingData(
                    index=item.get("text_index", item.get("index", position)),
                    embedding=item.get("embedding", [])
                )
                for position, item in enumerate(embeddings_list)
            ]

        bill = Decimal("0")
        log_service.create_log(
            user_id=user_id,
            api_key_id=api_key_id,
            model_id=model.id,
            model_name=request.model,
            is_local=model.is_local,
            input_tokens=total_tokens,
            output_tokens=0,
            bill=float(bill),
            status="success",
            request_ip=request_ip
        )
        return EmbeddingsResponse(
            model=request.model,
            data=data_list,
            usage=Usage(
                prompt_tokens=total_tokens,
                completion_tokens=0,
                total_tokens=total_tokens
            )
        )

    except Exception as e:
        error_msg = str(e) if not isinstance(e, OpenAICompatError) else e.message
        # Reuse the model resolved above; only query again when validation
        # itself failed (e.g. the model exists but is not API-enabled).
        failed_model = model if model is not None else self._find_model(request.model, user_id)
        log_service.create_log(
            user_id=user_id,
            api_key_id=api_key_id,
            model_id=failed_model.id if failed_model else None,
            model_name=request.model,
            is_local=failed_model.is_local if failed_model else False,
            input_tokens=0,
            output_tokens=0,
            bill=0,
            status="failed",
            error_message=error_msg,
            request_ip=request_ip
        )
        if isinstance(e, OpenAICompatError):
            raise
        raise OpenAICompatError(status_code=500, message=error_msg, error_type="embeddings_error") from e
-
async def image_generations(
    self, request: ImageGenerationRequest, user_id: str, api_key_id: int, request_ip: Optional[str] = None
) -> ImageGenerationResponse:
    """
    Handle a text-to-image generation request.

    Local models are called at ``{base_url}/images/generations`` with an
    OpenAI-style payload; base64 results are uploaded to OSS so the response
    always carries URLs. Cloud models are delegated to
    ``ImageGenerationService.text_to_image``. Every call is written to the
    API call log as success or failure.

    Args:
        request: Image generation request object.
        user_id: User ID.
        api_key_id: API key ID, used for logging.
        request_ip: Source IP of the request.

    Returns:
        ImageGenerationResponse: Generated image URLs.

    Raises:
        OpenAICompatError: On validation, permission, configuration or
            generation failures. Any other exception is wrapped in a 500
            ``image_generation_error``.
    """
    log_service = ApiCallLogService(self.db)
    model = None  # resolved inside the try; reused by the failure logger

    try:
        model = await self._validate_model_and_balance(request.model, user_id)

        # The model must be able to generate images.
        from app.models.model import ModelCategory
        image_capable = (int(ModelCategory.IMAGE_GEN), int(ModelCategory.MULTIMODAL))
        if not any(int(c) in image_capable for c in (model.categories or [])):
            raise OpenAICompatError(
                status_code=400,
                message=f"Model '{request.model}' does not support image generation. Use a model with category IMAGE_GEN or MULTIMODAL.",
                error_type="model_not_supported",
            )

        if model.is_local:
            # Local models are gated by a per-user permission.
            from app.services.user_local_model_permission_service import UserLocalModelPermissionService
            permission_service = UserLocalModelPermissionService(self.db)
            if not await permission_service.check_user_model_access(user_id, model.id):
                raise OpenAICompatError(
                    status_code=403,
                    message=f"You don't have permission to access model '{request.model}'",
                    error_type="permission_error",
                )

            # Endpoint settings come from the cache, falling back to the DB row.
            from app.services.cache_service import CacheService
            from app.services.crypto_utils import decrypt_api_key
            model_data = await CacheService.get_model(model.id)

            if model_data:
                base_url = (model_data.get("base_url") or "").rstrip("/")
                local_api_key = model_data.get("local_api_key")
            else:
                base_url = (model.base_url or "").rstrip("/")
                local_api_key = model.local_api_key
                # Populate the cache for subsequent requests.
                await CacheService.set_model(model.id, {
                    "base_url": base_url,
                    "local_api_key": local_api_key,
                    "is_local": model.is_local,
                    "name": model.name
                })
            if not base_url:
                raise OpenAICompatError(
                    status_code=500,
                    message="本地模型未配置 Base URL",
                    error_type="configuration_error",
                )

            headers = {"Content-Type": "application/json"}
            if local_api_key:
                # Only the decrypt call sits inside the try, so the
                # "decryption failed" error below is not re-wrapped.
                try:
                    api_key = decrypt_api_key(local_api_key)
                except Exception as decrypt_error:
                    raise OpenAICompatError(
                        status_code=500,
                        message=f"本地模型 API Key 处理失败: {str(decrypt_error)}",
                        error_type="configuration_error",
                    ) from decrypt_error
                if not api_key:
                    raise OpenAICompatError(
                        status_code=500,
                        message="本地模型 API Key 解密失败",
                        error_type="configuration_error",
                    )
                headers["Authorization"] = f"Bearer {api_key}"

            # OpenAI-format request body; optional fields are sent only when set.
            payload = {
                "model": model.name,
                "prompt": request.prompt,
                "n": request.n or 1,
                "size": request.size or "1024x1024"
            }
            if request.quality:
                payload["quality"] = request.quality
            if request.response_format:
                payload["response_format"] = request.response_format
            if request.style:
                payload["style"] = request.style
            if request.user:
                payload["user"] = request.user

            result_data = await self._handle_local_model_request(
                api_url=f"{base_url}/images/generations",
                headers=headers,
                payload=payload,
                model_name=model.name,
                base_url=base_url,
                endpoint_type="image",
                timeout=60.0
            )

            # Collect URLs; base64 payloads are uploaded to OSS first.
            import base64
            images = []
            oss_service = None  # created lazily, only if a base64 image shows up
            for item in result_data.get("data") or []:
                if item.get("url"):
                    images.append(item.get("url"))
                elif item.get("b64_json"):
                    if oss_service is None:
                        from app.services.oss_service import get_oss_service
                        oss_service = get_oss_service()
                    image_bytes = base64.b64decode(item.get("b64_json"))
                    uploaded_url = oss_service.upload_file(
                        image_bytes,
                        prefix="ai-images/generations",
                        original_filename=f"generated_{time.time()}.png"
                    )
                    images.append(uploaded_url)
            if not images:
                raise OpenAICompatError(
                    status_code=500,
                    message="图像生成失败:未返回图像",
                    error_type="image_generation_error",
                )
            result = ImageGenerationResponse(
                created=int(time.time()),
                data=[ImageData(url=url) for url in images]
            )
            bill = Decimal("0")
        else:
            # Cloud models go through the image service with the user's own
            # API key (an empty string when none is configured).
            user = self._get_user(user_id)
            dashscope_api_key = user.apikey if user and user.apikey else ""
            real_image_service = ImageGenerationService(self.db, api_key=dashscope_api_key)

            # The service expects "W*H" rather than OpenAI's "WxH".
            mapped_size = request.size.replace("x", "*") if request.size else "1024*1024"

            result_obj = await real_image_service.text_to_image(
                user_id=user_id,
                prompt=request.prompt,
                model=model.model_code,
                n=request.n or 1,
                size=mapped_size
            )

            if not result_obj.success:
                raise OpenAICompatError(
                    status_code=500,
                    message=result_obj.error or "图像生成失败",
                    error_type="image_generation_error"
                )

            result = ImageGenerationResponse(
                created=int(time.time()),
                data=[ImageData(url=url) for url in result_obj.images]
            )
            bill = result_obj.bill

        # The number of generated images is logged in the output_tokens column.
        log_service.create_log(
            user_id=user_id,
            api_key_id=api_key_id,
            model_id=model.id,
            model_name=request.model,
            is_local=model.is_local,
            input_tokens=0,
            output_tokens=len(result.data),
            bill=float(bill),
            status="success",
            request_ip=request_ip
        )

        return result

    except Exception as e:
        error_msg = str(e) if not isinstance(e, OpenAICompatError) else e.message
        # Reuse the model resolved above; only query again when validation itself failed.
        failed_model = model if model is not None else self._find_model(request.model, user_id)
        log_service.create_log(
            user_id=user_id,
            api_key_id=api_key_id,
            model_id=failed_model.id if failed_model else None,
            model_name=request.model,
            is_local=failed_model.is_local if failed_model else False,
            input_tokens=0,
            output_tokens=0,
            bill=0,
            status="failed",
            error_message=error_msg,
            request_ip=request_ip
        )
        if isinstance(e, OpenAICompatError):
            raise
        raise OpenAICompatError(status_code=500, message=error_msg, error_type="image_generation_error") from e
async def _edit_input_to_url(self, source: Any, oss_service: Any, fallback_filename: str) -> str:
    """Turn an image-edit input (URL, base64 string or uploaded file) into a URL.

    http(s) URLs are returned unchanged. Base64 strings (optionally wrapped
    in a ``data:`` URI) and uploaded files are stored via
    ``oss_service.upload_file`` under ``ai-images/edits`` and the resulting
    URL is returned.

    Args:
        source: A string (URL or base64) or an upload object exposing
            ``read()`` and ``filename``.
        oss_service: Object-storage service used for uploads.
        fallback_filename: File name used when the input carries none.

    Returns:
        A URL pointing at the image.
    """
    if isinstance(source, str):
        if source.startswith(('http://', 'https://')):
            return source

        import base64
        encoded = source
        # Drop a "data:<mime>;base64," header so only the payload is decoded.
        if encoded.startswith("data:") and "," in encoded:
            encoded = encoded.split(",", 1)[1]
        return oss_service.upload_file(
            base64.b64decode(encoded),
            prefix="ai-images/edits",
            original_filename=fallback_filename
        )

    uploaded_bytes = await source.read()
    return oss_service.upload_file(
        uploaded_bytes,
        prefix="ai-images/edits",
        original_filename=source.filename or fallback_filename
    )

async def image_edits(
    self,
    image: Union[str, UploadFile],
    prompt: str,
    mask: Optional[Union[str, UploadFile]],
    model_name: str,
    n: int,
    size: str,
    user_id: str,
    api_key_id: int,
    request_ip: Optional[str] = None
) -> ImageGenerationResponse:
    """
    Handle an image edit (image-to-image) request.

    The source image and optional mask are resolved to URLs (uploading to
    OSS when they arrive as base64 or files). Local models are then called
    at ``{base_url}/images/edits``; cloud models are delegated to
    ``ImageGenerationService.image_to_image``. Every call is written to the
    API call log as success or failure.

    Args:
        image: Source image as URL, base64 string or uploaded file.
        prompt: Text description of the desired image.
        mask: Optional mask image, in the same forms as ``image``.
        model_name: Model name.
        n: Number of images to generate.
        size: Output size, e.g. ``1024x1024``.
        user_id: User ID.
        api_key_id: API key ID, used for logging.
        request_ip: Source IP of the request.

    Returns:
        ImageGenerationResponse: URLs of the generated images.

    Raises:
        OpenAICompatError: On validation, permission, configuration or
            generation failures. Any other exception is wrapped in a 500
            ``image_edit_error``.
    """
    log_service = ApiCallLogService(self.db)
    model = None  # resolved inside the try; reused by the failure logger

    try:
        model = await self._validate_model_and_balance(model_name, user_id)

        # The model must support image editing.
        from app.models.model import ModelCategory
        edit_capable = (int(ModelCategory.IMAGE_EDIT), int(ModelCategory.MULTIMODAL))
        if not any(int(c) in edit_capable for c in (model.categories or [])):
            raise OpenAICompatError(
                status_code=400,
                message=f"Model '{model_name}' does not support image editing. Use a model with category IMAGE_EDIT or MULTIMODAL.",
                error_type="model_not_supported",
            )

        # Local models are gated by a per-user permission.
        if model.is_local:
            from app.services.user_local_model_permission_service import UserLocalModelPermissionService
            permission_service = UserLocalModelPermissionService(self.db)
            if not await permission_service.check_user_model_access(user_id, model.id):
                raise OpenAICompatError(
                    status_code=403,
                    message=f"You don't have permission to access model '{model_name}'",
                    error_type="permission_error",
                )

        import base64
        from app.services.oss_service import get_oss_service
        oss_service = get_oss_service()

        # image_urls is [image] or [image, mask].
        image_urls = [await self._edit_input_to_url(image, oss_service, "edit_source.png")]
        if mask:
            image_urls.append(await self._edit_input_to_url(mask, oss_service, "edit_mask.png"))

        if model.is_local:
            # Endpoint settings come from the cache, falling back to the DB row.
            from app.services.cache_service import CacheService
            from app.services.crypto_utils import decrypt_api_key
            model_data = await CacheService.get_model(model.id)

            if model_data:
                base_url = (model_data.get("base_url") or "").rstrip("/")
                local_api_key = model_data.get("local_api_key")
            else:
                base_url = (model.base_url or "").rstrip("/")
                local_api_key = model.local_api_key
                # Populate the cache for subsequent requests.
                await CacheService.set_model(model.id, {
                    "base_url": base_url,
                    "local_api_key": local_api_key,
                    "is_local": model.is_local,
                    "name": model.name
                })
            if not base_url:
                raise OpenAICompatError(
                    status_code=500,
                    message="本地模型未配置 Base URL",
                    error_type="configuration_error",
                )

            headers = {"Content-Type": "application/json"}
            if local_api_key:
                api_key = decrypt_api_key(local_api_key)
                if api_key:
                    headers["Authorization"] = f"Bearer {api_key}"

            # NOTE(review): image and mask are sent as URLs in a JSON body;
            # confirm the target server accepts that for /images/edits.
            payload = {
                "model": model.name,
                "prompt": prompt,
                "n": n,
                "size": size or "1024x1024",
                "image": image_urls[0],
            }
            if len(image_urls) > 1:
                payload["mask"] = image_urls[1]

            result_data = await self._handle_local_model_request(
                api_url=f"{base_url}/images/edits",
                headers=headers,
                payload=payload,
                model_name=model.name,
                base_url=base_url,
                endpoint_type="image",
                timeout=60.0
            )

            # Collect URLs; base64 payloads are uploaded to OSS first.
            images = []
            for item in result_data.get("data") or []:
                if item.get("url"):
                    images.append(item.get("url"))
                elif item.get("b64_json"):
                    edited_bytes = base64.b64decode(item.get("b64_json"))
                    uploaded_url = oss_service.upload_file(
                        edited_bytes,
                        prefix="ai-images/edits",
                        original_filename=f"edited_{time.time()}.png"
                    )
                    images.append(uploaded_url)
            if not images:
                raise OpenAICompatError(
                    status_code=500,
                    message="图像编辑失败:未返回图像",
                    error_type="image_edit_error",
                )
            result = ImageGenerationResponse(
                created=int(time.time()),
                data=[ImageData(url=url) for url in images]
            )
            bill = Decimal("0")
        else:
            # Cloud models go through the image service with the user's own
            # API key (an empty string when none is configured).
            user = self._get_user(user_id)
            dashscope_api_key = user.apikey if user and user.apikey else ""

            from app.services.image_service import ImageGenerationService
            real_image_service = ImageGenerationService(self.db, api_key=dashscope_api_key)

            # The service expects "W*H" rather than OpenAI's "WxH".
            mapped_size = size.replace("x", "*") if size else "1024*1024"

            result_obj = await real_image_service.image_to_image(
                user_id=user_id,
                image_urls=image_urls,
                prompt=prompt,
                model=model.model_code,
                n=n,
                size=mapped_size
            )

            if not result_obj.success:
                raise OpenAICompatError(
                    status_code=500,
                    message=result_obj.error or "图生图编辑失败",
                    error_type="image_edit_error"
                )

            result = ImageGenerationResponse(
                created=int(time.time()),
                data=[ImageData(url=url) for url in result_obj.images]
            )
            bill = result_obj.bill

        # The number of generated images is logged in the output_tokens column.
        log_service.create_log(
            user_id=user_id,
            api_key_id=api_key_id,
            model_id=model.id,
            model_name=model_name,
            is_local=model.is_local,
            input_tokens=0,
            output_tokens=len(result.data),
            bill=float(bill),
            status="success",
            request_ip=request_ip
        )

        return result

    except Exception as e:
        error_msg = str(e) if not isinstance(e, OpenAICompatError) else e.message
        # Reuse the model resolved above; only query again when validation itself failed.
        model_obj = model if model is not None else self._find_model(model_name, user_id)
        log_service.create_log(
            user_id=user_id,
            api_key_id=api_key_id,
            model_id=model_obj.id if model_obj else None,
            model_name=model_name,
            is_local=model_obj.is_local if model_obj else False,
            input_tokens=0,
            output_tokens=0,
            bill=0,
            status="failed",
            error_message=error_msg,
            request_ip=request_ip
        )
        if isinstance(e, OpenAICompatError):
            raise
        raise OpenAICompatError(status_code=500, message=error_msg, error_type="image_edit_error") from e
-
async def audio_transcriptions(
    self,
    file: Union[str, UploadFile],
    model_name: str,
    language: Optional[str],
    user_id: str,
    api_key_id: int,
    request_ip: Optional[str] = None
) -> AudioTranscriptionResponse:
    """
    Handle a speech-to-text (transcription) request.

    The audio is normalised to a base64 string (downloading it first when a
    URL is given). The OpenAI aliases ``whisper-1`` and ``whisper-large-v3``
    are mapped to ``qwen3-asr-flash``. Local models are called at
    ``{base_url}/audio/transcriptions``; cloud models are delegated to
    ``ASRService.recognize``. Every call is written to the API call log as
    success or failure.

    Args:
        file: Audio as URL, base64 string or uploaded file.
        model_name: Model name as sent by the client.
        language: Language code (ISO-639-1).
        user_id: User ID.
        api_key_id: API key ID, used for logging.
        request_ip: Source IP of the request.

    Returns:
        AudioTranscriptionResponse: The recognised text.

    Raises:
        OpenAICompatError: On validation, permission, configuration or
            recognition failures. Any other exception is wrapped in a 500
            ``stt_error``.
    """
    log_service = ApiCallLogService(self.db)
    # Map OpenAI aliases up front so the failure logger looks up the same
    # model that was actually requested upstream.
    actual_model = "qwen3-asr-flash" if model_name in ["whisper-1", "whisper-large-v3"] else model_name
    model = None  # resolved inside the try; reused by the failure logger

    try:
        # Realtime models speak a WebSocket streaming protocol and cannot be
        # served by this file-upload endpoint.
        if "realtime" in actual_model.lower():
            raise OpenAICompatError(
                status_code=400,
                message=f"Model '{actual_model}' is a real-time streaming model that uses WebSocket protocol. "
                        f"It cannot be called via /api/v1/audio/transcriptions. "
                        f"Please use the WebSocket API instead.",
                error_type="model_not_supported",
            )
        model = await self._validate_model_and_balance(actual_model, user_id)

        # The model must support speech recognition.
        from app.models.model import ModelCategory
        stt_capable = (int(ModelCategory.STT), int(ModelCategory.MULTIMODAL))
        if not any(int(c) in stt_capable for c in (model.categories or [])):
            raise OpenAICompatError(
                status_code=400,
                message=f"Model '{model_name}' does not support speech transcription",
                error_type="model_not_supported",
            )

        # Local models are gated by a per-user permission.
        if model.is_local:
            from app.services.user_local_model_permission_service import UserLocalModelPermissionService
            permission_service = UserLocalModelPermissionService(self.db)
            if not await permission_service.check_user_model_access(user_id, model.id):
                raise OpenAICompatError(
                    status_code=403,
                    message=f"You don't have permission to access model '{model_name}'",
                    error_type="permission_error",
                )

        import base64
        if isinstance(file, str):
            if file.startswith(('http://', 'https://')):
                # NOTE(review): this fetches a caller-supplied URL from the
                # server side; consider restricting allowed hosts.
                async with httpx.AsyncClient() as client:
                    response = await client.get(file)
                    response.raise_for_status()
                    audio_bytes = response.content
                audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
            else:
                # Anything else is taken to be base64 audio already.
                audio_base64 = file
        else:
            audio_bytes = await file.read()
            audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')

        if model.is_local:
            # Endpoint settings come from the cache, falling back to the DB row.
            from app.services.cache_service import CacheService
            from app.services.crypto_utils import decrypt_api_key
            model_data = await CacheService.get_model(model.id)

            if model_data:
                base_url = (model_data.get("base_url") or "").rstrip("/")
                local_api_key = model_data.get("local_api_key")
            else:
                base_url = (model.base_url or "").rstrip("/")
                local_api_key = model.local_api_key
                # Populate the cache for subsequent requests.
                await CacheService.set_model(model.id, {
                    "base_url": base_url,
                    "local_api_key": local_api_key,
                    "is_local": model.is_local,
                    "name": model.name
                })
            if not base_url:
                raise OpenAICompatError(
                    status_code=500,
                    message="本地模型未配置 Base URL",
                    error_type="configuration_error",
                )

            headers = {"Content-Type": "application/json"}
            if local_api_key:
                api_key = decrypt_api_key(local_api_key)
                if api_key:
                    headers["Authorization"] = f"Bearer {api_key}"

            # NOTE(review): the audio is sent as base64 inside a JSON body;
            # confirm the target server accepts that for /audio/transcriptions.
            payload = {
                "model": model.name,
                "file": audio_base64
            }
            if language:
                payload["language"] = language

            result_data = await self._handle_local_model_request(
                api_url=f"{base_url}/audio/transcriptions",
                headers=headers,
                payload=payload,
                model_name=model.name,
                base_url=base_url,
                endpoint_type="audio_stt",
                timeout=60.0
            )

            # `or ""` guards against an explicit null text field.
            text = result_data.get("text") or ""
            duration_seconds = len(audio_base64) // 10000  # rough estimate from payload size
            text_length = len(text)
            bill = Decimal("0")
        else:
            # Cloud models go through the ASR service with the user's own
            # API key (an empty string when none is configured).
            user = self._get_user(user_id)
            dashscope_api_key = user.apikey if user and user.apikey else ""

            from app.services.asr_service import ASRService
            from app.schemas.audio_schema import ASRRequest

            real_asr_service = ASRService(self.db, user_id=user_id, api_key=dashscope_api_key)
            internal_req = ASRRequest(
                model=actual_model,
                audio_base64=audio_base64,
                language=language
            )

            asr_response = await real_asr_service.recognize(internal_req)

            text = asr_response.text
            duration_seconds = asr_response.usage.seconds if asr_response.usage else 0
            text_length = len(text) if text else 0
            bill = Decimal("0")

        # Audio duration and text length are logged in the token columns.
        log_service.create_log(
            user_id=user_id,
            api_key_id=api_key_id,
            model_id=model.id,
            model_name=model_name,
            is_local=model.is_local,
            input_tokens=duration_seconds,
            output_tokens=text_length,
            bill=float(bill),
            status="success",
            request_ip=request_ip
        )

        return AudioTranscriptionResponse(text=text)

    except Exception as e:
        error_msg = str(e) if not isinstance(e, OpenAICompatError) else e.message
        # Reuse the resolved model; otherwise look up the mapped name so
        # alias requests (e.g. whisper-1) are still attributed to a model.
        model_obj = model if model is not None else self._find_model(actual_model, user_id)
        log_service.create_log(
            user_id=user_id,
            api_key_id=api_key_id,
            model_id=model_obj.id if model_obj else None,
            model_name=model_name,
            is_local=model_obj.is_local if model_obj else False,
            input_tokens=0,
            output_tokens=0,
            bill=0,
            status="failed",
            error_message=error_msg,
            request_ip=request_ip
        )
        if isinstance(e, OpenAICompatError):
            raise
        raise OpenAICompatError(status_code=500, message=error_msg, error_type="stt_error") from e
-
- async def audio_speech(
- self,
- request: AudioSpeechRequest,
- user_id: str,
- api_key_id: int,
- request_ip: Optional[str] = None
- ) -> Tuple[AsyncGenerator, str]:
- """
- 处理文字转语音(TTS)请求
-
- 执行全量音色映射与模型能力校验。
- 通过底层 TTSService 生成语音并转存 OSS,随后转换为流式下发。
-
- Args:
- request: TTS请求对象
- user_id: 用户ID
- api_key_id: API Key ID
- request_ip: 请求来源IP
-
- Returns:
- Tuple[AsyncGenerator, str]: 音频二进制生成器与 MIME 类型
-
- Raises:
- OpenAICompatError: 模型不支持所选音色或处理失败时抛出
-
- 需求: OpenAI兼容-文字转语音
- """
- log_service = ApiCallLogService(self.db)
-
- try:
- actual_model = "cosyvoice-v3-flash" if request.model in ["tts-1", "tts-1-hd"] else request.model
- # realtime 模型使用 WebSocket 实时流协议,不支持此 REST 接口
- if "realtime" in actual_model.lower():
- raise OpenAICompatError(
- status_code=400,
- message=f"Model '{actual_model}' is a real-time streaming model that uses WebSocket protocol. "
- f"It cannot be called via /api/v1/audio/speech. "
- f"Please use the WebSocket API instead.",
- error_type="model_not_supported",
- )
- # cosyvoice-clone 系列:voice 参数就是 voice_id,不做映射,直接透传
- is_clone = "clone" in actual_model.lower()
- voice_map = {
- "alloy": "longxiaochun_v3",
- "echo": "longcheng_v3",
- "shimmer": "longwan_v3",
- "onyx": "longhua_v3",
- "nova": "longxiaoxia_v3",
- "fable": "longshu_v3",
- "sunny": "longanyang",
- "lively": "longanhuan",
- "cute_girl": "longhuhu_v3",
- "cute_boy": "longniuniu_v3",
- "bubble": "longpaopao_v3",
- "naughty": "longjielidou_v3",
- "bold_girl": "longxian_v3",
- "cantonese_f": "longjiaxin_v3",
- "cantonese_m": "longanyue_v3",
- "dongbei": "longlaotie_v3",
- "shanbei": "longshange_v3",
- "korean": "loongkyong_v3",
- "japanese": "loongriko_v3",
- "news_m": "longfei_v3",
- "news_f": "longxiaoxia_v3",
- "story_m": "longxiu_v3",
- "story_f": "longmiao_v3",
- "customer_service": "longyingxiao_v3",
- "monkey": "longhouge_v3",
- "robot": "longjiqi_v3",
- "daiyu": "longdaiyu_v3",
- "uncle": "longlaobo_v3",
- "aunt": "longlaoyi_v3"
- }
-
- actual_voice = request.voice if is_clone else voice_map.get(request.voice.lower(), request.voice)
- if "plus" in actual_model.lower():
- plus_allowed_voices = ["longanyang", "longanhuan"]
- if actual_voice not in plus_allowed_voices:
- raise OpenAICompatError(
- status_code=400,
- message=f"Model '{actual_model}' only supports voices: {plus_allowed_voices}. Requested: '{actual_voice}'.",
- error_type="invalid_request_error"
- )
- model = await self._validate_model_and_balance(actual_model, user_id)
-
- # 检查模型类型是否支持语音合成
- from app.models.model import ModelCategory
- if not any(int(c) in [int(ModelCategory.TTS), int(ModelCategory.MULTIMODAL)] for c in (model.categories or [])):
- raise OpenAICompatError(
- status_code=400,
- message=f"Model '{request.model}' does not support speech synthesis",
- error_type="model_not_supported",
- )
-
- # 检查本地模型权限
- if model.is_local:
- from app.services.user_local_model_permission_service import UserLocalModelPermissionService
- permission_service = UserLocalModelPermissionService(self.db)
- if not await permission_service.check_user_model_access(user_id, model.id):
- raise OpenAICompatError(
- status_code=403,
- message=f"You don't have permission to access model '{request.model}'",
- error_type="permission_error",
- )
-
- if model.is_local:
- # 从缓存获取模型信息
- from app.services.cache_service import CacheService
- from app.services.crypto_utils import decrypt_api_key
- model_data = await CacheService.get_model(model.id)
-
- if model_data:
- base_url = model_data.get("base_url", "").rstrip("/")
- local_api_key = model_data.get("local_api_key")
- else:
- # 从数据库获取
- base_url = (model.base_url or "").rstrip("/")
- local_api_key = model.local_api_key
- # 缓存模型信息
- await CacheService.set_model(model.id, {
- "base_url": base_url,
- "local_api_key": local_api_key,
- "is_local": model.is_local,
- "name": model.name
- })
- if not base_url:
- raise OpenAICompatError(
- status_code=500,
- message="本地模型未配置 Base URL",
- error_type="configuration_error",
- )
- # 构建请求头
- headers = {"Content-Type": "application/json"}
- if local_api_key:
- api_key = decrypt_api_key(local_api_key)
- if api_key:
- headers["Authorization"] = f"Bearer {api_key}"
-
- # 构建请求体(OpenAI 格式)
- # 注意:本地模型使用原始的 OpenAI 音色名称,不进行映射
- payload = {
- "model": model.name,
- "input": request.input,
- "voice": request.voice, # 使用原始音色名称,不映射
- "response_format": request.response_format or "mp3",
- "speed": request.speed or 1.0
- }
- # 使用统一的适配器方法发送请求(获取原始响应)
- api_url = f"{base_url}/audio/speech"
- response = await self._handle_local_model_request(
- api_url=api_url,
- headers=headers,
- payload=payload,
- model_name=model.name,
- base_url=base_url,
- endpoint_type="audio_tts",
- timeout=60.0,
- return_raw_response=True
- )
-
- # 检查响应类型
- content_type = response.headers.get("content-type", "")
-
- import httpx
- if "application/json" in content_type:
- # 响应是 JSON,包含音频 URL
- result_data = response.json()
- audio_url = result_data.get("audio_url") or result_data.get("url")
-
- if not audio_url:
- raise OpenAICompatError(
- status_code=500,
- message="本地模型返回的 JSON 中未找到音频 URL",
- error_type="invalid_response",
- )
-
- # 下载音频并流式返回
- async def generate_audio():
- async with httpx.AsyncClient() as client:
- async with client.stream("GET", audio_url) as audio_response:
- audio_response.raise_for_status()
- async for chunk in audio_response.aiter_bytes():
- yield chunk
- else:
- # 响应直接是音频数据(如硅基流动)
- audio_bytes = response.content
-
- async def generate_audio():
- # 直接返回音频数据
- yield audio_bytes
- media_type = f"audio/{request.response_format or 'mp3'}" if (request.response_format or 'mp3') != 'mp3' else "audio/mpeg"
- bill = Decimal("0")
- else:
- # 云端模型处理
- user = self._get_user(user_id)
- dashscope_api_key = user.apikey if user and user.apikey else ""
-
- from app.services.tts_service import TTSService
- from app.schemas.audio_schema import TTSRequest
-
- real_tts_service = TTSService(self.db, user_id=user_id, api_key=dashscope_api_key)
-
- internal_req = TTSRequest(
- text=request.input,
- model=actual_model,
- voice=actual_voice,
- format=request.response_format or "mp3",
- sample_rate=24000
- )
-
- tts_response = await real_tts_service.synthesize(internal_req)
-
- import httpx
- async def generate_audio():
- async with httpx.AsyncClient() as client:
- async with client.stream("GET", tts_response.audio_url) as response:
- response.raise_for_status()
- async for chunk in response.aiter_bytes():
- yield chunk
-
- media_type = f"audio/{request.response_format}" if request.response_format != 'mp3' else "audio/mpeg"
- bill = Decimal("0")
- log_service.create_log(
- user_id=user_id,
- api_key_id=api_key_id,
- model_id=model.id if model else None,
- model_name=request.model,
- is_local=model.is_local if model else False,
- input_tokens=len(request.input),
- output_tokens=0,
- bill=float(bill),
- status="success",
- request_ip=request_ip
- )
- return generate_audio(), media_type
-
- except Exception as e:
- error_msg = str(e) if not isinstance(e, OpenAICompatError) else e.message
- model_obj = self._find_model(request.model, user_id)
- log_service.create_log(
- user_id=user_id,
- api_key_id=api_key_id,
- model_id=model_obj.id if model_obj else None,
- model_name=request.model,
- is_local=model_obj.is_local if model_obj else False,
- input_tokens=0,
- output_tokens=0,
- bill=0,
- status="failed",
- error_message=error_msg,
- request_ip=request_ip
- )
- if isinstance(e, OpenAICompatError):
- raise e
- raise OpenAICompatError(status_code=500, message=error_msg, error_type="tts_error")
-
async def video_generations(
    self,
    request: VideoGenerationRequest,
    user_id: str,
    api_key_id: int,
    request_ip: Optional[str] = None
) -> VideoGenerationResponse:
    """
    Handle a text-to-video generation request.

    Local models are called through ``_handle_local_model_request``. Cloud
    models are submitted to ``VideoService`` as an asynchronous task which is
    then polled until it reaches a terminal state, so the caller sees a
    synchronous, blocking interface.

    Args:
        request: Video generation request (model, prompt, size, duration).
        user_id: ID of the calling user.
        api_key_id: ID of the API key, recorded in the call log.
        request_ip: Source IP of the request, recorded in the call log.

    Returns:
        VideoGenerationResponse: Response holding the generated video URL(s).

    Raises:
        OpenAICompatError: If the model does not support video generation,
            the user lacks permission, the size is invalid, the task fails,
            returns no video, or does not finish within the polling window.

    Requirement: OpenAI compatibility - video generation
    """
    import time
    import asyncio
    from app.services.video_service import VideoService
    from app.schemas.video_schema import VideoGenerateRequest

    log_service = ApiCallLogService(self.db)

    try:
        model = await self._validate_model_and_balance(request.model, user_id)

        # Reject models whose categories cannot generate video.
        from app.models.model import ModelCategory
        if not any(
            int(c) in [int(ModelCategory.VIDEO_GEN), int(ModelCategory.MULTIMODAL), int(ModelCategory.LLM)]
            for c in (model.categories or [])
        ):
            raise OpenAICompatError(
                status_code=400,
                message=f"Model '{request.model}' does not support video generation",
                error_type="model_not_supported",
            )

        if model.is_local:
            # Local models require an explicit per-user permission.
            from app.services.user_local_model_permission_service import UserLocalModelPermissionService
            permission_service = UserLocalModelPermissionService(self.db)
            if not await permission_service.check_user_model_access(user_id, model.id):
                raise OpenAICompatError(
                    status_code=403,
                    message=f"You don't have permission to access model '{request.model}'",
                    error_type="permission_error",
                )

            # Resolve the endpoint from the cache, falling back to the
            # database record (and populating the cache) on a miss.
            from app.services.cache_service import CacheService
            from app.services.crypto_utils import decrypt_api_key
            model_data = await CacheService.get_model(model.id)

            if model_data:
                # `or ""` guards against a cached explicit None.
                base_url = (model_data.get("base_url") or "").rstrip("/")
                local_api_key = model_data.get("local_api_key")
            else:
                base_url = (model.base_url or "").rstrip("/")
                local_api_key = model.local_api_key
                await CacheService.set_model(model.id, {
                    "base_url": base_url,
                    "local_api_key": local_api_key,
                    "is_local": model.is_local,
                    "name": model.name
                })
            if not base_url:
                raise OpenAICompatError(
                    status_code=500,
                    message="本地模型未配置 Base URL",
                    error_type="configuration_error",
                )

            # Build request headers; the stored key is encrypted.
            headers = {"Content-Type": "application/json"}
            if local_api_key:
                api_key = decrypt_api_key(local_api_key)
                if api_key:
                    headers["Authorization"] = f"Bearer {api_key}"

            # Build the request body (OpenAI format); size is passed through
            # in the OpenAI "WIDTHxHEIGHT" form.
            payload = {
                "model": model.name,
                "prompt": request.prompt,
                "size": request.size or "1280x720",
                "duration": request.duration or 5
            }
            result_data = await self._handle_local_model_request(
                api_url=f"{base_url}/videos/generations",
                headers=headers,
                payload=payload,
                model_name=model.name,
                base_url=base_url,
                endpoint_type="video",
                timeout=60.0
            )

            # Keep only the entries that actually carry a URL.
            video_urls = [
                item.get("url")
                for item in (result_data.get("data") or [])
                if item.get("url")
            ]
            if not video_urls:
                raise OpenAICompatError(
                    status_code=500,
                    message="视频生成失败:未返回视频",
                    error_type="video_generation_error",
                )
        else:
            # Cloud model handling.
            user = self._get_user(user_id)
            from app.services.crypto_utils import get_effective_api_key
            dashscope_api_key = get_effective_api_key(self.db, request.model, user.apikey) if user else ""

            real_video_service = VideoService(
                self.db,
                user_id=int(user_id) if str(user_id).isdigit() else user_id,
                api_key=dashscope_api_key
            )

            # Convert the OpenAI-style size into the internal resolution
            # string (e.g. "720P"); an unsupported size is a client error.
            try:
                _openai_size, internal_size = parse_video_size(request.size or "1280x720")
            except ValueError as e:
                raise OpenAICompatError(
                    status_code=400,
                    message=str(e),
                    error_type="invalid_request_error"
                ) from e

            internal_req = VideoGenerateRequest(
                prompt=request.prompt,
                resolution=internal_size,
                duration=request.duration or 5,
                prompt_extend=True,
                watermark=False
            )

            # Submit the asynchronous task.
            task_resp = await real_video_service.generate(internal_req)
            task_id = task_resp.task_id

            # Blocking poll: 120 attempts x 5 s = at most 600 s.
            max_retries = 120
            poll_interval = 5
            final_video_url = None

            for _ in range(max_retries):
                await asyncio.sleep(poll_interval)
                status_result = await real_video_service.get_task_status(task_id)
                task_status = status_result.task_status

                if task_status == "SUCCEEDED":
                    final_video_url = status_result.video_url
                    if not final_video_url:
                        # A finished task without a URL is a generation
                        # failure, not a timeout.
                        raise OpenAICompatError(
                            status_code=500,
                            message="视频生成失败:未返回视频",
                            error_type="video_generation_error"
                        )
                    break
                # NOTE(review): "CANCELED" is assumed to be passed through
                # from the upstream task API as a terminal state - confirm
                # against VideoService.get_task_status.
                if task_status in ("FAILED", "CANCELED"):
                    raise OpenAICompatError(
                        status_code=500,
                        message=status_result.error_message or "底层视频生成任务失败",
                        error_type="video_generation_error"
                    )
            else:
                # Loop exhausted without reaching a terminal state.
                raise OpenAICompatError(
                    status_code=504,
                    message="视频生成任务超时,请稍后再试或通过平台任务列表查看结果",
                    error_type="timeout_error"
                )

            video_urls = [final_video_url]

        # Assemble the response (shared by both branches).
        from app.schemas.openai_compat import VideoData
        result = VideoGenerationResponse(
            created=int(time.time()),
            data=[VideoData(url=url, content_type="video/mp4") for url in video_urls]
        )
        bill = Decimal("0")

        log_service.create_log(
            user_id=user_id,
            api_key_id=api_key_id,
            model_id=model.id if model else None,
            model_name=request.model,
            is_local=model.is_local if model else False,
            input_tokens=0,
            # The requested duration is logged in the output-token column.
            output_tokens=request.duration or 5,
            bill=float(bill),
            status="success",
            request_ip=request_ip
        )

        return result

    except Exception as e:
        # Record the failure, then surface it as an OpenAICompatError.
        error_msg = str(e) if not isinstance(e, OpenAICompatError) else e.message
        model_obj = self._find_model(request.model, user_id)
        log_service.create_log(
            user_id=user_id,
            api_key_id=api_key_id,
            model_id=model_obj.id if model_obj else None,
            model_name=request.model,
            is_local=model_obj.is_local if model_obj else False,
            input_tokens=0,
            output_tokens=0,
            bill=0,
            status="failed",
            error_message=error_msg,
            request_ip=request_ip
        )
        if isinstance(e, OpenAICompatError):
            raise
        raise OpenAICompatError(status_code=500, message=error_msg, error_type="video_generation_error") from e
-
async def image_to_video_generations(
    self,
    image: Union[str, UploadFile],
    prompt: str,
    model_name: str,
    size: str,
    user_id: str,
    api_key_id: int,
    request_ip: Optional[str] = None
) -> VideoGenerationResponse:
    """
    Handle an image-to-video (I2V) request.

    The reference image is turned into a URL (used as-is when it already is
    an http(s) URL, otherwise uploaded through the OSS service). Local
    models are then called through ``_handle_local_model_request``; cloud
    models are submitted to ``VideoService`` as an asynchronous task that is
    polled until it reaches a terminal state.

    Args:
        image: Reference image: an http(s) URL, a base64 string, a base64
            ``data:`` URI, or an uploaded file.
        prompt: Text description of the video.
        model_name: Model name (e.g. wan2.6-i2v).
        size: Video size; "1280x720" is used when empty. Passed through
            unchanged to local models and converted with
            ``parse_video_size`` for cloud models.
        user_id: ID of the calling user.
        api_key_id: ID of the API key, recorded in the call log.
        request_ip: Source IP of the request, recorded in the call log.

    Returns:
        VideoGenerationResponse: Response holding the generated video URL(s).

    Raises:
        OpenAICompatError: If the model does not support video generation,
            the user lacks permission, the image or size is invalid, the
            task fails, returns no video, or does not finish in time.

    Requirement: OpenAI compatibility - image to video
    """
    import time
    import asyncio
    import base64
    from app.services.oss_service import get_oss_service

    log_service = ApiCallLogService(self.db)

    try:
        model = await self._validate_model_and_balance(model_name, user_id)

        # Reject models whose categories cannot generate video.
        from app.models.model import ModelCategory
        if not any(
            int(c) in [int(ModelCategory.VIDEO_GEN), int(ModelCategory.MULTIMODAL), int(ModelCategory.LLM)]
            for c in (model.categories or [])
        ):
            raise OpenAICompatError(
                status_code=400,
                message=f"Model '{model_name}' does not support video generation",
                error_type="model_not_supported",
            )

        # Local models require an explicit per-user permission.
        if model.is_local:
            from app.services.user_local_model_permission_service import UserLocalModelPermissionService
            permission_service = UserLocalModelPermissionService(self.db)
            if not await permission_service.check_user_model_access(user_id, model.id):
                raise OpenAICompatError(
                    status_code=403,
                    message=f"You don't have permission to access model '{model_name}'",
                    error_type="permission_error",
                )

        # Turn the reference image into a URL.
        oss_service = get_oss_service()
        if isinstance(image, str):
            if image.startswith(('http://', 'https://')):
                # Already a URL: use it directly.
                image_url = image
            else:
                # Strip a "data:<mime>;base64," prefix first: b64decode
                # silently drops the punctuation and would decode the
                # prefix letters into corrupt leading bytes.
                encoded_image = image
                if image.startswith("data:") and "," in image:
                    encoded_image = image.split(",", 1)[1]
                try:
                    image_bytes = base64.b64decode(encoded_image)
                except ValueError as e:
                    # binascii.Error subclasses ValueError; malformed input
                    # is a client error, not a server error.
                    raise OpenAICompatError(
                        status_code=400,
                        message="Invalid image: expected an http(s) URL, a base64 string or a base64 data URI",
                        error_type="invalid_request_error"
                    ) from e
                image_url = oss_service.upload_file(
                    image_bytes,
                    prefix="ai-videos/i2v-source",
                    original_filename="i2v_source.png"
                )
        else:
            # Uploaded file object.
            image_bytes = await image.read()
            image_url = oss_service.upload_file(
                image_bytes,
                prefix="ai-videos/i2v-source",
                original_filename=image.filename or "i2v_source.png"
            )

        if model.is_local:
            # Resolve the endpoint from the cache, falling back to the
            # database record (and populating the cache) on a miss.
            from app.services.cache_service import CacheService
            from app.services.crypto_utils import decrypt_api_key
            model_data = await CacheService.get_model(model.id)

            if model_data:
                # `or ""` guards against a cached explicit None.
                base_url = (model_data.get("base_url") or "").rstrip("/")
                local_api_key = model_data.get("local_api_key")
            else:
                base_url = (model.base_url or "").rstrip("/")
                local_api_key = model.local_api_key
                await CacheService.set_model(model.id, {
                    "base_url": base_url,
                    "local_api_key": local_api_key,
                    "is_local": model.is_local,
                    "name": model.name
                })
            if not base_url:
                raise OpenAICompatError(
                    status_code=500,
                    message="本地模型未配置 Base URL",
                    error_type="configuration_error",
                )

            # Build request headers; the stored key is encrypted.
            headers = {"Content-Type": "application/json"}
            if local_api_key:
                api_key = decrypt_api_key(local_api_key)
                if api_key:
                    headers["Authorization"] = f"Bearer {api_key}"

            # Build the request body (OpenAI format); duration is fixed at 5.
            payload = {
                "model": model.name,
                "prompt": prompt,
                "image": image_url,
                "size": size or "1280x720",
                "duration": 5
            }
            result_data = await self._handle_local_model_request(
                api_url=f"{base_url}/videos/generations",
                headers=headers,
                payload=payload,
                model_name=model.name,
                base_url=base_url,
                endpoint_type="video",
                timeout=60.0
            )

            # Keep only the entries that actually carry a URL.
            video_urls = [
                item.get("url")
                for item in (result_data.get("data") or [])
                if item.get("url")
            ]
            if not video_urls:
                raise OpenAICompatError(
                    status_code=500,
                    message="图生视频失败:未返回视频",
                    error_type="video_generation_error",
                )
        else:
            # Cloud model handling.
            user = self._get_user(user_id)
            from app.services.crypto_utils import get_effective_api_key
            dashscope_api_key = get_effective_api_key(self.db, model_name, user.apikey) if user else ""

            from app.services.video_service import VideoService
            from app.schemas.video_schema import VideoGenerateRequest

            real_video_service = VideoService(
                self.db,
                user_id=int(user_id) if str(user_id).isdigit() else user_id,
                api_key=dashscope_api_key
            )

            # Convert the OpenAI-style size into the internal resolution
            # string (e.g. "720P"); an unsupported size is a client error.
            try:
                _openai_size, internal_size = parse_video_size(size or "1280x720")
            except ValueError as e:
                raise OpenAICompatError(
                    status_code=400,
                    message=str(e),
                    error_type="invalid_request_error"
                ) from e

            internal_req = VideoGenerateRequest(
                prompt=prompt,
                first_frame_url=image_url,
                resolution=internal_size,
                duration=5,
                prompt_extend=True,
                watermark=False
            )

            # Submit the asynchronous task.
            task_resp = await real_video_service.generate(internal_req)
            task_id = task_resp.task_id

            # Blocking poll: 120 attempts x 5 s = at most 600 s.
            max_retries = 120
            poll_interval = 5
            final_video_url = None

            for _ in range(max_retries):
                await asyncio.sleep(poll_interval)
                status_result = await real_video_service.get_task_status(task_id)
                task_status = status_result.task_status

                if task_status == "SUCCEEDED":
                    final_video_url = status_result.video_url
                    if not final_video_url:
                        # A finished task without a URL is a generation
                        # failure, not a timeout.
                        raise OpenAICompatError(
                            status_code=500,
                            message="图生视频失败:未返回视频",
                            error_type="video_generation_error"
                        )
                    break
                # NOTE(review): "CANCELED" is assumed to be passed through
                # from the upstream task API as a terminal state - confirm
                # against VideoService.get_task_status.
                if task_status in ("FAILED", "CANCELED"):
                    raise OpenAICompatError(
                        status_code=500,
                        message=status_result.error_message or "底层图生视频任务失败",
                        error_type="video_generation_error"
                    )
            else:
                # Loop exhausted without reaching a terminal state.
                raise OpenAICompatError(
                    status_code=504,
                    message="图生视频任务超时,请稍后再试",
                    error_type="timeout_error"
                )

            video_urls = [final_video_url]

        # Assemble the response (shared by both branches).
        from app.schemas.openai_compat import VideoData
        result = VideoGenerationResponse(
            created=int(time.time()),
            data=[VideoData(url=url, content_type="video/mp4") for url in video_urls]
        )
        bill = Decimal("0")

        # Record the compatibility-layer call log.
        log_service.create_log(
            user_id=user_id,
            api_key_id=api_key_id,
            model_id=model.id if model else None,
            model_name=model_name,
            is_local=model.is_local if model else False,
            input_tokens=0,
            # The fixed 5-second duration is logged in the output-token column.
            output_tokens=5,
            bill=float(bill),
            status="success",
            request_ip=request_ip
        )

        return result

    except Exception as e:
        # Record the failure, then surface it as an OpenAICompatError.
        error_msg = str(e) if not isinstance(e, OpenAICompatError) else e.message
        model_obj = self._find_model(model_name, user_id)
        log_service.create_log(
            user_id=user_id,
            api_key_id=api_key_id,
            model_id=model_obj.id if model_obj else None,
            model_name=model_name,
            is_local=model_obj.is_local if model_obj else False,
            input_tokens=0,
            output_tokens=0,
            bill=0,
            status="failed",
            error_message=error_msg,
            request_ip=request_ip
        )
        if isinstance(e, OpenAICompatError):
            raise
        raise OpenAICompatError(status_code=500, message=error_msg, error_type="video_generation_error") from e
async def rerank(
    self,
    request: RerankRequest,
    user_id: str,
    api_key_id: int,
    request_ip: Optional[str] = None
) -> RerankResponse:
    """
    Handle a rerank request.

    Local models are called through ``_handle_local_model_request`` at
    ``{base_url}/rerank``; cloud models are sent to the DashScope
    text-rerank endpoint using the user's own API key.

    Args:
        request: Rerank request (model, query, documents, optional top_n,
            return_documents and user).
        user_id: ID of the calling user.
        api_key_id: ID of the API key, recorded in the call log.
        request_ip: Source IP of the request, recorded in the call log.

    Returns:
        RerankResponse: Ranked results plus token usage.

    Raises:
        OpenAICompatError: If the model does not support reranking, the
            user lacks permission or an API key, or the upstream call fails.
    """
    import logging
    # Imported locally like the sibling methods do, so the cloud branch
    # does not depend on a module-level import being present.
    import httpx

    logger = logging.getLogger(__name__)
    log_service = ApiCallLogService(self.db)

    try:
        model = await self._validate_model_and_balance(request.model, user_id)

        # Reject models whose categories cannot rerank.
        from app.models.model import ModelCategory
        if not any(
            int(c) in [int(ModelCategory.RERANK), int(ModelCategory.EMBEDDING), int(ModelCategory.LLM)]
            for c in (model.categories or [])
        ):
            raise OpenAICompatError(
                status_code=400,
                message=f"Model '{request.model}' does not support reranking",
                error_type="model_not_supported",
            )

        if model.is_local:
            # Local models require an explicit per-user permission.
            from app.services.user_local_model_permission_service import UserLocalModelPermissionService
            permission_service = UserLocalModelPermissionService(self.db)
            if not await permission_service.check_user_model_access(user_id, model.id):
                raise OpenAICompatError(
                    status_code=403,
                    message=f"You don't have permission to access model '{request.model}'",
                    error_type="permission_error",
                )
        else:
            # Cloud models are billed against the user's own API key.
            user = self._get_user(user_id)
            if not user or not user.apikey:
                raise OpenAICompatError(
                    status_code=403,
                    message="User API key not configured.",
                    error_type="api_key_not_configured"
                )

        if model.is_local:
            # Resolve the endpoint from the cache, falling back to the
            # database record (and populating the cache) on a miss.
            from app.services.cache_service import CacheService
            from app.services.crypto_utils import decrypt_api_key
            model_data = await CacheService.get_model(model.id)

            if model_data:
                # `or ""` guards against a cached explicit None.
                base_url = (model_data.get("base_url") or "").rstrip("/")
                local_api_key = model_data.get("local_api_key")
            else:
                base_url = (model.base_url or "").rstrip("/")
                local_api_key = model.local_api_key
                await CacheService.set_model(model.id, {
                    "base_url": base_url,
                    "local_api_key": local_api_key,
                    "is_local": model.is_local,
                    "name": model.name
                })
            if not base_url:
                raise OpenAICompatError(
                    status_code=500,
                    message="本地模型未配置 Base URL",
                    error_type="configuration_error",
                )

            # Build request headers; the stored key is encrypted.
            headers = {"Content-Type": "application/json"}
            if local_api_key:
                api_key = decrypt_api_key(local_api_key)
                if api_key:
                    headers["Authorization"] = f"Bearer {api_key}"

            # Build the request body; optional fields are sent only when set.
            payload = {
                "model": model.name,
                "query": request.query,
                "documents": request.documents
            }
            if request.top_n is not None:
                payload["top_n"] = request.top_n
            if request.return_documents is not None:
                payload["return_documents"] = request.return_documents
            if request.user:
                payload["user"] = request.user

            result_data = await self._handle_local_model_request(
                api_url=f"{base_url}/rerank",
                headers=headers,
                payload=payload,
                model_name=model.name,
                base_url=base_url,
                endpoint_type="rerank",
                timeout=30.0
            )

            # Results live under "data", or under "results" for some
            # providers (e.g. SiliconFlow).
            results_list = result_data.get("data") or []
            if not results_list:
                results_list = result_data.get("results") or []
                logger.debug(f"Using 'results' field, found {len(results_list)} items")

            # Token usage lives under usage.total_tokens, or under
            # meta.tokens (SiliconFlow format). `or` guards against nulls.
            usage_data = result_data.get("usage") or {}
            total_tokens = usage_data.get("total_tokens") or 0
            if not total_tokens:
                tokens = (result_data.get("meta") or {}).get("tokens") or {}
                if tokens:
                    input_tokens = tokens.get("input_tokens") or 0
                    output_tokens = tokens.get("output_tokens") or 0
                    total_tokens = input_tokens + output_tokens
                    logger.debug(f"Using meta.tokens: input={input_tokens}, output={output_tokens}, total={total_tokens}")

            data_list = []
            for item in results_list:
                # The returned document may be an object with a "text"
                # field or a plain string.
                doc_content = None
                if request.return_documents:
                    doc = item.get("document")
                    if isinstance(doc, dict):
                        doc_content = doc.get("text", "")
                    elif isinstance(doc, str):
                        doc_content = doc

                result_item = RerankResult(
                    index=item.get("index", 0),
                    relevance_score=item.get("relevance_score", 0.0)
                )
                if doc_content:
                    result_item.document = doc_content
                data_list.append(result_item)
        else:
            # Cloud model handling (Alibaba Cloud Bailian / DashScope).
            api_url = "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank"
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {user.apikey}"
            }

            payload = {
                "model": model.model_code,
                "input": {
                    "query": request.query,
                    "documents": request.documents
                }
            }
            if request.top_n is not None:
                payload["parameters"] = {"top_n": request.top_n}

            async with httpx.AsyncClient(timeout=30.0) as client:
                response = await client.post(api_url, headers=headers, json=payload)
            response.raise_for_status()
            result_data = response.json()

            output = result_data.get("output") or {}
            results_list = output.get("results") or []
            usage_data = result_data.get("usage") or {}
            total_tokens = usage_data.get("total_tokens") or 0

            data_list = []
            for item in results_list:
                result_item = RerankResult(
                    index=item.get("index", 0),
                    relevance_score=item.get("relevance_score", 0.0)
                )
                if request.return_documents:
                    # Documents are echoed back from the request by index.
                    result_item.document = request.documents[item.get("index", 0)]
                data_list.append(result_item)

        # Record the call log.
        bill = Decimal("0")
        log_service.create_log(
            user_id=user_id,
            api_key_id=api_key_id,
            model_id=model.id if model else None,
            model_name=request.model,
            is_local=model.is_local if model else False,
            input_tokens=total_tokens,
            output_tokens=0,
            bill=float(bill),
            status="success",
            request_ip=request_ip
        )
        return RerankResponse(
            model=request.model,
            data=data_list,
            usage=Usage(
                prompt_tokens=total_tokens,
                completion_tokens=0,
                total_tokens=total_tokens
            )
        )

    except Exception as e:
        # Record the failure, then surface it as an OpenAICompatError.
        error_msg = str(e) if not isinstance(e, OpenAICompatError) else e.message
        model_obj = self._find_model(request.model, user_id)
        log_service.create_log(
            user_id=user_id,
            api_key_id=api_key_id,
            model_id=model_obj.id if model_obj else None,
            model_name=request.model,
            is_local=model_obj.is_local if model_obj else False,
            input_tokens=0,
            output_tokens=0,
            bill=0,
            status="failed",
            error_message=error_msg,
            request_ip=request_ip
        )
        if isinstance(e, OpenAICompatError):
            raise
        raise OpenAICompatError(status_code=500, message=error_msg, error_type="rerank_error") from e
|