| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270 |
- """部署 API —— 导出、在线服务管理、OpenAI 兼容代理。
- 路由前缀:
- /api/v1/deployment - 管理端点(JWT 认证)
- /api/v1/deployment/proxy - 代理端点(API Key 认证)
- 代理路由:
- POST /proxy/{task_id}/v1/chat/completions - OpenAI 兼容聊天补全
- POST /proxy/{task_id}/v1/completions - OpenAI 兼容文本补全
- GET /proxy/{task_id}/v1/models - 模型列表
- GET /proxy/{task_id}/health - 健康检查
- """
- import time
- import uuid
- from fastapi import APIRouter, Depends, HTTPException, Request
- from fastapi.responses import JSONResponse
- from app.core.auth import get_current_user
- from app.schemas.deployment import (
- DeployConfig, DeployResponse, DeployServeConfig, DeployedServiceInfo,
- )
- from app.services import api_key_service, deploy_service
# Proxy endpoints: authenticated with an API key (Authorization: Bearer sk-...);
# no JWT login is required.
proxy_router = APIRouter()
# Management endpoints: require a logged-in user (JWT).
router = APIRouter()
- # ---------------------------------------------------------------------------
- # API Key 验证(代理端点专用)
- # ---------------------------------------------------------------------------
- async def _extract_api_key(request: Request) -> str | None:
- """从 Authorization: Bearer sk-xxx 提取 API Key。"""
- auth_header = request.headers.get("Authorization", "")
- if auth_header.startswith("Bearer "):
- token = auth_header[7:].strip()
- if token.startswith("sk-"):
- return token
- return None
async def _validate_proxy_auth(task_id: str, request: Request) -> None:
    """Authenticate a proxy request by API key and check it owns the deployment.

    Args:
        task_id: Deployment task ID taken from the proxy URL.
        request: Incoming request carrying ``Authorization: Bearer sk-...``.

    Raises:
        HTTPException: 401 (with a ``WWW-Authenticate`` challenge) when the key
            is missing or rejected by ``api_key_service.validate_api_key``;
            403 when ``api_key_service.check_deploy_ownership`` reports that the
            key's user does not own ``task_id``.
    """
    api_key = await _extract_api_key(request)
    if not api_key:
        # RFC 9110 §15.5.2: a 401 response must carry a WWW-Authenticate challenge.
        raise HTTPException(
            status_code=401,
            detail={"error": {"message": "Missing API key. Use Authorization: Bearer sk-xxx", "type": "auth_error"}},
            headers={"WWW-Authenticate": "Bearer"},
        )
    key_info = await api_key_service.validate_api_key(api_key)
    if not key_info:
        raise HTTPException(
            status_code=401,
            detail={"error": {"message": "Invalid or revoked API key", "type": "auth_error"}},
            headers={"WWW-Authenticate": 'Bearer error="invalid_token"'},
        )
    # A key belongs to a user; only that user's deployments may be proxied.
    if not await api_key_service.check_deploy_ownership(task_id, key_info["user_id"]):
        raise HTTPException(
            status_code=403,
            detail={"error": {"message": "Access denied: you do not own this deployment", "type": "permission_error"}},
        )
- # ---------------------------------------------------------------------------
- # 管理端点(JWT 认证)
- # ---------------------------------------------------------------------------
@router.post("/export", response_model=DeployResponse)
async def export_adapter(
    config: DeployConfig,
    current_user: dict = Depends(get_current_user),
):
    """Start a background task that exports the model files; return its task_id at once.

    The export options (merge with base model, output format) come from the
    request body; the owning user is the ``sub`` entry of the authenticated user.
    """
    export_options = {
        "merge_with_base": config.merge_with_base,
        "export_format": config.export_format,
    }
    task = await deploy_service.export_adapter(
        config.job_id,
        export_options,
        user_id=current_user.get("sub"),
    )
    return DeployResponse(**task)
@router.post("/serve", response_model=DeployResponse)
async def serve_model(
    config: DeployServeConfig,
    current_user: dict = Depends(get_current_user),
):
    """Deploy a fine-tuned job as an online inference service (OpenAI-compatible API).

    Host "151" serves the proxy API and host "253" runs the inference-only
    worker (presumably internal machine identifiers — confirm). Once started,
    clients call ``/v1/chat/completions`` and related endpoints via the
    service's base_url.

    Raises:
        HTTPException: 400 with the error text when
            ``deploy_service.start_serving`` raises ``RuntimeError``.
    """
    serve_options = {
        "merge_with_base": config.merge_with_base,
        "port": config.port,
        "host": config.host,
    }
    # Keep the try narrow: only the service call is expected to raise RuntimeError.
    try:
        result = await deploy_service.start_serving(
            config.job_id,
            serve_options,
            user_id=current_user.get("sub"),
        )
    except RuntimeError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    return DeployResponse(**result)
@router.get("/services", response_model=list[DeployedServiceInfo])
async def list_services(current_user: dict = Depends(get_current_user)):
    """List the online services deployed by the logged-in user.

    The user is identified by the ``sub`` entry of the authenticated user, and
    each record from ``deploy_service.list_deployed_services`` is converted
    into a ``DeployedServiceInfo``.
    """
    records = await deploy_service.list_deployed_services(current_user.get("sub"))
    return [DeployedServiceInfo(**record) for record in records]
@router.post("/{task_id}/stop")
async def stop_serving(
    task_id: str,
    current_user: dict = Depends(get_current_user),
):
    """Stop a deployed online service.

    The caller's user ID (``sub``) is passed to ``deploy_service.stop_serving``
    alongside the task ID.

    Raises:
        HTTPException: 400 carrying the service's message whenever the result
            contains an ``"error"`` key.
    """
    outcome = await deploy_service.stop_serving(task_id, current_user.get("sub"))
    if "error" not in outcome:
        return outcome
    raise HTTPException(status_code=400, detail=outcome["error"])
@router.get("/{deploy_id}/status", response_model=DeployResponse)
async def get_deployment_status(
    deploy_id: str,
    current_user: dict = Depends(get_current_user),
):
    """Return the status of an export or deployment task.

    Requires JWT authentication, consistent with every other management
    endpoint on this router.
    """
    # TODO(review): deploy_service.get_deploy_status takes no user ID, so any
    # authenticated user can still read any task's status. Consider an
    # ownership check (e.g. api_key_service.check_deploy_ownership) once it is
    # confirmed to cover export tasks as well as serving tasks.
    result = await deploy_service.get_deploy_status(deploy_id)
    return DeployResponse(**result)
- # ---------------------------------------------------------------------------
- # OpenAI 兼容代理端点(API Key 认证)
- # ---------------------------------------------------------------------------
def _param_or_default(body: dict, key: str, default):
    """Return ``body[key]``, or *default* when the key is missing or JSON ``null``.

    OpenAI-compatible clients may send an explicit ``null`` to mean "use the
    default"; a bare ``dict.get(key, default)`` would pass ``None`` through.
    """
    value = body.get(key)
    return default if value is None else value


async def _read_json_object(request: Request) -> dict:
    """Parse the request body and require it to be a JSON object.

    Raises:
        HTTPException: 400 when the body cannot be parsed as JSON, or when it
            parses to something other than an object (e.g. an array), which
            would otherwise crash later on ``body.get``.
    """
    try:
        body = await request.json()
    except Exception:
        # Any parse failure is a client error; the original traceback adds nothing.
        raise HTTPException(status_code=400, detail="Invalid JSON") from None
    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")
    return body


def _sampling_params(body: dict) -> dict:
    """Translate OpenAI sampling fields into the worker's generation options.

    Missing or ``null`` fields fall back to max_tokens=512, temperature=0.7,
    top_p=0.9 and repetition_penalty=1.0.

    Raises:
        HTTPException: 400 when ``temperature`` is not a number, which would
            otherwise raise TypeError in the ``> 0`` comparison below.
    """
    temperature = _param_or_default(body, "temperature", 0.7)
    # bool is an int subclass; reject it explicitly.
    if isinstance(temperature, bool) or not isinstance(temperature, (int, float)):
        raise HTTPException(status_code=400, detail="temperature must be a number")
    return {
        "max_tokens": _param_or_default(body, "max_tokens", 512),
        "temperature": temperature,
        "top_p": _param_or_default(body, "top_p", 0.9),
        # A non-positive temperature is sent to the worker as do_sample=False.
        "do_sample": temperature > 0,
        "repetition_penalty": _param_or_default(body, "repetition_penalty", 1.0),
    }


def _upstream_error_response(worker_resp: dict) -> JSONResponse:
    """Wrap a worker-reported error in an OpenAI-style 502 response."""
    return JSONResponse(
        status_code=502,
        content={"error": {"message": worker_resp["error"], "type": "upstream_error"}},
    )


def _usage(worker_resp: dict) -> dict:
    """Copy token counts from the worker response, defaulting each one to 0."""
    return {
        key: worker_resp.get(key, 0)
        for key in ("prompt_tokens", "completion_tokens", "total_tokens")
    }


@proxy_router.post("/proxy/{task_id}/v1/chat/completions")
async def proxy_chat_completions(task_id: str, request: Request):
    """OpenAI-compatible chat completion proxy.

    Authenticates the API key, forwards ``messages`` plus sampling options to
    the deployment's worker via ``deploy_service.proxy_to_worker``, and wraps
    the generated text in a ``chat.completion`` object. ``stream`` is not
    honoured; the response is always a single JSON object.

    Raises:
        HTTPException: 401/403 from API-key validation; 400 for a malformed
            body, a missing or non-list ``messages``, or a non-numeric
            ``temperature``.

    A 502 JSON error is returned when the worker reports an ``"error"``.
    """
    await _validate_proxy_auth(task_id, request)
    body = await _read_json_object(request)

    messages = body.get("messages")
    if not messages:
        raise HTTPException(status_code=400, detail="messages is required")
    if not isinstance(messages, list):
        raise HTTPException(status_code=400, detail="messages must be a list")

    worker_req = {"messages": messages, **_sampling_params(body)}
    worker_resp = await deploy_service.proxy_to_worker(task_id, worker_req)
    if "error" in worker_resp:
        return _upstream_error_response(worker_resp)

    return {
        "id": f"chatcmpl-{uuid.uuid4().hex[:12]}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": _param_or_default(body, "model", "local-model"),
        "choices": [{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": worker_resp.get("generated_text", ""),
            },
            # NOTE(review): no finish reason is read from the worker response,
            # so "stop" is reported even if max_tokens cut generation short.
            "finish_reason": "stop",
        }],
        "usage": _usage(worker_resp),
    }


@proxy_router.post("/proxy/{task_id}/v1/completions")
async def proxy_completions(task_id: str, request: Request):
    """OpenAI-compatible text completion proxy.

    Same flow as ``proxy_chat_completions``, but it forwards a ``prompt``
    instead of ``messages`` and returns a ``text_completion`` object.

    Raises:
        HTTPException: 401/403 from API-key validation; 400 for a malformed
            body, a missing ``prompt``, or a non-numeric ``temperature``.

    A 502 JSON error is returned when the worker reports an ``"error"``.
    """
    await _validate_proxy_auth(task_id, request)
    body = await _read_json_object(request)

    prompt = body.get("prompt")
    if not prompt:
        raise HTTPException(status_code=400, detail="prompt is required")

    worker_req = {"prompt": prompt, **_sampling_params(body)}
    worker_resp = await deploy_service.proxy_to_worker(task_id, worker_req)
    if "error" in worker_resp:
        return _upstream_error_response(worker_resp)

    return {
        "id": f"cmpl-{uuid.uuid4().hex[:12]}",
        "object": "text_completion",
        "created": int(time.time()),
        "model": _param_or_default(body, "model", "local-model"),
        "choices": [{
            "index": 0,
            "text": worker_resp.get("generated_text", ""),
            # NOTE(review): always "stop"; see proxy_chat_completions.
            "finish_reason": "stop",
        }],
        "usage": _usage(worker_resp),
    }
@proxy_router.get("/proxy/{task_id}/v1/models")
async def proxy_models(task_id: str, request: Request):
    """List the single model behind this deployment, in OpenAI ``/v1/models`` shape.

    The model ID is ``finetuned-`` plus the first 8 characters of the task's
    ``job_id``, or ``local-model`` when the status carries no job_id.
    """
    await _validate_proxy_auth(task_id, request)
    deploy_info = await deploy_service.get_deploy_status(task_id)
    job_id = deploy_info.get("job_id")
    if job_id:
        model_id = f"finetuned-{job_id[:8]}"
    else:
        model_id = "local-model"
    model_card = {
        "id": model_id,
        "object": "model",
        "created": int(time.time()),
        "owned_by": "local",
    }
    return {"object": "list", "data": [model_card]}
@proxy_router.get("/proxy/{task_id}/health")
async def proxy_health(task_id: str, request: Request):
    """Report whether the deployment behind ``task_id`` is in the ``running`` state."""
    await _validate_proxy_auth(task_id, request)
    deploy_info = await deploy_service.get_deploy_status(task_id)
    if deploy_info.get("status") == "running":
        return {"status": "ok", "task_id": task_id}
    # NOTE(review): an unhealthy service is reported with HTTP 200 and an error
    # body; monitors that only check the status code will see it as healthy.
    return {"status": "error", "message": f"服务状态: {deploy_info.get('status', 'unknown')}"}
|