""" OpenAI 兼容 API 路由 完整支持 /v1/chat/completions 和 /v1/models 接口 """ from fastapi import APIRouter, Depends, HTTPException, Request, Header, File, Form, UploadFile from fastapi.responses import StreamingResponse from sqlalchemy.orm import Session from typing import Optional import json from app.database import get_db, SessionLocal from app.services.openai_compat_service import OpenAICompatService, OpenAICompatError from app.services.platform_api_key_service import PlatformApiKeyService from app.schemas.openai_compat import ( ChatCompletionsRequest, ChatCompletionsResponse, EmbeddingsRequest, EmbeddingsResponse, ModelsListResponse, ImageGenerationRequest, ImageGenerationResponse, ImageEditsRequest, AudioTranscriptionResponse, AudioTranscriptionsRequest, AudioTranslationsRequest, AudioSpeechRequest, AudioSpeechResponse, VideoGenerationRequest, VideoGenerationResponse, RerankRequest, RerankResponse, ) router = APIRouter(prefix="/api/v1", tags=["OpenAI 兼容 API"]) # ───────────────────────────────────────────── # 认证依赖 # ───────────────────────────────────────────── async def get_api_key_auth( authorization: Optional[str] = Header(None), db: Session = Depends(get_db), ) -> tuple: """验证 Bearer Token,返回 (user_id, key_id, key_type)""" if not authorization: raise HTTPException( status_code=401, detail={"error": {"message": "Missing Authorization header", "type": "authentication_error", "code": "missing_auth"}}, ) if not authorization.startswith("Bearer "): raise HTTPException( status_code=401, detail={"error": {"message": "Invalid Authorization header format. Expected 'Bearer '", "type": "authentication_error", "code": "invalid_auth_format"}}, ) api_key = authorization[7:] # 首先验证API密钥是否有效 result = PlatformApiKeyService(db).verify_api_key(api_key) if not result: raise HTTPException( status_code=401, detail={"error": {"message": "Incorrect API key provided", "type": "authentication_error", "code": "invalid_api_key"}}, ) user_id, key_id = result # 从缓存获取API密钥类型 from app.services.cache_service import CacheService key_data = await CacheService.get_api_key(key_id) if key_data: key_type = key_data.get("key_type", "public") else: # 从数据库获取 from app.models.platform_api_key import PlatformApiKey api_key_record = db.query(PlatformApiKey).filter( PlatformApiKey.id == key_id ).first() key_type = api_key_record.key_type if api_key_record else "public" # 缓存API密钥信息 await CacheService.set_api_key(key_id, { "key_type": key_type, "status": api_key_record.status if api_key_record else "active" }) return (user_id, key_id, key_type) # (user_id, key_id, key_type) # ───────────────────────────────────────────── # POST /api/v1/chat/completions # ───────────────────────────────────────────── @router.post("/chat/completions", summary="聊天补全", description="OpenAI兼容的聊天补全接口。支持流式和非流式输出,支持多模态输入(文本、图片、音频)。") async def chat_completions( request: ChatCompletionsRequest, req: Request, auth: tuple = Depends(get_api_key_auth), db: Session = Depends(get_db), ): """ 聊天补全接口 完全兼容OpenAI的 /v1/chat/completions 接口规范。 **认证方式**: - 在请求头中添加:`Authorization: Bearer YOUR_API_KEY` - API Key需要先通过 `/api/platform/api-keys` 接口创建 **请求参数**: - **model**: 模型名称(必填),如 "gpt-4", "qwen-max" 等 - **messages**: 消息列表(必填),包含role和content - **temperature**: 采样温度,0-2之间,默认1 - **max_tokens**: 最大输出token数 - **stream**: 是否流式输出,默认false - 更多参数请参考OpenAI官方文档 **返回格式**: - 非流式:返回完整的JSON响应 - 流式:返回SSE格式的数据流 """ user_id, api_key_id, key_type = auth service = OpenAICompatService(db) request_ip = req.client.host if req.client else None try: # 检查 API Key 类型与模型类型是否匹配 model = service._find_model(request.model, user_id) if not model: raise OpenAICompatError( status_code=404, message=f"The model '{request.model}' does not exist", error_type="model_not_found", ) if model.is_local and key_type != "local": raise OpenAICompatError( status_code=403, message="Local models can only be accessed with local API keys", error_type="permission_error", ) if not model.is_local and key_type != "public": raise OpenAICompatError( status_code=403, message="Cloud models can only be accessed with public API keys", error_type="permission_error", ) # 流式:使用独立 session 避免请求结束时被关闭,并避免重复调用上游 if request.stream: stream_db = SessionLocal() async def stream_and_close(): try: stream_service = OpenAICompatService(stream_db) raw = await stream_service.chat_completions( request, user_id, api_key_id, request_ip ) async for chunk in raw: yield chunk finally: stream_db.close() return StreamingResponse( stream_and_close(), media_type="text/event-stream", headers={ "Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no", }, ) # 非流式:使用依赖注入的 db result = await service.chat_completions(request, user_id, api_key_id, request_ip) return result except OpenAICompatError as e: raise HTTPException( status_code=e.status_code, detail={"error": {"message": e.message, "type": e.error_type}}, ) # ───────────────────────────────────────────── # GET /api/v1/models # ───────────────────────────────────────────── @router.get("/models", response_model=ModelsListResponse, summary="获取模型列表", description="获取当前用户可用的模型列表。根据API Key类型返回相应的模型(public key返回云端模型,local key返回本地模型)。") def list_models( auth: tuple = Depends(get_api_key_auth), db: Session = Depends(get_db), ): """ 获取可用模型列表 **认证方式**: - 在请求头中添加:`Authorization: Bearer YOUR_API_KEY` **返回内容**: - 根据API Key类型返回相应的模型列表 - public类型的key返回云端模型 - local类型的key返回本地模型 """ user_id, _, key_type = auth service = OpenAICompatService(db) # 根据API密钥类型返回相应的模型列表 return service.get_available_models(user_id, key_type) @router.post("/embeddings", response_model=EmbeddingsResponse) async def embeddings( request: EmbeddingsRequest, req: Request, auth: tuple = Depends(get_api_key_auth), db: Session = Depends(get_db) ): """ 文本嵌入 将文本转换为向量表示。 参数说明: - input: 要嵌入的文本(字符串或字符串数组) - model: 使用的嵌入模型 - encoding_format: 返回格式(float或base64) - dimensions: 向量维度(可选) """ user_id, api_key_id, key_type = auth service = OpenAICompatService(db) try: return await service.embeddings(request, user_id, api_key_id, req.client.host) except OpenAICompatError as e: raise HTTPException( status_code=e.status_code, detail={"error": {"message": e.message, "type": e.error_type}} ) @router.post("/images/generations", response_model=ImageGenerationResponse) async def image_generations( request: ImageGenerationRequest, req: Request, auth: tuple = Depends(get_api_key_auth), db: Session = Depends(get_db) ): """ 文生图 根据文本描述生成图像。 参数说明: - prompt: 图像描述文本 - model: 使用的图像生成模型 - n: 生成图像数量 - quality: 图像质量(standard或hd) - size: 图像尺寸 - style: 图像风格(vivid或natural) - response_format: 返回格式(url或b64_json) """ user_id, api_key_id, key_type = auth service = OpenAICompatService(db) try: # 调用 Service 层的图像生成逻辑 return await service.image_generations(request, user_id, api_key_id, req.client.host) except OpenAICompatError as e: raise HTTPException( status_code=e.status_code, detail={"error": {"message": e.message, "type": e.error_type}} ) @router.post("/images/edits", response_model=ImageGenerationResponse) async def image_edits( req: Request, image: Optional[UploadFile] = File(None, description="要编辑的原始图像(推荐PNG/JPG格式)"), prompt: Optional[str] = Form(None, description="对新图像的文本描述"), mask: Optional[UploadFile] = File(None, description="可选的遮罩层图像"), model: Optional[str] = Form("wan2.6-image", description="模型ID"), n: Optional[int] = Form(1, description="生成数量"), size: Optional[str] = Form("1024x1024", description="图像尺寸"), response_format: Optional[str] = Form("url", description="返回格式"), user: Optional[str] = Form(None, description="终端用户标识"), auth: tuple = Depends(get_api_key_auth), db: Session = Depends(get_db) ): """ 图像编辑/图生图 基于原始图像和文本描述生成新图像。 支持两种请求方式: 1. application/json: 使用ImageEditsRequest模型,image和mask为base64编码 2. multipart/form-data: 使用File和Form参数,直接上传文件 参数说明: - image: 要编辑的原始图像 - prompt: 对新图像的文本描述 - mask: 遮罩层图像(可选) - model: 使用的图像生成模型 - n: 生成图像数量 - size: 图像尺寸 - response_format: 返回格式(url或b64_json) """ user_id, api_key_id, key_type = auth service = OpenAICompatService(db) try: content_type = req.headers.get("content-type", "") # 处理JSON请求 if "application/json" in content_type: body = await req.json() from app.schemas.openai_compat import ImageEditsRequest request_obj = ImageEditsRequest(**body) return await service.image_edits( image=request_obj.image, prompt=request_obj.prompt, mask=request_obj.mask, model_name=request_obj.model, n=request_obj.n, size=request_obj.size, user_id=user_id, api_key_id=api_key_id, request_ip=req.client.host ) # 处理multipart/form-data请求 elif "multipart/form-data" in content_type and image and prompt: return await service.image_edits( image=image, prompt=prompt, mask=mask, model_name=model, n=n, size=size, user_id=user_id, api_key_id=api_key_id, request_ip=req.client.host ) else: raise OpenAICompatError(415, "不支持的 Content-Type,请使用 application/json 或 multipart/form-data", "invalid_request_error") except OpenAICompatError as e: raise HTTPException( status_code=e.status_code, detail={"error": {"message": e.message, "type": e.error_type}} ) @router.post("/audio/transcriptions", response_model=AudioTranscriptionResponse) async def audio_transcriptions( req: Request, file: Optional[UploadFile] = File(None, description="要识别的音频文件(如 mp3, wav)"), model: Optional[str] = Form(None, description="模型名称"), language: Optional[str] = Form(None, description="ISO-639-1 语言代码"), response_format: Optional[str] = Form("json", description="返回格式"), auth: tuple = Depends(get_api_key_auth), db: Session = Depends(get_db) ): """ 语音转文字(STT) 将音频文件转换为文本。 支持两种请求方式: 1. application/json: 使用AudioTranscriptionsRequest模型,file为base64编码 2. multipart/form-data: 使用File和Form参数,直接上传文件 参数说明: - file: 要识别的音频文件 - model: 语音识别模型名称 - language: 音频语言代码(可选) - response_format: 返回格式 """ user_id, api_key_id, key_type = auth service = OpenAICompatService(db) try: content_type = req.headers.get("content-type", "") # 处理JSON请求 if "application/json" in content_type: body = await req.json() from app.schemas.openai_compat import AudioTranscriptionsRequest request_obj = AudioTranscriptionsRequest(**body) return await service.audio_transcriptions( file=request_obj.file, model_name=request_obj.model, language=request_obj.language, user_id=user_id, api_key_id=api_key_id, request_ip=req.client.host ) # 处理multipart/form-data请求 elif "multipart/form-data" in content_type and file and model: return await service.audio_transcriptions( file=file, model_name=model, language=language, user_id=user_id, api_key_id=api_key_id, request_ip=req.client.host ) else: raise OpenAICompatError(415, "不支持的 Content-Type,请使用 application/json 或 multipart/form-data", "invalid_request_error") except OpenAICompatError as e: raise HTTPException( status_code=e.status_code, detail={"error": {"message": e.message, "type": e.error_type}} ) @router.post("/audio/translations", response_model=AudioTranscriptionResponse) async def audio_translations( req: Request, file: Optional[UploadFile] = File(None, description="要翻译的音频文件(如 mp3, wav)"), model: Optional[str] = Form(None, description="使用的语音识别模型"), source_language: Optional[str] = Form(None, description="原语音语言代码"), target_language: Optional[str] = Form("en", description="目标翻译语言代码"), translation_model: Optional[str] = Form("qwen-max", description="执行翻译的文本大模型"), prompt: Optional[str] = Form(None, description="可选的翻译提示词"), auth: tuple = Depends(get_api_key_auth), db: Session = Depends(get_db) ): """ 语音翻译 将音频文件识别并翻译为目标语言文本。 支持两种请求方式: 1. application/json: 使用AudioTranslationsRequest模型,file为base64编码 2. multipart/form-data: 使用File和Form参数,直接上传文件 参数说明: - file: 要翻译的音频文件 - model: 语音识别模型名称 - source_language: 源语言代码(可选) - target_language: 目标语言代码,默认为英语 - translation_model: 用于翻译的文本模型 - prompt: 翻译提示词(可选) """ user_id, api_key_id, key_type = auth service = OpenAICompatService(db) try: content_type = req.headers.get("content-type", "") # 处理JSON请求 if "application/json" in content_type: body = await req.json() from app.schemas.openai_compat import AudioTranslationsRequest request_obj = AudioTranslationsRequest(**body) asr_result = await service.audio_transcriptions( file=request_obj.file, model_name=request_obj.model, language=request_obj.source_language, user_id=user_id, api_key_id=api_key_id, request_ip=req.client.host ) original_text = asr_result.text if not original_text or not original_text.strip(): return AudioTranscriptionResponse(text="") if request_obj.target_language == "en" and request_obj.source_language == "en": return AudioTranscriptionResponse(text=original_text) from app.schemas.openai_compat import ChatCompletionsRequest, Message import json from decimal import Decimal from app.models.model import ModelPriceNew from app.services.api_call_log_service import ApiCallLogService if request_obj.prompt: system_prompt += f"\n参考提示:{request_obj.prompt}" translation_request = ChatCompletionsRequest( model=request_obj.translation_model, messages=[ Message(role="system", content=system_prompt), Message(role="user", content=original_text) ], stream=True ) chat_result_stream = await service.chat_completions( request=translation_request, user_id=user_id, api_key_id=api_key_id ) final_text_chunks = [] async for chunk in chat_result_stream: if isinstance(chunk, str) and chunk.startswith("data: ") and chunk.strip() != "data: [DONE]": try: data_dict = json.loads(chunk[6:]) delta_content = data_dict.get("choices", [{}])[0].get("delta", {}).get("content", "") if delta_content: final_text_chunks.append(delta_content) except Exception: continue final_text = "".join(final_text_chunks) estimated_input_tokens = int((len(system_prompt) + len(original_text)) * 1.2) or 1 estimated_output_tokens = int(len(final_text) * 1.2) or 1 trans_model_obj = service._find_model(request_obj.translation_model, user_id) bill = Decimal("0") if trans_model_obj and not trans_model_obj.is_local: price_info = db.query(ModelPriceNew).filter( ModelPriceNew.model_code == trans_model_obj.model_code, ModelPriceNew.is_active == True, ).first() if price_info: in_cost = price_info.input_price_discounted * Decimal(str(estimated_input_tokens)) / Decimal(str(price_info.display_multiplier or 1000000)) out_cost = price_info.output_price_discounted * Decimal(str(estimated_output_tokens)) / Decimal(str(price_info.display_multiplier or 1000000)) bill = in_cost + out_cost log_service = ApiCallLogService(db) log_service.create_log( user_id=user_id, api_key_id=api_key_id, model_id=trans_model_obj.id if trans_model_obj else None, model_name=request_obj.translation_model, is_local=trans_model_obj.is_local if trans_model_obj else False, input_tokens=estimated_input_tokens, output_tokens=estimated_output_tokens, bill=float(bill), status="success", request_ip=req.client.host ) return AudioTranscriptionResponse(text=final_text.strip()) # 处理multipart/form-data请求 elif "multipart/form-data" in content_type and file and model: asr_result = await service.audio_transcriptions( file=file, model_name=model, language=source_language, user_id=user_id, api_key_id=api_key_id, request_ip=req.client.host ) original_text = asr_result.text if not original_text or not original_text.strip(): return AudioTranscriptionResponse(text="") if target_language == "en" and source_language == "en": return AudioTranscriptionResponse(text=original_text) from app.schemas.openai_compat import ChatCompletionsRequest, Message import json from decimal import Decimal from app.models.model import ModelPriceNew from app.services.api_call_log_service import ApiCallLogService system_prompt = f"你是一个专业的翻译官。请将用户的文本准确翻译成 {target_language}。只返回翻译后的纯文本结果,不要任何多余解释。" if prompt: system_prompt += f"\n参考提示:{prompt}" translation_request = ChatCompletionsRequest( model=translation_model, messages=[ Message(role="system", content=system_prompt), Message(role="user", content=original_text) ], stream=True ) chat_result_stream = await service.chat_completions( request=translation_request, user_id=user_id, api_key_id=api_key_id ) final_text_chunks = [] async for chunk in chat_result_stream: if isinstance(chunk, str) and chunk.startswith("data: ") and chunk.strip() != "data: [DONE]": try: data_dict = json.loads(chunk[6:]) delta_content = data_dict.get("choices", [{}])[0].get("delta", {}).get("content", "") if delta_content: final_text_chunks.append(delta_content) except Exception: continue final_text = "".join(final_text_chunks) estimated_input_tokens = int((len(system_prompt) + len(original_text)) * 1.2) or 1 estimated_output_tokens = int(len(final_text) * 1.2) or 1 trans_model_obj = service._find_model(translation_model, user_id) bill = Decimal("0") if trans_model_obj and not trans_model_obj.is_local: price_info = db.query(ModelPriceNew).filter( ModelPriceNew.model_code == trans_model_obj.model_code, ModelPriceNew.is_active == True, ).first() if price_info: in_cost = price_info.input_price_discounted * Decimal(str(estimated_input_tokens)) / Decimal(str(price_info.display_multiplier or 1000000)) out_cost = price_info.output_price_discounted * Decimal(str(estimated_output_tokens)) / Decimal(str(price_info.display_multiplier or 1000000)) bill = in_cost + out_cost log_service = ApiCallLogService(db) log_service.create_log( user_id=user_id, api_key_id=api_key_id, model_id=trans_model_obj.id if trans_model_obj else None, model_name=translation_model, is_local=trans_model_obj.is_local if trans_model_obj else False, input_tokens=estimated_input_tokens, output_tokens=estimated_output_tokens, bill=float(bill), status="success", request_ip=req.client.host ) return AudioTranscriptionResponse(text=final_text.strip()) else: raise OpenAICompatError(415, "不支持的 Content-Type,请使用 application/json 或 multipart/form-data", "invalid_request_error") except OpenAICompatError as e: raise HTTPException( status_code=e.status_code, detail={"error": {"message": e.message, "type": e.error_type}} ) except Exception as e: import httpx if isinstance(e, httpx.HTTPStatusError): raise HTTPException( status_code=400, detail={"error": {"message": f"Upstream API Error: {e.response.text}", "type": "upstream_api_error"}} ) raise HTTPException( status_code=500, detail={"error": {"message": str(e), "type": "internal_error"}} ) @router.post("/audio/speech", response_model=AudioSpeechResponse) async def audio_speech( request: AudioSpeechRequest, req: Request, auth: tuple = Depends(get_api_key_auth), db: Session = Depends(get_db) ): """ 文字转语音(TTS) 将文本转换为语音音频。 参数说明: - model: 使用的TTS模型 - input: 要转换的文本内容 - voice: 发音人声音类型 - response_format: 音频格式(mp3, opus, aac, flac, wav, pcm) - speed: 语速,范围0.25-4.0 返回音频二进制流。 """ user_id, api_key_id, key_type = auth service = OpenAICompatService(db) try: audio_stream, media_type = await service.audio_speech( request=request, user_id=user_id, api_key_id=api_key_id, request_ip=req.client.host ) return StreamingResponse( audio_stream, media_type=media_type, headers={"Content-Disposition": f"attachment; filename=speech.{request.response_format}"} ) except OpenAICompatError as e: raise HTTPException( status_code=e.status_code, detail={"error": {"message": e.message, "type": e.error_type}} ) @router.post("/videos/generations", response_model=VideoGenerationResponse) async def video_generations( req: Request, image: Optional[UploadFile] = File(None, description="图生视频的参考图像"), prompt: Optional[str] = Form(None, description="生成视频的文本提示词"), model: Optional[str] = Form("wan2.6-t2v", description="使用的模型ID"), size: Optional[str] = Form("1080P", description="视频分辨率,如 720P, 1080P"), duration: Optional[int] = Form(5, description="视频时长(秒)"), auth: tuple = Depends(get_api_key_auth), db: Session = Depends(get_db) ): """ 视频生成 根据文本提示词或参考图像生成视频。 支持两种请求方式: 1. application/json: 使用VideoGenerationRequest模型,image为base64编码 2. multipart/form-data: 使用File和Form参数,直接上传文件 支持两种模式: - 文生视频 (T2V): 只提供 prompt - 图生视频 (I2V): 同时提供 prompt 和 image 参数说明: - prompt: 生成视频的文本描述 - image: 参考图像(图生视频模式) - model: 使用的视频生成模型 - size: 视频分辨率 - duration: 视频时长(秒) """ user_id, api_key_id, key_type = auth service = OpenAICompatService(db) try: content_type = req.headers.get("content-type", "") # 处理JSON请求 if "application/json" in content_type: body = await req.json() from app.schemas.openai_compat import VideoGenerationRequest request_obj = VideoGenerationRequest(**body) # 图生视频 if request_obj.image: return await service.image_to_video_generations( image=request_obj.image, prompt=request_obj.prompt, model_name=request_obj.model, size=request_obj.size, user_id=user_id, api_key_id=api_key_id, request_ip=req.client.host ) # 文生视频 else: return await service.video_generations( request=request_obj, user_id=user_id, api_key_id=api_key_id, request_ip=req.client.host ) # 处理multipart/form-data请求 elif "multipart/form-data" in content_type and prompt: # 图生视频 if image: return await service.image_to_video_generations( image=image, prompt=prompt, model_name=model, size=size, user_id=user_id, api_key_id=api_key_id, request_ip=req.client.host ) # 文生视频 else: from app.schemas.openai_compat import VideoGenerationRequest request_obj = VideoGenerationRequest( prompt=prompt, model=model, size=size, duration=duration ) return await service.video_generations( request=request_obj, user_id=user_id, api_key_id=api_key_id, request_ip=req.client.host ) else: raise OpenAICompatError(415, "不支持的 Content-Type,请使用 application/json 或 multipart/form-data", "invalid_request_error") except OpenAICompatError as e: raise HTTPException( status_code=e.status_code, detail={"error": {"message": e.message, "type": e.error_type}} ) @router.post("/rerank", response_model=RerankResponse) async def rerank( request: RerankRequest, req: Request, auth: tuple = Depends(get_api_key_auth), db: Session = Depends(get_db) ): """ 文档重排序 根据查询文本对文档列表进行相关性排序。 参数说明: - model: 使用的重排序模型,如 bge-reranker-v2-m3 - query: 查询文本 - documents: 待排序的文档列表 - top_n: 返回前N个结果(可选,默认返回全部) - return_documents: 是否在结果中返回文档内容(默认true) 返回结果按相关性分数降序排列。 """ user_id, api_key_id, key_type = auth service = OpenAICompatService(db) try: return await service.rerank(request, user_id, api_key_id, req.client.host) except OpenAICompatError as e: raise HTTPException( status_code=e.status_code, detail={"error": {"message": e.message, "type": e.error_type}} )