| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811 |
- """
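OpenAI-compatible API routes, mounted under /api/v1.
Provides chat completions, models, embeddings, image generation/edits, audio transcription/translation/speech, video generation and rerank endpoints.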
- """
- from fastapi import APIRouter, Depends, HTTPException, Request, Header, File, Form, UploadFile
- from fastapi.responses import StreamingResponse
- from sqlalchemy.orm import Session
- from typing import Optional
- import json
- from app.database import get_db, SessionLocal
- from app.services.openai_compat_service import OpenAICompatService, OpenAICompatError
- from app.services.platform_api_key_service import PlatformApiKeyService
- from app.schemas.openai_compat import (
- ChatCompletionsRequest,
- ChatCompletionsResponse,
- EmbeddingsRequest,
- EmbeddingsResponse,
- ModelsListResponse,
- ImageGenerationRequest,
- ImageGenerationResponse,
- ImageEditsRequest,
- AudioTranscriptionResponse,
- AudioTranscriptionsRequest,
- AudioTranslationsRequest,
- AudioSpeechRequest,
- AudioSpeechResponse,
- VideoGenerationRequest,
- VideoGenerationResponse,
- RerankRequest,
- RerankResponse,
- )
# Router for every OpenAI-compatible endpoint in this module; all paths below
# are relative to the /api/v1 prefix.
router = APIRouter(
    tags=["OpenAI 兼容 API"],
    prefix="/api/v1",
)
# ─────────────────────────────────────────────
# Authentication dependency
# ─────────────────────────────────────────────
async def get_api_key_auth(
    authorization: Optional[str] = Header(None),
    db: Session = Depends(get_db),
) -> tuple:
    """Authenticate the caller from an ``Authorization: Bearer <api key>`` header.

    Returns:
        A ``(user_id, key_id, key_type)`` tuple. ``key_type`` is read from the
        cache entry for ``key_id`` when one exists, otherwise from the
        ``PlatformApiKey`` row; it falls back to ``"public"`` when no row is found.

    Raises:
        HTTPException: 401 when the header is missing, is not a Bearer
            credential, or ``PlatformApiKeyService.verify_api_key`` rejects the key.
    """
    if not authorization:
        raise HTTPException(
            status_code=401,
            detail={"error": {"message": "Missing Authorization header", "type": "authentication_error", "code": "missing_auth"}},
        )

    # The auth-scheme name is case-insensitive (RFC 7235 §2.1), so "bearer" and
    # "BEARER" are accepted too; whitespace around the key itself is ignored.
    scheme, _, api_key = authorization.partition(" ")
    api_key = api_key.strip()
    if scheme.lower() != "bearer" or not api_key:
        raise HTTPException(
            status_code=401,
            detail={"error": {"message": "Invalid Authorization header format. Expected 'Bearer <token>'", "type": "authentication_error", "code": "invalid_auth_format"}},
        )

    # Validate the key itself before looking up its metadata.
    result = PlatformApiKeyService(db).verify_api_key(api_key)
    if not result:
        raise HTTPException(
            status_code=401,
            detail={"error": {"message": "Incorrect API key provided", "type": "authentication_error", "code": "invalid_api_key"}},
        )
    user_id, key_id = result

    # Function-scope import kept from the original (presumably to avoid an
    # import cycle — TODO confirm).
    from app.services.cache_service import CacheService

    key_data = await CacheService.get_api_key(key_id)
    if key_data:
        key_type = key_data.get("key_type", "public")
    else:
        from app.models.platform_api_key import PlatformApiKey

        record = db.query(PlatformApiKey).filter(
            PlatformApiKey.id == key_id
        ).first()
        if record is None:
            # The key verified but its row was not found: use the default type,
            # and do not cache an invented "active" status for a missing row.
            key_type = "public"
        else:
            key_type = record.key_type
            await CacheService.set_api_key(key_id, {
                "key_type": key_type,
                "status": record.status,
            })

    return (user_id, key_id, key_type)
# ─────────────────────────────────────────────
# POST /api/v1/chat/completions
# ─────────────────────────────────────────────
@router.post("/chat/completions", summary="聊天补全", description="OpenAI兼容的聊天补全接口。支持流式和非流式输出,支持多模态输入(文本、图片、音频)。")
async def chat_completions(
    request: ChatCompletionsRequest,
    req: Request,
    auth: tuple = Depends(get_api_key_auth),
    db: Session = Depends(get_db),
):
    """Chat completions, following the OpenAI ``/v1/chat/completions`` contract.

    Authentication: ``Authorization: Bearer YOUR_API_KEY`` (keys are created via
    ``/api/platform/api-keys``).

    The requested model must match the key type: local models require a
    ``"local"`` key and cloud models a ``"public"`` key (403 otherwise); an
    unknown model yields 404.

    Returns the complete JSON response when ``stream`` is false, otherwise an
    SSE (``text/event-stream``) stream.
    """
    user_id, api_key_id, key_type = auth
    service = OpenAICompatService(db)
    request_ip = req.client.host if req.client else None
    try:
        # Enforce that the API key type matches the model's deployment type.
        model = service._find_model(request.model, user_id)
        if not model:
            raise OpenAICompatError(
                status_code=404,
                message=f"The model '{request.model}' does not exist",
                error_type="model_not_found",
            )
        if model.is_local and key_type != "local":
            raise OpenAICompatError(
                status_code=403,
                message="Local models can only be accessed with local API keys",
                error_type="permission_error",
            )
        if not model.is_local and key_type != "public":
            raise OpenAICompatError(
                status_code=403,
                message="Cloud models can only be accessed with public API keys",
                error_type="permission_error",
            )

        if request.stream:
            async def stream_and_close():
                # The request-scoped `db` may be closed before the body is
                # streamed, so use a dedicated session. Opening it inside the
                # generator ties its lifetime to the finally below: if the body
                # is never iterated, no session is created and none can leak.
                stream_db = SessionLocal()
                try:
                    stream_service = OpenAICompatService(stream_db)
                    raw = await stream_service.chat_completions(
                        request, user_id, api_key_id, request_ip
                    )
                    async for chunk in raw:
                        yield chunk
                except OpenAICompatError as exc:
                    # Status and headers are already sent, so an HTTPException is
                    # no longer possible; report the failure as an SSE event.
                    payload = {"error": {"message": exc.message, "type": exc.error_type}}
                    yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
                finally:
                    stream_db.close()

            return StreamingResponse(
                stream_and_close(),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    "X-Accel-Buffering": "no",
                },
            )

        # Non-streaming: the injected request-scoped session is sufficient.
        return await service.chat_completions(request, user_id, api_key_id, request_ip)
    except OpenAICompatError as e:
        raise HTTPException(
            status_code=e.status_code,
            detail={"error": {"message": e.message, "type": e.error_type}},
        )
# ─────────────────────────────────────────────
# GET /api/v1/models
# ─────────────────────────────────────────────
@router.get("/models", response_model=ModelsListResponse, summary="获取模型列表", description="获取当前用户可用的模型列表。根据API Key类型返回相应的模型(public key返回云端模型,local key返回本地模型)。")
def list_models(
    auth: tuple = Depends(get_api_key_auth),
    db: Session = Depends(get_db),
):
    """List the models available to the calling user.

    Authentication: ``Authorization: Bearer YOUR_API_KEY``.

    The list is chosen by ``get_available_models`` from the caller's user id
    and API key type (per the route description: public keys see cloud
    models, local keys see local models).
    """
    user_id, _key_id, key_type = auth
    return OpenAICompatService(db).get_available_models(user_id, key_type)
@router.post("/embeddings", response_model=EmbeddingsResponse)
async def embeddings(
    request: EmbeddingsRequest,
    req: Request,
    auth: tuple = Depends(get_api_key_auth),
    db: Session = Depends(get_db)
):
    """
    文本嵌入

    将文本转换为向量表示。

    参数说明:
    - input: 要嵌入的文本(字符串或字符串数组)
    - model: 使用的嵌入模型
    - encoding_format: 返回格式(float或base64)
    - dimensions: 向量维度(可选)
    """
    # Text embeddings. The docstring above is the route's published OpenAPI
    # description, so it stays verbatim. Service errors are mapped to
    # OpenAI-style {"error": {...}} HTTP errors.
    user_id, api_key_id, _key_type = auth
    service = OpenAICompatService(db)
    # Request.client is None when the server reports no peer address; the bare
    # `req.client.host` would raise AttributeError and surface as a 500.
    request_ip = req.client.host if req.client else None

    try:
        return await service.embeddings(request, user_id, api_key_id, request_ip)
    except OpenAICompatError as e:
        raise HTTPException(
            status_code=e.status_code,
            detail={"error": {"message": e.message, "type": e.error_type}}
        )
@router.post("/images/generations", response_model=ImageGenerationResponse)
async def image_generations(
    request: ImageGenerationRequest,
    req: Request,
    auth: tuple = Depends(get_api_key_auth),
    db: Session = Depends(get_db)
):
    """
    文生图

    根据文本描述生成图像。

    参数说明:
    - prompt: 图像描述文本
    - model: 使用的图像生成模型
    - n: 生成图像数量
    - quality: 图像质量(standard或hd)
    - size: 图像尺寸
    - style: 图像风格(vivid或natural)
    - response_format: 返回格式(url或b64_json)
    """
    # Text-to-image generation delegated to OpenAICompatService.image_generations.
    # The docstring above is the route's published OpenAPI description.
    user_id, api_key_id, _key_type = auth
    service = OpenAICompatService(db)
    # Request.client is None when the server reports no peer address.
    request_ip = req.client.host if req.client else None

    try:
        return await service.image_generations(request, user_id, api_key_id, request_ip)
    except OpenAICompatError as e:
        raise HTTPException(
            status_code=e.status_code,
            detail={"error": {"message": e.message, "type": e.error_type}}
        )
-
@router.post("/images/edits", response_model=ImageGenerationResponse)
async def image_edits(
    req: Request,
    image: Optional[UploadFile] = File(None, description="要编辑的原始图像(推荐PNG/JPG格式)"),
    prompt: Optional[str] = Form(None, description="对新图像的文本描述"),
    mask: Optional[UploadFile] = File(None, description="可选的遮罩层图像"),
    model: Optional[str] = Form("wan2.6-image", description="模型ID"),
    n: Optional[int] = Form(1, description="生成数量"),
    size: Optional[str] = Form("1024x1024", description="图像尺寸"),
    response_format: Optional[str] = Form("url", description="返回格式"),
    user: Optional[str] = Form(None, description="终端用户标识"),
    auth: tuple = Depends(get_api_key_auth),
    db: Session = Depends(get_db)
):
    """
    图像编辑/图生图

    基于原始图像和文本描述生成新图像。

    支持两种请求方式:
    1. application/json: 使用ImageEditsRequest模型,image和mask为base64编码
    2. multipart/form-data: 使用File和Form参数,直接上传文件

    参数说明:
    - image: 要编辑的原始图像
    - prompt: 对新图像的文本描述
    - mask: 遮罩层图像(可选)
    - model: 使用的图像生成模型
    - n: 生成图像数量
    - size: 图像尺寸
    - response_format: 返回格式(url或b64_json)
    """
    # Image editing from a JSON body (base64 image/mask) or a multipart upload.
    # The docstring above is the route's published OpenAPI description.
    # `response_format` and `user` are accepted for OpenAI compatibility but
    # are not forwarded to the service.
    user_id, api_key_id, _key_type = auth
    service = OpenAICompatService(db)
    # Request.client is None when the server reports no peer address.
    request_ip = req.client.host if req.client else None

    try:
        content_type = req.headers.get("content-type", "")

        if "application/json" in content_type:
            try:
                request_obj = ImageEditsRequest(**(await req.json()))
            except (ValueError, TypeError) as exc:
                # Malformed JSON or schema violations are client errors, not 500s.
                raise OpenAICompatError(400, f"Invalid request body: {exc}", "invalid_request_error") from exc
            edit_args = {
                "image": request_obj.image,
                "prompt": request_obj.prompt,
                "mask": request_obj.mask,
                "model_name": request_obj.model,
                "n": request_obj.n,
                "size": request_obj.size,
            }
        elif "multipart/form-data" in content_type:
            if not (image and prompt):
                # The content type is supported; a required field is missing.
                raise OpenAICompatError(400, "Both 'image' and 'prompt' fields are required", "invalid_request_error")
            edit_args = {
                "image": image,
                "prompt": prompt,
                "mask": mask,
                "model_name": model,
                "n": n,
                "size": size,
            }
        else:
            raise OpenAICompatError(415, "不支持的 Content-Type,请使用 application/json 或 multipart/form-data", "invalid_request_error")

        return await service.image_edits(
            **edit_args,
            user_id=user_id,
            api_key_id=api_key_id,
            request_ip=request_ip,
        )
    except OpenAICompatError as e:
        raise HTTPException(
            status_code=e.status_code,
            detail={"error": {"message": e.message, "type": e.error_type}}
        )
-
@router.post("/audio/transcriptions", response_model=AudioTranscriptionResponse)
async def audio_transcriptions(
    req: Request,
    file: Optional[UploadFile] = File(None, description="要识别的音频文件(如 mp3, wav)"),
    model: Optional[str] = Form(None, description="模型名称"),
    language: Optional[str] = Form(None, description="ISO-639-1 语言代码"),
    response_format: Optional[str] = Form("json", description="返回格式"),
    auth: tuple = Depends(get_api_key_auth),
    db: Session = Depends(get_db)
):
    """
    语音转文字(STT)

    将音频文件转换为文本。

    支持两种请求方式:
    1. application/json: 使用AudioTranscriptionsRequest模型,file为base64编码
    2. multipart/form-data: 使用File和Form参数,直接上传文件

    参数说明:
    - file: 要识别的音频文件
    - model: 语音识别模型名称
    - language: 音频语言代码(可选)
    - response_format: 返回格式
    """
    # Speech-to-text from a JSON body (base64 audio) or a multipart upload.
    # The docstring above is the route's published OpenAPI description.
    # `response_format` is accepted for OpenAI compatibility but not forwarded.
    user_id, api_key_id, _key_type = auth
    service = OpenAICompatService(db)
    # Request.client is None when the server reports no peer address.
    request_ip = req.client.host if req.client else None

    try:
        content_type = req.headers.get("content-type", "")

        if "application/json" in content_type:
            try:
                request_obj = AudioTranscriptionsRequest(**(await req.json()))
            except (ValueError, TypeError) as exc:
                # Malformed JSON or schema violations are client errors, not 500s.
                raise OpenAICompatError(400, f"Invalid request body: {exc}", "invalid_request_error") from exc
            audio, model_name, audio_language = request_obj.file, request_obj.model, request_obj.language
        elif "multipart/form-data" in content_type:
            if not (file and model):
                # The content type is supported; a required field is missing.
                raise OpenAICompatError(400, "Both 'file' and 'model' fields are required", "invalid_request_error")
            audio, model_name, audio_language = file, model, language
        else:
            raise OpenAICompatError(415, "不支持的 Content-Type,请使用 application/json 或 multipart/form-data", "invalid_request_error")

        return await service.audio_transcriptions(
            file=audio,
            model_name=model_name,
            language=audio_language,
            user_id=user_id,
            api_key_id=api_key_id,
            request_ip=request_ip,
        )
    except OpenAICompatError as e:
        raise HTTPException(
            status_code=e.status_code,
            detail={"error": {"message": e.message, "type": e.error_type}}
        )
def _sse_delta_text(chunk) -> str:
    """Concatenate the ``choices[0].delta.content`` text carried by one streamed chunk.

    ``chunk`` is SSE text such as ``'data: {...}\\n\\n'``. Bytes are decoded as
    UTF-8, and a chunk holding several ``data:`` lines is parsed line by line.
    ``[DONE]`` markers, non-JSON payloads and events without string delta
    content contribute nothing.
    """
    if isinstance(chunk, (bytes, bytearray)):
        chunk = bytes(chunk).decode("utf-8", errors="replace")
    if not isinstance(chunk, str):
        return ""
    pieces = []
    for line in chunk.splitlines():
        line = line.strip()
        if not line.startswith("data:"):
            continue
        payload = line[len("data:"):].strip()
        if payload == "[DONE]":
            continue
        try:
            content = json.loads(payload).get("choices", [{}])[0].get("delta", {}).get("content", "")
        except (ValueError, IndexError, AttributeError, TypeError, KeyError):
            # Not a well-formed completion delta; skip it like the original did.
            continue
        if content and isinstance(content, str):
            pieces.append(content)
    return "".join(pieces)


async def _translate_transcript(
    service: OpenAICompatService,
    db: Session,
    original_text: Optional[str],
    *,
    source_language: Optional[str],
    target_language: Optional[str],
    translation_model: Optional[str],
    prompt: Optional[str],
    user_id,
    api_key_id,
    request_ip: Optional[str],
) -> AudioTranscriptionResponse:
    """Translate an ASR transcript with a chat model and log an estimated cost.

    Returns an empty response for a blank transcript and the transcript itself
    when the declared source language already equals the target language.
    Otherwise it streams a chat completion from ``translation_model``, collects
    the delta text, and writes an ``ApiCallLogService`` record with
    character-based token estimates and (for non-local models) a bill computed
    from the active ``ModelPriceNew`` row.
    """
    from decimal import Decimal

    from app.models.model import ModelPriceNew
    from app.schemas.openai_compat import Message
    from app.services.api_call_log_service import ApiCallLogService

    if not original_text or not original_text.strip():
        return AudioTranscriptionResponse(text="")

    # Default to English, matching the multipart form default.
    target = target_language or "en"
    # Nothing to translate when the audio is already declared to be in the target language.
    if source_language and source_language.lower() == target.lower():
        return AudioTranscriptionResponse(text=original_text)

    system_prompt = f"你是一个专业的翻译官。请将用户的文本准确翻译成 {target}。只返回翻译后的纯文本结果,不要任何多余解释。"
    if prompt:
        system_prompt += f"\n参考提示:{prompt}"

    translation_request = ChatCompletionsRequest(
        model=translation_model,
        messages=[
            Message(role="system", content=system_prompt),
            Message(role="user", content=original_text),
        ],
        stream=True,
    )
    chat_stream = await service.chat_completions(
        request=translation_request,
        user_id=user_id,
        api_key_id=api_key_id,
    )
    translated = "".join([_sse_delta_text(chunk) async for chunk in chat_stream])

    # Rough token estimate: ~1.2 tokens per character, never below 1.
    estimated_input_tokens = int((len(system_prompt) + len(original_text)) * 1.2) or 1
    estimated_output_tokens = int(len(translated) * 1.2) or 1

    trans_model_obj = service._find_model(translation_model, user_id)
    bill = Decimal("0")
    if trans_model_obj and not trans_model_obj.is_local:
        price_info = db.query(ModelPriceNew).filter(
            ModelPriceNew.model_code == trans_model_obj.model_code,
            ModelPriceNew.is_active == True,  # noqa: E712 - SQLAlchemy column expression
        ).first()
        if price_info:
            divisor = Decimal(str(price_info.display_multiplier or 1000000))
            in_cost = price_info.input_price_discounted * Decimal(str(estimated_input_tokens)) / divisor
            out_cost = price_info.output_price_discounted * Decimal(str(estimated_output_tokens)) / divisor
            bill = in_cost + out_cost

    ApiCallLogService(db).create_log(
        user_id=user_id,
        api_key_id=api_key_id,
        model_id=trans_model_obj.id if trans_model_obj else None,
        model_name=translation_model,
        is_local=trans_model_obj.is_local if trans_model_obj else False,
        input_tokens=estimated_input_tokens,
        output_tokens=estimated_output_tokens,
        bill=float(bill),
        status="success",
        request_ip=request_ip,
    )
    return AudioTranscriptionResponse(text=translated.strip())


@router.post("/audio/translations", response_model=AudioTranscriptionResponse)
async def audio_translations(
    req: Request,
    file: Optional[UploadFile] = File(None, description="要翻译的音频文件(如 mp3, wav)"),
    model: Optional[str] = Form(None, description="使用的语音识别模型"),
    source_language: Optional[str] = Form(None, description="原语音语言代码"),
    target_language: Optional[str] = Form("en", description="目标翻译语言代码"),
    translation_model: Optional[str] = Form("qwen-max", description="执行翻译的文本大模型"),
    prompt: Optional[str] = Form(None, description="可选的翻译提示词"),
    auth: tuple = Depends(get_api_key_auth),
    db: Session = Depends(get_db)
):
    """
    语音翻译

    将音频文件识别并翻译为目标语言文本。

    支持两种请求方式:
    1. application/json: 使用AudioTranslationsRequest模型,file为base64编码
    2. multipart/form-data: 使用File和Form参数,直接上传文件

    参数说明:
    - file: 要翻译的音频文件
    - model: 语音识别模型名称
    - source_language: 源语言代码(可选)
    - target_language: 目标语言代码,默认为英语
    - translation_model: 用于翻译的文本模型
    - prompt: 翻译提示词(可选)
    """
    # Speech translation: transcribe the audio, then translate the transcript
    # with a chat model. Both request formats share one code path; previously
    # the JSON branch never defined `system_prompt` and always failed.
    # The docstring above is the route's published OpenAPI description.
    user_id, api_key_id, _key_type = auth
    service = OpenAICompatService(db)
    # Request.client is None when the server reports no peer address.
    request_ip = req.client.host if req.client else None

    try:
        content_type = req.headers.get("content-type", "")

        if "application/json" in content_type:
            try:
                request_obj = AudioTranslationsRequest(**(await req.json()))
            except (ValueError, TypeError) as exc:
                # Malformed JSON or schema violations are client errors, not 500s.
                raise OpenAICompatError(400, f"Invalid request body: {exc}", "invalid_request_error") from exc
            audio = request_obj.file
            asr_model = request_obj.model
            source = request_obj.source_language
            target = request_obj.target_language
            trans_model = request_obj.translation_model
            hint = request_obj.prompt
        elif "multipart/form-data" in content_type:
            if not (file and model):
                # The content type is supported; a required field is missing.
                raise OpenAICompatError(400, "Both 'file' and 'model' fields are required", "invalid_request_error")
            audio = file
            asr_model = model
            source = source_language
            target = target_language
            trans_model = translation_model
            hint = prompt
        else:
            raise OpenAICompatError(415, "不支持的 Content-Type,请使用 application/json 或 multipart/form-data", "invalid_request_error")

        asr_result = await service.audio_transcriptions(
            file=audio,
            model_name=asr_model,
            language=source,
            user_id=user_id,
            api_key_id=api_key_id,
            request_ip=request_ip,
        )
        return await _translate_transcript(
            service,
            db,
            asr_result.text,
            source_language=source,
            target_language=target,
            translation_model=trans_model,
            prompt=hint,
            user_id=user_id,
            api_key_id=api_key_id,
            request_ip=request_ip,
        )
    except OpenAICompatError as e:
        raise HTTPException(
            status_code=e.status_code,
            detail={"error": {"message": e.message, "type": e.error_type}}
        )
    except Exception as e:
        import httpx
        if isinstance(e, httpx.HTTPStatusError):
            raise HTTPException(
                status_code=400,
                detail={"error": {"message": f"Upstream API Error: {e.response.text}", "type": "upstream_api_error"}}
            )
        raise HTTPException(
            status_code=500,
            detail={"error": {"message": str(e), "type": "internal_error"}}
        )
@router.post("/audio/speech", response_model=AudioSpeechResponse)
async def audio_speech(
    request: AudioSpeechRequest,
    req: Request,
    auth: tuple = Depends(get_api_key_auth),
    db: Session = Depends(get_db)
):
    """
    文字转语音(TTS)

    将文本转换为语音音频。

    参数说明:
    - model: 使用的TTS模型
    - input: 要转换的文本内容
    - voice: 发音人声音类型
    - response_format: 音频格式(mp3, opus, aac, flac, wav, pcm)
    - speed: 语速,范围0.25-4.0

    返回音频二进制流。
    """
    # Text-to-speech: streams the audio returned by OpenAICompatService.audio_speech
    # with the media type it reports, as an attachment named after the format.
    # The docstring above is the route's published OpenAPI description.
    user_id, api_key_id, _key_type = auth
    service = OpenAICompatService(db)
    # Request.client is None when the server reports no peer address.
    request_ip = req.client.host if req.client else None

    try:
        audio_stream, media_type = await service.audio_speech(
            request=request,
            user_id=user_id,
            api_key_id=api_key_id,
            request_ip=request_ip
        )

        return StreamingResponse(
            audio_stream,
            media_type=media_type,
            headers={"Content-Disposition": f"attachment; filename=speech.{request.response_format}"}
        )
    except OpenAICompatError as e:
        raise HTTPException(
            status_code=e.status_code,
            detail={"error": {"message": e.message, "type": e.error_type}}
        )
-
@router.post("/videos/generations", response_model=VideoGenerationResponse)
async def video_generations(
    req: Request,
    image: Optional[UploadFile] = File(None, description="图生视频的参考图像"),
    prompt: Optional[str] = Form(None, description="生成视频的文本提示词"),
    model: Optional[str] = Form("wan2.6-t2v", description="使用的模型ID"),
    size: Optional[str] = Form("1080P", description="视频分辨率,如 720P, 1080P"),
    duration: Optional[int] = Form(5, description="视频时长(秒)"),
    auth: tuple = Depends(get_api_key_auth),
    db: Session = Depends(get_db)
):
    """
    视频生成

    根据文本提示词或参考图像生成视频。

    支持两种请求方式:
    1. application/json: 使用VideoGenerationRequest模型,image为base64编码
    2. multipart/form-data: 使用File和Form参数,直接上传文件

    支持两种模式:
    - 文生视频 (T2V): 只提供 prompt
    - 图生视频 (I2V): 同时提供 prompt 和 image

    参数说明:
    - prompt: 生成视频的文本描述
    - image: 参考图像(图生视频模式)
    - model: 使用的视频生成模型
    - size: 视频分辨率
    - duration: 视频时长(秒)
    """
    # Video generation: text-to-video when no image is supplied, image-to-video
    # otherwise. Accepts a JSON body (base64 image) or a multipart upload.
    # The docstring above is the route's published OpenAPI description.
    user_id, api_key_id, _key_type = auth
    service = OpenAICompatService(db)
    # Request.client is None when the server reports no peer address.
    request_ip = req.client.host if req.client else None

    try:
        content_type = req.headers.get("content-type", "")

        if "application/json" in content_type:
            try:
                request_obj = VideoGenerationRequest(**(await req.json()))
            except (ValueError, TypeError) as exc:
                # Malformed JSON or schema violations are client errors, not 500s.
                raise OpenAICompatError(400, f"Invalid request body: {exc}", "invalid_request_error") from exc
            source_image = request_obj.image
            prompt_text, model_name, size_value = request_obj.prompt, request_obj.model, request_obj.size
        elif "multipart/form-data" in content_type:
            if not prompt:
                # The content type is supported; the required prompt is missing.
                raise OpenAICompatError(400, "The 'prompt' field is required", "invalid_request_error")
            source_image = image
            prompt_text, model_name, size_value = prompt, model, size
            request_obj = None
            if not image:
                # Text-to-video goes through the request model, as in the JSON path.
                try:
                    request_obj = VideoGenerationRequest(
                        prompt=prompt,
                        model=model,
                        size=size,
                        duration=duration
                    )
                except (ValueError, TypeError) as exc:
                    raise OpenAICompatError(400, f"Invalid request body: {exc}", "invalid_request_error") from exc
        else:
            raise OpenAICompatError(415, "不支持的 Content-Type,请使用 application/json 或 multipart/form-data", "invalid_request_error")

        if source_image:
            # Image-to-video (I2V)
            return await service.image_to_video_generations(
                image=source_image,
                prompt=prompt_text,
                model_name=model_name,
                size=size_value,
                user_id=user_id,
                api_key_id=api_key_id,
                request_ip=request_ip
            )
        # Text-to-video (T2V)
        return await service.video_generations(
            request=request_obj,
            user_id=user_id,
            api_key_id=api_key_id,
            request_ip=request_ip
        )
    except OpenAICompatError as e:
        raise HTTPException(
            status_code=e.status_code,
            detail={"error": {"message": e.message, "type": e.error_type}}
        )
@router.post("/rerank", response_model=RerankResponse)
async def rerank(
    request: RerankRequest,
    req: Request,
    auth: tuple = Depends(get_api_key_auth),
    db: Session = Depends(get_db)
):
    """
    文档重排序

    根据查询文本对文档列表进行相关性排序。

    参数说明:
    - model: 使用的重排序模型,如 bge-reranker-v2-m3
    - query: 查询文本
    - documents: 待排序的文档列表
    - top_n: 返回前N个结果(可选,默认返回全部)
    - return_documents: 是否在结果中返回文档内容(默认true)

    返回结果按相关性分数降序排列。
    """
    # Document reranking delegated to OpenAICompatService.rerank.
    # The docstring above is the route's published OpenAPI description.
    user_id, api_key_id, _key_type = auth
    service = OpenAICompatService(db)
    # Request.client is None when the server reports no peer address.
    request_ip = req.client.host if req.client else None

    try:
        return await service.rerank(request, user_id, api_key_id, request_ip)
    except OpenAICompatError as e:
        raise HTTPException(
            status_code=e.status_code,
            detail={"error": {"message": e.message, "type": e.error_type}}
        )
|