| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562 |
- #!/usr/bin/env python3
- """
- RBAC权限管理API接口
- """
- from fastapi import HTTPException, Depends
- from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
- from pydantic import BaseModel
- from typing import Optional, List, Dict, Any
- import json
- from datetime import datetime, timezone
- # 导入必要的模块
- import pymysql
- from urllib.parse import urlparse
- import os
- from dotenv import load_dotenv
- load_dotenv()
# Utility functions are duplicated here (rather than imported) to avoid a circular import.
def get_db_connection():
    """Open a MySQL connection described by the ``DATABASE_URL`` environment variable.

    The URL is split with ``urlparse``. Missing components fall back to
    defaults: host ``localhost``, port ``3306``, user ``root``, empty
    password, database ``sso_db``.

    Returns:
        A new ``pymysql`` connection, or ``None`` when ``DATABASE_URL`` is
        unset/empty or when parsing or connecting fails (the error is printed).
    """
    database_url = os.getenv('DATABASE_URL', '')
    if not database_url:
        return None

    try:
        from urllib.parse import unquote

        parsed = urlparse(database_url)
        config = {
            'host': parsed.hostname or 'localhost',
            'port': parsed.port or 3306,
            # urlparse leaves percent-escapes untouched, so credentials that
            # contain reserved characters (e.g. "@" written as "%40") must be
            # decoded before they are handed to the driver.
            'user': unquote(parsed.username) if parsed.username else 'root',
            'password': unquote(parsed.password) if parsed.password else '',
            # A bare "/" path would give an empty name; use the default instead.
            'database': parsed.path[1:] or 'sso_db',
            'charset': 'utf8mb4',
        }
        return pymysql.connect(**config)
    except Exception as e:
        print(f"数据库连接失败: {e}")
        return None
def verify_token(token: str) -> Optional[dict]:
    """Decode an HS256-signed JWT and return its claims.

    The ``jwt`` module is tried first; if importing it or a probe ``encode``
    call fails, ``jose.jwt`` is used instead. The signing key is read from the
    ``JWT_SECRET_KEY`` environment variable.

    Returns:
        The decoded payload, or ``None`` if anything goes wrong (no usable
        JWT library, malformed token, bad signature, ...).
    """
    try:
        try:
            import jwt as pyjwt
            # Probe: raises AttributeError/TypeError when the installed `jwt`
            # module does not provide this encode() API, which triggers the
            # fallback below.
            pyjwt.encode({"test": "data"}, "secret", algorithm="HS256")
            jwt_backend = pyjwt
        except (ImportError, AttributeError, TypeError):
            from jose import jwt as jwt_backend

        # NOTE(review): a hard-coded fallback secret is used when the
        # environment variable is missing — confirm this never applies in
        # production deployments.
        secret_key = os.getenv("JWT_SECRET_KEY", "dev-jwt-secret-key-12345")
        return jwt_backend.decode(token, secret_key, algorithms=["HS256"])
    except Exception:
        return None
# Uniform response envelope; endpoints in this module return it via model_dump().
class ApiResponse(BaseModel):
    # Application-level status: 0 on success; 401/403/500 are used for failures.
    code: int
    # Human-readable status message.
    message: str
    # Payload. Typed as Any because endpoints in this module pass both dicts
    # (paged listings) and lists (menu tree, permission list); a dict-only
    # annotation makes validation fail for the list payloads.
    data: Optional[Any] = None
    # Callers pass an ISO-8601 string from datetime.now(timezone.utc).isoformat().
    timestamp: str
# HTTPBearer scheme instance; the endpoints below receive their credentials
# through Depends(security).
security = HTTPBearer()
# Data models
# Fields for creating a menu entry.
# NOTE(review): not referenced in the visible part of this file — presumably a
# request body; confirm against the route definitions.
class MenuCreate(BaseModel):
    # Parent menu id; build_menu_tree treats a None parent_id as a root entry.
    parent_id: Optional[str] = None
    name: str
    title: str
    path: Optional[str] = None
    component: Optional[str] = None
    icon: Optional[str] = None
    sort_order: int = 0
    menu_type: str = "menu"
    is_hidden: bool = False
    description: Optional[str] = None
# Fields for updating a menu entry. Unlike MenuCreate there is no `name`
# field, and `is_active` is added.
# NOTE(review): not referenced in the visible part of this file — confirm usage.
class MenuUpdate(BaseModel):
    parent_id: Optional[str] = None
    title: str
    path: Optional[str] = None
    component: Optional[str] = None
    icon: Optional[str] = None
    sort_order: int = 0
    menu_type: str = "menu"
    is_hidden: bool = False
    is_active: bool = True
    description: Optional[str] = None
# Fields for creating a role.
# NOTE(review): not referenced in the visible part of this file — confirm usage.
class RoleCreate(BaseModel):
    name: str
    display_name: str
    description: Optional[str] = None
# Fields for updating a role; there is no `name` field here, only the
# display name, description and active flag.
# NOTE(review): not referenced in the visible part of this file — confirm usage.
class RoleUpdate(BaseModel):
    display_name: str
    description: Optional[str] = None
    is_active: bool = True
# Fields for creating a permission.
# NOTE(review): not referenced in the visible part of this file — confirm usage.
class PermissionCreate(BaseModel):
    name: str
    display_name: str
    # has_permission matches permissions on the (resource, action) pair,
    # e.g. ("menu", "view").
    resource: str
    action: str
    description: Optional[str] = None
# Pairs one user id with a list of role ids.
# NOTE(review): presumably the body of a "assign roles to user" request —
# not referenced in the visible part of this file; confirm usage.
class UserRoleAssign(BaseModel):
    user_id: str
    role_ids: List[str]
# Pairs one role id with a list of menu ids.
# NOTE(review): presumably the body of a "assign menus to role" request —
# not referenced in the visible part of this file; confirm usage.
class RoleMenuAssign(BaseModel):
    role_id: str
    menu_ids: List[str]
# Pairs one role id with a list of permission ids.
# NOTE(review): presumably the body of a "assign permissions to role" request —
# not referenced in the visible part of this file; confirm usage.
class RolePermissionAssign(BaseModel):
    role_id: str
    permission_ids: List[str]
# Permission-check decorator
def check_permission(resource: str, action: str):
    """Build a decorator that guards an async handler with a permission check.

    The decorated coroutine must receive its bearer credentials as the keyword
    argument ``credentials``. Before the handler runs, the wrapper:

    1. returns a 401 envelope if ``credentials`` is missing or its token does
       not verify;
    2. returns a 403 envelope if ``has_permission`` denies
       ``(resource, action)`` for the token's ``sub`` claim;
    3. otherwise calls the handler with an extra ``current_user_id`` keyword
       argument, so the handler must accept that parameter.

    Args:
        resource: Resource name passed to ``has_permission``.
        action: Action name passed to ``has_permission``.

    Returns:
        A decorator for async callables. Failures are reported as
        ``ApiResponse`` dicts rather than raised exceptions.
    """
    import functools

    def decorator(func):
        # wraps() keeps the handler's name, docstring and (via __wrapped__)
        # its signature visible to anything that introspects the endpoint.
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            # Credentials are only looked up among keyword arguments.
            credentials = kwargs.get('credentials')
            if not credentials:
                return ApiResponse(
                    code=401,
                    message="未提供访问令牌",
                    timestamp=datetime.now(timezone.utc).isoformat()
                ).model_dump()

            # Validate the token.
            payload = verify_token(credentials.credentials)
            if not payload:
                return ApiResponse(
                    code=401,
                    message="无效的访问令牌",
                    timestamp=datetime.now(timezone.utc).isoformat()
                ).model_dump()

            user_id = payload.get("sub")

            # Check the user's permission.
            if not await has_permission(user_id, resource, action):
                return ApiResponse(
                    code=403,
                    message="权限不足",
                    timestamp=datetime.now(timezone.utc).isoformat()
                ).model_dump()

            # Hand the authenticated user id to the handler.
            kwargs['current_user_id'] = user_id
            return await func(*args, **kwargs)

        return wrapper

    return decorator
async def has_permission(user_id: str, resource: str, action: str) -> bool:
    """Report whether a user holds an active permission for ``(resource, action)``.

    Counts rows joining ``user_roles`` -> ``role_permissions`` ->
    ``permissions`` where the user-role link and the permission are both
    active and the permission matches ``resource`` and ``action``.

    Args:
        user_id: Id matched against ``user_roles.user_id``.
        resource: Value matched against ``permissions.resource``.
        action: Value matched against ``permissions.action``.

    Returns:
        ``True`` if at least one matching row exists. ``False`` otherwise,
        including when no connection is available or the query fails (the
        error is printed), so the check fails closed.
    """
    # NOTE(review): the query does not look at the role's own is_active flag —
    # confirm whether a deactivated role should still grant permissions.
    conn = get_db_connection()
    if not conn:
        return False

    try:
        # The cursor is opened inside the try block so the connection is
        # closed even if creating or closing the cursor fails.
        with conn.cursor() as cursor:
            cursor.execute("""
                SELECT COUNT(*) FROM user_roles ur
                JOIN role_permissions rp ON ur.role_id = rp.role_id
                JOIN permissions p ON rp.permission_id = p.id
                WHERE ur.user_id = %s
                AND ur.is_active = 1
                AND p.resource = %s
                AND p.action = %s
                AND p.is_active = 1
            """, (user_id, resource, action))
            count = cursor.fetchone()[0]
        return count > 0
    except Exception as e:
        print(f"权限检查错误: {e}")
        return False
    finally:
        conn.close()
# Menu management API
async def get_user_menus(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Return the menu tree visible to the authenticated user.

    The bearer token is verified, active menus reachable through the user's
    active roles are loaded (ordered by ``sort_order`` then ``created_at``),
    and the flat rows are nested with ``build_menu_tree``.

    Args:
        credentials: Bearer credentials supplied by the ``security`` dependency.

    Returns:
        A response dict with keys ``code``, ``message``, ``data`` and
        ``timestamp``. ``code`` is 0 on success (``data`` is the list of root
        menus), 401 for an invalid token, 500 when the database is
        unavailable or any other error occurs.
    """
    def _error(code, message):
        return ApiResponse(
            code=code,
            message=message,
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()

    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return _error(401, "无效的访问令牌")

        user_id = payload.get("sub")

        conn = get_db_connection()
        if not conn:
            return _error(500, "数据库连接失败")

        # try/finally guarantees the connection is released even when the
        # query raises.
        try:
            with conn.cursor() as cursor:
                # Menus the user can access through active role assignments.
                cursor.execute("""
                    SELECT DISTINCT m.id, m.parent_id, m.name, m.title, m.path,
                    m.component, m.icon, m.sort_order, m.menu_type,
                    m.is_hidden, m.is_active
                    FROM menus m
                    JOIN role_menus rm ON m.id = rm.menu_id
                    JOIN user_roles ur ON rm.role_id = ur.role_id
                    WHERE ur.user_id = %s
                    AND ur.is_active = 1
                    AND m.is_active = 1
                    ORDER BY m.sort_order, m.created_at
                """, (user_id,))
                rows = cursor.fetchall()
        finally:
            conn.close()

        menus = [
            {
                "id": row[0],
                "parent_id": row[1],
                "name": row[2],
                "title": row[3],
                "path": row[4],
                "component": row[5],
                "icon": row[6],
                "sort_order": row[7],
                "menu_type": row[8],
                "is_hidden": bool(row[9]),
                "is_active": bool(row[10]),
                "children": [],
            }
            for row in rows
        ]

        # Nest the flat rows into a tree.
        menu_tree = build_menu_tree(menus)

        # The payload is a list, so the envelope is built directly instead of
        # going through ApiResponse, whose `data` field may only accept a
        # dict; the keys and their order match ApiResponse.model_dump().
        return {
            "code": 0,
            "message": "获取用户菜单成功",
            "data": menu_tree,
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }

    except Exception as e:
        print(f"获取用户菜单错误: {e}")
        return _error(500, "服务器内部错误")
def build_menu_tree(menus: List[Dict]) -> List[Dict]:
    """Nest a flat list of menu dicts into a tree.

    Each dict needs the keys ``id``, ``parent_id`` and ``children``. Entries
    whose ``parent_id`` is ``None`` become roots; every other entry is
    appended to its parent's ``children`` list. Entries whose parent is not
    in ``menus`` are left out of the result. The input dicts are modified in
    place.

    Args:
        menus: Flat menu dicts, in the order children should appear.

    Returns:
        The root entries, in input order.
    """
    by_id = {}
    for entry in menus:
        by_id[entry["id"]] = entry

    roots = []
    for entry in menus:
        if entry["parent_id"] is None:
            roots.append(entry)
            continue
        owner = by_id.get(entry["parent_id"])
        if owner:
            owner["children"].append(entry)

    return roots
async def get_all_menus(
    page: int = 1,
    page_size: int = 20,
    keyword: Optional[str] = None,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Return one page of all menus (administrative listing).

    Requires a valid token whose user holds the ``("menu", "view")``
    permission. Results are ordered by ``sort_order`` then ``created_at`` and
    each row carries its parent's title.

    Args:
        page: 1-based page number; values below 1 are treated as 1.
        page_size: Rows per page; negative values are treated as 0.
        keyword: Optional substring matched against menu ``title`` and
            ``name`` with SQL ``LIKE``.
        credentials: Bearer credentials supplied by the ``security`` dependency.

    Returns:
        A response dict. On success ``code`` is 0 and ``data`` holds
        ``items``, ``total``, ``page`` and ``page_size``. ``code`` is 401 for
        an invalid token, 403 when the permission is missing, 500 when the
        database is unavailable or any other error occurs.
    """
    def _error(code, message):
        return ApiResponse(
            code=code,
            message=message,
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()

    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return _error(401, "无效的访问令牌")

        user_id = payload.get("sub")

        # Permission check.
        if not await has_permission(user_id, "menu", "view"):
            return _error(403, "权限不足")

        # Out-of-range paging values would otherwise produce a negative
        # LIMIT/OFFSET and surface as a generic 500.
        page = max(page, 1)
        page_size = max(page_size, 0)
        offset = (page - 1) * page_size

        # Build the filter; only this fixed fragment is interpolated into the
        # SQL text, user input always travels as bound parameters.
        where_conditions = []
        params = []
        if keyword:
            where_conditions.append("(m.title LIKE %s OR m.name LIKE %s)")
            params.extend([f"%{keyword}%", f"%{keyword}%"])
        where_clause = " AND ".join(where_conditions) if where_conditions else "1=1"

        conn = get_db_connection()
        if not conn:
            return _error(500, "数据库连接失败")

        # try/finally guarantees the connection is released even when a query
        # raises.
        try:
            with conn.cursor() as cursor:
                # Total number of matching rows.
                cursor.execute(f"SELECT COUNT(*) FROM menus m WHERE {where_clause}", params)
                total = cursor.fetchone()[0]

                # Requested page of rows.
                cursor.execute(f"""
                    SELECT m.id, m.parent_id, m.name, m.title, m.path, m.component,
                    m.icon, m.sort_order, m.menu_type, m.is_hidden, m.is_active,
                    m.description, m.created_at, m.updated_at,
                    pm.title as parent_title
                    FROM menus m
                    LEFT JOIN menus pm ON m.parent_id = pm.id
                    WHERE {where_clause}
                    ORDER BY m.sort_order, m.created_at
                    LIMIT %s OFFSET %s
                """, params + [page_size, offset])
                rows = cursor.fetchall()
        finally:
            conn.close()

        menus = [
            {
                "id": row[0],
                "parent_id": row[1],
                "name": row[2],
                "title": row[3],
                "path": row[4],
                "component": row[5],
                "icon": row[6],
                "sort_order": row[7],
                "menu_type": row[8],
                "is_hidden": bool(row[9]),
                "is_active": bool(row[10]),
                "description": row[11],
                "created_at": row[12].isoformat() if row[12] else None,
                "updated_at": row[13].isoformat() if row[13] else None,
                "parent_title": row[14],
            }
            for row in rows
        ]

        return ApiResponse(
            code=0,
            message="获取菜单列表成功",
            data={
                "items": menus,
                "total": total,
                "page": page,
                "page_size": page_size
            },
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()

    except Exception as e:
        print(f"获取菜单列表错误: {e}")
        return _error(500, "服务器内部错误")
# Role management API
async def get_all_roles(
    page: int = 1,
    page_size: int = 20,
    keyword: Optional[str] = None,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Return one page of roles together with their active-user counts.

    Requires a valid token whose user holds the ``("role", "view")``
    permission. Results are ordered with system roles first, then by
    ``created_at``.

    Args:
        page: 1-based page number; values below 1 are treated as 1.
        page_size: Rows per page; negative values are treated as 0.
        keyword: Optional substring matched against role ``display_name`` and
            ``name`` with SQL ``LIKE``.
        credentials: Bearer credentials supplied by the ``security`` dependency.

    Returns:
        A response dict. On success ``code`` is 0 and ``data`` holds
        ``items``, ``total``, ``page`` and ``page_size``. ``code`` is 401 for
        an invalid token, 403 when the permission is missing, 500 when the
        database is unavailable or any other error occurs.
    """
    def _error(code, message):
        return ApiResponse(
            code=code,
            message=message,
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()

    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return _error(401, "无效的访问令牌")

        user_id = payload.get("sub")

        # Permission check.
        if not await has_permission(user_id, "role", "view"):
            return _error(403, "权限不足")

        # Out-of-range paging values would otherwise produce a negative
        # LIMIT/OFFSET and surface as a generic 500.
        page = max(page, 1)
        page_size = max(page_size, 0)
        offset = (page - 1) * page_size

        # Build the filter; only this fixed fragment is interpolated into the
        # SQL text, user input always travels as bound parameters.
        where_conditions = []
        params = []
        if keyword:
            where_conditions.append("(r.display_name LIKE %s OR r.name LIKE %s)")
            params.extend([f"%{keyword}%", f"%{keyword}%"])
        where_clause = " AND ".join(where_conditions) if where_conditions else "1=1"

        conn = get_db_connection()
        if not conn:
            return _error(500, "数据库连接失败")

        # try/finally guarantees the connection is released even when a query
        # raises.
        try:
            with conn.cursor() as cursor:
                # Total number of matching rows.
                cursor.execute(f"SELECT COUNT(*) FROM roles r WHERE {where_clause}", params)
                total = cursor.fetchone()[0]

                # Requested page of rows; user_count only counts active
                # user-role links.
                cursor.execute(f"""
                    SELECT r.id, r.name, r.display_name, r.description, r.is_active,
                    r.is_system, r.created_at, r.updated_at,
                    COUNT(ur.user_id) as user_count
                    FROM roles r
                    LEFT JOIN user_roles ur ON r.id = ur.role_id AND ur.is_active = 1
                    WHERE {where_clause}
                    GROUP BY r.id
                    ORDER BY r.is_system DESC, r.created_at
                    LIMIT %s OFFSET %s
                """, params + [page_size, offset])
                rows = cursor.fetchall()
        finally:
            conn.close()

        roles = [
            {
                "id": row[0],
                "name": row[1],
                "display_name": row[2],
                "description": row[3],
                "is_active": bool(row[4]),
                "is_system": bool(row[5]),
                "created_at": row[6].isoformat() if row[6] else None,
                "updated_at": row[7].isoformat() if row[7] else None,
                "user_count": row[8],
            }
            for row in rows
        ]

        return ApiResponse(
            code=0,
            message="获取角色列表成功",
            data={
                "items": roles,
                "total": total,
                "page": page,
                "page_size": page_size
            },
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()

    except Exception as e:
        print(f"获取角色列表错误: {e}")
        return _error(500, "服务器内部错误")
async def get_user_permissions(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Return the distinct active permissions of the authenticated user.

    The bearer token is verified, then permissions reachable through the
    user's active roles are loaded.

    Args:
        credentials: Bearer credentials supplied by the ``security`` dependency.

    Returns:
        A response dict with keys ``code``, ``message``, ``data`` and
        ``timestamp``. ``code`` is 0 on success (``data`` is a list of
        ``{"name", "resource", "action"}`` dicts), 401 for an invalid token,
        500 when the database is unavailable or any other error occurs.
    """
    def _error(code, message):
        return ApiResponse(
            code=code,
            message=message,
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()

    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return _error(401, "无效的访问令牌")

        user_id = payload.get("sub")

        conn = get_db_connection()
        if not conn:
            return _error(500, "数据库连接失败")

        # try/finally guarantees the connection is released even when the
        # query raises.
        try:
            with conn.cursor() as cursor:
                # Permissions granted through the user's active roles.
                cursor.execute("""
                    SELECT DISTINCT p.name, p.resource, p.action
                    FROM permissions p
                    JOIN role_permissions rp ON p.id = rp.permission_id
                    JOIN user_roles ur ON rp.role_id = ur.role_id
                    WHERE ur.user_id = %s
                    AND ur.is_active = 1
                    AND p.is_active = 1
                """, (user_id,))
                rows = cursor.fetchall()
        finally:
            conn.close()

        permissions = [
            {"name": row[0], "resource": row[1], "action": row[2]}
            for row in rows
        ]

        # The payload is a list, so the envelope is built directly instead of
        # going through ApiResponse, whose `data` field may only accept a
        # dict; the keys and their order match ApiResponse.model_dump().
        return {
            "code": 0,
            "message": "获取用户权限成功",
            "data": permissions,
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }

    except Exception as e:
        print(f"获取用户权限错误: {e}")
        return _error(500, "服务器内部错误")
# API functions exported for use by the main server.
__all__ = [
    "get_user_menus",
    "get_all_menus",
    "get_all_roles",
    "get_user_permissions",
    "has_permission",
]
|