| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274 |
- """
- 全局权限+日志管理中间件
- 提供接口鉴权装饰器和请求日志记录
- 功能:
- - 角色权限验证(admin/user)
- - 请求日志记录(路径、用户ID、时间、响应状态、响应结果)
- - 本地环境默认管理员权限
- """
import asyncio
import functools
import inspect
import json
import logging
import os
import time
from datetime import datetime
from enum import Enum
from typing import Optional, List, Callable

from fastapi import Request, HTTPException, status, Depends
from fastapi.responses import JSONResponse
from sqlalchemy.orm import Session

from app.database import get_db, SessionLocal
from app.models.user import User
from app.services.auth_service import AuthService
from app.services.user_service import UserService
from app.services.token_revocation_service import token_revocation_service
# Module-level logger, shared by RequestLogger (request summaries) and
# require_auth (DEBUG default-admin notices).
logger = logging.getLogger(__name__)
class UserRole(str, Enum):
    """User role enumeration.

    Mixes in ``str`` so each member compares equal to its plain string value
    (e.g. ``UserRole.ADMIN == "admin"``).
    """
    # Administrator role; get_user_role assigns it only to the account named "admin".
    ADMIN = "admin"
    # Role for every other authenticated account.
    USER = "user"
def get_debug_mode() -> bool:
    """Report whether DEBUG mode is enabled.

    DEBUG is on only when the ``DEBUG`` environment variable equals ``"true"``
    case-insensitively; an unset variable counts as ``"False"``.
    """
    raw_flag = os.environ.get("DEBUG", "False")
    normalized = raw_flag.lower()
    return normalized == "true"
def get_default_admin_user(db: Session) -> Optional[User]:
    """Return the user whose username is ``"admin"``, or None if no such row exists.

    Used as the fallback identity when DEBUG mode is enabled.
    """
    admin_lookup = db.query(User).filter(User.username == "admin")
    return admin_lookup.first()
def extract_token_from_request(request: "Request") -> Optional[str]:
    """Extract the bearer credential from the ``Authorization`` header.

    The scheme name is matched case-insensitively (``Bearer``, ``bearer`` and
    ``BEARER`` are all accepted, as HTTP auth schemes are case-insensitive), and
    surrounding whitespace is ignored.

    Args:
        request: The incoming request; only ``request.headers`` is read.

    Returns:
        The credential string, or None when the header is missing, uses a
        different scheme, or carries an empty credential.
    """
    auth_header = request.headers.get("Authorization")
    if not auth_header:
        return None
    scheme, _, credentials = auth_header.strip().partition(" ")
    if scheme.lower() != "bearer":
        return None
    token = credentials.strip()
    # An empty credential ("Bearer ") is treated as no token at all.
    return token or None
def get_user_from_token(token: str, db: Session) -> Optional[User]:
    """Resolve the user that owns ``token``.

    Two credential formats are accepted:

    * Platform API keys (prefix ``sk-aigc-``), checked with
      ``PlatformApiKeyService.verify_api_key``. These never fall through to
      JWT handling.
    * Anything else is treated as a JWT, checked with
      ``AuthService.verify_token``. It is rejected when
      ``token_revocation_service`` reports its payload as revoked.

    Args:
        token: Raw credential taken from the ``Authorization`` header.
        db: Database session used for the lookups.

    Returns:
        The matching ``User``, or None when the token is empty, not accepted,
        revoked, or verification raises. Failures are logged at DEBUG level
        (never including the token itself) rather than silently discarded.
    """
    if not token:
        return None
    if token.startswith("sk-aigc-"):
        return _get_user_from_platform_api_key(token, db)
    return _get_user_from_jwt(token, db)


def _get_user_from_platform_api_key(token: str, db: Session) -> Optional[User]:
    """Return the owner of a platform API key, or None if the key is not accepted."""
    try:
        # Kept as a function-level import like the original; presumably this
        # avoids an import cycle (unverified).
        from app.services.platform_api_key_service import PlatformApiKeyService
        result = PlatformApiKeyService(db).verify_api_key(token)
        if not result:
            return None
        user_id, _ = result
        return UserService(db).get_user_by_id(user_id)
    except Exception:
        # Best-effort: any failure means "not authenticated", but keep a trace.
        logger.debug("Platform API key authentication failed", exc_info=True)
        return None


def _get_user_from_jwt(token: str, db: Session) -> Optional[User]:
    """Verify a JWT and return the user named by its ``user_id`` claim, or None."""
    try:
        payload = AuthService.verify_token(token)
        if token_revocation_service.is_payload_revoked(payload):
            return None
        user_id = payload.get("user_id")
        if not user_id:
            return None
        return UserService(db).get_user_by_id(user_id)
    except Exception:
        # Invalid/expired tokens are routine, hence DEBUG rather than WARNING.
        logger.debug("JWT authentication failed", exc_info=True)
        return None
def get_user_role(user: User) -> UserRole:
    """Map a user to a role.

    Only the account whose username is exactly ``"admin"`` is an
    administrator; every other user gets ``UserRole.USER``.
    """
    is_admin = user.username == "admin"
    return UserRole.ADMIN if is_admin else UserRole.USER
class RequestLogger:
    """Writes one summary log record per handled request."""

    @staticmethod
    def log_request(
        request: Request,
        user_id: Optional[str],
        status_code: int,
        response_data: Optional[dict],
        duration_ms: float
    ):
        """Log a summary of a finished request.

        Args:
            request: The incoming request. Its method, path and query string
                are included.
            user_id: Id of the authenticated user. None is logged as
                ``"anonymous"``. NOTE(review): callers may pass a non-str id
                (e.g. ``user.id``); confirm the annotation.
            status_code: Response status. Values >= 400 are logged at WARNING,
                everything else at INFO.
            response_data: Parsed response body or error summary, logged as is.
            duration_ms: Handling time in milliseconds, rounded to 2 decimals.
        """
        log_data = {
            # Local time, but with an explicit UTC offset so entries are unambiguous.
            "timestamp": datetime.now().astimezone().isoformat(),
            "method": request.method,
            "path": str(request.url.path),
            "query_params": str(request.query_params),
            # Compare with None so falsy ids (e.g. 0) are not mislabelled anonymous.
            "user_id": user_id if user_id is not None else "anonymous",
            "status_code": status_code,
            "duration_ms": round(duration_ms, 2),
            "response": response_data
        }

        # Lazy %-args: the dict is only rendered if a handler emits the record.
        if status_code >= 400:
            logger.warning("Request failed: %s", log_data)
        else:
            logger.info("Request completed: %s", log_data)
def _require_auth_find_context(args, kwargs):
    """Return the ``(Request, Session)`` pair passed to a decorated endpoint.

    Keyword arguments named ``request`` / ``db`` take precedence. A missing one
    is filled from the first positional argument of the matching type. Either
    element may be None if the endpoint did not receive it.
    """
    request = kwargs.get("request")
    db = kwargs.get("db")
    for arg in args:
        if request is None and isinstance(arg, Request):
            request = arg
        elif db is None and isinstance(arg, Session):
            db = arg
    return request, db


def _require_auth_is_loopback(request: Request) -> bool:
    """Return True when the direct peer of ``request`` has a loopback address.

    Uses the same allow-list as ``get_current_user_from_request``, so both
    auth entry points grant the DEBUG default admin under identical conditions.

    NOTE(review): ``request.client`` is the immediate TCP peer. Behind a reverse
    proxy on the same host every request appears to come from 127.0.0.1, so
    this check alone does not protect a DEBUG deployment fronted by such a proxy.
    """
    client_ip = request.client.host if request.client else ""
    return client_ip in ("127.0.0.1", "::1", "localhost")


def _require_auth_resolve_user(request: Request, db: Session) -> Optional[User]:
    """Resolve the caller from its bearer token, with a DEBUG-only fallback.

    The default admin account is used only when DEBUG mode is on, the token is
    missing or resolves to no user, and the client is on a loopback address.
    """
    token = extract_token_from_request(request)
    user = get_user_from_token(token, db) if token else None
    if not user and get_debug_mode() and _require_auth_is_loopback(request):
        user = get_default_admin_user(db)
        if user:
            logger.debug("DEBUG模式:使用默认管理员用户 %s", user.username)
    return user


def _require_auth_response_summary(result):
    """Return ``(status_code, response_data)`` describing an endpoint result for logging.

    A result without ``status_code`` (e.g. a plain dict FastAPI serializes later)
    is reported as 200. A ``body`` attribute is parsed as JSON; if that fails,
    the body is summarized as ``{"type": "non-json"}``.
    """
    status_code = getattr(result, "status_code", 200)
    response_data = None
    if hasattr(result, "body"):
        try:
            response_data = json.loads(result.body)
        except Exception:
            # Best-effort only: logging must never turn a successful response
            # into an error. Unlike the former bare ``except:``, this no longer
            # traps KeyboardInterrupt/SystemExit.
            response_data = {"type": "non-json"}
    return status_code, response_data


def require_auth(
    roles: Optional[List[UserRole]] = None,
    allow_anonymous: bool = False
):
    """Decorator that authenticates the caller, checks roles, and logs the request.

    The decorated endpoint must receive the FastAPI ``Request`` and a SQLAlchemy
    ``Session``, either as keyword arguments named ``request`` / ``db`` or
    positionally. The resolved user, or None for an allowed anonymous call, is
    passed to the endpoint as the ``current_user`` keyword argument.

    Without a valid token, and only in DEBUG mode for loopback clients, the
    default admin account is used. This matches ``get_current_user_from_request``.

    Args:
        roles: Allowed roles. None (or empty) means any authenticated user is
            accepted.
        allow_anonymous: When True, requests with no resolvable user are let
            through with ``current_user=None``. This applies in every mode.
            NOTE(review): an earlier docstring said this only takes effect in
            DEBUG mode, but the code never checked DEBUG for it; confirm intent.

    Raises:
        HTTPException: 401 when no user is resolved and anonymous access is not
            allowed; 403 when the user's role is not in ``roles``. HTTPExceptions
            raised by the endpoint itself are logged and re-raised.
        RuntimeError: when the call supplies no ``Request`` or no ``Session``.
    """
    def decorator(func: Callable):
        # Decided once at decoration time; inspect.iscoroutinefunction replaces
        # asyncio.iscoroutinefunction, which is deprecated as of Python 3.14.
        is_coroutine = inspect.iscoroutinefunction(func)

        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            request, db = _require_auth_find_context(args, kwargs)
            if request is None or db is None:
                name = getattr(func, "__qualname__", repr(func))
                raise RuntimeError(
                    f"require_auth: {name} must receive a Request and a Session"
                )

            start_time = time.perf_counter()
            user_id = None

            try:
                user = _require_auth_resolve_user(request, db)

                # Endpoints that are not anonymous must have a user.
                if not user and not allow_anonymous:
                    raise HTTPException(
                        status_code=status.HTTP_401_UNAUTHORIZED,
                        detail="Authentication required"
                    )

                if user:
                    user_id = user.id
                    if roles and get_user_role(user) not in roles:
                        raise HTTPException(
                            status_code=status.HTTP_403_FORBIDDEN,
                            detail=f"Permission denied. Required roles: {[r.value for r in roles]}"
                        )

                kwargs['current_user'] = user

                # NOTE(review): a sync endpoint (and the DB lookups above) run
                # directly on the event loop here, not in FastAPI's threadpool.
                if is_coroutine:
                    result = await func(*args, **kwargs)
                else:
                    result = func(*args, **kwargs)

            except HTTPException as e:
                duration_ms = (time.perf_counter() - start_time) * 1000
                RequestLogger.log_request(
                    request, user_id, e.status_code,
                    {"error": e.detail}, duration_ms
                )
                raise
            except Exception as e:
                duration_ms = (time.perf_counter() - start_time) * 1000
                RequestLogger.log_request(
                    request, user_id, 500,
                    {"error": str(e)}, duration_ms
                )
                raise

            # Success path is logged outside the try so that a logging failure
            # is not misreported as a 500 from the endpoint.
            duration_ms = (time.perf_counter() - start_time) * 1000
            status_code, response_data = _require_auth_response_summary(result)
            RequestLogger.log_request(request, user_id, status_code, response_data, duration_ms)
            return result

        return wrapper
    return decorator
def get_current_user_from_request(
    request: Request,
    db: Session = Depends(get_db)
) -> User:
    """FastAPI dependency that returns the authenticated user for ``request``.

    Resolution order:

    1. The bearer credential in the ``Authorization`` header, resolved via
       ``get_user_from_token``.
    2. In DEBUG mode only, and only when the client address is a loopback
       address, the default ``admin`` account. This keeps a DEBUG flag that
       was accidentally enabled in production from giving admin rights to
       arbitrary remote callers.

    NOTE(review): ``request.client.host`` is the direct peer address. Behind a
    reverse proxy on the same machine every request looks like 127.0.0.1, and
    the loopback restriction no longer holds.

    Raises:
        HTTPException: 401 with ``WWW-Authenticate: Bearer`` when no user is resolved.
    """
    bearer = extract_token_from_request(request)
    current = get_user_from_token(bearer, db) if bearer else None

    if not current and get_debug_mode():
        loopback_hosts = ("127.0.0.1", "::1", "localhost")
        peer_host = request.client.host if request.client else ""
        if peer_host in loopback_hosts:
            current = get_default_admin_user(db)

    if current:
        return current

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Authentication required",
        headers={"WWW-Authenticate": "Bearer"}
    )
def require_role(roles: List[UserRole]):
    """Build a FastAPI dependency that admits only users holding one of ``roles``.

    Usage: ``Depends(require_role([UserRole.ADMIN]))``.

    The returned dependency authenticates through
    ``get_current_user_from_request`` (401 on failure). It raises 403 when the
    user's role is not listed, and otherwise returns the user.
    """
    def role_checker(
        request: Request,
        db: Session = Depends(get_db)
    ) -> User:
        current = get_current_user_from_request(request, db)
        if get_user_role(current) in roles:
            return current
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"Permission denied. Required roles: {[r.value for r in roles]}"
        )

    return role_checker
|