"""Deployment API — export, online-serving management, and an OpenAI-compatible proxy.

Route prefixes:
    /api/v1/deployment        - management endpoints (JWT authentication)
    /api/v1/deployment/proxy  - proxy endpoints (API key authentication)

Proxy routes:
    POST /proxy/{task_id}/v1/chat/completions - OpenAI-compatible chat completions
    POST /proxy/{task_id}/v1/completions      - OpenAI-compatible text completions
    GET  /proxy/{task_id}/v1/models           - model list
    GET  /proxy/{task_id}/health              - health check
"""

import time
import uuid

from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import JSONResponse

from app.core.auth import get_current_user
from app.schemas.deployment import (
    DeployConfig,
    DeployResponse,
    DeployServeConfig,
    DeployedServiceInfo,
)
from app.services import api_key_service, deploy_service

# Management endpoints (require a JWT login).
router = APIRouter()
# Proxy endpoints (require an API key, no JWT).
proxy_router = APIRouter()


# ---------------------------------------------------------------------------
# API key validation (proxy endpoints only)
# ---------------------------------------------------------------------------

async def _extract_api_key(request: Request) -> str | None:
    """Extract the API key from an ``Authorization: Bearer sk-xxx`` header.

    The auth scheme is compared case-insensitively (RFC 7235), so
    ``bearer sk-xxx`` is accepted as well as ``Bearer sk-xxx``.

    Returns:
        The ``sk-``-prefixed token, or ``None`` when the header is missing,
        uses a different scheme, or carries a token without the ``sk-`` prefix.
    """
    scheme, _, credentials = request.headers.get("Authorization", "").partition(" ")
    if scheme.lower() != "bearer":
        return None
    token = credentials.strip()
    return token if token.startswith("sk-") else None


async def _validate_proxy_auth(task_id: str, request: Request) -> None:
    """Validate a proxy request's API key and check the key's user owns the deployment.

    Args:
        task_id: Deployment task the request targets.
        request: Incoming request carrying the ``Authorization`` header.

    Raises:
        HTTPException: 401 when the API key is missing or rejected by
            ``api_key_service.validate_api_key``; 403 when the key's user
            does not own ``task_id``.
    """
    api_key = await _extract_api_key(request)
    if not api_key:
        raise HTTPException(
            status_code=401,
            detail={"error": {"message": "Missing API key. Use Authorization: Bearer sk-xxx", "type": "auth_error"}},
            # RFC 7235: a 401 response must carry a challenge for the expected scheme.
            headers={"WWW-Authenticate": "Bearer"},
        )

    key_info = await api_key_service.validate_api_key(api_key)
    if not key_info:
        raise HTTPException(
            status_code=401,
            detail={"error": {"message": "Invalid or revoked API key", "type": "auth_error"}},
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Make sure the key's user owns this deployment task.
    if not await api_key_service.check_deploy_ownership(task_id, key_info["user_id"]):
        raise HTTPException(
            status_code=403,
            detail={"error": {"message": "Access denied: you do not own this deployment", "type": "permission_error"}},
        )


# ---------------------------------------------------------------------------
# Management endpoints (JWT authentication)
# ---------------------------------------------------------------------------

@router.post("/export", response_model=DeployResponse)
async def export_adapter(
    config: DeployConfig,
    current_user: dict = Depends(get_current_user),
):
    """Start a background export task (writes model files) and return its task_id immediately."""
    user_id = current_user.get("sub")
    result = await deploy_service.export_adapter(
        config.job_id,
        {"merge_with_base": config.merge_with_base, "export_format": config.export_format},
        user_id=user_id,
    )
    return DeployResponse(**result)


@router.post("/serve", response_model=DeployResponse)
async def serve_model(
    config: DeployServeConfig,
    current_user: dict = Depends(get_current_user),
):
    """Deploy a model as an online inference service (OpenAI-compatible API).

    Machine 151 serves the proxy API; machine 253 runs the inference-only worker.
    Once started, call /v1/chat/completions and related endpoints via base_url.

    Raises:
        HTTPException: 400 when ``deploy_service.start_serving`` raises RuntimeError.
    """
    user_id = current_user.get("sub")
    try:
        result = await deploy_service.start_serving(
            config.job_id,
            {"merge_with_base": config.merge_with_base, "port": config.port, "host": config.host},
            user_id=user_id,
        )
    except RuntimeError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    return DeployResponse(**result)


@router.get("/services", response_model=list[DeployedServiceInfo])
async def list_services(current_user: dict = Depends(get_current_user)):
    """List the online services deployed by the current user."""
    user_id = current_user.get("sub")
    services = await deploy_service.list_deployed_services(user_id)
    return [DeployedServiceInfo(**s) for s in services]
@router.post("/{task_id}/stop")
async def stop_serving(
    task_id: str,
    current_user: dict = Depends(get_current_user),
):
    """Stop a deployed online service on behalf of the current user.

    Raises:
        HTTPException: 400 when the service layer's result contains ``error``.
    """
    user_id = current_user.get("sub")
    result = await deploy_service.stop_serving(task_id, user_id)
    if "error" in result:
        raise HTTPException(status_code=400, detail=result["error"])
    return result


@router.post("/{task_id}/restart")
async def restart_serving(
    task_id: str,
    current_user: dict = Depends(get_current_user),
):
    """Restart a stopped online service (does not re-export the model; only starts the worker).

    Raises:
        HTTPException: 400 when the service layer raises RuntimeError or its
            result contains ``error``.
    """
    user_id = current_user.get("sub")
    try:
        result = await deploy_service.restart_serving(task_id, user_id)
    except RuntimeError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    if "error" in result:
        raise HTTPException(status_code=400, detail=result["error"])
    return result


@router.get("/{deploy_id}/status", response_model=DeployResponse)
async def get_deployment_status(deploy_id: str):
    """Return the status of an export/deployment task.

    NOTE(review): unlike the other management endpoints, this route has no
    ``get_current_user`` dependency, so any caller that knows ``deploy_id``
    can read the status. Confirm whether that is intentional.
    """
    result = await deploy_service.get_deploy_status(deploy_id)
    return DeployResponse(**result)


# ---------------------------------------------------------------------------
# OpenAI-compatible proxy endpoints (API key authentication)
# ---------------------------------------------------------------------------

async def _read_json_object(request: Request) -> dict:
    """Parse the request body and require it to be a JSON object.

    Raises:
        HTTPException: 400 when the body is not valid JSON or is not an object.
    """
    try:
        body = await request.json()
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid JSON") from None
    if not isinstance(body, dict):
        # A valid JSON array/scalar would otherwise crash on ``body.get`` (500).
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")
    return body


def _field(body: dict, key: str, default):
    """Return ``body[key]``, or ``default`` when the key is absent or JSON null."""
    value = body.get(key)
    return default if value is None else value


def _sampling_params(body: dict) -> dict:
    """Build the worker sampling parameters from an OpenAI-style request body.

    Explicit ``null`` values are treated like omitted fields and get the defaults.

    Raises:
        HTTPException: 400 when ``temperature`` is not a number.
    """
    temperature = _field(body, "temperature", 0.7)
    if not isinstance(temperature, (int, float)):
        # Without this check ``temperature > 0`` below raises TypeError (500).
        raise HTTPException(status_code=400, detail="temperature must be a number")
    return {
        "max_tokens": _field(body, "max_tokens", 512),
        "temperature": temperature,
        "top_p": _field(body, "top_p", 0.9),
        # Sampling is disabled when temperature <= 0.
        "do_sample": temperature > 0,
        "repetition_penalty": _field(body, "repetition_penalty", 1.0),
    }


def _usage(worker_resp: dict) -> dict:
    """Copy the worker's token counts into an OpenAI ``usage`` object (0 when missing)."""
    return {
        "prompt_tokens": worker_resp.get("prompt_tokens", 0),
        "completion_tokens": worker_resp.get("completion_tokens", 0),
        "total_tokens": worker_resp.get("total_tokens", 0),
    }


def _upstream_error(worker_resp: dict) -> JSONResponse:
    """Wrap a worker-reported error as a 502 with an OpenAI-style error body."""
    return JSONResponse(
        status_code=502,
        content={"error": {"message": worker_resp["error"], "type": "upstream_error"}},
    )


@proxy_router.post("/proxy/{task_id}/v1/chat/completions")
async def proxy_chat_completions(task_id: str, request: Request):
    """OpenAI-compatible chat completions proxy.

    Forwards ``messages`` plus sampling parameters to the deployment's worker
    and reshapes the worker reply into a ``chat.completion`` object.

    Raises:
        HTTPException: 401/403 on auth failure; 400 on a malformed body,
            missing ``messages`` or a non-numeric ``temperature``.
    """
    await _validate_proxy_auth(task_id, request)
    body = await _read_json_object(request)

    messages = body.get("messages", [])
    if not messages:
        raise HTTPException(status_code=400, detail="messages is required")

    worker_req = {"messages": messages, **_sampling_params(body)}
    worker_resp = await deploy_service.proxy_to_worker(task_id, worker_req)
    if "error" in worker_resp:
        return _upstream_error(worker_resp)

    return {
        "id": f"chatcmpl-{uuid.uuid4().hex[:12]}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": body.get("model", "local-model"),
        "choices": [{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": worker_resp.get("generated_text", ""),
            },
            # The worker response is not consulted here; always reported as "stop".
            "finish_reason": "stop",
        }],
        "usage": _usage(worker_resp),
    }


@proxy_router.post("/proxy/{task_id}/v1/completions")
async def proxy_completions(task_id: str, request: Request):
    """OpenAI-compatible text completions proxy.

    Forwards ``prompt`` plus sampling parameters to the deployment's worker
    and reshapes the worker reply into a ``text_completion`` object.

    Raises:
        HTTPException: 401/403 on auth failure; 400 on a malformed body,
            missing ``prompt`` or a non-numeric ``temperature``.
    """
    await _validate_proxy_auth(task_id, request)
    body = await _read_json_object(request)

    prompt = body.get("prompt", "")
    if not prompt:
        raise HTTPException(status_code=400, detail="prompt is required")

    worker_req = {"prompt": prompt, **_sampling_params(body)}
    worker_resp = await deploy_service.proxy_to_worker(task_id, worker_req)
    if "error" in worker_resp:
        return _upstream_error(worker_resp)

    return {
        "id": f"cmpl-{uuid.uuid4().hex[:12]}",
        "object": "text_completion",
        "created": int(time.time()),
        "model": body.get("model", "local-model"),
        "choices": [{
            "index": 0,
            "text": worker_resp.get("generated_text", ""),
            # The worker response is not consulted here; always reported as "stop".
            "finish_reason": "stop",
        }],
        "usage": _usage(worker_resp),
    }


@proxy_router.get("/proxy/{task_id}/v1/models")
async def proxy_models(task_id: str, request: Request):
    """Return the model list for the deployment (proxy).

    The single model id is ``finetuned-<first 8 chars of job_id>``, or
    ``local-model`` when the deployment status has no job_id.
    """
    await _validate_proxy_auth(task_id, request)
    result = await deploy_service.get_deploy_status(task_id)
    job_id = result.get("job_id")
    model_name = f"finetuned-{job_id[:8]}" if job_id else "local-model"
    return {
        "object": "list",
        "data": [{
            "id": model_name,
            "object": "model",
            "created": int(time.time()),
            "owned_by": "local",
        }],
    }


@proxy_router.get("/proxy/{task_id}/health")
async def proxy_health(task_id: str, request: Request):
    """Health check: ``ok`` only when the deployment status is ``running``."""
    await _validate_proxy_auth(task_id, request)
    result = await deploy_service.get_deploy_status(task_id)
    if result.get("status") != "running":
        return {"status": "error", "message": f"服务状态: {result.get('status', 'unknown')}"}
    return {"status": "ok", "task_id": task_id}