# openai_compat_router.py
  1. """
  2. OpenAI 兼容 API 路由
  3. 完整支持 /v1/chat/completions 和 /v1/models 接口
  4. """
  5. from fastapi import APIRouter, Depends, HTTPException, Request, Header, File, Form, UploadFile
  6. from fastapi.responses import StreamingResponse
  7. from sqlalchemy.orm import Session
  8. from typing import Optional
  9. import json
  10. from app.database import get_db, SessionLocal
  11. from app.services.openai_compat_service import OpenAICompatService, OpenAICompatError
  12. from app.services.platform_api_key_service import PlatformApiKeyService
  13. from app.schemas.openai_compat import (
  14. ChatCompletionsRequest,
  15. ChatCompletionsResponse,
  16. EmbeddingsRequest,
  17. EmbeddingsResponse,
  18. ModelsListResponse,
  19. ImageGenerationRequest,
  20. ImageGenerationResponse,
  21. ImageEditsRequest,
  22. AudioTranscriptionResponse,
  23. AudioTranscriptionsRequest,
  24. AudioTranslationsRequest,
  25. AudioSpeechRequest,
  26. AudioSpeechResponse,
  27. VideoGenerationRequest,
  28. VideoGenerationResponse,
  29. RerankRequest,
  30. RerankResponse,
  31. )
  32. router = APIRouter(prefix="/api/v1", tags=["OpenAI 兼容 API"])
  33. # ─────────────────────────────────────────────
  34. # 认证依赖
  35. # ─────────────────────────────────────────────
  36. async def get_api_key_auth(
  37. authorization: Optional[str] = Header(None),
  38. db: Session = Depends(get_db),
  39. ) -> tuple:
  40. """验证 Bearer Token,返回 (user_id, key_id, key_type)"""
  41. if not authorization:
  42. raise HTTPException(
  43. status_code=401,
  44. detail={"error": {"message": "Missing Authorization header", "type": "authentication_error", "code": "missing_auth"}},
  45. )
  46. if not authorization.startswith("Bearer "):
  47. raise HTTPException(
  48. status_code=401,
  49. detail={"error": {"message": "Invalid Authorization header format. Expected 'Bearer <token>'", "type": "authentication_error", "code": "invalid_auth_format"}},
  50. )
  51. api_key = authorization[7:]
  52. # 首先验证API密钥是否有效
  53. result = PlatformApiKeyService(db).verify_api_key(api_key)
  54. if not result:
  55. raise HTTPException(
  56. status_code=401,
  57. detail={"error": {"message": "Incorrect API key provided", "type": "authentication_error", "code": "invalid_api_key"}},
  58. )
  59. user_id, key_id = result
  60. # 从缓存获取API密钥类型
  61. from app.services.cache_service import CacheService
  62. key_data = await CacheService.get_api_key(key_id)
  63. if key_data:
  64. key_type = key_data.get("key_type", "public")
  65. else:
  66. # 从数据库获取
  67. from app.models.platform_api_key import PlatformApiKey
  68. api_key_record = db.query(PlatformApiKey).filter(
  69. PlatformApiKey.id == key_id
  70. ).first()
  71. key_type = api_key_record.key_type if api_key_record else "public"
  72. # 缓存API密钥信息
  73. await CacheService.set_api_key(key_id, {
  74. "key_type": key_type,
  75. "status": api_key_record.status if api_key_record else "active"
  76. })
  77. return (user_id, key_id, key_type) # (user_id, key_id, key_type)
  78. # ─────────────────────────────────────────────
  79. # POST /api/v1/chat/completions
  80. # ─────────────────────────────────────────────
  81. @router.post("/chat/completions", summary="聊天补全", description="OpenAI兼容的聊天补全接口。支持流式和非流式输出,支持多模态输入(文本、图片、音频)。")
  82. async def chat_completions(
  83. request: ChatCompletionsRequest,
  84. req: Request,
  85. auth: tuple = Depends(get_api_key_auth),
  86. db: Session = Depends(get_db),
  87. ):
  88. """
  89. 聊天补全接口
  90. 完全兼容OpenAI的 /v1/chat/completions 接口规范。
  91. **认证方式**:
  92. - 在请求头中添加:`Authorization: Bearer YOUR_API_KEY`
  93. - API Key需要先通过 `/api/platform/api-keys` 接口创建
  94. **请求参数**:
  95. - **model**: 模型名称(必填),如 "gpt-4", "qwen-max" 等
  96. - **messages**: 消息列表(必填),包含role和content
  97. - **temperature**: 采样温度,0-2之间,默认1
  98. - **max_tokens**: 最大输出token数
  99. - **stream**: 是否流式输出,默认false
  100. - 更多参数请参考OpenAI官方文档
  101. **返回格式**:
  102. - 非流式:返回完整的JSON响应
  103. - 流式:返回SSE格式的数据流
  104. """
  105. user_id, api_key_id, key_type = auth
  106. service = OpenAICompatService(db)
  107. request_ip = req.client.host if req.client else None
  108. try:
  109. # 检查 API Key 类型与模型类型是否匹配
  110. model = service._find_model(request.model, user_id)
  111. if not model:
  112. raise OpenAICompatError(
  113. status_code=404,
  114. message=f"The model '{request.model}' does not exist",
  115. error_type="model_not_found",
  116. )
  117. if model.is_local and key_type != "local":
  118. raise OpenAICompatError(
  119. status_code=403,
  120. message="Local models can only be accessed with local API keys",
  121. error_type="permission_error",
  122. )
  123. if not model.is_local and key_type != "public":
  124. raise OpenAICompatError(
  125. status_code=403,
  126. message="Cloud models can only be accessed with public API keys",
  127. error_type="permission_error",
  128. )
  129. # 流式:使用独立 session 避免请求结束时被关闭,并避免重复调用上游
  130. if request.stream:
  131. stream_db = SessionLocal()
  132. async def stream_and_close():
  133. try:
  134. stream_service = OpenAICompatService(stream_db)
  135. raw = await stream_service.chat_completions(
  136. request, user_id, api_key_id, request_ip
  137. )
  138. async for chunk in raw:
  139. yield chunk
  140. finally:
  141. stream_db.close()
  142. return StreamingResponse(
  143. stream_and_close(),
  144. media_type="text/event-stream",
  145. headers={
  146. "Cache-Control": "no-cache",
  147. "Connection": "keep-alive",
  148. "X-Accel-Buffering": "no",
  149. },
  150. )
  151. # 非流式:使用依赖注入的 db
  152. result = await service.chat_completions(request, user_id, api_key_id, request_ip)
  153. return result
  154. except OpenAICompatError as e:
  155. raise HTTPException(
  156. status_code=e.status_code,
  157. detail={"error": {"message": e.message, "type": e.error_type}},
  158. )
  159. # ─────────────────────────────────────────────
  160. # GET /api/v1/models
  161. # ─────────────────────────────────────────────
  162. @router.get("/models", response_model=ModelsListResponse, summary="获取模型列表", description="获取当前用户可用的模型列表。根据API Key类型返回相应的模型(public key返回云端模型,local key返回本地模型)。")
  163. def list_models(
  164. auth: tuple = Depends(get_api_key_auth),
  165. db: Session = Depends(get_db),
  166. ):
  167. """
  168. 获取可用模型列表
  169. **认证方式**:
  170. - 在请求头中添加:`Authorization: Bearer YOUR_API_KEY`
  171. **返回内容**:
  172. - 根据API Key类型返回相应的模型列表
  173. - public类型的key返回云端模型
  174. - local类型的key返回本地模型
  175. """
  176. user_id, _, key_type = auth
  177. service = OpenAICompatService(db)
  178. # 根据API密钥类型返回相应的模型列表
  179. return service.get_available_models(user_id, key_type)
  180. @router.post("/embeddings", response_model=EmbeddingsResponse)
  181. async def embeddings(
  182. request: EmbeddingsRequest,
  183. req: Request,
  184. auth: tuple = Depends(get_api_key_auth),
  185. db: Session = Depends(get_db)
  186. ):
  187. """
  188. 文本嵌入
  189. 将文本转换为向量表示。
  190. 参数说明:
  191. - input: 要嵌入的文本(字符串或字符串数组)
  192. - model: 使用的嵌入模型
  193. - encoding_format: 返回格式(float或base64)
  194. - dimensions: 向量维度(可选)
  195. """
  196. user_id, api_key_id, key_type = auth
  197. service = OpenAICompatService(db)
  198. try:
  199. return await service.embeddings(request, user_id, api_key_id, req.client.host)
  200. except OpenAICompatError as e:
  201. raise HTTPException(
  202. status_code=e.status_code,
  203. detail={"error": {"message": e.message, "type": e.error_type}}
  204. )
  205. @router.post("/images/generations", response_model=ImageGenerationResponse)
  206. async def image_generations(
  207. request: ImageGenerationRequest,
  208. req: Request,
  209. auth: tuple = Depends(get_api_key_auth),
  210. db: Session = Depends(get_db)
  211. ):
  212. """
  213. 文生图
  214. 根据文本描述生成图像。
  215. 参数说明:
  216. - prompt: 图像描述文本
  217. - model: 使用的图像生成模型
  218. - n: 生成图像数量
  219. - quality: 图像质量(standard或hd)
  220. - size: 图像尺寸
  221. - style: 图像风格(vivid或natural)
  222. - response_format: 返回格式(url或b64_json)
  223. """
  224. user_id, api_key_id, key_type = auth
  225. service = OpenAICompatService(db)
  226. try:
  227. # 调用 Service 层的图像生成逻辑
  228. return await service.image_generations(request, user_id, api_key_id, req.client.host)
  229. except OpenAICompatError as e:
  230. raise HTTPException(
  231. status_code=e.status_code,
  232. detail={"error": {"message": e.message, "type": e.error_type}}
  233. )
  234. @router.post("/images/edits", response_model=ImageGenerationResponse)
  235. async def image_edits(
  236. req: Request,
  237. image: Optional[UploadFile] = File(None, description="要编辑的原始图像(推荐PNG/JPG格式)"),
  238. prompt: Optional[str] = Form(None, description="对新图像的文本描述"),
  239. mask: Optional[UploadFile] = File(None, description="可选的遮罩层图像"),
  240. model: Optional[str] = Form("wan2.6-image", description="模型ID"),
  241. n: Optional[int] = Form(1, description="生成数量"),
  242. size: Optional[str] = Form("1024x1024", description="图像尺寸"),
  243. response_format: Optional[str] = Form("url", description="返回格式"),
  244. user: Optional[str] = Form(None, description="终端用户标识"),
  245. auth: tuple = Depends(get_api_key_auth),
  246. db: Session = Depends(get_db)
  247. ):
  248. """
  249. 图像编辑/图生图
  250. 基于原始图像和文本描述生成新图像。
  251. 支持两种请求方式:
  252. 1. application/json: 使用ImageEditsRequest模型,image和mask为base64编码
  253. 2. multipart/form-data: 使用File和Form参数,直接上传文件
  254. 参数说明:
  255. - image: 要编辑的原始图像
  256. - prompt: 对新图像的文本描述
  257. - mask: 遮罩层图像(可选)
  258. - model: 使用的图像生成模型
  259. - n: 生成图像数量
  260. - size: 图像尺寸
  261. - response_format: 返回格式(url或b64_json)
  262. """
  263. user_id, api_key_id, key_type = auth
  264. service = OpenAICompatService(db)
  265. try:
  266. content_type = req.headers.get("content-type", "")
  267. # 处理JSON请求
  268. if "application/json" in content_type:
  269. body = await req.json()
  270. from app.schemas.openai_compat import ImageEditsRequest
  271. request_obj = ImageEditsRequest(**body)
  272. return await service.image_edits(
  273. image=request_obj.image,
  274. prompt=request_obj.prompt,
  275. mask=request_obj.mask,
  276. model_name=request_obj.model,
  277. n=request_obj.n,
  278. size=request_obj.size,
  279. user_id=user_id,
  280. api_key_id=api_key_id,
  281. request_ip=req.client.host
  282. )
  283. # 处理multipart/form-data请求
  284. elif "multipart/form-data" in content_type and image and prompt:
  285. return await service.image_edits(
  286. image=image,
  287. prompt=prompt,
  288. mask=mask,
  289. model_name=model,
  290. n=n,
  291. size=size,
  292. user_id=user_id,
  293. api_key_id=api_key_id,
  294. request_ip=req.client.host
  295. )
  296. else:
  297. raise OpenAICompatError(415, "不支持的 Content-Type,请使用 application/json 或 multipart/form-data", "invalid_request_error")
  298. except OpenAICompatError as e:
  299. raise HTTPException(
  300. status_code=e.status_code,
  301. detail={"error": {"message": e.message, "type": e.error_type}}
  302. )
  303. @router.post("/audio/transcriptions", response_model=AudioTranscriptionResponse)
  304. async def audio_transcriptions(
  305. req: Request,
  306. file: Optional[UploadFile] = File(None, description="要识别的音频文件(如 mp3, wav)"),
  307. model: Optional[str] = Form(None, description="模型名称"),
  308. language: Optional[str] = Form(None, description="ISO-639-1 语言代码"),
  309. response_format: Optional[str] = Form("json", description="返回格式"),
  310. auth: tuple = Depends(get_api_key_auth),
  311. db: Session = Depends(get_db)
  312. ):
  313. """
  314. 语音转文字(STT)
  315. 将音频文件转换为文本。
  316. 支持两种请求方式:
  317. 1. application/json: 使用AudioTranscriptionsRequest模型,file为base64编码
  318. 2. multipart/form-data: 使用File和Form参数,直接上传文件
  319. 参数说明:
  320. - file: 要识别的音频文件
  321. - model: 语音识别模型名称
  322. - language: 音频语言代码(可选)
  323. - response_format: 返回格式
  324. """
  325. user_id, api_key_id, key_type = auth
  326. service = OpenAICompatService(db)
  327. try:
  328. content_type = req.headers.get("content-type", "")
  329. # 处理JSON请求
  330. if "application/json" in content_type:
  331. body = await req.json()
  332. from app.schemas.openai_compat import AudioTranscriptionsRequest
  333. request_obj = AudioTranscriptionsRequest(**body)
  334. return await service.audio_transcriptions(
  335. file=request_obj.file,
  336. model_name=request_obj.model,
  337. language=request_obj.language,
  338. user_id=user_id,
  339. api_key_id=api_key_id,
  340. request_ip=req.client.host
  341. )
  342. # 处理multipart/form-data请求
  343. elif "multipart/form-data" in content_type and file and model:
  344. return await service.audio_transcriptions(
  345. file=file,
  346. model_name=model,
  347. language=language,
  348. user_id=user_id,
  349. api_key_id=api_key_id,
  350. request_ip=req.client.host
  351. )
  352. else:
  353. raise OpenAICompatError(415, "不支持的 Content-Type,请使用 application/json 或 multipart/form-data", "invalid_request_error")
  354. except OpenAICompatError as e:
  355. raise HTTPException(
  356. status_code=e.status_code,
  357. detail={"error": {"message": e.message, "type": e.error_type}}
  358. )
  359. @router.post("/audio/translations", response_model=AudioTranscriptionResponse)
  360. async def audio_translations(
  361. req: Request,
  362. file: Optional[UploadFile] = File(None, description="要翻译的音频文件(如 mp3, wav)"),
  363. model: Optional[str] = Form(None, description="使用的语音识别模型"),
  364. source_language: Optional[str] = Form(None, description="原语音语言代码"),
  365. target_language: Optional[str] = Form("en", description="目标翻译语言代码"),
  366. translation_model: Optional[str] = Form("qwen-max", description="执行翻译的文本大模型"),
  367. prompt: Optional[str] = Form(None, description="可选的翻译提示词"),
  368. auth: tuple = Depends(get_api_key_auth),
  369. db: Session = Depends(get_db)
  370. ):
  371. """
  372. 语音翻译
  373. 将音频文件识别并翻译为目标语言文本。
  374. 支持两种请求方式:
  375. 1. application/json: 使用AudioTranslationsRequest模型,file为base64编码
  376. 2. multipart/form-data: 使用File和Form参数,直接上传文件
  377. 参数说明:
  378. - file: 要翻译的音频文件
  379. - model: 语音识别模型名称
  380. - source_language: 源语言代码(可选)
  381. - target_language: 目标语言代码,默认为英语
  382. - translation_model: 用于翻译的文本模型
  383. - prompt: 翻译提示词(可选)
  384. """
  385. user_id, api_key_id, key_type = auth
  386. service = OpenAICompatService(db)
  387. try:
  388. content_type = req.headers.get("content-type", "")
  389. # 处理JSON请求
  390. if "application/json" in content_type:
  391. body = await req.json()
  392. from app.schemas.openai_compat import AudioTranslationsRequest
  393. request_obj = AudioTranslationsRequest(**body)
  394. asr_result = await service.audio_transcriptions(
  395. file=request_obj.file,
  396. model_name=request_obj.model,
  397. language=request_obj.source_language,
  398. user_id=user_id,
  399. api_key_id=api_key_id,
  400. request_ip=req.client.host
  401. )
  402. original_text = asr_result.text
  403. if not original_text or not original_text.strip():
  404. return AudioTranscriptionResponse(text="")
  405. if request_obj.target_language == "en" and request_obj.source_language == "en":
  406. return AudioTranscriptionResponse(text=original_text)
  407. from app.schemas.openai_compat import ChatCompletionsRequest, Message
  408. import json
  409. from decimal import Decimal
  410. from app.models.model import ModelPriceNew
  411. from app.services.api_call_log_service import ApiCallLogService
  412. if request_obj.prompt:
  413. system_prompt += f"\n参考提示:{request_obj.prompt}"
  414. translation_request = ChatCompletionsRequest(
  415. model=request_obj.translation_model,
  416. messages=[
  417. Message(role="system", content=system_prompt),
  418. Message(role="user", content=original_text)
  419. ],
  420. stream=True
  421. )
  422. chat_result_stream = await service.chat_completions(
  423. request=translation_request,
  424. user_id=user_id,
  425. api_key_id=api_key_id
  426. )
  427. final_text_chunks = []
  428. async for chunk in chat_result_stream:
  429. if isinstance(chunk, str) and chunk.startswith("data: ") and chunk.strip() != "data: [DONE]":
  430. try:
  431. data_dict = json.loads(chunk[6:])
  432. delta_content = data_dict.get("choices", [{}])[0].get("delta", {}).get("content", "")
  433. if delta_content:
  434. final_text_chunks.append(delta_content)
  435. except Exception:
  436. continue
  437. final_text = "".join(final_text_chunks)
  438. estimated_input_tokens = int((len(system_prompt) + len(original_text)) * 1.2) or 1
  439. estimated_output_tokens = int(len(final_text) * 1.2) or 1
  440. trans_model_obj = service._find_model(request_obj.translation_model, user_id)
  441. bill = Decimal("0")
  442. if trans_model_obj and not trans_model_obj.is_local:
  443. price_info = db.query(ModelPriceNew).filter(
  444. ModelPriceNew.model_code == trans_model_obj.model_code,
  445. ModelPriceNew.is_active == True,
  446. ).first()
  447. if price_info:
  448. in_cost = price_info.input_price_discounted * Decimal(str(estimated_input_tokens)) / Decimal(str(price_info.display_multiplier or 1000000))
  449. out_cost = price_info.output_price_discounted * Decimal(str(estimated_output_tokens)) / Decimal(str(price_info.display_multiplier or 1000000))
  450. bill = in_cost + out_cost
  451. log_service = ApiCallLogService(db)
  452. log_service.create_log(
  453. user_id=user_id,
  454. api_key_id=api_key_id,
  455. model_id=trans_model_obj.id if trans_model_obj else None,
  456. model_name=request_obj.translation_model,
  457. is_local=trans_model_obj.is_local if trans_model_obj else False,
  458. input_tokens=estimated_input_tokens,
  459. output_tokens=estimated_output_tokens,
  460. bill=float(bill),
  461. status="success",
  462. request_ip=req.client.host
  463. )
  464. return AudioTranscriptionResponse(text=final_text.strip())
  465. # 处理multipart/form-data请求
  466. elif "multipart/form-data" in content_type and file and model:
  467. asr_result = await service.audio_transcriptions(
  468. file=file,
  469. model_name=model,
  470. language=source_language,
  471. user_id=user_id,
  472. api_key_id=api_key_id,
  473. request_ip=req.client.host
  474. )
  475. original_text = asr_result.text
  476. if not original_text or not original_text.strip():
  477. return AudioTranscriptionResponse(text="")
  478. if target_language == "en" and source_language == "en":
  479. return AudioTranscriptionResponse(text=original_text)
  480. from app.schemas.openai_compat import ChatCompletionsRequest, Message
  481. import json
  482. from decimal import Decimal
  483. from app.models.model import ModelPriceNew
  484. from app.services.api_call_log_service import ApiCallLogService
  485. system_prompt = f"你是一个专业的翻译官。请将用户的文本准确翻译成 {target_language}。只返回翻译后的纯文本结果,不要任何多余解释。"
  486. if prompt:
  487. system_prompt += f"\n参考提示:{prompt}"
  488. translation_request = ChatCompletionsRequest(
  489. model=translation_model,
  490. messages=[
  491. Message(role="system", content=system_prompt),
  492. Message(role="user", content=original_text)
  493. ],
  494. stream=True
  495. )
  496. chat_result_stream = await service.chat_completions(
  497. request=translation_request,
  498. user_id=user_id,
  499. api_key_id=api_key_id
  500. )
  501. final_text_chunks = []
  502. async for chunk in chat_result_stream:
  503. if isinstance(chunk, str) and chunk.startswith("data: ") and chunk.strip() != "data: [DONE]":
  504. try:
  505. data_dict = json.loads(chunk[6:])
  506. delta_content = data_dict.get("choices", [{}])[0].get("delta", {}).get("content", "")
  507. if delta_content:
  508. final_text_chunks.append(delta_content)
  509. except Exception:
  510. continue
  511. final_text = "".join(final_text_chunks)
  512. estimated_input_tokens = int((len(system_prompt) + len(original_text)) * 1.2) or 1
  513. estimated_output_tokens = int(len(final_text) * 1.2) or 1
  514. trans_model_obj = service._find_model(translation_model, user_id)
  515. bill = Decimal("0")
  516. if trans_model_obj and not trans_model_obj.is_local:
  517. price_info = db.query(ModelPriceNew).filter(
  518. ModelPriceNew.model_code == trans_model_obj.model_code,
  519. ModelPriceNew.is_active == True,
  520. ).first()
  521. if price_info:
  522. in_cost = price_info.input_price_discounted * Decimal(str(estimated_input_tokens)) / Decimal(str(price_info.display_multiplier or 1000000))
  523. out_cost = price_info.output_price_discounted * Decimal(str(estimated_output_tokens)) / Decimal(str(price_info.display_multiplier or 1000000))
  524. bill = in_cost + out_cost
  525. log_service = ApiCallLogService(db)
  526. log_service.create_log(
  527. user_id=user_id,
  528. api_key_id=api_key_id,
  529. model_id=trans_model_obj.id if trans_model_obj else None,
  530. model_name=translation_model,
  531. is_local=trans_model_obj.is_local if trans_model_obj else False,
  532. input_tokens=estimated_input_tokens,
  533. output_tokens=estimated_output_tokens,
  534. bill=float(bill),
  535. status="success",
  536. request_ip=req.client.host
  537. )
  538. return AudioTranscriptionResponse(text=final_text.strip())
  539. else:
  540. raise OpenAICompatError(415, "不支持的 Content-Type,请使用 application/json 或 multipart/form-data", "invalid_request_error")
  541. except OpenAICompatError as e:
  542. raise HTTPException(
  543. status_code=e.status_code,
  544. detail={"error": {"message": e.message, "type": e.error_type}}
  545. )
  546. except Exception as e:
  547. import httpx
  548. if isinstance(e, httpx.HTTPStatusError):
  549. raise HTTPException(
  550. status_code=400,
  551. detail={"error": {"message": f"Upstream API Error: {e.response.text}", "type": "upstream_api_error"}}
  552. )
  553. raise HTTPException(
  554. status_code=500,
  555. detail={"error": {"message": str(e), "type": "internal_error"}}
  556. )
  557. @router.post("/audio/speech", response_model=AudioSpeechResponse)
  558. async def audio_speech(
  559. request: AudioSpeechRequest,
  560. req: Request,
  561. auth: tuple = Depends(get_api_key_auth),
  562. db: Session = Depends(get_db)
  563. ):
  564. """
  565. 文字转语音(TTS)
  566. 将文本转换为语音音频。
  567. 参数说明:
  568. - model: 使用的TTS模型
  569. - input: 要转换的文本内容
  570. - voice: 发音人声音类型
  571. - response_format: 音频格式(mp3, opus, aac, flac, wav, pcm)
  572. - speed: 语速,范围0.25-4.0
  573. 返回音频二进制流。
  574. """
  575. user_id, api_key_id, key_type = auth
  576. service = OpenAICompatService(db)
  577. try:
  578. audio_stream, media_type = await service.audio_speech(
  579. request=request,
  580. user_id=user_id,
  581. api_key_id=api_key_id,
  582. request_ip=req.client.host
  583. )
  584. return StreamingResponse(
  585. audio_stream,
  586. media_type=media_type,
  587. headers={"Content-Disposition": f"attachment; filename=speech.{request.response_format}"}
  588. )
  589. except OpenAICompatError as e:
  590. raise HTTPException(
  591. status_code=e.status_code,
  592. detail={"error": {"message": e.message, "type": e.error_type}}
  593. )
  594. @router.post("/videos/generations", response_model=VideoGenerationResponse)
  595. async def video_generations(
  596. req: Request,
  597. image: Optional[UploadFile] = File(None, description="图生视频的参考图像"),
  598. prompt: Optional[str] = Form(None, description="生成视频的文本提示词"),
  599. model: Optional[str] = Form("wan2.6-t2v", description="使用的模型ID"),
  600. size: Optional[str] = Form("1080P", description="视频分辨率,如 720P, 1080P"),
  601. duration: Optional[int] = Form(5, description="视频时长(秒)"),
  602. auth: tuple = Depends(get_api_key_auth),
  603. db: Session = Depends(get_db)
  604. ):
  605. """
  606. 视频生成
  607. 根据文本提示词或参考图像生成视频。
  608. 支持两种请求方式:
  609. 1. application/json: 使用VideoGenerationRequest模型,image为base64编码
  610. 2. multipart/form-data: 使用File和Form参数,直接上传文件
  611. 支持两种模式:
  612. - 文生视频 (T2V): 只提供 prompt
  613. - 图生视频 (I2V): 同时提供 prompt 和 image
  614. 参数说明:
  615. - prompt: 生成视频的文本描述
  616. - image: 参考图像(图生视频模式)
  617. - model: 使用的视频生成模型
  618. - size: 视频分辨率
  619. - duration: 视频时长(秒)
  620. """
  621. user_id, api_key_id, key_type = auth
  622. service = OpenAICompatService(db)
  623. try:
  624. content_type = req.headers.get("content-type", "")
  625. # 处理JSON请求
  626. if "application/json" in content_type:
  627. body = await req.json()
  628. from app.schemas.openai_compat import VideoGenerationRequest
  629. request_obj = VideoGenerationRequest(**body)
  630. # 图生视频
  631. if request_obj.image:
  632. return await service.image_to_video_generations(
  633. image=request_obj.image,
  634. prompt=request_obj.prompt,
  635. model_name=request_obj.model,
  636. size=request_obj.size,
  637. user_id=user_id,
  638. api_key_id=api_key_id,
  639. request_ip=req.client.host
  640. )
  641. # 文生视频
  642. else:
  643. return await service.video_generations(
  644. request=request_obj,
  645. user_id=user_id,
  646. api_key_id=api_key_id,
  647. request_ip=req.client.host
  648. )
  649. # 处理multipart/form-data请求
  650. elif "multipart/form-data" in content_type and prompt:
  651. # 图生视频
  652. if image:
  653. return await service.image_to_video_generations(
  654. image=image,
  655. prompt=prompt,
  656. model_name=model,
  657. size=size,
  658. user_id=user_id,
  659. api_key_id=api_key_id,
  660. request_ip=req.client.host
  661. )
  662. # 文生视频
  663. else:
  664. from app.schemas.openai_compat import VideoGenerationRequest
  665. request_obj = VideoGenerationRequest(
  666. prompt=prompt,
  667. model=model,
  668. size=size,
  669. duration=duration
  670. )
  671. return await service.video_generations(
  672. request=request_obj,
  673. user_id=user_id,
  674. api_key_id=api_key_id,
  675. request_ip=req.client.host
  676. )
  677. else:
  678. raise OpenAICompatError(415, "不支持的 Content-Type,请使用 application/json 或 multipart/form-data", "invalid_request_error")
  679. except OpenAICompatError as e:
  680. raise HTTPException(
  681. status_code=e.status_code,
  682. detail={"error": {"message": e.message, "type": e.error_type}}
  683. )
  684. @router.post("/rerank", response_model=RerankResponse)
  685. async def rerank(
  686. request: RerankRequest,
  687. req: Request,
  688. auth: tuple = Depends(get_api_key_auth),
  689. db: Session = Depends(get_db)
  690. ):
  691. """
  692. 文档重排序
  693. 根据查询文本对文档列表进行相关性排序。
  694. 参数说明:
  695. - model: 使用的重排序模型,如 bge-reranker-v2-m3
  696. - query: 查询文本
  697. - documents: 待排序的文档列表
  698. - top_n: 返回前N个结果(可选,默认返回全部)
  699. - return_documents: 是否在结果中返回文档内容(默认true)
  700. 返回结果按相关性分数降序排列。
  701. """
  702. user_id, api_key_id, key_type = auth
  703. service = OpenAICompatService(db)
  704. try:
  705. return await service.rerank(request, user_id, api_key_id, req.client.host)
  706. except OpenAICompatError as e:
  707. raise HTTPException(
  708. status_code=e.status_code,
  709. detail={"error": {"message": e.message, "type": e.error_type}}
  710. )