#!/usr/bin/env python3
"""
RBAC permission management API endpoints.
"""
import functools
import json
import logging
import os
from contextlib import closing
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Tuple, Union
from urllib.parse import unquote, urlparse

import pymysql
from dotenv import load_dotenv
from fastapi import Depends, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel

load_dotenv()

logger = logging.getLogger(__name__)


# Utility functions duplicated here (instead of imported) to avoid a circular import.
def get_db_connection():
    """Open a new PyMySQL connection described by the DATABASE_URL environment variable.

    Missing URL parts fall back to host "localhost", port 3306, user "root",
    an empty password and database "sso_db".

    Returns:
        A pymysql connection, or None when DATABASE_URL is unset/empty or the
        connection attempt fails (the error is logged).
    """
    try:
        database_url = os.getenv('DATABASE_URL', '')
        if not database_url:
            return None

        parsed = urlparse(database_url)
        config = {
            'host': parsed.hostname or 'localhost',
            'port': parsed.port or 3306,
            # urlparse keeps percent-escapes; decode them so credentials containing
            # reserved characters (e.g. "@" written as "%40") authenticate correctly.
            'user': unquote(parsed.username) if parsed.username else 'root',
            'password': unquote(parsed.password) if parsed.password else '',
            # Drop the leading "/"; an empty path or a bare "/" falls back to the default DB.
            'database': parsed.path.lstrip('/') or 'sso_db',
            'charset': 'utf8mb4',
        }
        return pymysql.connect(**config)
    except Exception as e:
        logger.error("数据库连接失败: %s", e)
        return None


@functools.lru_cache(maxsize=1)
def _load_jwt_module():
    """Return a JWT implementation exposing ``decode(token, key, algorithms=...)``.

    Prefers PyJWT. A probe ``encode`` call rules out an unrelated package that is
    also importable as ``jwt``; in that case it falls back to python-jose's ``jwt``.
    The result is cached so the probe runs once instead of on every request.
    """
    try:
        import jwt as pyjwt
        pyjwt.encode({"test": "data"}, "secret", algorithm="HS256")
        return pyjwt
    except (ImportError, AttributeError, TypeError):
        from jose import jwt
        return jwt


def verify_token(token: str) -> Optional[dict]:
    """Decode an HS256-signed JWT and return its claims.

    Args:
        token: The raw bearer token string.

    Returns:
        The decoded payload dict, or None if the token cannot be decoded or
        validated (or no JWT library is available).
    """
    try:
        jwt = _load_jwt_module()
        # NOTE(review): the hard-coded fallback key lets anyone forge tokens when
        # JWT_SECRET_KEY is unset; this should be mandatory outside development.
        secret_key = os.getenv("JWT_SECRET_KEY", "dev-jwt-secret-key-12345")
        return jwt.decode(token, secret_key, algorithms=["HS256"])
    except Exception:
        # Any decode/validation failure means "not authenticated".
        return None


class ApiResponse(BaseModel):
    """Uniform response envelope returned by every endpoint in this module."""

    code: int
    message: str
    # Widened from Optional[dict]: some endpoints return a list payload (menu tree,
    # permission list), which pydantic v2 rejects for a dict-typed field.
    data: Optional[Union[dict, list]] = None
    timestamp: str


security = HTTPBearer()


def _utc_now_iso() -> str:
    """Return the current UTC time as an ISO-8601 string."""
    return datetime.now(timezone.utc).isoformat()


def _response(code: int, message: str, data: Any = None) -> dict:
    """Build a serialized ApiResponse stamped with the current UTC time."""
    return ApiResponse(
        code=code,
        message=message,
        data=data,
        timestamp=_utc_now_iso(),
    ).model_dump()


def _iso(value) -> Optional[str]:
    """Return ``value.isoformat()`` for a truthy (datetime) value, else None."""
    return value.isoformat() if value else None


def _normalize_page(page: int, page_size: int) -> Tuple[int, int, int]:
    """Clamp paging inputs to at least 1 and return (page, page_size, offset).

    Without clamping, page < 1 or a negative page_size yields a negative
    OFFSET/LIMIT, which the database rejects.
    """
    page = max(page, 1)
    page_size = max(page_size, 1)
    return page, page_size, (page - 1) * page_size


# Request models
class MenuCreate(BaseModel):
    """Payload for creating a menu entry."""

    parent_id: Optional[str] = None
    name: str
    title: str
    path: Optional[str] = None
    component: Optional[str] = None
    icon: Optional[str] = None
    sort_order: int = 0
    menu_type: str = 'menu'
    is_hidden: bool = False
    description: Optional[str] = None


class MenuUpdate(BaseModel):
    """Payload for updating a menu entry."""

    parent_id: Optional[str] = None
    title: str
    path: Optional[str] = None
    component: Optional[str] = None
    icon: Optional[str] = None
    sort_order: int = 0
    menu_type: str = 'menu'
    is_hidden: bool = False
    is_active: bool = True
    description: Optional[str] = None


class RoleCreate(BaseModel):
    """Payload for creating a role."""

    name: str
    display_name: str
    description: Optional[str] = None


class RoleUpdate(BaseModel):
    """Payload for updating a role."""

    display_name: str
    description: Optional[str] = None
    is_active: bool = True


class PermissionCreate(BaseModel):
    """Payload for creating a permission (a resource/action pair)."""

    name: str
    display_name: str
    resource: str
    action: str
    description: Optional[str] = None


class UserRoleAssign(BaseModel):
    """Payload assigning a set of roles to a user."""

    user_id: str
    role_ids: List[str]


class RoleMenuAssign(BaseModel):
    """Payload assigning a set of menus to a role."""

    role_id: str
    menu_ids: List[str]


class RolePermissionAssign(BaseModel):
    """Payload assigning a set of permissions to a role."""

    role_id: str
    permission_ids: List[str]


# Permission-check decorator
def check_permission(resource: str, action: str):
    """Decorator factory that guards an async endpoint with an RBAC check.

    The decorated coroutine must receive ``credentials`` as a keyword argument.
    On success it is called with an extra ``current_user_id`` keyword argument
    (the token's "sub" claim). Otherwise a 401 (missing/invalid token) or
    403 (insufficient permission) response dict is returned instead.
    """
    def decorator(func):
        # wraps() preserves the endpoint's name/signature (via __wrapped__) so
        # frameworks that introspect it still see the real parameters.
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            credentials = kwargs.get('credentials')
            if not credentials:
                return _response(401, "未提供访问令牌")

            payload = verify_token(credentials.credentials)
            if not payload:
                return _response(401, "无效的访问令牌")

            user_id = payload.get("sub")
            if not await has_permission(user_id, resource, action):
                return _response(403, "权限不足")

            kwargs['current_user_id'] = user_id
            return await func(*args, **kwargs)
        return wrapper
    return decorator


async def has_permission(user_id: str, resource: str, action: str) -> bool:
    """Return True if one of the user's active roles grants the active permission.

    Fails closed: returns False when no database connection is available or
    the query raises (the error is logged).
    """
    conn = get_db_connection()
    if not conn:
        return False
    try:
        with closing(conn.cursor()) as cursor:
            cursor.execute("""
                SELECT COUNT(*) FROM user_roles ur
                JOIN role_permissions rp ON ur.role_id = rp.role_id
                JOIN permissions p ON rp.permission_id = p.id
                WHERE ur.user_id = %s AND ur.is_active = 1
                AND p.resource = %s AND p.action = %s AND p.is_active = 1
            """, (user_id, resource, action))
            return cursor.fetchone()[0] > 0
    except Exception as e:
        logger.error("权限检查错误: %s", e)
        return False
    finally:
        conn.close()


# Menu management API
# NOTE(review): these coroutines call blocking PyMySQL directly, so each query
# blocks the event loop; consider offloading to a thread pool.
async def get_user_menus(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Return the menu tree reachable through the caller's active roles.

    Response codes: 0 success (data is a list of root menus with nested
    "children"), 401 invalid token, 500 database or unexpected error.
    """
    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return _response(401, "无效的访问令牌")
        user_id = payload.get("sub")

        conn = get_db_connection()
        if not conn:
            return _response(500, "数据库连接失败")

        # closing() releases the cursor and connection even if the query raises.
        with closing(conn), closing(conn.cursor()) as cursor:
            # m.created_at is selected (unused below) because MySQL 5.7+ may reject
            # an ORDER BY column missing from a SELECT DISTINCT list. Selecting it
            # does not change distinctness since m.id is already selected.
            cursor.execute("""
                SELECT DISTINCT m.id, m.parent_id, m.name, m.title, m.path, m.component,
                       m.icon, m.sort_order, m.menu_type, m.is_hidden, m.is_active,
                       m.created_at
                FROM menus m
                JOIN role_menus rm ON m.id = rm.menu_id
                JOIN user_roles ur ON rm.role_id = ur.role_id
                WHERE ur.user_id = %s AND ur.is_active = 1 AND m.is_active = 1
                ORDER BY m.sort_order, m.created_at
            """, (user_id,))
            rows = cursor.fetchall()

        menus = [
            {
                "id": row[0],
                "parent_id": row[1],
                "name": row[2],
                "title": row[3],
                "path": row[4],
                "component": row[5],
                "icon": row[6],
                "sort_order": row[7],
                "menu_type": row[8],
                "is_hidden": bool(row[9]),
                "is_active": bool(row[10]),
                "children": [],
            }
            for row in rows
        ]
        return _response(0, "获取用户菜单成功", build_menu_tree(menus))
    except Exception as e:
        logger.error("获取用户菜单错误: %s", e)
        return _response(500, "服务器内部错误")


def build_menu_tree(menus: List[Dict]) -> List[Dict]:
    """Nest a flat list of menu dicts into a tree using each item's parent_id.

    Each dict must carry "id", "parent_id" and a "children" list. The input
    dicts are mutated: children are appended to their parent's "children" in
    input order. Items whose parent_id is None become roots; items whose
    parent is not present in ``menus`` are dropped.
    """
    menu_map = {menu["id"]: menu for menu in menus}
    tree = []
    for menu in menus:
        parent_id = menu["parent_id"]
        if parent_id is None:
            tree.append(menu)
            continue
        parent = menu_map.get(parent_id)
        if parent is not None:
            parent["children"].append(menu)
    return tree


async def get_all_menus(
    page: int = 1,
    page_size: int = 20,
    keyword: Optional[str] = None,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Return a paginated list of all menus (requires the menu:view permission).

    Args:
        page: 1-based page number; values below 1 are treated as 1.
        page_size: Items per page; values below 1 are treated as 1.
        keyword: Optional substring matched against menu title or name.
        credentials: Bearer token injected by FastAPI.

    Response codes: 0 success, 401 invalid token, 403 missing permission,
    500 database or unexpected error.
    """
    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return _response(401, "无效的访问令牌")
        user_id = payload.get("sub")

        if not await has_permission(user_id, "menu", "view"):
            return _response(403, "权限不足")

        conn = get_db_connection()
        if not conn:
            return _response(500, "数据库连接失败")

        page, page_size, offset = _normalize_page(page, page_size)

        # Only constant SQL fragments go into where_clause; user input is
        # always passed as bound parameters.
        where_conditions = []
        params = []
        if keyword:
            pattern = f"%{keyword}%"
            where_conditions.append("(m.title LIKE %s OR m.name LIKE %s)")
            params.extend([pattern, pattern])
        where_clause = " AND ".join(where_conditions) if where_conditions else "1=1"

        with closing(conn), closing(conn.cursor()) as cursor:
            cursor.execute(f"SELECT COUNT(*) FROM menus m WHERE {where_clause}", params)
            total = cursor.fetchone()[0]

            cursor.execute(f"""
                SELECT m.id, m.parent_id, m.name, m.title, m.path, m.component,
                       m.icon, m.sort_order, m.menu_type, m.is_hidden, m.is_active,
                       m.description, m.created_at, m.updated_at,
                       pm.title as parent_title
                FROM menus m
                LEFT JOIN menus pm ON m.parent_id = pm.id
                WHERE {where_clause}
                ORDER BY m.sort_order, m.created_at
                LIMIT %s OFFSET %s
            """, params + [page_size, offset])
            rows = cursor.fetchall()

        menus = [
            {
                "id": row[0],
                "parent_id": row[1],
                "name": row[2],
                "title": row[3],
                "path": row[4],
                "component": row[5],
                "icon": row[6],
                "sort_order": row[7],
                "menu_type": row[8],
                "is_hidden": bool(row[9]),
                "is_active": bool(row[10]),
                "description": row[11],
                "created_at": _iso(row[12]),
                "updated_at": _iso(row[13]),
                "parent_title": row[14],
            }
            for row in rows
        ]
        return _response(0, "获取菜单列表成功", {
            "items": menus,
            "total": total,
            "page": page,
            "page_size": page_size,
        })
    except Exception as e:
        logger.error("获取菜单列表错误: %s", e)
        return _response(500, "服务器内部错误")


# Role management API
async def get_all_roles(
    page: int = 1,
    page_size: int = 20,
    keyword: Optional[str] = None,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Return a paginated list of roles with active-user counts (requires role:view).

    Args:
        page: 1-based page number; values below 1 are treated as 1.
        page_size: Items per page; values below 1 are treated as 1.
        keyword: Optional substring matched against role display_name or name.
        credentials: Bearer token injected by FastAPI.

    Response codes: 0 success, 401 invalid token, 403 missing permission,
    500 database or unexpected error.
    """
    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return _response(401, "无效的访问令牌")
        user_id = payload.get("sub")

        if not await has_permission(user_id, "role", "view"):
            return _response(403, "权限不足")

        conn = get_db_connection()
        if not conn:
            return _response(500, "数据库连接失败")

        page, page_size, offset = _normalize_page(page, page_size)

        # Only constant SQL fragments go into where_clause; user input is
        # always passed as bound parameters.
        where_conditions = []
        params = []
        if keyword:
            pattern = f"%{keyword}%"
            where_conditions.append("(r.display_name LIKE %s OR r.name LIKE %s)")
            params.extend([pattern, pattern])
        where_clause = " AND ".join(where_conditions) if where_conditions else "1=1"

        with closing(conn), closing(conn.cursor()) as cursor:
            cursor.execute(f"SELECT COUNT(*) FROM roles r WHERE {where_clause}", params)
            total = cursor.fetchone()[0]

            # System roles are listed first; user_count counts only active assignments.
            cursor.execute(f"""
                SELECT r.id, r.name, r.display_name, r.description, r.is_active, r.is_system,
                       r.created_at, r.updated_at,
                       COUNT(ur.user_id) as user_count
                FROM roles r
                LEFT JOIN user_roles ur ON r.id = ur.role_id AND ur.is_active = 1
                WHERE {where_clause}
                GROUP BY r.id
                ORDER BY r.is_system DESC, r.created_at
                LIMIT %s OFFSET %s
            """, params + [page_size, offset])
            rows = cursor.fetchall()

        roles = [
            {
                "id": row[0],
                "name": row[1],
                "display_name": row[2],
                "description": row[3],
                "is_active": bool(row[4]),
                "is_system": bool(row[5]),
                "created_at": _iso(row[6]),
                "updated_at": _iso(row[7]),
                "user_count": row[8],
            }
            for row in rows
        ]
        return _response(0, "获取角色列表成功", {
            "items": roles,
            "total": total,
            "page": page,
            "page_size": page_size,
        })
    except Exception as e:
        logger.error("获取角色列表错误: %s", e)
        return _response(500, "服务器内部错误")


async def get_user_permissions(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Return the distinct active permissions granted to the caller's active roles.

    Response codes: 0 success (data is a list of {"name", "resource", "action"}),
    401 invalid token, 500 database or unexpected error.
    """
    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return _response(401, "无效的访问令牌")
        user_id = payload.get("sub")

        conn = get_db_connection()
        if not conn:
            return _response(500, "数据库连接失败")

        with closing(conn), closing(conn.cursor()) as cursor:
            cursor.execute("""
                SELECT DISTINCT p.name, p.resource, p.action
                FROM permissions p
                JOIN role_permissions rp ON p.id = rp.permission_id
                JOIN user_roles ur ON rp.role_id = ur.role_id
                WHERE ur.user_id = %s AND ur.is_active = 1 AND p.is_active = 1
            """, (user_id,))
            rows = cursor.fetchall()

        permissions = [
            {"name": row[0], "resource": row[1], "action": row[2]}
            for row in rows
        ]
        return _response(0, "获取用户权限成功", permissions)
    except Exception as e:
        logger.error("获取用户权限错误: %s", e)
        return _response(500, "服务器内部错误")


# Public API consumed by the main server
__all__ = [
    'get_user_menus',
    'get_all_menus',
    'get_all_roles',
    'get_user_permissions',
    'has_permission'
]