| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762
27722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600260126022603260426052606260726082609261026112612261326142615261626172618261926202621262226232624262526262627262826292630263126322633263426352636263726382639264026412642264326442645264626472648264926502651265226532654265526562657265826592660266126622663266426652666266726682669267026712672267326742675267626772678267926802681268226832684268526862687268826892690269126922693269426952696269726982699270027012702270327042705270627072708270927102711271227132714271527162717271827192720272127222723272427252726272727282729273027312732273327342735273627372738273927402741274227432744274527462747274827492750275127522753275427552756275727582759276027612762276327642765276627672768276927702771277227732774277527762
7772778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865 |
- """
- OpenAI 兼容服务层
- 完整支持 /v1/chat/completions 和 /v1/models 接口
- 支持多种模型提供商的自动适配
- """
- import asyncio
- import logging
- import time
- import uuid
- from decimal import Decimal
- import httpx
- from fastapi import UploadFile
- from app.services.api_call_log_service import ApiCallLogService
- from app.services.image_service import ImageGenerationService
- from app.services.model_adapters import BaseAdapter, get_adapter, ModelProvider
- from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
- from sqlalchemy import desc
- from sqlalchemy.orm import Session
- from app.models.model import ModelNew as Model, ModelPriceNew as ModelPrice
- from app.models.user import User
- from app.schemas.openai_compat import (
- ChatCompletionsRequest,
- ModelInfo,
- EmbeddingsRequest,
- EmbeddingsResponse,
- EmbeddingData,
- Usage,
- ModelsListResponse,
- ImageGenerationRequest,
- ImageGenerationResponse,
- ImageData,
- AudioTranscriptionResponse,
- AudioSpeechRequest,
- VideoGenerationRequest,
- VideoGenerationResponse,
- RerankRequest,
- RerankResponse,
- RerankResult,
- )
- from app.services.crypto_utils import decrypt_api_key
- from app.services.system_config_manager import get_config_bool
- logger = logging.getLogger(__name__)
- # ─────────────────────────────────────────────
- # 工具函数
- # ─────────────────────────────────────────────
def parse_video_size(size: str) -> Tuple[str, str]:
    """Normalize a video size string into OpenAI and internal shorthand formats.

    Accepted inputs (case-insensitive, surrounding whitespace ignored):
        - OpenAI format ``WIDTHxHEIGHT``: "1280x720", "1920x1080", "720x1280"
        - Shorthand ``<HEIGHT>P``: "720P", "1080P", "720p", "1080p"

    Args:
        size: The requested video size.

    Returns:
        ``(openai_format, internal_format)``, e.g. ``("1280x720", "720P")``.
        The shorthand is always derived from the height. For shorthand input
        the width is computed assuming a 16:9 aspect ratio.

    Raises:
        ValueError: If the string matches neither format, or if a dimension
            is zero.
    """
    import re

    size = size.strip()
    normalized = size.lower()

    # OpenAI format (WIDTHxHEIGHT).
    if "x" in normalized:
        match = re.match(r'^(\d+)x(\d+)$', normalized)
        if not match:
            raise ValueError(f"Invalid size format: {size}. Expected format: 1280x720")

        width, height = int(match.group(1)), int(match.group(2))
        if width == 0 or height == 0:
            # A zero-sized frame can never be rendered; fail here instead of
            # passing "0x0" on to the caller.
            raise ValueError(f"Invalid size: {size}. Width and height must be greater than 0")

        # The shorthand is the height followed by "P" (720 -> "720P").
        return (f"{width}x{height}", f"{height}P")

    # Shorthand format (720P, 1080P, ...).
    match = re.match(r'^(\d+)p$', normalized)
    if not match:
        raise ValueError(f"Invalid size format: {size}. Expected format: 720P or 1280x720")

    height = int(match.group(1))
    if height == 0:
        raise ValueError(f"Invalid size: {size}. Width and height must be greater than 0")

    # 16:9 width, truncated to an integer: 720 -> 1280, 1080 -> 1920, 480 -> 853.
    width = height * 16 // 9
    return (f"{width}x{height}", f"{height}P")
class OpenAICompatError(Exception):
    """Error raised by the OpenAI-compatible service layer.

    Carries the HTTP status code, a human-readable message and an
    OpenAI-style error type string alongside the normal exception message.
    """

    def __init__(self, status_code: int, message: str, error_type: str = "invalid_request_error"):
        # The message doubles as the standard exception text (str(exc)).
        super().__init__(message)
        self.status_code, self.message, self.error_type = status_code, message, error_type
- class OpenAICompatService:
- """OpenAI API 兼容服务"""
- def __init__(self, db: Session):
- self.db = db
- self._user_cache: dict = {} # user_id → User,请求内缓存,避免重复查询
- def _get_user(self, user_id: str):
- """获取用户对象,同一请求内缓存,避免重复查询 users 表。"""
- if user_id not in self._user_cache:
- self._user_cache[user_id] = self.db.query(User).filter(User.id == user_id).first()
- return self._user_cache[user_id]
- # ─────────────────────────────────────────────
- # 主入口
- # ─────────────────────────────────────────────
async def chat_completions(
    self,
    request: ChatCompletionsRequest,
    user_id: str,
    api_key_id: int,
    request_ip: Optional[str] = None,
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
    """Handle a Chat Completions request, including call logging and billing.

    Checks performed before the upstream call, in order:
        1. The model must resolve through ``_find_model`` (404 otherwise).
        2. Cloud models must have ``is_api_enabled`` set (403).
        3. Cloud models whose requested name contains "realtime" are rejected (400).
        4. Local models require ``check_user_model_access`` to succeed (403).
        5. Cloud models need a user record (401) and an API key: the model's
           own decrypted key first, then ``user.apikey`` (403 if neither is set).
        6. Cloud OCR models need at least one image part in the messages (400).

    Args:
        request: The parsed chat completions request.
        user_id: Id of the calling user.
        api_key_id: Id of the API key used for the call; written to the call log.
        request_ip: Client IP written to the call log, if known.

    Returns:
        For ``request.stream``: an async generator that relays the upstream
        chunks and writes the call log when the stream ends, is closed early,
        or fails. Otherwise the upstream response.

    Raises:
        OpenAICompatError: For failed checks and for upstream failures.
    """
    log_service = ApiCallLogService(self.db)
    model = self._find_model(request.model, user_id)
    if not model:
        raise OpenAICompatError(
            status_code=404,
            message=f"The model '{request.model}' does not exist",
            error_type="model_not_found",
        )
    if not model.is_local and not model.is_api_enabled:
        raise OpenAICompatError(
            status_code=403,
            message=f"Model '{request.model}' does not support API calls",
            error_type="model_not_available",
        )
    # Realtime models speak a WebSocket streaming protocol and cannot be
    # served through this REST endpoint.
    if "realtime" in request.model.lower() and not model.is_local:
        raise OpenAICompatError(
            status_code=400,
            message=f"Model '{request.model}' is a real-time streaming model that uses WebSocket protocol. "
                    f"It cannot be called via /api/v1/chat/completions. "
                    f"Please use the WebSocket API instead.",
            error_type="model_not_supported",
        )
    # Access check for local models.
    if model.is_local:
        from app.services.user_local_model_permission_service import UserLocalModelPermissionService
        permission_service = UserLocalModelPermissionService(self.db)
        if not await permission_service.check_user_model_access(user_id, model.id):
            raise OpenAICompatError(
                status_code=403,
                message=f"You don't have permission to access model '{request.model}'",
                error_type="permission_error",
            )
    user_api_key: Optional[str] = None
    if not model.is_local:
        user = self._get_user(user_id)
        if not user:
            raise OpenAICompatError(
                status_code=401,
                message="User not found",
                error_type="authentication_error",
            )
        # Prefer the key stored on the model; fall back to the key the user
        # configured for themselves.
        if model.encrypted_api_key:
            from app.services.crypto_utils import decrypt_api_key
            decrypted = decrypt_api_key(model.encrypted_api_key)
            user_api_key = decrypted if decrypted else None
        if not user_api_key:
            user_api_key = user.apikey
        if not user_api_key:
            raise OpenAICompatError(
                status_code=403,
                message="User API key not configured. Please configure your API key in settings.",
                error_type="api_key_not_configured",
            )

    # ── OCR models must receive at least one image ──────────────────────
    OCR_MODELS = ("qwen-vl-ocr",)
    if model.model_code in OCR_MODELS and not model.is_local:
        has_image = any(
            isinstance(msg.content, list) and
            any(isinstance(part, dict) and part.get("type") == "image_url"
                for part in msg.content)
            for msg in request.messages
        )
        # Content parts may also arrive as Pydantic objects.
        if not has_image:
            from app.schemas.openai_compat import ContentPartImage
            has_image = any(
                isinstance(msg.content, list) and
                any(isinstance(part, ContentPartImage) for part in msg.content)
                for msg in request.messages
            )
        if not has_image:
            raise OpenAICompatError(
                status_code=400,
                message=f"Model '{model.model_code}' is an OCR model and requires at least one image in the messages. "
                        f"Please include an image_url content part in your user message.",
                error_type="invalid_request_error",
            )

    # ── Streaming ───────────────────────────────────────────────────────
    if request.stream:
        if model.is_local:
            raw_stream = await self._call_local_model(model, request)
        else:
            raw_stream = await self._call_cloud_model(model, request, user_api_key)

        async def _stream_with_billing() -> AsyncGenerator[str, None]:
            import json as _json

            # Rough estimates (1.2 tokens per character); replaced by the
            # upstream usage figures when a chunk carries them.
            input_text = "".join(
                m.content for m in request.messages if isinstance(m.content, str)
            )
            input_tokens = max(int(len(input_text) * 1.2), 1)
            output_tokens = 0
            stream_error: Optional[Exception] = None
            try:
                async for chunk in raw_stream:
                    yield chunk
                    if not isinstance(chunk, str) or not chunk.startswith("data: "):
                        continue
                    if "data: [DONE]" in chunk:
                        continue
                    try:
                        data_dict = _json.loads(chunk[6:])
                    except ValueError:
                        # Not JSON; nothing to account for in this chunk.
                        continue
                    if not isinstance(data_dict, dict):
                        continue

                    # Delta accounting. Every level is type-checked so that a
                    # chunk with an empty "choices" list (the shape of a
                    # trailing usage-only chunk) cannot abort the usage
                    # handling below.
                    choices = data_dict.get("choices")
                    first_choice = choices[0] if isinstance(choices, list) and choices else None
                    delta = first_choice.get("delta") if isinstance(first_choice, dict) else None
                    delta_text = delta.get("content") if isinstance(delta, dict) else None
                    if isinstance(delta_text, str) and delta_text:
                        output_tokens += max(int(len(delta_text) * 1.2), 1)

                    # Real usage from upstream wins over the estimates.
                    usage = data_dict.get("usage")
                    if isinstance(usage, dict):
                        prompt_tokens = usage.get("prompt_tokens")
                        completion_tokens = usage.get("completion_tokens")
                        if prompt_tokens is not None:
                            input_tokens = prompt_tokens
                        if completion_tokens is not None:
                            output_tokens = completion_tokens
            except (GeneratorExit, asyncio.CancelledError):
                # Client disconnected / task cancelled: the finally block
                # still bills for the tokens produced so far.
                raise
            except Exception as exc:
                stream_error = exc
                raise
            finally:
                # Always write the log entry, whether the stream completed,
                # was interrupted or failed, so consumed tokens do not go
                # unbilled.
                try:
                    if stream_error is not None:
                        # Upstream error: record a failed call.
                        log_service.create_log(
                            user_id=user_id, api_key_id=api_key_id,
                            model_id=model.id, model_name=request.model,
                            is_local=model.is_local,
                            input_tokens=0, output_tokens=0,
                            bill=0, status="failed",
                            error_message=str(stream_error), request_ip=request_ip,
                        )
                    else:
                        bill = self.calculate_bill(model, input_tokens, output_tokens)
                        log_service.create_log(
                            user_id=user_id, api_key_id=api_key_id,
                            model_id=model.id, model_name=request.model,
                            is_local=model.is_local,
                            input_tokens=input_tokens, output_tokens=output_tokens,
                            bill=float(bill), status="success", request_ip=request_ip,
                        )
                except Exception as fin_exc:
                    logger.error("流式响应收尾日志记录失败: %s", fin_exc)

        return _stream_with_billing()

    # ── Non-streaming ───────────────────────────────────────────────────
    try:
        if model.is_local:
            result = await self._call_local_model(model, request)
        else:
            result = await self._call_cloud_model(model, request, user_api_key)
        input_tokens, output_tokens = self.extract_usage_from_response(result)
        bill = self.calculate_bill(model, input_tokens, output_tokens)
        log_service.create_log(
            user_id=user_id, api_key_id=api_key_id,
            model_id=model.id, model_name=request.model,
            is_local=model.is_local,
            input_tokens=input_tokens, output_tokens=output_tokens,
            bill=float(bill), status="success", request_ip=request_ip,
        )
        return result
    except OpenAICompatError:
        raise
    except Exception as exc:
        error_msg = str(exc) or repr(exc)
        logger.warning(
            "[CHAT_COMPLETION_FAILED] model=%s user_id=%s error_type=%s error=%s",
            request.model, user_id, type(exc).__name__, error_msg,
        )
        log_service.create_log(
            user_id=user_id, api_key_id=api_key_id,
            model_id=model.id, model_name=request.model,
            is_local=model.is_local,
            input_tokens=0, output_tokens=0,
            bill=0, status="failed",
            error_message=error_msg, request_ip=request_ip,
        )
        raise OpenAICompatError(
            status_code=500, message=error_msg, error_type="upstream_error"
        ) from exc
- # ─────────────────────────────────────────────
- # 构建请求体(完整参数透传)
- # ─────────────────────────────────────────────
- def _build_request_body(
- self, request: ChatCompletionsRequest, model_name: str
- ) -> Dict[str, Any]:
- """将 ChatCompletionsRequest 转换为上游 API 请求体,透传所有非 None 参数"""
- # 序列化消息(content 支持 str 或 list)
- messages = []
- for msg in request.messages:
- m: Dict[str, Any] = {"role": msg.role}
- if msg.content is None:
- m["content"] = None
- elif isinstance(msg.content, str):
- m["content"] = msg.content
- else:
- # 多模态内容列表
- parts = []
- for part in msg.content:
- part_dict = part.model_dump(exclude_none=True)
- # 校验 image_url 格式
- if part_dict.get("type") == "image_url":
- url = (part_dict.get("image_url") or {}).get("url", "")
- if url and not url.startswith("data:"):
- import os
- ext = os.path.splitext(url.split("?")[0].lower())[1]
- SUPPORTED_IMAGE_EXTS = {'.jpg', '.jpeg', '.png', '.webp', '.gif', '.bmp'}
- if ext and ext not in SUPPORTED_IMAGE_EXTS:
- raise OpenAICompatError(
- status_code=400,
- message=f"Unsupported image format '{ext}'. "
- f"Supported formats: jpg, jpeg, png, webp, gif, bmp.",
- error_type="invalid_request_error",
- )
- parts.append(part_dict)
- m["content"] = parts
- if msg.name is not None:
- m["name"] = msg.name
- if msg.tool_calls is not None:
- m["tool_calls"] = [tc.model_dump() for tc in msg.tool_calls]
- if msg.tool_call_id is not None:
- m["tool_call_id"] = msg.tool_call_id
- messages.append(m)
- body: Dict[str, Any] = {
- "model": model_name,
- "messages": messages,
- "stream": request.stream,
- }
- # 可选参数:只有非 None 才透传
- optional_fields = [
- "temperature", "top_p", "n", "stop",
- "presence_penalty", "frequency_penalty",
- "logit_bias", "logprobs", "top_logprobs",
- "seed", "user", "service_tier", "store", "metadata",
- "parallel_tool_calls",
- ]
- for field in optional_fields:
- val = getattr(request, field, None)
- if val is not None:
- body[field] = val
- # max_tokens / max_completion_tokens(优先新版)
- if request.max_completion_tokens is not None:
- body["max_completion_tokens"] = request.max_completion_tokens
- elif request.max_tokens is not None:
- body["max_tokens"] = request.max_tokens
- # 流式选项
- if request.stream and request.stream_options is not None:
- body["stream_options"] = request.stream_options.model_dump(exclude_none=True)
- # 工具调用
- if request.tools is not None:
- body["tools"] = [t.model_dump(exclude_none=True) for t in request.tools]
- if request.tool_choice is not None:
- if isinstance(request.tool_choice, str):
- body["tool_choice"] = request.tool_choice
- else:
- body["tool_choice"] = request.tool_choice.model_dump()
- # 响应格式
- if request.response_format is not None:
- body["response_format"] = request.response_format.model_dump(exclude_none=True)
- return body
- # ─────────────────────────────────────────────
- # 本地模型调用
- # ─────────────────────────────────────────────
async def _call_local_model(
    self, model: Model, request: ChatCompletionsRequest
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
    """Send the request to a local model's ``/chat/completions`` endpoint.

    Connection details (base URL and the model's own API key) are read from
    ``CacheService`` when present; otherwise they come from the model row
    and are written to the cache.

    Args:
        model: The local model to call.
        request: The chat completions request to forward.

    Returns:
        The streaming generator from ``_stream_response`` when
        ``request.stream`` is set, otherwise the result of
        ``_non_stream_response``.

    Raises:
        OpenAICompatError: 500 ``configuration_error`` when no base URL is
            configured for the model.
    """
    from app.services.cache_service import CacheService

    model_data = await CacheService.get_model(model.id)
    if model_data:
        # A cached entry may hold None for base_url; treat it like "missing"
        # instead of failing on None.rstrip().
        base_url = (model_data.get("base_url") or "").rstrip("/")
        local_api_key = model_data.get("local_api_key")
    else:
        # Cache miss: read from the model row and populate the cache.
        base_url = (model.base_url or "").rstrip("/")
        local_api_key = model.local_api_key
        await CacheService.set_model(model.id, {
            "base_url": base_url,
            "local_api_key": local_api_key,
            "is_local": model.is_local,
            "name": model.name
        })

    if not base_url:
        raise OpenAICompatError(
            status_code=500,
            message="本地模型未配置 Base URL",
            error_type="configuration_error",
        )

    headers: Dict[str, str] = {"Content-Type": "application/json"}
    # Local models authenticate with the key configured on the model, never
    # with the calling user's key.
    if local_api_key:
        api_key = decrypt_api_key(local_api_key)
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"

    # The model's ``name`` field is the model name sent upstream.
    body = self._build_request_body(request, model.name)
    api_url = f"{base_url}/chat/completions"
    if request.stream:
        return self._stream_response(api_url, headers, body)
    return await self._non_stream_response(api_url, headers, body)
- # ─────────────────────────────────────────────
- # 云端模型调用(阿里云百炼 OpenAI 兼容模式)
- # ─────────────────────────────────────────────
- async def _call_cloud_model(
- self, model: Model, request: ChatCompletionsRequest, user_api_key: str
- ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
- api_url = "https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions"
- headers = {
- "Content-Type": "application/json",
- "Authorization": f"Bearer {user_api_key}",
- }
- body = self._build_request_body(request, model.model_code)
- if request.stream:
- return self._stream_response(api_url, headers, body)
- else:
- return await self._non_stream_response(api_url, headers, body)
- # ─────────────────────────────────────────────
- # HTTP 请求封装
- # ─────────────────────────────────────────────
async def _non_stream_response(
    self, api_url: str, headers: Dict[str, str], body: Dict[str, Any],
    timeout: float = 300.0,
) -> Dict[str, Any]:
    """POST ``body`` to ``api_url`` and return the decoded JSON response.

    Args:
        api_url: Upstream endpoint URL.
        headers: Request headers.
        body: JSON request body; its ``model`` entry is used in logs and
            error messages.
        timeout: Timeout in seconds passed to the HTTP client.

    Returns:
        The upstream response parsed as JSON.

    Raises:
        OpenAICompatError: With the upstream status code when it is >= 400,
            or 504 ``timeout_error`` when the request times out.
    """
    model_name = body.get("model", "unknown")
    try:
        async with httpx.AsyncClient(timeout=timeout) as client:
            resp = await client.post(api_url, headers=headers, json=body)
            if resp.status_code >= 400:
                error_msg = self._extract_upstream_error(resp)
                logger.warning(
                    "[UPSTREAM_ERROR] model=%s status=%d url=%s response=%s",
                    model_name, resp.status_code, api_url, resp.text[:500],
                )
                raise OpenAICompatError(
                    status_code=resp.status_code,
                    message=error_msg,
                    error_type="upstream_error",
                )
            return resp.json()
    except httpx.TimeoutException as exc:
        # TimeoutException is the common base of the connect/read/write/pool
        # timeouts, so every timeout maps to 504 rather than only ReadTimeout.
        logger.warning("[UPSTREAM_TIMEOUT] model=%s url=%s timeout=%ss", model_name, api_url, timeout)
        raise OpenAICompatError(
            status_code=504,
            message=f"模型 '{model_name}' 响应超时({timeout}s),请稍后重试或换个模型",
            error_type="timeout_error",
        ) from exc
async def _stream_response(
    self, api_url: str, headers: Dict[str, str], body: Dict[str, Any]
) -> AsyncGenerator[str, None]:
    """Relay an upstream SSE stream as ``data: ...`` chunks.

    Each upstream ``data:`` line is re-emitted as ``"data: <payload>\\n\\n"``.
    The stream ends after ``[DONE]`` is forwarded. SSE comment lines (those
    starting with ``:``) are dropped. Any other non-empty line is wrapped in
    a ``data:`` chunk.

    Args:
        api_url: Upstream endpoint URL.
        headers: Request headers.
        body: JSON request body; its ``model`` entry is used in logs and
            error messages.

    Raises:
        OpenAICompatError: When upstream answers with status >= 400 (401 and
            404 are reported as 500 with a configuration hint), or 504 when
            the connection times out.
    """
    model_name = body.get("model", "unknown")
    try:
        async with httpx.AsyncClient(
            timeout=httpx.Timeout(30.0, read=None)
        ) as client:
            async with client.stream("POST", api_url, headers=headers, json=body) as resp:
                if resp.status_code >= 400:
                    # The body must be read explicitly on a streamed response
                    # before its text/JSON can be inspected.
                    error_body = await resp.aread()
                    error_text = error_body.decode("utf-8", errors="replace")
                    logger.warning(
                        "[UPSTREAM_STREAM_ERROR] model=%s status=%d url=%s response=%s",
                        model_name, resp.status_code, api_url, error_text[:500],
                    )
                    error_detail = self._extract_upstream_error(resp)
                    # Friendlier messages for specific status codes.
                    if resp.status_code == 401:
                        raise OpenAICompatError(
                            status_code=500,
                            message=f"模型 '{model_name}' 认证失败(401 Unauthorized)。\n"
                                    f"请在管理后台检查该模型的 API Key 配置是否正确。\n"
                                    f"错误详情: {error_detail}",
                            error_type="authentication_error",
                        )
                    elif resp.status_code == 404:
                        raise OpenAICompatError(
                            status_code=500,
                            message=f"模型 '{model_name}' 接口不存在(404 Not Found)。请检查 Base URL 和模型名称是否正确。\n"
                                    f"错误详情: {error_detail}",
                            error_type="not_found_error",
                        )
                    else:
                        raise OpenAICompatError(
                            status_code=resp.status_code,
                            message=f"模型 '{model_name}' 调用失败({resp.status_code}): {error_detail}",
                            error_type="upstream_error",
                        )
                async for line in resp.aiter_lines():
                    if line.startswith(":"):
                        # SSE comment (e.g. a keep-alive ping): not event
                        # data, so it must not be forwarded as a data chunk.
                        continue
                    if line.startswith("data:"):
                        # The space after "data:" is optional; strip at most one.
                        data = line[5:]
                        if data.startswith(" "):
                            data = data[1:]
                        if data.strip() == "[DONE]":
                            yield "data: [DONE]\n\n"
                            return
                        yield f"data: {data}\n\n"
                    elif line.strip():
                        # Upstream sent a bare payload line; wrap it.
                        yield f"data: {line}\n\n"
    except httpx.ConnectTimeout as exc:
        logger.warning("[UPSTREAM_CONNECT_TIMEOUT] model=%s url=%s", model_name, api_url)
        raise OpenAICompatError(
            status_code=504,
            message=f"模型 '{model_name}' 连接超时,请稍后重试",
            error_type="timeout_error",
        ) from exc
- def _extract_upstream_error(self, resp: httpx.Response) -> str:
- """从上游错误响应中提取错误信息,确保始终返回有意义的内容"""
- raw_text = ""
- try:
- raw_text = resp.text
- data = resp.json()
- # 标准 OpenAI 格式: {"error": {"message": "...", "type": "...", "code": "..."}}
- err = data.get("error", {})
- if isinstance(err, dict):
- msg = err.get("message", "")
- err_type = err.get("type", "")
- code = err.get("code", "")
- parts = [p for p in [msg, f"type={err_type}" if err_type else "", f"code={code}" if code else ""] if p]
- if parts:
- return " | ".join(parts)
- # DashScope 格式: {"code": "...", "message": "...", "request_id": "..."}
- code = data.get("code", "")
- msg = data.get("message", "")
- request_id = data.get("request_id", "")
- if msg:
- parts = [msg]
- if code:
- parts.append(f"code={code}")
- if request_id:
- parts.append(f"request_id={request_id}")
- return " | ".join(parts)
- # 兜底:返回整个 JSON
- if data:
- return str(data)
- except Exception:
- pass
- # JSON 解析失败或为空,返回原始文本
- if raw_text:
- return raw_text[:500]
- return f"Upstream error {resp.status_code} (empty response body)"
- # ─────────────────────────────────────────────
- # Models 列表
- # ─────────────────────────────────────────────
def get_available_models(self, user_id: str, key_type: str = "public") -> ModelsListResponse:
    """List the models visible to a caller, filtered by API key type.

    - ``"public"`` keys: cloud models with ``is_api_enabled`` set, identified
      by ``model_code``.
    - ``"local"`` keys: local models identified as ``local:<id>``. All local
      models are returned when the ``enable_local_models`` config flag is on;
      otherwise only the ones the user has a permission entry with
      ``has_access`` for.
    - Any other key type yields an empty list.
    """

    def _created_ts(record) -> int:
        # Fall back to "now" when the row has no creation timestamp.
        return int(record.created_at.timestamp()) if record.created_at else int(time.time())

    def _as_local_infos(records) -> List[ModelInfo]:
        return [
            ModelInfo(
                id=f"local:{rec.id}",
                object="model",
                created=_created_ts(rec),
                owned_by="local",
            )
            for rec in records
        ]

    models_data: List[ModelInfo] = []

    if key_type == "public":
        # Public keys may only see cloud models.
        cloud_rows = (
            self.db.query(Model)
            .filter(Model.is_local == False, Model.is_api_enabled == True)
            .all()
        )
        models_data.extend(
            ModelInfo(
                id=row.model_code,
                object="model",
                created=_created_ts(row),
                owned_by=row.supplier or "platform",
            )
            for row in cloud_rows
        )
    elif key_type == "local":
        if get_config_bool("enable_local_models", True):
            # Flag on: every local model is listed.
            local_rows = self.db.query(Model).filter(Model.is_local == True).all()
        else:
            # Flag off: only local models the user was granted access to.
            from app.services.user_local_model_permission_service import UserLocalModelPermissionService
            granted = UserLocalModelPermissionService(self.db).get_user_model_permissions(user_id)
            permitted_model_ids = [entry["model_id"] for entry in granted if entry["has_access"]]
            local_rows = (
                self.db.query(Model)
                .filter(Model.is_local == True, Model.id.in_(permitted_model_ids))
                .all()
            )
        models_data.extend(_as_local_infos(local_rows))

    return ModelsListResponse(object="list", data=models_data)
- # ─────────────────────────────────────────────
- # 工具方法
- # ─────────────────────────────────────────────
async def _handle_local_model_request(
    self,
    api_url: str,
    headers: Dict[str, str],
    payload: Dict[str, Any],
    model_name: str,
    base_url: str,
    endpoint_type: str = "chat",
    timeout: float = 60.0,
    return_raw_response: bool = False
) -> Union[Dict[str, Any], httpx.Response]:
    """POST an OpenAI-format payload to a local model and return its response.

    The payload and the JSON response are passed through unchanged; no
    format adaptation is done here.

    Args:
        api_url: Request URL.
        headers: Request headers.
        payload: Request body (OpenAI format).
        model_name: Model name, used in error messages.
        base_url: The model's base URL (accepted but not used here).
        endpoint_type: Endpoint type (accepted but not used here).
        timeout: Timeout in seconds passed to the HTTP client.
        return_raw_response: Return the ``httpx.Response`` itself instead of
            its decoded JSON (e.g. for binary audio data).

    Returns:
        The decoded JSON response, or the raw response object.

    Raises:
        OpenAICompatError: 500 for upstream 401/404 and for transport
            errors, 504 on timeout, otherwise the upstream status code.
    """
    async with httpx.AsyncClient(timeout=timeout) as client:
        try:
            response = await client.post(api_url, headers=headers, json=payload)
            response.raise_for_status()

            # Callers that need the raw body (e.g. audio) get the response.
            if return_raw_response:
                return response

            # JSON is returned as-is, without any conversion.
            return response.json()

        except httpx.HTTPStatusError as e:
            status = e.response.status_code
            if status == 404:
                raise OpenAICompatError(
                    status_code=500,
                    message=f"本地模型 '{model_name}' 接口不存在(404 Not Found)。请检查 Base URL 和模型名称是否正确。",
                    error_type="not_found_error",
                ) from e

            # Same error extraction as the other upstream call paths in this
            # class.
            error_detail = self._extract_upstream_error(e.response)
            if status == 401:
                # Authentication failure.
                raise OpenAICompatError(
                    status_code=500,
                    message=f"本地模型 '{model_name}' 认证失败(401 Unauthorized)。\n"
                            f"请在管理后台检查该模型的 API Key 配置是否正确。\n"
                            f"错误详情: {error_detail}",
                    error_type="authentication_error",
                ) from e

            # Any other HTTP error keeps the upstream status code.
            raise OpenAICompatError(
                status_code=status,
                message=f"本地模型调用失败({e.response.status_code}): {error_detail}",
                error_type="upstream_error",
            ) from e
        except httpx.TimeoutException as e:
            raise OpenAICompatError(
                status_code=504,
                message=f"本地模型 '{model_name}' 请求超时。请检查网络连接或增加超时时间。",
                error_type="timeout_error",
            ) from e
        except httpx.RequestError as e:
            raise OpenAICompatError(
                status_code=500,
                message=f"本地模型 '{model_name}' 请求失败: {str(e)}",
                error_type="request_error",
            ) from e
def _find_model(self, model_name: str, user_id: str) -> Optional[Model]:
    """
    Resolve a model name from a request to a Model row.

    Lookup order:
      1. ``local:{id}``: local model by primary key (None if the id is not an integer).
      2. ``supplier/name``: newest local model with that supplier and display name.
      3. Cloud model whose ``model_code`` equals the full name.
      4. Newest local model whose ``display_name`` equals the full name.

    Args:
        model_name: Name as given by the caller.
        user_id: ID of the calling user (not read by this method).

    Returns:
        The matching Model, or None when nothing matches.
    """
    local_prefix = "local:"

    # 1. Explicit local id.
    if model_name.startswith(local_prefix):
        raw_id = model_name[len(local_prefix):]
        try:
            local_id = int(raw_id)
        except ValueError:
            return None
        by_id = self.db.query(Model).filter(Model.id == local_id, Model.is_local == True)
        return by_id.first()

    # 2. supplier/name form (local models); split on the first slash only.
    supplier, slash, name = model_name.partition("/")
    if slash:
        supplier_match = (
            self.db.query(Model)
            .filter(
                Model.supplier == supplier,
                Model.display_name == name,
                Model.is_local == True,
            )
            .order_by(Model.created_at.desc())
            .first()
        )
        if supplier_match:
            return supplier_match

    # 3. Cloud model, exact model_code match.
    cloud_match = (
        self.db.query(Model)
        .filter(Model.model_code == model_name, Model.is_local == False)
        .first()
    )
    if cloud_match:
        return cloud_match

    # 4. Local model by display name, newest first.
    by_display_name = (
        self.db.query(Model)
        .filter(Model.display_name == model_name, Model.is_local == True)
        .order_by(Model.created_at.desc())
    )
    return by_display_name.first()
def calculate_bill(
    self, model: Model, input_tokens: int, output_tokens: int
) -> Decimal:
    """Return the charge for a call; API calls are free, so this is always 0."""
    no_charge = Decimal("0")
    return no_charge
def extract_usage_from_response(self, response: Dict[str, Any]) -> Tuple[int, int]:
    """
    Read the prompt and completion token counts from a response dict.

    Args:
        response: Response body; token counts are read from its ``usage``
            entry (``prompt_tokens`` / ``completion_tokens``).

    Returns:
        ``(input_tokens, output_tokens)``; a missing or null ``usage`` block
        or a missing or null counter yields 0.
    """
    # `or {}` / `or 0` also cover keys that are present but null, which
    # `.get(key, default)` alone would pass through as None.
    usage = response.get("usage") or {}
    input_tokens = usage.get("prompt_tokens") or 0
    output_tokens = usage.get("completion_tokens") or 0
    return input_tokens, output_tokens
async def _validate_model_and_balance(self, model_name: str, user_id: str) -> Model:
    """
    Look up a model and check that it may be called through the API.

    Args:
        model_name: Model name as given by the caller.
        user_id: ID of the calling user.

    Returns:
        Model: The model that passed validation.

    Raises:
        OpenAICompatError: 404 when the model cannot be found; 403 when it is
            a non-local model that is not API-enabled.
    """
    found = self._find_model(model_name, user_id)

    if found:
        # Local models are always callable; cloud models need the API flag.
        if not (found.is_local or found.is_api_enabled):
            raise OpenAICompatError(
                status_code=403,
                message=f"Model '{model_name}' does not support API calls",
                error_type="model_not_available"
            )
        return found

    raise OpenAICompatError(
        status_code=404,
        message=f"Model '{model_name}' not found",
        error_type="model_not_found"
    )
-
async def embeddings(
    self,
    request: EmbeddingsRequest,
    user_id: str,
    api_key_id: int,
    request_ip: Optional[str] = None
) -> EmbeddingsResponse:
    """
    Handle a text-embeddings request.

    Local models are called through their ``{base_url}/embeddings`` endpoint
    with an OpenAI-format body; cloud models are sent to the DashScope
    embedding services. Success and failure are both written to the API
    call log.

    Args:
        request: Embeddings request (model, input, optional dimensions/user).
        user_id: ID of the calling user.
        api_key_id: API key ID, recorded in the call log.
        request_ip: Source IP of the request, recorded in the call log.

    Returns:
        EmbeddingsResponse: The embedding vectors plus token usage.

    Raises:
        OpenAICompatError: On validation, permission, configuration or
            upstream failure. Any other exception is wrapped with status 500
            and error type ``embeddings_error``.
    """
    log_service = ApiCallLogService(self.db)

    try:
        model = await self._validate_model_and_balance(request.model, user_id)

        # The model must carry at least one category that can produce embeddings.
        from app.models.model import ModelCategory
        if not any(int(c) in [int(ModelCategory.EMBEDDING), int(ModelCategory.LLM), int(ModelCategory.MULTIMODAL)] for c in (model.categories or [])):
            raise OpenAICompatError(
                status_code=400,
                message=f"Model '{request.model}' does not support embeddings",
                error_type="model_not_supported",
            )

        if model.is_local:
            # Local models are gated by per-user permissions.
            from app.services.user_local_model_permission_service import UserLocalModelPermissionService
            permission_service = UserLocalModelPermissionService(self.db)
            if not await permission_service.check_user_model_access(user_id, model.id):
                raise OpenAICompatError(
                    status_code=403,
                    message=f"You don't have permission to access model '{request.model}'",
                    error_type="permission_error",
                )
        else:
            user = self._get_user(user_id)
            if not user:
                raise OpenAICompatError(
                    status_code=403,
                    message="User API key not configured.",
                    error_type="api_key_not_configured"
                )
            # Prefer the key stored on the model; fall back to the user's own key.
            effective_api_key: Optional[str] = None
            if model.encrypted_api_key:
                from app.services.crypto_utils import decrypt_api_key as _decrypt
                decrypted = _decrypt(model.encrypted_api_key)
                effective_api_key = decrypted if decrypted else None
            if not effective_api_key:
                effective_api_key = user.apikey
            if not effective_api_key:
                raise OpenAICompatError(
                    status_code=403,
                    message="User API key not configured.",
                    error_type="api_key_not_configured"
                )

        # Normalise the input to a list of strings.
        texts = [request.input] if isinstance(request.input, str) else request.input

        if model.is_local:
            # Connection settings come from the cache, falling back to the DB row.
            from app.services.cache_service import CacheService
            from app.services.crypto_utils import decrypt_api_key
            model_data = await CacheService.get_model(model.id)

            if model_data:
                base_url = model_data.get("base_url", "").rstrip("/")
                local_api_key = model_data.get("local_api_key")
            else:
                base_url = (model.base_url or "").rstrip("/")
                local_api_key = model.local_api_key
                # Populate the cache for later requests.
                await CacheService.set_model(model.id, {
                    "base_url": base_url,
                    "local_api_key": local_api_key,
                    "is_local": model.is_local,
                    "name": model.name
                })
            if not base_url:
                raise OpenAICompatError(
                    status_code=500,
                    message="本地模型未配置 Base URL",
                    error_type="configuration_error",
                )

            headers = {"Content-Type": "application/json"}
            if local_api_key:
                api_key = decrypt_api_key(local_api_key)
                if api_key:
                    headers["Authorization"] = f"Bearer {api_key}"

            # Request body in OpenAI format.
            payload = {
                "model": model.name,
                "input": texts
            }
            if request.dimensions:
                payload["dimensions"] = request.dimensions
            if request.user:
                payload["user"] = request.user

            api_url = f"{base_url}/embeddings"
            result_data = await self._handle_local_model_request(
                api_url=api_url,
                headers=headers,
                payload=payload,
                model_name=model.name,
                base_url=base_url,
                endpoint_type="embedding",
                timeout=30.0
            )

            # `or` fallbacks: a key that is present but null must not turn a
            # successful upstream call into a failure.
            embeddings_list = result_data.get("data") or []
            usage_data = result_data.get("usage") or {}
            total_tokens = usage_data.get("total_tokens") or 0

            data_list = [
                EmbeddingData(
                    index=item.get("index", 0),
                    embedding=item.get("embedding", [])
                )
                for item in embeddings_list
            ]
        else:
            # Cloud model: pick the multimodal or the text embedding endpoint
            # from keywords in the model code.
            code_lower = model.model_code.lower()
            is_multimodal = any(kw in code_lower for kw in ("vl", "vision", "multimodal"))
            if is_multimodal:
                api_url = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/multimodal-embedding/multimodal-embedding"
                payload = {
                    "model": model.model_code,
                    "input": {
                        "contents": [{"text": t} for t in texts]
                    }
                }
            else:
                api_url = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
                payload = {
                    "model": model.model_code,
                    "input": texts
                }
                if request.dimensions:
                    payload["dimensions"] = request.dimensions
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {effective_api_key}"
            }
            # The multimodal endpoint takes the dimension under "parameters".
            if is_multimodal and request.dimensions:
                payload.setdefault("parameters", {})["dimension"] = request.dimensions

            async with httpx.AsyncClient(timeout=30.0) as client:
                response = await client.post(api_url, headers=headers, json=payload)
                response.raise_for_status()
                result_data = response.json()

            output = result_data.get("output") or {}
            embeddings_list = output.get("embeddings") or []
            usage_data = result_data.get("usage") or {}
            total_tokens = usage_data.get("total_tokens") or 0

            data_list = [
                EmbeddingData(
                    index=item.get("text_index", 0),
                    embedding=item.get("embedding", [])
                )
                for item in embeddings_list
            ]

        bill = Decimal("0")
        log_service.create_log(
            user_id=user_id,
            api_key_id=api_key_id,
            model_id=model.id if model else None,
            model_name=request.model,
            is_local=model.is_local if model else False,
            input_tokens=total_tokens,
            output_tokens=0,
            bill=float(bill),
            status="success",
            request_ip=request_ip
        )
        return EmbeddingsResponse(
            model=request.model,
            data=data_list,
            usage=Usage(
                prompt_tokens=total_tokens,
                completion_tokens=0,
                total_tokens=total_tokens
            )
        )

    except Exception as e:
        error_msg = str(e) if not isinstance(e, OpenAICompatError) else e.message
        # Look the model up again: validation itself may be what failed.
        model = self._find_model(request.model, user_id)
        log_service.create_log(
            user_id=user_id,
            api_key_id=api_key_id,
            model_id=model.id if model else None,
            model_name=request.model,
            is_local=model.is_local if model else False,
            input_tokens=0,
            output_tokens=0,
            bill=0,
            status="failed",
            error_message=error_msg,
            request_ip=request_ip
        )
        if isinstance(e, OpenAICompatError):
            raise
        raise OpenAICompatError(status_code=500, message=error_msg, error_type="embeddings_error") from e
-
async def image_generations(
    self, request: ImageGenerationRequest, user_id: str, api_key_id: int, request_ip: Optional[str] = None
) -> ImageGenerationResponse:
    """
    Handle an image-generation (text-to-image) request.

    Local models are called through ``{base_url}/images/generations`` with an
    OpenAI-format body; base64 results are uploaded to OSS so the response
    always carries URLs. Cloud models are delegated to ImageGenerationService.
    Success and failure are both written to the API call log.

    Args:
        request: Image generation request.
        user_id: ID of the calling user.
        api_key_id: API key ID, recorded in the call log.
        request_ip: Source IP of the request, recorded in the call log.

    Returns:
        ImageGenerationResponse: Creation timestamp plus one URL per image.

    Raises:
        OpenAICompatError: On validation, permission, configuration or
            generation failure. Any other exception is wrapped with status
            500 and error type ``image_generation_error``.
    """
    log_service = ApiCallLogService(self.db)

    try:
        model = await self._validate_model_and_balance(request.model, user_id)

        # The model must be categorised for image generation.
        from app.models.model import ModelCategory
        if not any(int(c) in [int(ModelCategory.IMAGE_GEN), int(ModelCategory.MULTIMODAL)] for c in (model.categories or [])):
            raise OpenAICompatError(
                status_code=400,
                message=f"Model '{request.model}' does not support image generation. Use a model with category IMAGE_GEN or MULTIMODAL.",
                error_type="model_not_supported",
            )

        # Local models are gated by per-user permissions.
        if model.is_local:
            from app.services.user_local_model_permission_service import UserLocalModelPermissionService
            permission_service = UserLocalModelPermissionService(self.db)
            if not await permission_service.check_user_model_access(user_id, model.id):
                raise OpenAICompatError(
                    status_code=403,
                    message=f"You don't have permission to access model '{request.model}'",
                    error_type="permission_error",
                )

        if model.is_local:
            import base64
            # Connection settings come from the cache, falling back to the DB row.
            from app.services.cache_service import CacheService
            from app.services.crypto_utils import decrypt_api_key
            model_data = await CacheService.get_model(model.id)

            if model_data:
                base_url = model_data.get("base_url", "").rstrip("/")
                local_api_key = model_data.get("local_api_key")
            else:
                base_url = (model.base_url or "").rstrip("/")
                local_api_key = model.local_api_key
                # Populate the cache for later requests.
                await CacheService.set_model(model.id, {
                    "base_url": base_url,
                    "local_api_key": local_api_key,
                    "is_local": model.is_local,
                    "name": model.name
                })
            if not base_url:
                raise OpenAICompatError(
                    status_code=500,
                    message="本地模型未配置 Base URL",
                    error_type="configuration_error",
                )

            headers = {"Content-Type": "application/json"}
            if local_api_key:
                # Only the decrypt call sits in the try, so the "decrypt
                # returned nothing" error below is not re-wrapped by this handler.
                try:
                    api_key = decrypt_api_key(local_api_key)
                except Exception as exc:
                    raise OpenAICompatError(
                        status_code=500,
                        message=f"本地模型 API Key 处理失败: {str(exc)}",
                        error_type="configuration_error",
                    ) from exc
                if not api_key:
                    raise OpenAICompatError(
                        status_code=500,
                        message="本地模型 API Key 解密失败",
                        error_type="configuration_error",
                    )
                headers["Authorization"] = f"Bearer {api_key}"

            # Request body in OpenAI format; optional fields only when set.
            payload = {
                "model": model.name,
                "prompt": request.prompt,
                "n": request.n or 1,
                "size": request.size or "1024x1024"
            }
            if request.quality:
                payload["quality"] = request.quality
            if request.response_format:
                payload["response_format"] = request.response_format
            if request.style:
                payload["style"] = request.style
            if request.user:
                payload["user"] = request.user

            api_url = f"{base_url}/images/generations"
            result_data = await self._handle_local_model_request(
                api_url=api_url,
                headers=headers,
                payload=payload,
                model_name=model.name,
                base_url=base_url,
                endpoint_type="image",
                timeout=60.0
            )

            # Collect URLs; base64 images are uploaded to OSS first.
            images = []
            for item in result_data.get("data", []):
                if item.get("url"):
                    images.append(item.get("url"))
                elif item.get("b64_json"):
                    from app.services.oss_service import get_oss_service
                    oss_service = get_oss_service()
                    image_bytes = base64.b64decode(item.get("b64_json"))
                    url = oss_service.upload_file(
                        image_bytes,
                        prefix="ai-images/generations",
                        original_filename=f"generated_{time.time()}.png"
                    )
                    images.append(url)
            if not images:
                raise OpenAICompatError(
                    status_code=500,
                    message="图像生成失败:未返回图像",
                    error_type="image_generation_error",
                )
            result = ImageGenerationResponse(
                created=int(time.time()),
                data=[ImageData(url=url) for url in images]
            )
            bill = Decimal("0")
        else:
            # Cloud model: delegate to the image service with the user's key.
            user = self._get_user(user_id)
            dashscope_api_key = user.apikey if user and user.apikey else ""
            real_image_service = ImageGenerationService(self.db, api_key=dashscope_api_key)

            # The service is given "W*H" sizes; the request uses "WxH".
            mapped_size = request.size.replace("x", "*") if request.size else "1024*1024"

            result_obj = await real_image_service.text_to_image(
                user_id=user_id,
                prompt=request.prompt,
                model=model.model_code,
                n=request.n or 1,
                size=mapped_size
            )

            if not result_obj.success:
                raise OpenAICompatError(
                    status_code=500,
                    message=result_obj.error or "图像生成失败",
                    error_type="image_generation_error"
                )

            result = ImageGenerationResponse(
                created=int(time.time()),
                data=[ImageData(url=url) for url in result_obj.images]
            )
            bill = result_obj.bill

        # Success log; output_tokens carries the number of images returned.
        log_service.create_log(
            user_id=user_id,
            api_key_id=api_key_id,
            model_id=model.id if model else None,
            model_name=request.model,
            is_local=model.is_local if model else False,
            input_tokens=0,
            output_tokens=len(result.data),
            bill=float(bill),
            status="success",
            request_ip=request_ip
        )

        return result

    except Exception as e:
        error_msg = str(e) if not isinstance(e, OpenAICompatError) else e.message
        # Look the model up again: validation itself may be what failed.
        model = self._find_model(request.model, user_id)
        log_service.create_log(
            user_id=user_id,
            api_key_id=api_key_id,
            model_id=model.id if model else None,
            model_name=request.model,
            is_local=model.is_local if model else False,
            input_tokens=0,
            output_tokens=0,
            bill=0,
            status="failed",
            error_message=error_msg,
            request_ip=request_ip
        )
        if isinstance(e, OpenAICompatError):
            raise
        raise OpenAICompatError(status_code=500, message=error_msg, error_type="image_generation_error") from e
async def image_edits(
    self,
    image: Union[str, UploadFile],
    prompt: str,
    mask: Optional[Union[str, UploadFile]],
    model_name: str,
    n: int,
    size: str,
    user_id: str,
    api_key_id: int,
    request_ip: Optional[str] = None
) -> ImageGenerationResponse:
    """
    Handle an image-edit (image-to-image) request.

    The source image and optional mask are turned into URLs (http/https URLs
    are used as-is; base64 strings and uploaded files are stored in OSS).
    Local models are then called through ``{base_url}/images/edits``; cloud
    models are delegated to ImageGenerationService. Success and failure are
    both written to the API call log.

    Args:
        image: Source image as URL, base64 string (optionally a
            ``data:...;base64,`` URI) or uploaded file.
        prompt: Text description of the desired image.
        mask: Optional mask, in the same forms as ``image``.
        model_name: Model name.
        n: Number of images to generate.
        size: Output size such as ``1024x1024``.
        user_id: ID of the calling user.
        api_key_id: API key ID, recorded in the call log.
        request_ip: Source IP of the request, recorded in the call log.

    Returns:
        ImageGenerationResponse: Creation timestamp plus one URL per image.

    Raises:
        OpenAICompatError: On validation, permission, configuration or edit
            failure. Any other exception is wrapped with status 500 and
            error type ``image_edit_error``.
    """
    log_service = ApiCallLogService(self.db)

    try:
        model = await self._validate_model_and_balance(model_name, user_id)

        # The model must be categorised for image editing.
        from app.models.model import ModelCategory
        if not any(int(c) in [int(ModelCategory.IMAGE_EDIT), int(ModelCategory.MULTIMODAL)] for c in (model.categories or [])):
            raise OpenAICompatError(
                status_code=400,
                message=f"Model '{model_name}' does not support image editing. Use a model with category IMAGE_EDIT or MULTIMODAL.",
                error_type="model_not_supported",
            )

        # Local models are gated by per-user permissions.
        if model.is_local:
            from app.services.user_local_model_permission_service import UserLocalModelPermissionService
            permission_service = UserLocalModelPermissionService(self.db)
            if not await permission_service.check_user_model_access(user_id, model.id):
                raise OpenAICompatError(
                    status_code=403,
                    message=f"You don't have permission to access model '{model_name}'",
                    error_type="permission_error",
                )

        import base64
        from app.services.oss_service import get_oss_service
        oss_service = get_oss_service()

        async def _as_url(source, fallback_name: str) -> str:
            # Turn a URL / base64 string / uploaded file into a URL.
            if isinstance(source, str):
                if source.startswith(('http://', 'https://')):
                    return source
                # Strip a data-URI header before decoding. Plain base64 never
                # contains ':' or ',', so this cannot affect a raw payload.
                if source.startswith("data:") and "," in source:
                    source = source.split(",", 1)[1]
                raw_bytes = base64.b64decode(source)
                filename = fallback_name
            else:
                raw_bytes = await source.read()
                filename = source.filename or fallback_name
            return oss_service.upload_file(
                raw_bytes,
                prefix="ai-images/edits",
                original_filename=filename
            )

        # Source image first, mask (when given) second.
        image_urls = [await _as_url(image, "edit_source.png")]
        if mask:
            image_urls.append(await _as_url(mask, "edit_mask.png"))

        if model.is_local:
            # Connection settings come from the cache, falling back to the DB row.
            from app.services.cache_service import CacheService
            from app.services.crypto_utils import decrypt_api_key
            model_data = await CacheService.get_model(model.id)

            if model_data:
                base_url = model_data.get("base_url", "").rstrip("/")
                local_api_key = model_data.get("local_api_key")
            else:
                base_url = (model.base_url or "").rstrip("/")
                local_api_key = model.local_api_key
                # Populate the cache for later requests.
                await CacheService.set_model(model.id, {
                    "base_url": base_url,
                    "local_api_key": local_api_key,
                    "is_local": model.is_local,
                    "name": model.name
                })
            if not base_url:
                raise OpenAICompatError(
                    status_code=500,
                    message="本地模型未配置 Base URL",
                    error_type="configuration_error",
                )

            headers = {"Content-Type": "application/json"}
            if local_api_key:
                api_key = decrypt_api_key(local_api_key)
                if api_key:
                    headers["Authorization"] = f"Bearer {api_key}"

            # Request body in OpenAI format.
            payload = {
                "model": model.name,
                "prompt": prompt,
                "n": n,
                "size": size or "1024x1024"
            }
            payload["image"] = image_urls[0]
            if len(image_urls) > 1:
                payload["mask"] = image_urls[1]

            api_url = f"{base_url}/images/edits"
            result_data = await self._handle_local_model_request(
                api_url=api_url,
                headers=headers,
                payload=payload,
                model_name=model.name,
                base_url=base_url,
                endpoint_type="image",
                timeout=60.0
            )

            # Collect URLs; base64 images are uploaded to OSS first.
            images = []
            for item in result_data.get("data", []):
                if item.get("url"):
                    images.append(item.get("url"))
                elif item.get("b64_json"):
                    image_bytes = base64.b64decode(item.get("b64_json"))
                    url = oss_service.upload_file(
                        image_bytes,
                        prefix="ai-images/edits",
                        original_filename=f"edited_{time.time()}.png"
                    )
                    images.append(url)
            if not images:
                raise OpenAICompatError(
                    status_code=500,
                    message="图像编辑失败:未返回图像",
                    error_type="image_edit_error",
                )
            result = ImageGenerationResponse(
                created=int(time.time()),
                data=[ImageData(url=url) for url in images]
            )
            bill = Decimal("0")
        else:
            # Cloud model: delegate to the image service with the user's key.
            user = self._get_user(user_id)
            dashscope_api_key = user.apikey if user and user.apikey else ""

            from app.services.image_service import ImageGenerationService
            real_image_service = ImageGenerationService(self.db, api_key=dashscope_api_key)

            # The service is given "W*H" sizes; the request uses "WxH".
            mapped_size = size.replace("x", "*") if size else "1024*1024"

            result_obj = await real_image_service.image_to_image(
                user_id=user_id,
                image_urls=image_urls,
                prompt=prompt,
                model=model.model_code,
                n=n,
                size=mapped_size
            )

            if not result_obj.success:
                raise OpenAICompatError(
                    status_code=500,
                    message=result_obj.error or "图生图编辑失败",
                    error_type="image_edit_error"
                )

            result = ImageGenerationResponse(
                created=int(time.time()),
                data=[ImageData(url=url) for url in result_obj.images]
            )
            bill = result_obj.bill

        # Success log; output_tokens carries the number of images returned.
        log_service.create_log(
            user_id=user_id,
            api_key_id=api_key_id,
            model_id=model.id if model else None,
            model_name=model_name,
            is_local=model.is_local if model else False,
            input_tokens=0,
            output_tokens=len(result.data),
            bill=float(bill),
            status="success",
            request_ip=request_ip
        )

        return result

    except Exception as e:
        error_msg = str(e) if not isinstance(e, OpenAICompatError) else e.message
        # Look the model up again: validation itself may be what failed.
        model_obj = self._find_model(model_name, user_id)
        log_service.create_log(
            user_id=user_id,
            api_key_id=api_key_id,
            model_id=model_obj.id if model_obj else None,
            model_name=model_name,
            is_local=model_obj.is_local if model_obj else False,
            input_tokens=0,
            output_tokens=0,
            bill=0,
            status="failed",
            error_message=error_msg,
            request_ip=request_ip
        )
        if isinstance(e, OpenAICompatError):
            raise
        raise OpenAICompatError(status_code=500, message=error_msg, error_type="image_edit_error") from e
-
async def audio_transcriptions(
    self,
    file: Union[str, UploadFile],
    model_name: str,
    language: Optional[str],
    user_id: str,
    api_key_id: int,
    request_ip: Optional[str] = None
) -> AudioTranscriptionResponse:
    """
    Handle a speech-to-text (STT) request.

    The audio is converted to a base64 string (downloaded when ``file`` is an
    http/https URL, read when it is an upload, used as-is otherwise). Local
    models are called through ``{base_url}/audio/transcriptions``; cloud
    models go through ASRService. ``whisper-1`` and ``whisper-large-v3`` are
    mapped to ``qwen3-asr-flash``. Success and failure are both written to
    the API call log.

    Args:
        file: Audio as URL, base64 string or uploaded file.
        model_name: Model name as given by the caller.
        language: Language code, forwarded when set.
        user_id: ID of the calling user.
        api_key_id: API key ID, recorded in the call log.
        request_ip: Source IP of the request, recorded in the call log.

    Returns:
        AudioTranscriptionResponse: The recognised text.

    Raises:
        OpenAICompatError: On validation, permission, configuration or
            recognition failure. Any other exception is wrapped with status
            500 and error type ``stt_error``.
    """
    log_service = ApiCallLogService(self.db)

    try:
        actual_model = "qwen3-asr-flash" if model_name in ["whisper-1", "whisper-large-v3"] else model_name
        # Models with "realtime" in their name are rejected on this endpoint.
        if "realtime" in actual_model.lower():
            raise OpenAICompatError(
                status_code=400,
                message=f"Model '{actual_model}' is a real-time streaming model that uses WebSocket protocol. "
                        f"It cannot be called via /api/v1/audio/transcriptions. "
                        f"Please use the WebSocket API instead.",
                error_type="model_not_supported",
            )
        model = await self._validate_model_and_balance(actual_model, user_id)

        # The model must be categorised for speech recognition.
        from app.models.model import ModelCategory
        if not any(int(c) in [int(ModelCategory.STT), int(ModelCategory.MULTIMODAL)] for c in (model.categories or [])):
            raise OpenAICompatError(
                status_code=400,
                message=f"Model '{model_name}' does not support speech transcription",
                error_type="model_not_supported",
            )

        # Local models are gated by per-user permissions.
        if model.is_local:
            from app.services.user_local_model_permission_service import UserLocalModelPermissionService
            permission_service = UserLocalModelPermissionService(self.db)
            if not await permission_service.check_user_model_access(user_id, model.id):
                raise OpenAICompatError(
                    status_code=403,
                    message=f"You don't have permission to access model '{model_name}'",
                    error_type="permission_error",
                )

        import base64
        import httpx
        if isinstance(file, str):
            if file.startswith(('http://', 'https://')):
                # NOTE(review): this fetches a caller-supplied URL from the
                # server; confirm whether callers restrict allowed hosts.
                async with httpx.AsyncClient() as client:
                    response = await client.get(file)
                    response.raise_for_status()
                    audio_bytes = response.content
                audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
            else:
                # Any other string is used as the base64 payload as-is.
                audio_base64 = file
        else:
            # Uploaded file object.
            audio_bytes = await file.read()
            audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')

        if model.is_local:
            # Connection settings come from the cache, falling back to the DB row.
            from app.services.cache_service import CacheService
            from app.services.crypto_utils import decrypt_api_key
            model_data = await CacheService.get_model(model.id)

            if model_data:
                base_url = model_data.get("base_url", "").rstrip("/")
                local_api_key = model_data.get("local_api_key")
            else:
                base_url = (model.base_url or "").rstrip("/")
                local_api_key = model.local_api_key
                # Populate the cache for later requests.
                await CacheService.set_model(model.id, {
                    "base_url": base_url,
                    "local_api_key": local_api_key,
                    "is_local": model.is_local,
                    "name": model.name
                })
            if not base_url:
                raise OpenAICompatError(
                    status_code=500,
                    message="本地模型未配置 Base URL",
                    error_type="configuration_error",
                )

            headers = {"Content-Type": "application/json"}
            if local_api_key:
                api_key = decrypt_api_key(local_api_key)
                if api_key:
                    headers["Authorization"] = f"Bearer {api_key}"

            # JSON request body carrying the base64 audio.
            payload = {
                "model": model.name,
                "file": audio_base64
            }
            if language:
                payload["language"] = language

            api_url = f"{base_url}/audio/transcriptions"
            result_data = await self._handle_local_model_request(
                api_url=api_url,
                headers=headers,
                payload=payload,
                model_name=model.name,
                base_url=base_url,
                endpoint_type="audio_stt",
                timeout=60.0
            )

            # A null "text" is treated like a missing one, matching the
            # guard the cloud branch applies before taking len().
            text = result_data.get("text") or ""
            duration_seconds = len(audio_base64) // 10000  # rough estimate
            text_length = len(text)
            bill = Decimal("0")
        else:
            # Cloud model: delegate to the ASR service with the user's key.
            user = self._get_user(user_id)
            dashscope_api_key = user.apikey if user and user.apikey else ""

            from app.services.asr_service import ASRService
            from app.schemas.audio_schema import ASRRequest

            real_asr_service = ASRService(self.db, user_id=user_id, api_key=dashscope_api_key)
            internal_req = ASRRequest(
                model=actual_model,
                audio_base64=audio_base64,
                language=language
            )

            asr_response = await real_asr_service.recognize(internal_req)

            text = asr_response.text
            duration_seconds = asr_response.usage.seconds if asr_response.usage else 0
            text_length = len(text) if text else 0
            bill = Decimal("0")

        # Success log: input_tokens carries the duration, output_tokens the text length.
        log_service.create_log(
            user_id=user_id,
            api_key_id=api_key_id,
            model_id=model.id if model else None,
            model_name=model_name,
            is_local=model.is_local if model else False,
            input_tokens=duration_seconds,
            output_tokens=text_length,
            bill=float(bill),
            status="success",
            request_ip=request_ip
        )

        return AudioTranscriptionResponse(text=text)

    except Exception as e:
        error_msg = str(e) if not isinstance(e, OpenAICompatError) else e.message
        # Look the model up again: validation itself may be what failed.
        model_obj = self._find_model(model_name, user_id)
        log_service.create_log(
            user_id=user_id,
            api_key_id=api_key_id,
            model_id=model_obj.id if model_obj else None,
            model_name=model_name,
            is_local=model_obj.is_local if model_obj else False,
            input_tokens=0,
            output_tokens=0,
            bill=0,
            status="failed",
            error_message=error_msg,
            request_ip=request_ip
        )
        if isinstance(e, OpenAICompatError):
            raise
        raise OpenAICompatError(status_code=500, message=error_msg, error_type="stt_error") from e
-
- async def audio_speech(
- self,
- request: AudioSpeechRequest,
- user_id: str,
- api_key_id: int,
- request_ip: Optional[str] = None
- ) -> Tuple[AsyncGenerator, str]:
- """
- 处理文字转语音(TTS)请求
-
- 执行全量音色映射与模型能力校验。
- 通过底层 TTSService 生成语音并转存 OSS,随后转换为流式下发。
-
- Args:
- request: TTS请求对象
- user_id: 用户ID
- api_key_id: API Key ID
- request_ip: 请求来源IP
-
- Returns:
- Tuple[AsyncGenerator, str]: 音频二进制生成器与 MIME 类型
-
- Raises:
- OpenAICompatError: 模型不支持所选音色或处理失败时抛出
-
- 需求: OpenAI兼容-文字转语音
- """
- log_service = ApiCallLogService(self.db)
-
- try:
- actual_model = "cosyvoice-v3-flash" if request.model in ["tts-1", "tts-1-hd"] else request.model
- # realtime 模型使用 WebSocket 实时流协议,不支持此 REST 接口
- if "realtime" in actual_model.lower():
- raise OpenAICompatError(
- status_code=400,
- message=f"Model '{actual_model}' is a real-time streaming model that uses WebSocket protocol. "
- f"It cannot be called via /api/v1/audio/speech. "
- f"Please use the WebSocket API instead.",
- error_type="model_not_supported",
- )
- # cosyvoice-clone 系列:voice 参数就是 voice_id,不做映射,直接透传
- is_clone = "clone" in actual_model.lower()
- voice_map = {
- "alloy": "longxiaochun_v3",
- "echo": "longcheng_v3",
- "shimmer": "longwan_v3",
- "onyx": "longhua_v3",
- "nova": "longxiaoxia_v3",
- "fable": "longshu_v3",
- "sunny": "longanyang",
- "lively": "longanhuan",
- "cute_girl": "longhuhu_v3",
- "cute_boy": "longniuniu_v3",
- "bubble": "longpaopao_v3",
- "naughty": "longjielidou_v3",
- "bold_girl": "longxian_v3",
- "cantonese_f": "longjiaxin_v3",
- "cantonese_m": "longanyue_v3",
- "dongbei": "longlaotie_v3",
- "shanbei": "longshange_v3",
- "korean": "loongkyong_v3",
- "japanese": "loongriko_v3",
- "news_m": "longfei_v3",
- "news_f": "longxiaoxia_v3",
- "story_m": "longxiu_v3",
- "story_f": "longmiao_v3",
- "customer_service": "longyingxiao_v3",
- "monkey": "longhouge_v3",
- "robot": "longjiqi_v3",
- "daiyu": "longdaiyu_v3",
- "uncle": "longlaobo_v3",
- "aunt": "longlaoyi_v3"
- }
-
- actual_voice = request.voice if is_clone else voice_map.get(request.voice.lower(), request.voice)
- if "plus" in actual_model.lower():
- plus_allowed_voices = ["longanyang", "longanhuan"]
- if actual_voice not in plus_allowed_voices:
- raise OpenAICompatError(
- status_code=400,
- message=f"Model '{actual_model}' only supports voices: {plus_allowed_voices}. Requested: '{actual_voice}'.",
- error_type="invalid_request_error"
- )
- model = await self._validate_model_and_balance(actual_model, user_id)
-
- # 检查模型类型是否支持语音合成
- from app.models.model import ModelCategory
- if not any(int(c) in [int(ModelCategory.TTS), int(ModelCategory.MULTIMODAL)] for c in (model.categories or [])):
- raise OpenAICompatError(
- status_code=400,
- message=f"Model '{request.model}' does not support speech synthesis",
- error_type="model_not_supported",
- )
-
- # 检查本地模型权限
- if model.is_local:
- from app.services.user_local_model_permission_service import UserLocalModelPermissionService
- permission_service = UserLocalModelPermissionService(self.db)
- if not await permission_service.check_user_model_access(user_id, model.id):
- raise OpenAICompatError(
- status_code=403,
- message=f"You don't have permission to access model '{request.model}'",
- error_type="permission_error",
- )
-
- if model.is_local:
- # 从缓存获取模型信息
- from app.services.cache_service import CacheService
- from app.services.crypto_utils import decrypt_api_key
- model_data = await CacheService.get_model(model.id)
-
- if model_data:
- base_url = model_data.get("base_url", "").rstrip("/")
- local_api_key = model_data.get("local_api_key")
- else:
- # 从数据库获取
- base_url = (model.base_url or "").rstrip("/")
- local_api_key = model.local_api_key
- # 缓存模型信息
- await CacheService.set_model(model.id, {
- "base_url": base_url,
- "local_api_key": local_api_key,
- "is_local": model.is_local,
- "name": model.name
- })
- if not base_url:
- raise OpenAICompatError(
- status_code=500,
- message="本地模型未配置 Base URL",
- error_type="configuration_error",
- )
- # 构建请求头
- headers = {"Content-Type": "application/json"}
- if local_api_key:
- api_key = decrypt_api_key(local_api_key)
- if api_key:
- headers["Authorization"] = f"Bearer {api_key}"
-
- # 构建请求体(OpenAI 格式)
- # 注意:本地模型使用原始的 OpenAI 音色名称,不进行映射
- payload = {
- "model": model.name,
- "input": request.input,
- "voice": request.voice, # 使用原始音色名称,不映射
- "response_format": request.response_format or "mp3",
- "speed": request.speed or 1.0
- }
- # 使用统一的适配器方法发送请求(获取原始响应)
- api_url = f"{base_url}/audio/speech"
- response = await self._handle_local_model_request(
- api_url=api_url,
- headers=headers,
- payload=payload,
- model_name=model.name,
- base_url=base_url,
- endpoint_type="audio_tts",
- timeout=60.0,
- return_raw_response=True
- )
-
- # 检查响应类型
- content_type = response.headers.get("content-type", "")
-
- import httpx
- if "application/json" in content_type:
- # 响应是 JSON,包含音频 URL
- result_data = response.json()
- audio_url = result_data.get("audio_url") or result_data.get("url")
-
- if not audio_url:
- raise OpenAICompatError(
- status_code=500,
- message="本地模型返回的 JSON 中未找到音频 URL",
- error_type="invalid_response",
- )
-
- # 下载音频并流式返回
- async def generate_audio():
- async with httpx.AsyncClient() as client:
- async with client.stream("GET", audio_url) as audio_response:
- audio_response.raise_for_status()
- async for chunk in audio_response.aiter_bytes():
- yield chunk
- else:
- # 响应直接是音频数据(如硅基流动)
- audio_bytes = response.content
-
- async def generate_audio():
- # 直接返回音频数据
- yield audio_bytes
- media_type = f"audio/{request.response_format or 'mp3'}" if (request.response_format or 'mp3') != 'mp3' else "audio/mpeg"
- bill = Decimal("0")
- else:
- # 云端模型处理
- user = self._get_user(user_id)
- dashscope_api_key = user.apikey if user and user.apikey else ""
-
- from app.services.tts_service import TTSService
- from app.schemas.audio_schema import TTSRequest
-
- real_tts_service = TTSService(self.db, user_id=user_id, api_key=dashscope_api_key)
-
- internal_req = TTSRequest(
- text=request.input,
- model=actual_model,
- voice=actual_voice,
- format=request.response_format or "mp3",
- sample_rate=24000
- )
-
- tts_response = await real_tts_service.synthesize(internal_req)
-
- import httpx
- async def generate_audio():
- async with httpx.AsyncClient() as client:
- async with client.stream("GET", tts_response.audio_url) as response:
- response.raise_for_status()
- async for chunk in response.aiter_bytes():
- yield chunk
-
- media_type = f"audio/{request.response_format}" if request.response_format != 'mp3' else "audio/mpeg"
- bill = Decimal("0")
- log_service.create_log(
- user_id=user_id,
- api_key_id=api_key_id,
- model_id=model.id if model else None,
- model_name=request.model,
- is_local=model.is_local if model else False,
- input_tokens=len(request.input),
- output_tokens=0,
- bill=float(bill),
- status="success",
- request_ip=request_ip
- )
- return generate_audio(), media_type
-
- except Exception as e:
- error_msg = str(e) if not isinstance(e, OpenAICompatError) else e.message
- model_obj = self._find_model(request.model, user_id)
- log_service.create_log(
- user_id=user_id,
- api_key_id=api_key_id,
- model_id=model_obj.id if model_obj else None,
- model_name=request.model,
- is_local=model_obj.is_local if model_obj else False,
- input_tokens=0,
- output_tokens=0,
- bill=0,
- status="failed",
- error_message=error_msg,
- request_ip=request_ip
- )
- if isinstance(e, OpenAICompatError):
- raise e
- raise OpenAICompatError(status_code=500, message=error_msg, error_type="tts_error")
-
async def video_generations(
    self,
    request: VideoGenerationRequest,
    user_id: str,
    api_key_id: int,
    request_ip: Optional[str] = None
) -> VideoGenerationResponse:
    """
    Handle an OpenAI-compatible text-to-video generation request.

    Local models are called once at ``{base_url}/videos/generations``.
    Cloud models go through ``VideoService``: an asynchronous task is
    submitted and then polled (every 5 seconds, at most 120 times) so the
    caller sees a synchronous, blocking interface.

    Args:
        request: Video generation request (model, prompt, size, duration).
        user_id: ID of the calling user.
        api_key_id: API key ID, recorded in the call log.
        request_ip: Source IP of the request, recorded in the call log.

    Returns:
        VideoGenerationResponse: Response holding the generated video URL(s).

    Raises:
        OpenAICompatError: If the model does not support video generation,
            the user may not access a local model, the size is invalid,
            the task fails, returns no video, or times out. Any other
            exception is wrapped as a 500 ``video_generation_error``.

    Requirement: OpenAI compatibility - video generation
    """
    import time
    import asyncio
    from app.services.video_service import VideoService
    from app.schemas.video_schema import VideoGenerateRequest

    log_service = ApiCallLogService(self.db)

    try:
        model = await self._validate_model_and_balance(request.model, user_id)

        # Reject models whose categories do not allow video generation.
        from app.models.model import ModelCategory
        allowed_categories = [
            int(ModelCategory.VIDEO_GEN),
            int(ModelCategory.MULTIMODAL),
            int(ModelCategory.LLM),
        ]
        if not any(int(c) in allowed_categories for c in (model.categories or [])):
            raise OpenAICompatError(
                status_code=400,
                message=f"Model '{request.model}' does not support video generation",
                error_type="model_not_supported",
            )

        # Local models require an explicit per-user permission.
        if model.is_local:
            from app.services.user_local_model_permission_service import UserLocalModelPermissionService
            permission_service = UserLocalModelPermissionService(self.db)
            if not await permission_service.check_user_model_access(user_id, model.id):
                raise OpenAICompatError(
                    status_code=403,
                    message=f"You don't have permission to access model '{request.model}'",
                    error_type="permission_error",
                )

        if model.is_local:
            # Resolve the endpoint configuration, preferring the cache.
            from app.services.cache_service import CacheService
            from app.services.crypto_utils import decrypt_api_key
            model_data = await CacheService.get_model(model.id)

            if model_data:
                # ``or ""`` guards against a cached None, matching the DB path.
                base_url = (model_data.get("base_url") or "").rstrip("/")
                local_api_key = model_data.get("local_api_key")
            else:
                # Cache miss: read from the database row and populate the cache.
                base_url = (model.base_url or "").rstrip("/")
                local_api_key = model.local_api_key
                await CacheService.set_model(model.id, {
                    "base_url": base_url,
                    "local_api_key": local_api_key,
                    "is_local": model.is_local,
                    "name": model.name
                })
            if not base_url:
                raise OpenAICompatError(
                    status_code=500,
                    message="本地模型未配置 Base URL",
                    error_type="configuration_error",
                )

            # Build request headers; the stored key is encrypted at rest.
            headers = {"Content-Type": "application/json"}
            if local_api_key:
                api_key = decrypt_api_key(local_api_key)
                if api_key:
                    headers["Authorization"] = f"Bearer {api_key}"

            # Build the request body (OpenAI format).
            payload = {
                "model": model.name,
                "prompt": request.prompt,
                "size": request.size or "1280x720",  # OpenAI-style size string
                "duration": request.duration or 5
            }
            # Send through the shared local-model adapter.
            api_url = f"{base_url}/videos/generations"
            result_data = await self._handle_local_model_request(
                api_url=api_url,
                headers=headers,
                payload=payload,
                model_name=model.name,
                base_url=base_url,
                endpoint_type="video",
                timeout=60.0
            )

            # Keep only entries that actually carry a URL.
            videos = [
                item.get("url")
                for item in (result_data.get("data") or [])
                if item.get("url")
            ]
            if not videos:
                raise OpenAICompatError(
                    status_code=500,
                    message="视频生成失败:未返回视频",
                    error_type="video_generation_error",
                )

            from app.schemas.openai_compat import VideoData
            result = VideoGenerationResponse(
                created=int(time.time()),
                data=[VideoData(url=url, content_type="video/mp4") for url in videos]
            )
            bill = Decimal("0")
        else:
            # Cloud model: resolve the effective API key for this user/model.
            user = self._get_user(user_id)
            from app.services.crypto_utils import get_effective_api_key
            dashscope_api_key = (
                get_effective_api_key(self.db, request.model, user.apikey) if user else ""
            )

            real_video_service = VideoService(
                self.db,
                user_id=int(user_id) if str(user_id).isdigit() else user_id,
                api_key=dashscope_api_key
            )

            # Convert the OpenAI-style size into the internal resolution format.
            try:
                _, internal_size = parse_video_size(request.size or "1280x720")
            except ValueError as size_err:
                raise OpenAICompatError(
                    status_code=400,
                    message=str(size_err),
                    error_type="invalid_request_error"
                ) from size_err

            internal_req = VideoGenerateRequest(
                prompt=request.prompt,
                resolution=internal_size,  # internal format, e.g. "720P"
                duration=request.duration or 5,
                prompt_extend=True,
                watermark=False
            )

            # Submit the asynchronous task.
            task_resp = await real_video_service.generate(internal_req)
            task_id = task_resp.task_id

            # Poll until the task reaches a terminal state or we give up.
            max_retries = 120
            poll_interval = 5
            final_video_url = None

            for _ in range(max_retries):
                await asyncio.sleep(poll_interval)
                status_result = await real_video_service.get_task_status(task_id)

                if status_result.task_status == "SUCCEEDED":
                    final_video_url = status_result.video_url
                    if not final_video_url:
                        # A finished task without a URL is a generation
                        # failure, not a timeout.
                        raise OpenAICompatError(
                            status_code=500,
                            message="视频生成失败:未返回视频",
                            error_type="video_generation_error"
                        )
                    break
                elif status_result.task_status == "FAILED":
                    raise OpenAICompatError(
                        status_code=500,
                        message=status_result.error_message or "底层视频生成任务失败",
                        error_type="video_generation_error"
                    )

            if not final_video_url:
                # Loop exhausted without reaching a terminal state.
                raise OpenAICompatError(
                    status_code=504,
                    message="视频生成任务超时,请稍后再试或通过平台任务列表查看结果",
                    error_type="timeout_error"
                )

            from app.schemas.openai_compat import VideoData
            result = VideoGenerationResponse(
                created=int(time.time()),
                data=[VideoData(url=final_video_url, content_type="video/mp4")]
            )
            bill = Decimal("0")

        # Record the successful call; output_tokens carries the duration.
        log_service.create_log(
            user_id=user_id,
            api_key_id=api_key_id,
            model_id=model.id if model else None,
            model_name=request.model,
            is_local=model.is_local if model else False,
            input_tokens=0,
            output_tokens=request.duration or 5,
            bill=float(bill),
            status="success",
            request_ip=request_ip
        )

        return result

    except Exception as e:
        # Log the failure, then surface it as an OpenAICompatError.
        error_msg = str(e) if not isinstance(e, OpenAICompatError) else e.message
        model_obj = self._find_model(request.model, user_id)
        log_service.create_log(
            user_id=user_id,
            api_key_id=api_key_id,
            model_id=model_obj.id if model_obj else None,
            model_name=request.model,
            is_local=model_obj.is_local if model_obj else False,
            input_tokens=0,
            output_tokens=0,
            bill=0,
            status="failed",
            error_message=error_msg,
            request_ip=request_ip
        )
        if isinstance(e, OpenAICompatError):
            raise
        raise OpenAICompatError(
            status_code=500, message=error_msg, error_type="video_generation_error"
        ) from e
-
async def image_to_video_generations(
    self,
    image: Union[str, UploadFile],
    prompt: str,
    model_name: str,
    size: str,
    user_id: str,
    api_key_id: int,
    request_ip: Optional[str] = None
) -> VideoGenerationResponse:
    """
    Handle an image-to-video (I2V) request.

    The reference image is turned into a URL: an http(s) URL is used as-is,
    while a base64 string (optionally wrapped in a ``data:`` URI) or an
    uploaded file is stored through the OSS service. Local models are then
    called once at ``{base_url}/videos/generations``; cloud models submit an
    asynchronous ``VideoService`` task that is polled (every 5 seconds, at
    most 120 times) so the caller sees a synchronous interface.

    Args:
        image: Reference image as a URL, a base64 string, or an upload.
        prompt: Text description of the video.
        model_name: Model name (e.g. wan2.6-i2v).
        size: Video size string; "1280x720" is used when empty.
        user_id: ID of the calling user.
        api_key_id: API key ID, recorded in the call log.
        request_ip: Source IP of the request, recorded in the call log.

    Returns:
        VideoGenerationResponse: Response holding the generated video URL(s).

    Raises:
        OpenAICompatError: If the model does not support video generation,
            the user may not access a local model, the image or size is
            invalid, the task fails, returns no video, or times out. Any
            other exception is wrapped as a 500 ``video_generation_error``.

    Requirement: OpenAI compatibility - image-to-video
    """
    import time
    import asyncio
    import base64
    from app.services.oss_service import get_oss_service

    log_service = ApiCallLogService(self.db)

    try:
        model = await self._validate_model_and_balance(model_name, user_id)

        # Reject models whose categories do not allow video generation.
        from app.models.model import ModelCategory
        allowed_categories = [
            int(ModelCategory.VIDEO_GEN),
            int(ModelCategory.MULTIMODAL),
            int(ModelCategory.LLM),
        ]
        if not any(int(c) in allowed_categories for c in (model.categories or [])):
            raise OpenAICompatError(
                status_code=400,
                message=f"Model '{model_name}' does not support video generation",
                error_type="model_not_supported",
            )

        # Local models require an explicit per-user permission.
        if model.is_local:
            from app.services.user_local_model_permission_service import UserLocalModelPermissionService
            permission_service = UserLocalModelPermissionService(self.db)
            if not await permission_service.check_user_model_access(user_id, model.id):
                raise OpenAICompatError(
                    status_code=403,
                    message=f"You don't have permission to access model '{model_name}'",
                    error_type="permission_error",
                )

        # Resolve the reference image to a URL.
        oss_service = get_oss_service()
        if isinstance(image, str):
            if image.startswith(('http://', 'https://')):
                # Already a URL: use it directly.
                image_url = image
            else:
                encoded = image
                if encoded.startswith("data:"):
                    # Strip a "data:<mime>;base64," header. b64decode would
                    # otherwise silently decode the header's letters as
                    # payload and produce a corrupt image.
                    _, _, encoded = encoded.partition(",")
                try:
                    image_bytes = base64.b64decode(encoded)
                except ValueError as decode_err:
                    # binascii.Error is a ValueError subclass.
                    raise OpenAICompatError(
                        status_code=400,
                        message="Invalid base64 image data",
                        error_type="invalid_request_error"
                    ) from decode_err
                if not image_bytes:
                    raise OpenAICompatError(
                        status_code=400,
                        message="Invalid base64 image data",
                        error_type="invalid_request_error"
                    )
                image_url = oss_service.upload_file(
                    image_bytes,
                    prefix="ai-videos/i2v-source",
                    original_filename="i2v_source.png"
                )
        else:
            # UploadFile: read the content and store it.
            image_bytes = await image.read()
            image_url = oss_service.upload_file(
                image_bytes,
                prefix="ai-videos/i2v-source",
                original_filename=image.filename or "i2v_source.png"
            )

        if model.is_local:
            # Resolve the endpoint configuration, preferring the cache.
            from app.services.cache_service import CacheService
            from app.services.crypto_utils import decrypt_api_key
            model_data = await CacheService.get_model(model.id)

            if model_data:
                # ``or ""`` guards against a cached None, matching the DB path.
                base_url = (model_data.get("base_url") or "").rstrip("/")
                local_api_key = model_data.get("local_api_key")
            else:
                # Cache miss: read from the database row and populate the cache.
                base_url = (model.base_url or "").rstrip("/")
                local_api_key = model.local_api_key
                await CacheService.set_model(model.id, {
                    "base_url": base_url,
                    "local_api_key": local_api_key,
                    "is_local": model.is_local,
                    "name": model.name
                })
            if not base_url:
                raise OpenAICompatError(
                    status_code=500,
                    message="本地模型未配置 Base URL",
                    error_type="configuration_error",
                )

            # Build request headers; the stored key is encrypted at rest.
            headers = {"Content-Type": "application/json"}
            if local_api_key:
                api_key = decrypt_api_key(local_api_key)
                if api_key:
                    headers["Authorization"] = f"Bearer {api_key}"

            # Build the request body (OpenAI format).
            payload = {
                "model": model.name,
                "prompt": prompt,
                "image": image_url,
                "size": size or "1280x720",  # OpenAI-style size string
                "duration": 5
            }
            # Send through the shared local-model adapter.
            api_url = f"{base_url}/videos/generations"
            result_data = await self._handle_local_model_request(
                api_url=api_url,
                headers=headers,
                payload=payload,
                model_name=model.name,
                base_url=base_url,
                endpoint_type="video",
                timeout=60.0
            )

            # Keep only entries that actually carry a URL.
            videos = [
                item.get("url")
                for item in (result_data.get("data") or [])
                if item.get("url")
            ]
            if not videos:
                raise OpenAICompatError(
                    status_code=500,
                    message="图生视频失败:未返回视频",
                    error_type="video_generation_error",
                )

            from app.schemas.openai_compat import VideoData
            result = VideoGenerationResponse(
                created=int(time.time()),
                data=[VideoData(url=url, content_type="video/mp4") for url in videos]
            )
            bill = Decimal("0")
        else:
            # Cloud model: resolve the effective API key for this user/model.
            user = self._get_user(user_id)
            from app.services.crypto_utils import get_effective_api_key
            dashscope_api_key = (
                get_effective_api_key(self.db, model_name, user.apikey) if user else ""
            )

            from app.services.video_service import VideoService
            from app.schemas.video_schema import VideoGenerateRequest

            real_video_service = VideoService(
                self.db,
                user_id=int(user_id) if str(user_id).isdigit() else user_id,
                api_key=dashscope_api_key
            )

            # Convert the OpenAI-style size into the internal resolution format.
            try:
                _, internal_size = parse_video_size(size or "1280x720")
            except ValueError as size_err:
                raise OpenAICompatError(
                    status_code=400,
                    message=str(size_err),
                    error_type="invalid_request_error"
                ) from size_err

            internal_req = VideoGenerateRequest(
                prompt=prompt,
                first_frame_url=image_url,
                resolution=internal_size,  # internal format, e.g. "720P"
                duration=5,
                prompt_extend=True,
                watermark=False
            )

            # Submit the asynchronous task.
            task_resp = await real_video_service.generate(internal_req)
            task_id = task_resp.task_id

            # Poll until the task reaches a terminal state or we give up.
            max_retries = 120
            poll_interval = 5
            final_video_url = None

            for _ in range(max_retries):
                await asyncio.sleep(poll_interval)
                status_result = await real_video_service.get_task_status(task_id)

                if status_result.task_status == "SUCCEEDED":
                    final_video_url = status_result.video_url
                    if not final_video_url:
                        # A finished task without a URL is a generation
                        # failure, not a timeout.
                        raise OpenAICompatError(
                            status_code=500,
                            message="图生视频失败:未返回视频",
                            error_type="video_generation_error"
                        )
                    break
                elif status_result.task_status == "FAILED":
                    raise OpenAICompatError(
                        status_code=500,
                        message=status_result.error_message or "底层图生视频任务失败",
                        error_type="video_generation_error"
                    )

            if not final_video_url:
                # Loop exhausted without reaching a terminal state.
                raise OpenAICompatError(
                    status_code=504,
                    message="图生视频任务超时,请稍后再试",
                    error_type="timeout_error"
                )

            from app.schemas.openai_compat import VideoData
            result = VideoGenerationResponse(
                created=int(time.time()),
                data=[VideoData(url=final_video_url, content_type="video/mp4")]
            )
            bill = Decimal("0")

        # Record the successful call; output_tokens carries the fixed duration.
        log_service.create_log(
            user_id=user_id,
            api_key_id=api_key_id,
            model_id=model.id if model else None,
            model_name=model_name,
            is_local=model.is_local if model else False,
            input_tokens=0,
            output_tokens=5,
            bill=float(bill),
            status="success",
            request_ip=request_ip
        )

        return result

    except Exception as e:
        # Log the failure, then surface it as an OpenAICompatError.
        error_msg = str(e) if not isinstance(e, OpenAICompatError) else e.message
        model_obj = self._find_model(model_name, user_id)
        log_service.create_log(
            user_id=user_id,
            api_key_id=api_key_id,
            model_id=model_obj.id if model_obj else None,
            model_name=model_name,
            is_local=model_obj.is_local if model_obj else False,
            input_tokens=0,
            output_tokens=0,
            bill=0,
            status="failed",
            error_message=error_msg,
            request_ip=request_ip
        )
        if isinstance(e, OpenAICompatError):
            raise
        raise OpenAICompatError(
            status_code=500, message=error_msg, error_type="video_generation_error"
        ) from e
async def rerank(
    self,
    request: RerankRequest,
    user_id: str,
    api_key_id: int,
    request_ip: Optional[str] = None
) -> RerankResponse:
    """
    Handle a rerank request.

    Local models are called at ``{base_url}/rerank`` with an OpenAI-style
    body; both ``data`` and ``results`` response fields and both ``usage``
    and ``meta.tokens`` token layouts are accepted. Cloud models are sent
    to the DashScope text-rerank endpoint with the user's own API key.

    Args:
        request: Rerank request (model, query, documents, options).
        user_id: ID of the calling user.
        api_key_id: API key ID, recorded in the call log.
        request_ip: Source IP of the request, recorded in the call log.

    Returns:
        RerankResponse: Ranked results plus token usage.

    Raises:
        OpenAICompatError: If the model does not support reranking, the
            user may not access a local model, the user has no API key
            configured for a cloud model, or the request fails. Any other
            exception is wrapped as a 500 ``rerank_error``.
    """
    log_service = ApiCallLogService(self.db)

    try:
        model = await self._validate_model_and_balance(request.model, user_id)

        # Reject models whose categories do not allow reranking.
        from app.models.model import ModelCategory
        allowed_categories = [
            int(ModelCategory.RERANK),
            int(ModelCategory.EMBEDDING),
            int(ModelCategory.LLM),
        ]
        if not any(int(c) in allowed_categories for c in (model.categories or [])):
            raise OpenAICompatError(
                status_code=400,
                message=f"Model '{request.model}' does not support reranking",
                error_type="model_not_supported",
            )

        if model.is_local:
            # Local models require an explicit per-user permission.
            from app.services.user_local_model_permission_service import UserLocalModelPermissionService
            permission_service = UserLocalModelPermissionService(self.db)
            if not await permission_service.check_user_model_access(user_id, model.id):
                raise OpenAICompatError(
                    status_code=403,
                    message=f"You don't have permission to access model '{request.model}'",
                    error_type="permission_error",
                )
        else:
            # Cloud models are billed against the user's own API key.
            user = self._get_user(user_id)
            if not user or not user.apikey:
                raise OpenAICompatError(
                    status_code=403,
                    message="User API key not configured.",
                    error_type="api_key_not_configured"
                )

        if model.is_local:
            # Resolve the endpoint configuration, preferring the cache.
            from app.services.cache_service import CacheService
            from app.services.crypto_utils import decrypt_api_key
            model_data = await CacheService.get_model(model.id)

            if model_data:
                # ``or ""`` guards against a cached None, matching the DB path.
                base_url = (model_data.get("base_url") or "").rstrip("/")
                local_api_key = model_data.get("local_api_key")
            else:
                # Cache miss: read from the database row and populate the cache.
                base_url = (model.base_url or "").rstrip("/")
                local_api_key = model.local_api_key
                await CacheService.set_model(model.id, {
                    "base_url": base_url,
                    "local_api_key": local_api_key,
                    "is_local": model.is_local,
                    "name": model.name
                })
            if not base_url:
                raise OpenAICompatError(
                    status_code=500,
                    message="本地模型未配置 Base URL",
                    error_type="configuration_error",
                )

            # Build request headers; the stored key is encrypted at rest.
            headers = {"Content-Type": "application/json"}
            if local_api_key:
                api_key = decrypt_api_key(local_api_key)
                if api_key:
                    headers["Authorization"] = f"Bearer {api_key}"

            # Build the request body (OpenAI format); optional fields are
            # only sent when the caller supplied them.
            payload = {
                "model": model.name,
                "query": request.query,
                "documents": request.documents
            }
            if request.top_n is not None:
                payload["top_n"] = request.top_n
            if request.return_documents is not None:
                payload["return_documents"] = request.return_documents
            if request.user:
                payload["user"] = request.user

            # Send through the shared local-model adapter.
            api_url = f"{base_url}/rerank"
            result_data = await self._handle_local_model_request(
                api_url=api_url,
                headers=headers,
                payload=payload,
                model_name=model.name,
                base_url=base_url,
                endpoint_type="rerank",
                timeout=30.0
            )

            import logging
            logger = logging.getLogger(__name__)

            # ``or`` defaults also cover keys that are present but null.
            results_list = result_data.get("data") or []
            usage_data = result_data.get("usage") or {}

            # Some providers (e.g. SiliconFlow) return "results" instead of "data".
            if not results_list:
                results_list = result_data.get("results") or []
                logger.debug("Using 'results' field, found %d items", len(results_list))

            # Token usage may live in "usage" or in "meta.tokens".
            if not usage_data or usage_data.get("total_tokens", 0) == 0:
                meta = result_data.get("meta") or {}
                tokens = meta.get("tokens") or {}
                if tokens:
                    input_tokens = tokens.get("input_tokens", 0)
                    output_tokens = tokens.get("output_tokens", 0)
                    total_tokens = input_tokens + output_tokens
                    logger.debug(
                        "Using meta.tokens: input=%s, output=%s, total=%s",
                        input_tokens, output_tokens, total_tokens
                    )
                else:
                    total_tokens = 0
            else:
                total_tokens = usage_data.get("total_tokens", 0)

            data_list = []
            for item in results_list:
                # The document may come back as an object with "text" or
                # as a plain string.
                doc_content = None
                if request.return_documents:
                    doc = item.get("document")
                    if isinstance(doc, dict):
                        doc_content = doc.get("text", "")
                    elif isinstance(doc, str):
                        doc_content = doc

                result_item = RerankResult(
                    index=item.get("index", 0),
                    relevance_score=item.get("relevance_score", 0.0)
                )
                if doc_content:
                    result_item.document = doc_content
                data_list.append(result_item)
        else:
            # Cloud model (Alibaba Cloud Bailian / DashScope).
            # Imported locally, as the other methods in this file do, so
            # this branch does not depend on a module-level import.
            import httpx

            api_url = "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank"
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {user.apikey}"
            }

            payload = {
                "model": model.model_code,
                "input": {
                    "query": request.query,
                    "documents": request.documents
                }
            }
            if request.top_n is not None:
                payload["parameters"] = {"top_n": request.top_n}

            async with httpx.AsyncClient(timeout=30.0) as client:
                response = await client.post(api_url, headers=headers, json=payload)
                response.raise_for_status()
                result_data = response.json()

            output = result_data.get("output") or {}
            results_list = output.get("results") or []
            usage_data = result_data.get("usage") or {}
            total_tokens = usage_data.get("total_tokens", 0)

            data_list = []
            for item in results_list:
                result_item = RerankResult(
                    index=item.get("index", 0),
                    relevance_score=item.get("relevance_score", 0.0)
                )
                if request.return_documents:
                    # Documents are echoed back from the request by index.
                    result_item.document = request.documents[item.get("index", 0)]
                data_list.append(result_item)

        # Record the successful call.
        bill = Decimal("0")
        log_service.create_log(
            user_id=user_id,
            api_key_id=api_key_id,
            model_id=model.id if model else None,
            model_name=request.model,
            is_local=model.is_local if model else False,
            input_tokens=total_tokens,
            output_tokens=0,
            bill=float(bill),
            status="success",
            request_ip=request_ip
        )
        return RerankResponse(
            model=request.model,
            data=data_list,
            usage=Usage(
                prompt_tokens=total_tokens,
                completion_tokens=0,
                total_tokens=total_tokens
            )
        )

    except Exception as e:
        # Log the failure, then surface it as an OpenAICompatError.
        error_msg = str(e) if not isinstance(e, OpenAICompatError) else e.message
        model_obj = self._find_model(request.model, user_id)
        log_service.create_log(
            user_id=user_id,
            api_key_id=api_key_id,
            model_id=model_obj.id if model_obj else None,
            model_name=request.model,
            is_local=model_obj.is_local if model_obj else False,
            input_tokens=0,
            output_tokens=0,
            bill=0,
            status="failed",
            error_message=error_msg,
            request_ip=request_ip
        )
        if isinstance(e, OpenAICompatError):
            raise
        raise OpenAICompatError(
            status_code=500, message=error_msg, error_type="rerank_error"
        ) from e
|