auth_log_middleware.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
  1. """
  2. 全局权限+日志管理中间件
  3. 提供接口鉴权装饰器和请求日志记录
  4. 功能:
  5. - 角色权限验证(admin/user)
  6. - 请求日志记录(路径、用户ID、时间、响应状态、响应结果)
  7. - 本地环境默认管理员权限
  8. """
  9. import os
  10. import time
  11. import logging
  12. import functools
  13. from enum import Enum
  14. from typing import Optional, List, Callable
  15. from datetime import datetime
  16. from fastapi import Request, HTTPException, status, Depends
  17. from fastapi.responses import JSONResponse
  18. from sqlalchemy.orm import Session
  19. from app.database import get_db, SessionLocal
  20. from app.models.user import User
  21. from app.services.auth_service import AuthService
  22. from app.services.user_service import UserService
  23. from app.services.token_revocation_service import token_revocation_service
  24. logger = logging.getLogger(__name__)
  25. class UserRole(str, Enum):
  26. """用户角色枚举"""
  27. ADMIN = "admin"
  28. USER = "user"
  29. def get_debug_mode() -> bool:
  30. """获取调试模式状态"""
  31. return os.getenv("DEBUG", "False").lower() == "true"
  32. def get_default_admin_user(db: Session) -> Optional[User]:
  33. """获取默认管理员用户(本地环境使用)"""
  34. return db.query(User).filter(User.username == "admin").first()
  35. def extract_token_from_request(request: Request) -> Optional[str]:
  36. """从请求中提取JWT令牌"""
  37. auth_header = request.headers.get("Authorization")
  38. if auth_header and auth_header.startswith("Bearer "):
  39. return auth_header[7:]
  40. return None
  41. def get_user_from_token(token: str, db: Session) -> Optional[User]:
  42. """从令牌获取用户(支持 JWT 和平台 API Key)"""
  43. # 优先尝试 platform API Key(sk-aigc- 前缀)
  44. if token.startswith("sk-aigc-"):
  45. try:
  46. from app.services.platform_api_key_service import PlatformApiKeyService
  47. result = PlatformApiKeyService(db).verify_api_key(token)
  48. if result:
  49. user_id, _ = result
  50. user_service = UserService(db)
  51. return user_service.get_user_by_id(user_id)
  52. except Exception:
  53. pass
  54. return None
  55. # 否则按 JWT 处理
  56. try:
  57. payload = AuthService.verify_token(token)
  58. if token_revocation_service.is_payload_revoked(payload):
  59. return None
  60. user_id = payload.get("user_id")
  61. if user_id:
  62. user_service = UserService(db)
  63. return user_service.get_user_by_id(user_id)
  64. except Exception:
  65. pass
  66. return None
  67. def get_user_role(user: User) -> UserRole:
  68. """获取用户角色"""
  69. if user.username == "admin":
  70. return UserRole.ADMIN
  71. return UserRole.USER
  72. class RequestLogger:
  73. """请求日志记录器"""
  74. @staticmethod
  75. def log_request(
  76. request: Request,
  77. user_id: Optional[str],
  78. status_code: int,
  79. response_data: Optional[dict],
  80. duration_ms: float
  81. ):
  82. """记录请求日志"""
  83. log_data = {
  84. "timestamp": datetime.now().isoformat(),
  85. "method": request.method,
  86. "path": str(request.url.path),
  87. "query_params": str(request.query_params),
  88. "user_id": user_id or "anonymous",
  89. "status_code": status_code,
  90. "duration_ms": round(duration_ms, 2),
  91. "response": response_data
  92. }
  93. if status_code >= 400:
  94. logger.warning(f"Request failed: {log_data}")
  95. else:
  96. logger.info(f"Request completed: {log_data}")
  97. def require_auth(
  98. roles: Optional[List[UserRole]] = None,
  99. allow_anonymous: bool = False
  100. ):
  101. """
  102. 权限验证装饰器
  103. Args:
  104. roles: 允许的角色列表,None表示只需登录
  105. allow_anonymous: 是否允许匿名访问(仅在DEBUG模式下生效)
  106. """
  107. def decorator(func: Callable):
  108. @functools.wraps(func)
  109. async def wrapper(*args, **kwargs):
  110. request: Request = kwargs.get('request')
  111. db: Session = kwargs.get('db')
  112. if not request or not db:
  113. for arg in args:
  114. if isinstance(arg, Request):
  115. request = arg
  116. elif isinstance(arg, Session):
  117. db = arg
  118. start_time = time.time()
  119. user = None
  120. user_id = None
  121. try:
  122. is_debug = get_debug_mode()
  123. token = extract_token_from_request(request)
  124. if token:
  125. user = get_user_from_token(token, db)
  126. # 本地环境且无token时,使用默认管理员
  127. if not user and is_debug:
  128. user = get_default_admin_user(db)
  129. if user:
  130. logger.debug(f"DEBUG模式:使用默认管理员用户 {user.username}")
  131. # 非匿名接口必须有用户
  132. if not user and not allow_anonymous:
  133. raise HTTPException(
  134. status_code=status.HTTP_401_UNAUTHORIZED,
  135. detail="Authentication required"
  136. )
  137. if user:
  138. user_id = user.id
  139. # 角色验证
  140. if roles:
  141. user_role = get_user_role(user)
  142. if user_role not in roles:
  143. raise HTTPException(
  144. status_code=status.HTTP_403_FORBIDDEN,
  145. detail=f"Permission denied. Required roles: {[r.value for r in roles]}"
  146. )
  147. # 将用户注入到kwargs
  148. kwargs['current_user'] = user
  149. result = await func(*args, **kwargs) if asyncio.iscoroutinefunction(func) else func(*args, **kwargs)
  150. duration_ms = (time.time() - start_time) * 1000
  151. response_data = None
  152. status_code = 200
  153. if hasattr(result, 'status_code'):
  154. status_code = result.status_code
  155. if hasattr(result, 'body'):
  156. try:
  157. import json
  158. response_data = json.loads(result.body)
  159. except:
  160. response_data = {"type": "non-json"}
  161. RequestLogger.log_request(request, user_id, status_code, response_data, duration_ms)
  162. return result
  163. except HTTPException as e:
  164. duration_ms = (time.time() - start_time) * 1000
  165. RequestLogger.log_request(
  166. request, user_id, e.status_code,
  167. {"error": e.detail}, duration_ms
  168. )
  169. raise
  170. except Exception as e:
  171. duration_ms = (time.time() - start_time) * 1000
  172. RequestLogger.log_request(
  173. request, user_id, 500,
  174. {"error": str(e)}, duration_ms
  175. )
  176. raise
  177. return wrapper
  178. return decorator
  179. import asyncio
  180. def get_current_user_from_request(
  181. request: Request,
  182. db: Session = Depends(get_db)
  183. ) -> User:
  184. """
  185. FastAPI依赖:从请求获取当前用户
  186. 安全:DEBUG 模式下的默认管理员仅在本地回环地址生效,
  187. 防止生产环境误开 DEBUG 导致任意请求获得管理员权限。
  188. """
  189. token = extract_token_from_request(request)
  190. user = None
  191. if token:
  192. user = get_user_from_token(token, db)
  193. if not user and get_debug_mode():
  194. # 仅本地回环地址允许 DEBUG 默认管理员
  195. client_ip = request.client.host if request.client else ""
  196. if client_ip in ("127.0.0.1", "::1", "localhost"):
  197. user = get_default_admin_user(db)
  198. if not user:
  199. raise HTTPException(
  200. status_code=status.HTTP_401_UNAUTHORIZED,
  201. detail="Authentication required",
  202. headers={"WWW-Authenticate": "Bearer"}
  203. )
  204. return user
  205. def require_role(roles: List[UserRole]):
  206. """
  207. 角色验证依赖
  208. 用法: Depends(require_role([UserRole.ADMIN]))
  209. """
  210. def role_checker(
  211. request: Request,
  212. db: Session = Depends(get_db)
  213. ) -> User:
  214. user = get_current_user_from_request(request, db)
  215. user_role = get_user_role(user)
  216. if user_role not in roles:
  217. raise HTTPException(
  218. status_code=status.HTTP_403_FORBIDDEN,
  219. detail=f"Permission denied. Required roles: {[r.value for r in roles]}"
  220. )
  221. return user
  222. return role_checker