rbac_api.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562
  1. #!/usr/bin/env python3
  2. """
  3. RBAC权限管理API接口
  4. """
  5. from fastapi import HTTPException, Depends
  6. from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
  7. from pydantic import BaseModel
  8. from typing import Optional, List, Dict, Any
  9. import json
  10. from datetime import datetime, timezone
  11. # 导入必要的模块
  12. import pymysql
  13. from urllib.parse import urlparse
  14. import os
  15. from dotenv import load_dotenv
  16. load_dotenv()
  17. # 复制必要的工具函数以避免循环导入
  18. def get_db_connection():
  19. """获取数据库连接"""
  20. try:
  21. database_url = os.getenv('DATABASE_URL', '')
  22. if not database_url:
  23. return None
  24. parsed = urlparse(database_url)
  25. config = {
  26. 'host': parsed.hostname or 'localhost',
  27. 'port': parsed.port or 3306,
  28. 'user': parsed.username or 'root',
  29. 'password': parsed.password or '',
  30. 'database': parsed.path[1:] if parsed.path else 'sso_db',
  31. 'charset': 'utf8mb4'
  32. }
  33. return pymysql.connect(**config)
  34. except Exception as e:
  35. print(f"数据库连接失败: {e}")
  36. return None
  37. def verify_token(token: str) -> Optional[dict]:
  38. """验证令牌"""
  39. try:
  40. # 导入JWT库
  41. try:
  42. import jwt as pyjwt
  43. test_token = pyjwt.encode({"test": "data"}, "secret", algorithm="HS256")
  44. jwt = pyjwt
  45. except (ImportError, AttributeError, TypeError):
  46. from jose import jwt
  47. JWT_SECRET_KEY = os.getenv("JWT_SECRET_KEY", "dev-jwt-secret-key-12345")
  48. payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=["HS256"])
  49. return payload
  50. except Exception:
  51. return None
  52. class ApiResponse(BaseModel):
  53. code: int
  54. message: str
  55. data: Optional[dict] = None
  56. timestamp: str
# Bearer-token scheme instance; the handlers in this file receive their
# credentials through ``Depends(security)``.
security = HTTPBearer()
# Data models
class MenuCreate(BaseModel):
    """Fields describing a new menu entry.

    Field names match the ``menus`` columns read elsewhere in this file.
    NOTE(review): presumably a request body for a create-menu endpoint; no
    function in this file references it.
    """
    # Id of the parent menu; None means no parent.
    parent_id: Optional[str] = None
    name: str
    title: str
    path: Optional[str] = None
    component: Optional[str] = None
    icon: Optional[str] = None
    # Listings in this file order menus by this value.
    sort_order: int = 0
    menu_type: str = 'menu'
    is_hidden: bool = False
    description: Optional[str] = None
class MenuUpdate(BaseModel):
    """Fields for modifying a menu entry.

    Same fields as ``MenuCreate`` except that ``name`` is absent and
    ``is_active`` is added.
    NOTE(review): presumably a request body for an update-menu endpoint; no
    function in this file references it.
    """
    # Id of the parent menu; None means no parent.
    parent_id: Optional[str] = None
    title: str
    path: Optional[str] = None
    component: Optional[str] = None
    icon: Optional[str] = None
    sort_order: int = 0
    menu_type: str = 'menu'
    is_hidden: bool = False
    is_active: bool = True
    description: Optional[str] = None
class RoleCreate(BaseModel):
    """Fields describing a new role.

    Field names match ``roles`` columns read elsewhere in this file.
    NOTE(review): presumably a create-role request body; not referenced here.
    """
    name: str
    display_name: str
    description: Optional[str] = None
class RoleUpdate(BaseModel):
    """Fields for modifying a role; ``name`` is not part of this model.

    NOTE(review): presumably an update-role request body; not referenced here.
    """
    display_name: str
    description: Optional[str] = None
    is_active: bool = True
class PermissionCreate(BaseModel):
    """Fields describing a new permission.

    ``resource`` and ``action`` are the pair that ``has_permission`` in this
    file matches against the ``permissions`` table.
    NOTE(review): presumably a create-permission request body; not referenced
    here.
    """
    name: str
    display_name: str
    resource: str
    action: str
    description: Optional[str] = None
class UserRoleAssign(BaseModel):
    """Pairs one user id with a list of role ids.

    NOTE(review): presumably the body of a role-assignment endpoint; not
    referenced in this file.
    """
    user_id: str
    role_ids: List[str]
class RoleMenuAssign(BaseModel):
    """Pairs one role id with a list of menu ids.

    NOTE(review): presumably the body of a menu-assignment endpoint; not
    referenced in this file.
    """
    role_id: str
    menu_ids: List[str]
class RolePermissionAssign(BaseModel):
    """Pairs one role id with a list of permission ids.

    NOTE(review): presumably the body of a permission-assignment endpoint;
    not referenced in this file.
    """
    role_id: str
    permission_ids: List[str]
  104. # 权限检查装饰器
  105. def check_permission(resource: str, action: str):
  106. """检查用户权限"""
  107. def decorator(func):
  108. async def wrapper(*args, **kwargs):
  109. # 从kwargs中获取credentials
  110. credentials = kwargs.get('credentials')
  111. if not credentials:
  112. return ApiResponse(
  113. code=401,
  114. message="未提供访问令牌",
  115. timestamp=datetime.now(timezone.utc).isoformat()
  116. ).model_dump()
  117. # 验证token
  118. payload = verify_token(credentials.credentials)
  119. if not payload:
  120. return ApiResponse(
  121. code=401,
  122. message="无效的访问令牌",
  123. timestamp=datetime.now(timezone.utc).isoformat()
  124. ).model_dump()
  125. user_id = payload.get("sub")
  126. # 检查用户权限
  127. if not await has_permission(user_id, resource, action):
  128. return ApiResponse(
  129. code=403,
  130. message="权限不足",
  131. timestamp=datetime.now(timezone.utc).isoformat()
  132. ).model_dump()
  133. # 将user_id添加到kwargs中
  134. kwargs['current_user_id'] = user_id
  135. return await func(*args, **kwargs)
  136. return wrapper
  137. return decorator
  138. async def has_permission(user_id: str, resource: str, action: str) -> bool:
  139. """检查用户是否有指定权限"""
  140. conn = get_db_connection()
  141. if not conn:
  142. return False
  143. cursor = conn.cursor()
  144. try:
  145. # 查询用户是否有指定权限
  146. cursor.execute("""
  147. SELECT COUNT(*) FROM user_roles ur
  148. JOIN role_permissions rp ON ur.role_id = rp.role_id
  149. JOIN permissions p ON rp.permission_id = p.id
  150. WHERE ur.user_id = %s
  151. AND ur.is_active = 1
  152. AND p.resource = %s
  153. AND p.action = %s
  154. AND p.is_active = 1
  155. """, (user_id, resource, action))
  156. count = cursor.fetchone()[0]
  157. return count > 0
  158. except Exception as e:
  159. print(f"权限检查错误: {e}")
  160. return False
  161. finally:
  162. cursor.close()
  163. conn.close()
  164. # 菜单管理API
  165. async def get_user_menus(credentials: HTTPAuthorizationCredentials = Depends(security)):
  166. """获取用户菜单"""
  167. try:
  168. payload = verify_token(credentials.credentials)
  169. if not payload:
  170. return ApiResponse(
  171. code=401,
  172. message="无效的访问令牌",
  173. timestamp=datetime.now(timezone.utc).isoformat()
  174. ).model_dump()
  175. user_id = payload.get("sub")
  176. conn = get_db_connection()
  177. if not conn:
  178. return ApiResponse(
  179. code=500,
  180. message="数据库连接失败",
  181. timestamp=datetime.now(timezone.utc).isoformat()
  182. ).model_dump()
  183. cursor = conn.cursor()
  184. # 获取用户可访问的菜单
  185. cursor.execute("""
  186. SELECT DISTINCT m.id, m.parent_id, m.name, m.title, m.path,
  187. m.component, m.icon, m.sort_order, m.menu_type,
  188. m.is_hidden, m.is_active
  189. FROM menus m
  190. JOIN role_menus rm ON m.id = rm.menu_id
  191. JOIN user_roles ur ON rm.role_id = ur.role_id
  192. WHERE ur.user_id = %s
  193. AND ur.is_active = 1
  194. AND m.is_active = 1
  195. ORDER BY m.sort_order, m.created_at
  196. """, (user_id,))
  197. menus = []
  198. for row in cursor.fetchall():
  199. menu = {
  200. "id": row[0],
  201. "parent_id": row[1],
  202. "name": row[2],
  203. "title": row[3],
  204. "path": row[4],
  205. "component": row[5],
  206. "icon": row[6],
  207. "sort_order": row[7],
  208. "menu_type": row[8],
  209. "is_hidden": bool(row[9]),
  210. "is_active": bool(row[10]),
  211. "children": []
  212. }
  213. menus.append(menu)
  214. # 构建菜单树
  215. menu_tree = build_menu_tree(menus)
  216. cursor.close()
  217. conn.close()
  218. return ApiResponse(
  219. code=0,
  220. message="获取用户菜单成功",
  221. data=menu_tree,
  222. timestamp=datetime.now(timezone.utc).isoformat()
  223. ).model_dump()
  224. except Exception as e:
  225. print(f"获取用户菜单错误: {e}")
  226. return ApiResponse(
  227. code=500,
  228. message="服务器内部错误",
  229. timestamp=datetime.now(timezone.utc).isoformat()
  230. ).model_dump()
  231. def build_menu_tree(menus: List[Dict]) -> List[Dict]:
  232. """构建菜单树结构"""
  233. menu_map = {menu["id"]: menu for menu in menus}
  234. tree = []
  235. for menu in menus:
  236. if menu["parent_id"] is None:
  237. tree.append(menu)
  238. else:
  239. parent = menu_map.get(menu["parent_id"])
  240. if parent:
  241. parent["children"].append(menu)
  242. return tree
  243. async def get_all_menus(
  244. page: int = 1,
  245. page_size: int = 20,
  246. keyword: Optional[str] = None,
  247. credentials: HTTPAuthorizationCredentials = Depends(security)
  248. ):
  249. """获取所有菜单(管理员)"""
  250. try:
  251. payload = verify_token(credentials.credentials)
  252. if not payload:
  253. return ApiResponse(
  254. code=401,
  255. message="无效的访问令牌",
  256. timestamp=datetime.now(timezone.utc).isoformat()
  257. ).model_dump()
  258. user_id = payload.get("sub")
  259. # 检查权限
  260. if not await has_permission(user_id, "menu", "view"):
  261. return ApiResponse(
  262. code=403,
  263. message="权限不足",
  264. timestamp=datetime.now(timezone.utc).isoformat()
  265. ).model_dump()
  266. conn = get_db_connection()
  267. if not conn:
  268. return ApiResponse(
  269. code=500,
  270. message="数据库连接失败",
  271. timestamp=datetime.now(timezone.utc).isoformat()
  272. ).model_dump()
  273. cursor = conn.cursor()
  274. # 构建查询条件
  275. where_conditions = []
  276. params = []
  277. if keyword:
  278. where_conditions.append("(m.title LIKE %s OR m.name LIKE %s)")
  279. params.extend([f"%{keyword}%", f"%{keyword}%"])
  280. where_clause = " AND ".join(where_conditions) if where_conditions else "1=1"
  281. # 查询总数
  282. cursor.execute(f"SELECT COUNT(*) FROM menus m WHERE {where_clause}", params)
  283. total = cursor.fetchone()[0]
  284. # 查询菜单列表
  285. offset = (page - 1) * page_size
  286. cursor.execute(f"""
  287. SELECT m.id, m.parent_id, m.name, m.title, m.path, m.component,
  288. m.icon, m.sort_order, m.menu_type, m.is_hidden, m.is_active,
  289. m.description, m.created_at, m.updated_at,
  290. pm.title as parent_title
  291. FROM menus m
  292. LEFT JOIN menus pm ON m.parent_id = pm.id
  293. WHERE {where_clause}
  294. ORDER BY m.sort_order, m.created_at
  295. LIMIT %s OFFSET %s
  296. """, params + [page_size, offset])
  297. menus = []
  298. for row in cursor.fetchall():
  299. menu = {
  300. "id": row[0],
  301. "parent_id": row[1],
  302. "name": row[2],
  303. "title": row[3],
  304. "path": row[4],
  305. "component": row[5],
  306. "icon": row[6],
  307. "sort_order": row[7],
  308. "menu_type": row[8],
  309. "is_hidden": bool(row[9]),
  310. "is_active": bool(row[10]),
  311. "description": row[11],
  312. "created_at": row[12].isoformat() if row[12] else None,
  313. "updated_at": row[13].isoformat() if row[13] else None,
  314. "parent_title": row[14]
  315. }
  316. menus.append(menu)
  317. cursor.close()
  318. conn.close()
  319. return ApiResponse(
  320. code=0,
  321. message="获取菜单列表成功",
  322. data={
  323. "items": menus,
  324. "total": total,
  325. "page": page,
  326. "page_size": page_size
  327. },
  328. timestamp=datetime.now(timezone.utc).isoformat()
  329. ).model_dump()
  330. except Exception as e:
  331. print(f"获取菜单列表错误: {e}")
  332. return ApiResponse(
  333. code=500,
  334. message="服务器内部错误",
  335. timestamp=datetime.now(timezone.utc).isoformat()
  336. ).model_dump()
  337. # 角色管理API
  338. async def get_all_roles(
  339. page: int = 1,
  340. page_size: int = 20,
  341. keyword: Optional[str] = None,
  342. credentials: HTTPAuthorizationCredentials = Depends(security)
  343. ):
  344. """获取所有角色"""
  345. try:
  346. payload = verify_token(credentials.credentials)
  347. if not payload:
  348. return ApiResponse(
  349. code=401,
  350. message="无效的访问令牌",
  351. timestamp=datetime.now(timezone.utc).isoformat()
  352. ).model_dump()
  353. user_id = payload.get("sub")
  354. # 检查权限
  355. if not await has_permission(user_id, "role", "view"):
  356. return ApiResponse(
  357. code=403,
  358. message="权限不足",
  359. timestamp=datetime.now(timezone.utc).isoformat()
  360. ).model_dump()
  361. conn = get_db_connection()
  362. if not conn:
  363. return ApiResponse(
  364. code=500,
  365. message="数据库连接失败",
  366. timestamp=datetime.now(timezone.utc).isoformat()
  367. ).model_dump()
  368. cursor = conn.cursor()
  369. # 构建查询条件
  370. where_conditions = []
  371. params = []
  372. if keyword:
  373. where_conditions.append("(r.display_name LIKE %s OR r.name LIKE %s)")
  374. params.extend([f"%{keyword}%", f"%{keyword}%"])
  375. where_clause = " AND ".join(where_conditions) if where_conditions else "1=1"
  376. # 查询总数
  377. cursor.execute(f"SELECT COUNT(*) FROM roles r WHERE {where_clause}", params)
  378. total = cursor.fetchone()[0]
  379. # 查询角色列表
  380. offset = (page - 1) * page_size
  381. cursor.execute(f"""
  382. SELECT r.id, r.name, r.display_name, r.description, r.is_active,
  383. r.is_system, r.created_at, r.updated_at,
  384. COUNT(ur.user_id) as user_count
  385. FROM roles r
  386. LEFT JOIN user_roles ur ON r.id = ur.role_id AND ur.is_active = 1
  387. WHERE {where_clause}
  388. GROUP BY r.id
  389. ORDER BY r.is_system DESC, r.created_at
  390. LIMIT %s OFFSET %s
  391. """, params + [page_size, offset])
  392. roles = []
  393. for row in cursor.fetchall():
  394. role = {
  395. "id": row[0],
  396. "name": row[1],
  397. "display_name": row[2],
  398. "description": row[3],
  399. "is_active": bool(row[4]),
  400. "is_system": bool(row[5]),
  401. "created_at": row[6].isoformat() if row[6] else None,
  402. "updated_at": row[7].isoformat() if row[7] else None,
  403. "user_count": row[8]
  404. }
  405. roles.append(role)
  406. cursor.close()
  407. conn.close()
  408. return ApiResponse(
  409. code=0,
  410. message="获取角色列表成功",
  411. data={
  412. "items": roles,
  413. "total": total,
  414. "page": page,
  415. "page_size": page_size
  416. },
  417. timestamp=datetime.now(timezone.utc).isoformat()
  418. ).model_dump()
  419. except Exception as e:
  420. print(f"获取角色列表错误: {e}")
  421. return ApiResponse(
  422. code=500,
  423. message="服务器内部错误",
  424. timestamp=datetime.now(timezone.utc).isoformat()
  425. ).model_dump()
  426. async def get_user_permissions(credentials: HTTPAuthorizationCredentials = Depends(security)):
  427. """获取用户权限"""
  428. try:
  429. payload = verify_token(credentials.credentials)
  430. if not payload:
  431. return ApiResponse(
  432. code=401,
  433. message="无效的访问令牌",
  434. timestamp=datetime.now(timezone.utc).isoformat()
  435. ).model_dump()
  436. user_id = payload.get("sub")
  437. conn = get_db_connection()
  438. if not conn:
  439. return ApiResponse(
  440. code=500,
  441. message="数据库连接失败",
  442. timestamp=datetime.now(timezone.utc).isoformat()
  443. ).model_dump()
  444. cursor = conn.cursor()
  445. # 获取用户权限
  446. cursor.execute("""
  447. SELECT DISTINCT p.name, p.resource, p.action
  448. FROM permissions p
  449. JOIN role_permissions rp ON p.id = rp.permission_id
  450. JOIN user_roles ur ON rp.role_id = ur.role_id
  451. WHERE ur.user_id = %s
  452. AND ur.is_active = 1
  453. AND p.is_active = 1
  454. """, (user_id,))
  455. permissions = []
  456. for row in cursor.fetchall():
  457. permissions.append({
  458. "name": row[0],
  459. "resource": row[1],
  460. "action": row[2]
  461. })
  462. cursor.close()
  463. conn.close()
  464. return ApiResponse(
  465. code=0,
  466. message="获取用户权限成功",
  467. data=permissions,
  468. timestamp=datetime.now(timezone.utc).isoformat()
  469. ).model_dump()
  470. except Exception as e:
  471. print(f"获取用户权限错误: {e}")
  472. return ApiResponse(
  473. code=500,
  474. message="服务器内部错误",
  475. timestamp=datetime.now(timezone.utc).isoformat()
  476. ).model_dump()
# Export the API functions for use by the main server.
# Names listed here are what ``from <this module> import *`` exposes.
__all__ = [
    'get_user_menus',
    'get_all_menus',
    'get_all_roles',
    'get_user_permissions',
    'has_permission'
]