#!/usr/bin/env python3
"""
完整的SSO服务器 - 包含认证API
"""
import sys
import os
import socket
import json
import uuid
# 添加src目录到Python路径
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
# 加载环境变量
from dotenv import load_dotenv
load_dotenv()
from fastapi import FastAPI, HTTPException, Depends, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel
from typing import Optional, Any
import hashlib
import secrets
# Resolve a JWT implementation at import time, in order of preference:
# PyJWT, then python-jose, then a last-resort runtime install of PyJWT.
# NOTE(review): running pip from inside the server process is an ops and
# supply-chain risk; prefer pinning PyJWT in the project's requirements.
try:
    # Prefer PyJWT.
    import jwt as pyjwt
    # Smoke-test encode(): an unrelated package that is also importable as
    # ``jwt`` may lack it (AttributeError) or use another signature (TypeError).
    test_token = pyjwt.encode({"test": "data"}, "secret", algorithm="HS256")
    jwt = pyjwt
    print("✅ 使用PyJWT库")
except (ImportError, AttributeError, TypeError) as e:
    print(f"PyJWT导入失败: {e}")
    try:
        # Fall back to python-jose, which also offers jwt.encode / jwt.decode.
        from jose import jwt
        print("✅ 使用python-jose库")
    except ImportError as e:
        print(f"python-jose导入失败: {e}")
        # Last resort: install PyJWT into the running interpreter.
        print("尝试安装PyJWT...")
        import importlib
        import subprocess
        import sys
        try:
            subprocess.check_call([sys.executable, "-m", "pip", "install", "PyJWT"])
            # A wrong package imported above may still be cached as "jwt";
            # drop it, and refresh the import finders' directory caches so a
            # package installed after interpreter start-up is actually found.
            sys.modules.pop("jwt", None)
            importlib.invalidate_caches()
            import jwt
            print("✅ PyJWT安装成功")
        except Exception as install_error:
            print(f"❌ PyJWT安装失败: {install_error}")
            raise ImportError("无法导入JWT库,请手动安装: pip install PyJWT") from install_error
from datetime import datetime, timedelta, timezone
import pymysql
from urllib.parse import urlparse
# 导入RBAC API - 移除循环导入
# from rbac_api import get_user_menus, get_all_menus, get_all_roles, get_user_permissions
# Data models
class LoginRequest(BaseModel):
    """Request body for ``POST /api/v1/auth/login``."""

    # Matched against both the ``username`` and the ``email`` column by the login handler.
    username: str
    # Plain-text password; checked with verify_password_simple().
    password: str
    # Accepted, but not read by the login handler visible in this file.
    remember_me: bool = False
class TokenResponse(BaseModel):
    """Token payload placed in ``ApiResponse.data`` by the login handler."""

    # Signed JWT produced by create_access_token().
    access_token: str
    # The login handler in this file does not set this field.
    refresh_token: Optional[str] = None
    token_type: str = "Bearer"
    # Lifetime in seconds (ACCESS_TOKEN_EXPIRE_MINUTES * 60 in the login handler).
    expires_in: int
    # Space-delimited scope string, e.g. "profile email".
    scope: Optional[str] = None
class UserInfo(BaseModel):
    """Public user representation.

    NOTE(review): not referenced by any handler visible in this file; the
    profile endpoint builds a plain dict instead. Confirm whether it is
    still used elsewhere.
    """

    id: str
    username: str
    email: str
    phone: Optional[str] = None
    avatar_url: Optional[str] = None
    is_active: bool
    is_superuser: bool = False
    # Presumably role names / permission identifiers; TODO confirm element types.
    roles: list = []
    permissions: list = []
class ApiResponse(BaseModel):
    """Uniform JSON envelope returned by the ``/api/v1`` and root endpoints."""

    # 0 means success. Non-zero values are application error codes used in
    # this file, e.g. 100001 (bad/missing parameters), 200001/200002
    # (authentication/authorization failures), 500001 (database/server error).
    code: int
    # Human-readable (Chinese) message.
    message: str
    # Endpoint-specific payload; omitted (None) on errors.
    data: Optional[Any] = None
    # ISO-8601 UTC timestamp of when the response was built.
    timestamp: str
# Configuration (overridable through environment variables / .env).
# Publicly known fallback signing secret: acceptable for local development only.
_DEV_JWT_SECRET_KEY = "dev-jwt-secret-key-12345"
JWT_SECRET_KEY = os.getenv("JWT_SECRET_KEY", _DEV_JWT_SECRET_KEY)
if JWT_SECRET_KEY == _DEV_JWT_SECRET_KEY:
    # Anyone who knows this default can forge valid tokens, so make the
    # fallback loud instead of silent.
    print("⚠️ 未设置JWT_SECRET_KEY,正在使用开发环境默认密钥,请勿用于生产环境")
# Access-token lifetime in minutes (reported to clients as expires_in * 60 seconds).
ACCESS_TOKEN_EXPIRE_MINUTES = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "30"))
def check_port(port):
    """Return True if a TCP socket can currently be bound to ``localhost:port``.

    The probe socket is always closed before returning. Only ``localhost`` is
    probed, so a port bound solely on some other specific interface may still
    be reported as free.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        probe.bind(('localhost', port))
    except OSError:
        return False
    else:
        return True
    finally:
        probe.close()
def find_available_port(start_port=8000, max_port=8010):
    """Return the first port in ``[start_port, max_port]`` accepted by check_port().

    Ports are probed in ascending order. Returns None when every port in the
    inclusive range is taken (or when the range is empty).
    """
    candidates = range(start_port, max_port + 1)
    return next((candidate for candidate in candidates if check_port(candidate)), None)
def get_db_connection():
    """Open a new PyMySQL connection described by the ``DATABASE_URL`` env var.

    Expected form: ``mysql://user:password@host:port/database``. Missing parts
    fall back to host ``localhost``, port 3306, user ``root``, an empty
    password and database ``sso_db``.

    Returns:
        A new connection that the caller must close, or None when
        ``DATABASE_URL`` is unset/empty or connecting fails. The failure is
        printed, not raised.
    """
    from urllib.parse import unquote

    try:
        database_url = os.getenv('DATABASE_URL', '')
        if not database_url:
            return None
        parsed = urlparse(database_url)
        # urlparse leaves URL userinfo percent-encoded (e.g. "p%40ss" for
        # "p@ss"), so decode it before handing it to the driver.
        user = unquote(parsed.username) if parsed.username else 'root'
        password = unquote(parsed.password) if parsed.password else ''
        # A bare "/" path names no database; use the default instead of ''.
        database = unquote(parsed.path[1:]) or 'sso_db'
        config = {
            'host': parsed.hostname or 'localhost',
            'port': parsed.port or 3306,
            'user': user,
            'password': password,
            'database': database,
            'charset': 'utf8mb4'
        }
        return pymysql.connect(**config)
    except Exception as e:
        print(f"数据库连接失败: {e}")
        return None
def verify_password_simple(password: str, stored_hash: str) -> bool:
    """Check ``password`` against a ``"sha256$<salt>$<hexdigest>"`` hash.

    The digest is ``sha256(password + salt)``, the format written by
    hash_password_simple(). Any other format, or an empty/None stored hash
    (e.g. a NULL column), yields False.
    """
    if not stored_hash or not stored_hash.startswith("sha256$"):
        return False
    parts = stored_hash.split("$")
    if len(parts) != 3:
        return False
    _, salt, expected_hash = parts
    actual_hash = hashlib.sha256((password + salt).encode()).hexdigest()
    # Constant-time comparison, so response timing does not reveal how much of
    # the digest matched. Compare bytes: compare_digest rejects str arguments
    # that contain non-ASCII characters.
    return secrets.compare_digest(actual_hash.encode(), expected_hash.encode())
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Encode ``data`` as an HS256-signed JWT with ``iat`` and ``exp`` claims.

    Args:
        data: Claims to embed. The dict is copied, never mutated.
        expires_delta: Token lifetime; defaults to ACCESS_TOKEN_EXPIRE_MINUTES.

    Returns:
        The encoded token, signed with JWT_SECRET_KEY.
    """
    to_encode = data.copy()
    # Use one timestamp for both claims so exp - iat is exactly the lifetime.
    now = datetime.now(timezone.utc)
    if expires_delta is None:
        expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    # RFC 7519 defines "sub" as a string, and PyJWT 2.10+ rejects a
    # non-string "sub" on decode. Without this, verify_token() would return
    # None for tokens built from integer user ids.
    if to_encode.get("sub") is not None:
        to_encode["sub"] = str(to_encode["sub"])
    to_encode.update({"exp": now + expires_delta, "iat": now})
    return jwt.encode(to_encode, JWT_SECRET_KEY, algorithm="HS256")
def verify_token(token: str) -> Optional[dict]:
    """Decode and validate an HS256 JWT signed with JWT_SECRET_KEY.

    Returns:
        The claims dict, or None if the token is malformed, badly signed or
        expired (any error raised by the loaded JWT library).
    """
    # PyJWT raises jwt.PyJWTError, while python-jose (the fallback) raises
    # jose.jwt.JWTError. Look up whichever exists so the except clause itself
    # cannot fail with AttributeError.
    jwt_errors = tuple(
        err
        for err in (getattr(jwt, "PyJWTError", None), getattr(jwt, "JWTError", None))
        if err is not None
    ) or (Exception,)
    try:
        return jwt.decode(token, JWT_SECRET_KEY, algorithms=["HS256"])
    except jwt_errors:
        return None
# Create the FastAPI application (interactive docs at /docs and /redoc).
app = FastAPI(
    title="SSO认证中心",
    version="1.0.0",
    description="OAuth2单点登录认证中心",
    docs_url="/docs",
    redoc_url="/redoc"
)
# Configure CORS.
# NOTE(review): wildcard origins combined with allow_credentials=True lets
# any website make credentialed cross-origin requests (Starlette echoes the
# request Origin back in that case). Restrict allow_origins to the known
# front-end origins outside development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Bearer-token extractor injected via Depends(security) into protected routes;
# handlers read the raw token from HTTPAuthorizationCredentials.credentials.
security = HTTPBearer()
@app.get("/")
async def root():
    """Landing endpoint: service name, version and docs path in an ApiResponse."""
    service_info = {
        "name": "SSO认证中心",
        "version": "1.0.0",
        "docs": "/docs",
    }
    envelope = ApiResponse(
        code=0,
        message="欢迎使用SSO认证中心",
        data=service_info,
        timestamp=datetime.now(timezone.utc).isoformat(),
    )
    return envelope.model_dump()
@app.get("/health")
async def health_check():
    """Liveness probe: always reports ``healthy`` with the service version."""
    health = {
        "status": "healthy",
        "version": "1.0.0",
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
    envelope = ApiResponse(
        code=0,
        message="服务正常运行",
        data=health,
        timestamp=datetime.now(timezone.utc).isoformat(),
    )
    return envelope.model_dump()
@app.post("/api/v1/auth/login")
async def login(request: Request, login_data: LoginRequest):
    """Authenticate by username or email plus password and issue a JWT.

    Returns an ``ApiResponse`` dict:
      * code 0 with a ``TokenResponse`` payload on success;
      * 200001 when the user is unknown or the password is wrong;
      * 200002 when the correctly authenticated account is disabled;
      * 500001 when the database is unavailable or anything else fails.
    """
    print(f"🔐 收到登录请求: username={login_data.username}")
    try:
        # Get a database connection.
        print("📊 尝试连接数据库...")
        conn = get_db_connection()
        if not conn:
            print("❌ 数据库连接失败")
            return ApiResponse(
                code=500001,
                message="数据库连接失败",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        print("✅ 数据库连接成功")
        # Release the connection as soon as the lookup is done, even if the
        # query raises.
        try:
            with conn.cursor() as cursor:
                print(f"🔍 查找用户: {login_data.username}")
                cursor.execute(
                    "SELECT id, username, email, password_hash, is_active, is_superuser FROM users WHERE username = %s OR email = %s",
                    (login_data.username, login_data.username)
                )
                user_data = cursor.fetchone()
        finally:
            conn.close()
        print(f"👤 用户查询结果: {user_data is not None}")
        if not user_data:
            print("❌ 用户不存在")
            return ApiResponse(
                code=200001,
                message="用户名或密码错误",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        user_id, username, email, password_hash, is_active, is_superuser = user_data
        print(f"✅ 找到用户: {username}, 激活状态: {is_active}")
        # Verify the password BEFORE reporting the account status. Otherwise
        # anyone could find out which accounts exist and are disabled without
        # knowing the password. The stored hash is never logged.
        password_valid = verify_password_simple(login_data.password, password_hash)
        print(f"🔑 密码验证结果: {password_valid}")
        if not password_valid:
            print("❌ 密码验证失败")
            return ApiResponse(
                code=200001,
                message="用户名或密码错误",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        if not is_active:
            print("❌ 用户已被禁用")
            return ApiResponse(
                code=200002,
                message="用户已被禁用",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        # Issue the access token.
        print("🎫 生成访问令牌...")
        token_data = {
            "sub": user_id,
            "username": username,
            "email": email,
            "is_superuser": is_superuser
        }
        access_token = create_access_token(token_data)
        # Never log the token itself: anyone who can read the logs could replay it.
        print("✅ 令牌生成成功")
        token_response = TokenResponse(
            access_token=access_token,
            token_type="Bearer",
            expires_in=ACCESS_TOKEN_EXPIRE_MINUTES * 60,
            scope="profile email"
        )
        print("🎉 登录成功")
        return ApiResponse(
            code=0,
            message="登录成功",
            data=token_response.model_dump(),
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
    except Exception as e:
        print(f"❌ 登录错误详情: {type(e).__name__}: {str(e)}")
        import traceback
        print(f"❌ 错误堆栈: {traceback.format_exc()}")
        return ApiResponse(
            code=500001,
            message="服务器内部错误",
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
@app.get("/api/v1/users/profile")
async def get_user_profile(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Return the profile of the user identified by the bearer token.

    Joins ``users`` with ``user_profiles`` and lists the names of the user's
    active roles. Returns an ``ApiResponse`` dict with one of these codes:
    0 (profile in ``data``), 200002 (invalid token), 200001 (user not found)
    or 500001 (database/server error).
    """
    try:
        # Validate the token.
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(
                code=200002,
                message="无效的访问令牌",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        user_id = payload.get("sub")
        if not user_id:
            return ApiResponse(
                code=200002,
                message="无效的访问令牌",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        conn = get_db_connection()
        if not conn:
            return ApiResponse(
                code=500001,
                message="数据库连接失败",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        # Always release the connection, even if one of the queries raises.
        try:
            with conn.cursor() as cursor:
                # User details plus the optional profile row.
                cursor.execute("""
                    SELECT u.id, u.username, u.email, u.phone, u.avatar_url, u.is_active, u.is_superuser,
                           u.last_login_at, u.created_at, u.updated_at,
                           p.real_name, p.company, p.department, p.position
                    FROM users u
                    LEFT JOIN user_profiles p ON u.id = p.user_id
                    WHERE u.id = %s
                """, (user_id,))
                user_data = cursor.fetchone()
                # Names of the user's active role assignments.
                cursor.execute("""
                    SELECT r.name
                    FROM user_roles ur
                    JOIN roles r ON ur.role_id = r.id
                    WHERE ur.user_id = %s AND ur.is_active = 1
                """, (user_id,))
                roles = [row[0] for row in cursor.fetchall()]
        finally:
            conn.close()
        if not user_data:
            return ApiResponse(
                code=200001,
                message="用户不存在",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        user_info = {
            "id": user_data[0],
            "username": user_data[1],
            "email": user_data[2],
            "phone": user_data[3],
            "avatar_url": user_data[4],
            "is_active": user_data[5],
            "is_superuser": user_data[6],
            "last_login_at": user_data[7].isoformat() if user_data[7] else None,
            "created_at": user_data[8].isoformat() if user_data[8] else None,
            "updated_at": user_data[9].isoformat() if user_data[9] else None,
            "real_name": user_data[10],
            "company": user_data[11],
            "department": user_data[12],
            "position": user_data[13],
            "roles": roles
        }
        return ApiResponse(
            code=0,
            message="获取用户资料成功",
            data=user_info,
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
    except Exception as e:
        print(f"获取用户资料错误: {e}")
        return ApiResponse(
            code=500001,
            message="服务器内部错误",
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
@app.put("/api/v1/users/profile")
async def update_user_profile(
    request: Request,
    profile_data: dict,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Update the caller's contact fields and upsert their ``user_profiles`` row.

    Recognized keys in ``profile_data``:
      * ``email`` and ``phone``, stored on ``users``;
      * ``real_name``, ``company``, ``department`` and ``position``, stored
        on ``user_profiles`` (the row is created if missing).
    Other keys are ignored. Column names come only from these fixed
    whitelists, so the f-string SQL below never interpolates user-supplied
    identifiers; values are always bound as parameters.
    """
    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(
                code=200002,
                message="无效的访问令牌",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        user_id = payload.get("sub")
        # A token without "sub" cannot identify whose profile to change.
        if not user_id:
            return ApiResponse(
                code=200002,
                message="无效的访问令牌",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        conn = get_db_connection()
        if not conn:
            return ApiResponse(
                code=500001,
                message="数据库连接失败",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        # Close the connection on every path. If an error occurs before
        # commit(), closing discards the uncommitted changes.
        try:
            with conn.cursor() as cursor:
                # Basic contact fields on the users table.
                user_set = []
                user_values = []
                if 'email' in profile_data:
                    user_set.append('email = %s')
                    user_values.append(profile_data['email'])
                if 'phone' in profile_data:
                    user_set.append('phone = %s')
                    user_values.append(profile_data['phone'])
                if user_set:
                    user_values.append(user_id)
                    cursor.execute(f"""
                        UPDATE users
                        SET {', '.join(user_set)}, updated_at = NOW()
                        WHERE id = %s
                    """, user_values)
                # Extended profile: update the row if present, else insert one.
                profile_fields = ['real_name', 'company', 'department', 'position']
                profile_updates = {k: v for k, v in profile_data.items() if k in profile_fields}
                if profile_updates:
                    cursor.execute("SELECT id FROM user_profiles WHERE user_id = %s", (user_id,))
                    profile_exists = cursor.fetchone()
                    if profile_exists:
                        profile_set = [f'{field} = %s' for field in profile_updates]
                        profile_values = list(profile_updates.values()) + [user_id]
                        cursor.execute(f"""
                            UPDATE user_profiles
                            SET {', '.join(profile_set)}, updated_at = NOW()
                            WHERE user_id = %s
                        """, profile_values)
                    else:
                        fields = ['user_id'] + list(profile_updates.keys())
                        values = [user_id] + list(profile_updates.values())
                        placeholders = ', '.join(['%s'] * len(values))
                        cursor.execute(f"""
                            INSERT INTO user_profiles ({', '.join(fields)}, created_at, updated_at)
                            VALUES ({placeholders}, NOW(), NOW())
                        """, values)
            conn.commit()
        finally:
            conn.close()
        return ApiResponse(
            code=0,
            message="用户资料更新成功",
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
    except Exception as e:
        print(f"更新用户资料错误: {e}")
        return ApiResponse(
            code=500001,
            message="服务器内部错误",
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
@app.put("/api/v1/users/password")
async def change_user_password(
    request: Request,
    password_data: dict,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Change the caller's password after re-checking the current one.

    ``password_data`` must contain non-empty string ``old_password`` and
    ``new_password``. Returns an ``ApiResponse`` dict with one of these
    codes: 0 (changed), 100001 (missing or invalid parameters), 200001
    (current password wrong), 200002 (invalid token) or 500001
    (database/server error).
    """
    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(
                code=200002,
                message="无效的访问令牌",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        user_id = payload.get("sub")
        if not user_id:
            return ApiResponse(
                code=200002,
                message="无效的访问令牌",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        old_password = password_data.get('old_password')
        new_password = password_data.get('new_password')
        # Reject non-string values up front. Otherwise they fail inside the
        # hashing code and surface as a generic 500.
        if not (isinstance(old_password, str) and old_password
                and isinstance(new_password, str) and new_password):
            return ApiResponse(
                code=100001,
                message="缺少必要参数",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        conn = get_db_connection()
        if not conn:
            return ApiResponse(
                code=500001,
                message="数据库连接失败",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        try:
            with conn.cursor() as cursor:
                # Re-authenticate with the current password.
                cursor.execute("SELECT password_hash FROM users WHERE id = %s", (user_id,))
                result = cursor.fetchone()
                if not result or not verify_password_simple(old_password, result[0]):
                    return ApiResponse(
                        code=200001,
                        message="当前密码错误",
                        timestamp=datetime.now(timezone.utc).isoformat()
                    ).model_dump()
                new_password_hash = hash_password_simple(new_password)
                cursor.execute("""
                    UPDATE users
                    SET password_hash = %s, updated_at = NOW()
                    WHERE id = %s
                """, (new_password_hash, user_id))
            conn.commit()
        finally:
            conn.close()
        return ApiResponse(
            code=0,
            message="密码修改成功",
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
    except Exception as e:
        print(f"修改密码错误: {e}")
        return ApiResponse(
            code=500001,
            message="服务器内部错误",
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
def hash_password_simple(password):
    """Return ``"sha256$<salt>$<hexdigest>"`` where digest = sha256(password + salt).

    The salt is 16 random bytes from ``secrets`` rendered as 32 hex characters.
    This is the format checked by verify_password_simple().

    NOTE(review): a single unstretched SHA-256 round is weak for password
    storage; consider migrating to a slow KDF (PBKDF2/scrypt) together with
    the verifier.
    """
    salt_hex = secrets.token_hex(16)
    digest = hashlib.sha256((password + salt_hex).encode()).hexdigest()
    return "$".join(("sha256", salt_hex, digest))
@app.post("/api/v1/auth/logout")
async def logout():
    """Acknowledge a logout request.

    Stateless: no token is revoked here, so previously issued JWTs stay valid
    until their ``exp``.
    """
    envelope = ApiResponse(
        code=0,
        message="登出成功",
        timestamp=datetime.now(timezone.utc).isoformat(),
    )
    return envelope.model_dump()
# OAuth2 authorization endpoint
@app.get("/oauth/authorize")
async def oauth_authorize(
    response_type: str,
    client_id: str,
    redirect_uri: str,
    scope: str = "profile",
    state: str = None
):
    """OAuth2 authorization endpoint (authorization-code flow only).

    Validates ``response_type``, the client (an active row in ``apps``), the
    ``redirect_uri`` (must be registered) and every requested scope. On
    success it redirects (302) to ``/oauth/login`` with the same parameters.
    Failures are returned as a JSON body ``{"error", "redirect_url"}`` rather
    than as an HTTP redirect.

    NOTE(review): for invalid_client / invalid redirect_uri, ``redirect_url``
    points at an unverified URI; RFC 6749 §4.1.2.1 says not to redirect
    there, so front-ends must not follow it automatically.
    """
    from urllib.parse import urlencode
    from fastapi.responses import RedirectResponse

    def _error(error: str, description: str) -> dict:
        # Build the error callback URL with URL-encoded parameters. Any query
        # already on redirect_uri is preserved (RFC 6749 §3.1.2).
        params = {"error": error, "error_description": description}
        if state:
            params["state"] = state
        separator = "&" if "?" in redirect_uri else "?"
        return {"error": error, "redirect_url": f"{redirect_uri}{separator}{urlencode(params)}"}

    try:
        print(f"🔐 OAuth授权请求: client_id={client_id}, redirect_uri={redirect_uri}, scope={scope}")
        if not response_type or not client_id or not redirect_uri:
            return _error("invalid_request", "Missing required parameters")
        if response_type != "code":
            return _error("unsupported_response_type", "Only authorization code flow is supported")
        conn = get_db_connection()
        if not conn:
            return _error("server_error", "Database connection failed")
        try:
            with conn.cursor() as cursor:
                cursor.execute("""
                    SELECT id, name, redirect_uris, scope, is_active, is_trusted
                    FROM apps
                    WHERE app_key = %s AND is_active = 1
                """, (client_id,))
                app_data = cursor.fetchone()
        finally:
            conn.close()
        if not app_data:
            return _error("invalid_client", "Invalid client_id")
        _app_id, _app_name, redirect_uris_json, app_scope_json, _is_active, _is_trusted = app_data
        # Require an exact match in a JSON list. A JSON string here would turn
        # "in" into a substring test and accept attacker-chosen prefixes.
        redirect_uris = json.loads(redirect_uris_json) if redirect_uris_json else []
        if not isinstance(redirect_uris, list) or redirect_uri not in redirect_uris:
            return _error("invalid_request", "Invalid redirect_uri")
        app_scopes = json.loads(app_scope_json) if app_scope_json else []
        if not isinstance(app_scopes, list):
            app_scopes = []
        requested_scopes = scope.split() if scope else []
        invalid_scopes = [s for s in requested_scopes if s not in app_scopes]
        if invalid_scopes:
            return _error("invalid_scope", f"Invalid scope: {' '.join(invalid_scopes)}")
        # TODO: detect an existing login session (cookie/session) and skip
        # the login page when the user is already signed in. A consent page
        # for non-trusted apps is also not implemented yet.
        login_params = {
            "response_type": response_type,
            "client_id": client_id,
            "redirect_uri": redirect_uri,
            "scope": scope,
        }
        if state:
            login_params["state"] = state
        login_page_url = f"/oauth/login?{urlencode(login_params)}"
        print(f"🔐 需要用户登录,重定向到登录页面: {login_page_url}")
        return RedirectResponse(url=login_page_url, status_code=302)
    except Exception as e:
        print(f"❌ OAuth授权错误: {e}")
        return _error("server_error", "Internal server error")
@app.get("/oauth/login")
async def oauth_login_page(
    response_type: str,
    client_id: str,
    redirect_uri: str,
    scope: str = "profile",
    state: str = None
):
    """Render the SSO login page used during the OAuth2 authorization flow.

    The client's display name is read from ``apps`` (falling back to
    "未知应用"). Returns an HTML page, or a JSON error dict on
    database/server failure.
    """
    import html
    from fastapi.responses import HTMLResponse

    try:
        print(f"🔐 显示OAuth登录页面: client_id={client_id}")
        conn = get_db_connection()
        if not conn:
            return {"error": "server_error", "message": "数据库连接失败"}
        try:
            with conn.cursor() as cursor:
                cursor.execute("SELECT name FROM apps WHERE app_key = %s", (client_id,))
                app_data = cursor.fetchone()
        finally:
            conn.close()
        app_name = app_data[0] if app_data else "未知应用"
        # apps.name is chosen by whoever registered the app (POST /api/v1/apps),
        # so it must be HTML-escaped before it goes into the page (stored XSS).
        safe_app_name = html.escape(str(app_name))
        # NOTE(review): escape any request parameter (client_id, redirect_uri,
        # state, ...) interpolated into this template the same way.
        # NOTE(review): the page advertises a default test credential; remove
        # it before exposing this server outside development.
        login_html = f"""
SSO登录 - {safe_app_name}
{safe_app_name} 请求访问您的账户
测试账号: admin / Admin123456
"""
        return HTMLResponse(content=login_html)
    except Exception as e:
        print(f"❌ OAuth登录页面错误: {e}")
        return {"error": "server_error", "message": "服务器内部错误"}
@app.get("/oauth/authorize/authenticated")
async def oauth_authorize_authenticated(
    response_type: str,
    client_id: str,
    redirect_uri: str,
    access_token: str,
    scope: str = "profile",
    state: str = None
):
    """Issue an authorization code for a user who has already logged in.

    The client must be an active app and ``redirect_uri`` must be one of its
    registered URIs. Both are checked before any redirect is issued, so this
    endpoint cannot be used as an open redirector. Until that check passes,
    errors are returned as JSON. Afterwards, errors and the code are sent by
    302 redirect to ``redirect_uri``.

    NOTE(review): the code is not persisted (see TODO) and /oauth/token does
    not verify it. Passing the SSO access token in the query string also
    exposes it to logs and Referer headers.
    """
    from urllib.parse import urlencode
    from fastapi.responses import RedirectResponse

    def _redirect(params: dict):
        # Append URL-encoded params (and state) to the verified redirect_uri,
        # preserving any query it already carries.
        if state:
            params = {**params, "state": state}
        separator = "&" if "?" in redirect_uri else "?"
        return RedirectResponse(url=f"{redirect_uri}{separator}{urlencode(params)}", status_code=302)

    redirect_uri_verified = False
    try:
        print(f"🔐 用户已登录,处理授权: client_id={client_id}")
        conn = get_db_connection()
        if not conn:
            return {"error": "server_error", "error_description": "Database connection failed"}
        try:
            with conn.cursor() as cursor:
                cursor.execute(
                    "SELECT name, is_trusted, redirect_uris FROM apps WHERE app_key = %s AND is_active = 1",
                    (client_id,)
                )
                app_data = cursor.fetchone()
        finally:
            conn.close()
        if not app_data:
            return {"error": "invalid_client", "error_description": "Invalid client"}
        _app_name, is_trusted, redirect_uris_json = app_data
        redirect_uris = json.loads(redirect_uris_json) if redirect_uris_json else []
        if not isinstance(redirect_uris, list) or redirect_uri not in redirect_uris:
            return {"error": "invalid_request", "error_description": "Invalid redirect_uri"}
        redirect_uri_verified = True

        payload = verify_token(access_token)
        if not payload:
            return _redirect({"error": "invalid_token", "error_description": "Invalid access token"})
        user_id = payload.get("sub")
        username = payload.get("username", "")
        print(f"✅ 用户已验证: {username} ({user_id})")

        # Trusted and non-trusted apps are currently handled the same way,
        # because the consent page for non-trusted apps is not implemented.
        # TODO: persist auth_code with user, client, redirect_uri, scope and a
        # short expiry (~10 minutes) so /oauth/token can verify it.
        auth_code = secrets.token_urlsafe(32)
        # Log the decision, not the code itself.
        if is_trusted:
            print(f"✅ 受信任应用自动授权: client_id={client_id}")
        else:
            print(f"✅ 用户授权完成: client_id={client_id}")
        return _redirect({"code": auth_code})
    except Exception as e:
        print(f"❌ 授权处理错误: {e}")
        if redirect_uri_verified:
            return _redirect({"error": "server_error", "error_description": "Authorization failed"})
        return {"error": "server_error", "error_description": "Authorization failed"}
async def oauth_approve(
    client_id: str,
    redirect_uri: str,
    scope: str = "profile",
    state: str = None
):
    """Redirect to ``redirect_uri`` with a freshly generated authorization code.

    NOTE(review): this coroutine has no route decorator, so FastAPI does not
    expose it. Do not register it as-is. It does not check that a user is
    logged in, does not validate ``redirect_uri`` against the client's
    registered URIs, and does not persist the code it issues.
    """
    from urllib.parse import urlencode
    from fastapi.responses import RedirectResponse

    def _redirect(params: dict):
        # URL-encode params (and state), keeping any query already on redirect_uri.
        if state:
            params = {**params, "state": state}
        separator = "&" if "?" in redirect_uri else "?"
        return RedirectResponse(url=f"{redirect_uri}{separator}{urlencode(params)}", status_code=302)

    try:
        print(f"✅ 用户同意授权: client_id={client_id}")
        auth_code = secrets.token_urlsafe(32)
        # TODO: before exposing this endpoint:
        # 1. verify the user's login state;
        # 2. store the code with user/client/redirect_uri/scope;
        # 3. give it a short expiry (typically 10 minutes).
        # Log the destination only, never the code.
        print(f"🔄 重定向到: {redirect_uri}")
        return _redirect({"code": auth_code})
    except Exception as e:
        print(f"❌ 授权确认错误: {e}")
        return _redirect({"error": "server_error", "error_description": "Authorization failed"})
@app.get("/oauth/authorize/deny")
async def oauth_deny(
    client_id: str,
    redirect_uri: str,
    state: str = None
):
    """Report ``access_denied`` to the client after the user refuses consent.

    ``redirect_uri`` must be registered for ``client_id``. Otherwise a JSON
    error is returned instead of a redirect. Without this check the endpoint
    would redirect anyone to an arbitrary URL (open redirect).
    """
    from urllib.parse import urlencode
    from fastapi.responses import RedirectResponse

    try:
        print(f"❌ 用户拒绝授权: client_id={client_id}")
        conn = get_db_connection()
        if not conn:
            return {"error": "server_error", "error_description": "Database connection failed"}
        try:
            with conn.cursor() as cursor:
                cursor.execute("SELECT redirect_uris FROM apps WHERE app_key = %s", (client_id,))
                app_data = cursor.fetchone()
        finally:
            conn.close()
        registered = json.loads(app_data[0]) if app_data and app_data[0] else []
        if not isinstance(registered, list) or redirect_uri not in registered:
            return {"error": "invalid_request", "error_description": "Invalid client_id or redirect_uri"}
        params = {"error": "access_denied", "error_description": "User denied authorization"}
        if state:
            params["state"] = state
        # Keep any query already present on the registered redirect URI.
        separator = "&" if "?" in redirect_uri else "?"
        return RedirectResponse(url=f"{redirect_uri}{separator}{urlencode(params)}", status_code=302)
    except Exception as e:
        print(f"❌ 拒绝授权错误: {e}")
        # redirect_uri may be unverified at this point, so do not redirect.
        return {"error": "server_error", "error_description": "Authorization failed"}
@app.post("/oauth/token")
async def oauth_token(request: Request):
    """OAuth2 token endpoint; only ``grant_type=authorization_code`` is supported.

    Reads the form fields ``grant_type``, ``code``, ``redirect_uri``,
    ``client_id`` and ``client_secret``. The client must be active, must
    authenticate with its ``app_secret``, and ``redirect_uri`` must be one of
    its registered URIs. Returns a token dict on success, or an
    ``{"error", "error_description"}`` dict on failure (HTTP 200 either way).

    SECURITY WARNING: authorization codes are never persisted in this file,
    so ``code`` is NOT verified and every token is issued for the hard-coded
    ``user_id`` below. Treat this endpoint as a development stub.
    """
    try:
        form_data = await request.form()
        grant_type = form_data.get("grant_type")
        code = form_data.get("code")
        redirect_uri = form_data.get("redirect_uri")
        client_id = form_data.get("client_id")
        client_secret = form_data.get("client_secret")
        print(f"🎫 令牌请求: grant_type={grant_type}, client_id={client_id}")
        if grant_type != "authorization_code":
            return {
                "error": "unsupported_grant_type",
                "error_description": "Only authorization_code grant type is supported"
            }
        if not code or not redirect_uri or not client_id:
            return {
                "error": "invalid_request",
                "error_description": "Missing required parameters"
            }
        conn = get_db_connection()
        if not conn:
            return {
                "error": "server_error",
                "error_description": "Database connection failed"
            }
        try:
            with conn.cursor() as cursor:
                cursor.execute("""
                    SELECT id, name, app_secret, redirect_uris, scope, is_active
                    FROM apps
                    WHERE app_key = %s AND is_active = 1
                """, (client_id,))
                app_data = cursor.fetchone()
        finally:
            conn.close()
        if not app_data:
            return {
                "error": "invalid_client",
                "error_description": "Invalid client credentials"
            }
        _app_id, _app_name, stored_secret, redirect_uris_json, _scope_json, _is_active = app_data
        # Client authentication is mandatory (RFC 6749 §3.2.1). The old check
        # only ran when a secret was supplied, so omitting client_secret
        # skipped client authentication entirely. The comparison is
        # constant-time so it does not leak the secret through timing.
        if not client_secret or not stored_secret or not secrets.compare_digest(
            str(client_secret).encode(), str(stored_secret).encode()
        ):
            return {
                "error": "invalid_client",
                "error_description": "Invalid client credentials"
            }
        redirect_uris = json.loads(redirect_uris_json) if redirect_uris_json else []
        if not isinstance(redirect_uris, list) or redirect_uri not in redirect_uris:
            return {
                "error": "invalid_grant",
                "error_description": "Invalid redirect_uri"
            }
        # TODO: verify the authorization code:
        # 1. look it up in storage;
        # 2. check it is unexpired and unused, and issued to this client and redirect_uri;
        # 3. take the user id from the code record.
        # Placeholder user id (the admin user) until codes are persisted.
        user_id = "ed6a79d3-0083-4d81-8b48-fc522f686f74"
        token_data = {
            "sub": user_id,
            "client_id": client_id,
            "scope": "profile email"
        }
        access_token = create_access_token(token_data)
        refresh_token = secrets.token_urlsafe(32)
        # TODO: persist the tokens (refresh_token is currently unusable).
        token_response = {
            "access_token": access_token,
            "token_type": "Bearer",
            "expires_in": ACCESS_TOKEN_EXPIRE_MINUTES * 60,
            "refresh_token": refresh_token,
            "scope": "profile email"
        }
        # Do not log token material.
        print("✅ 令牌生成成功")
        return token_response
    except Exception as e:
        print(f"❌ 令牌生成错误: {e}")
        return {
            "error": "server_error",
            "error_description": "Internal server error"
        }
@app.get("/oauth/userinfo")
async def oauth_userinfo(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """OAuth2 userinfo endpoint: claims about the token's subject, filtered by scope.

    ``sub`` is always returned. The ``profile`` scope adds username, avatar
    and organisation fields; ``email`` and ``phone`` add those fields. Only
    active users are returned. Errors come back as
    ``{"error", "error_description"}`` dicts.
    """
    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return {
                "error": "invalid_token",
                "error_description": "Invalid or expired access token"
            }
        user_id = payload.get("sub")
        client_id = payload.get("client_id")
        # Accept the space-delimited string issued by /oauth/token, and also
        # a list, in case a token carries its scopes as an array.
        raw_scope = payload.get("scope") or ""
        scope = raw_scope.split() if isinstance(raw_scope, str) else list(raw_scope)
        print(f"👤 用户信息请求: user_id={user_id}, client_id={client_id}, scope={scope}")
        conn = get_db_connection()
        if not conn:
            return {
                "error": "server_error",
                "error_description": "Database connection failed"
            }
        try:
            with conn.cursor() as cursor:
                cursor.execute("""
                    SELECT u.id, u.username, u.email, u.phone, u.avatar_url, u.is_active,
                           p.real_name, p.company, p.department, p.position
                    FROM users u
                    LEFT JOIN user_profiles p ON u.id = p.user_id
                    WHERE u.id = %s AND u.is_active = 1
                """, (user_id,))
                user_data = cursor.fetchone()
        finally:
            conn.close()
        if not user_data:
            return {
                "error": "invalid_token",
                "error_description": "User not found or inactive"
            }
        user_info = {"sub": user_data[0]}
        if "profile" in scope:
            user_info.update({
                "username": user_data[1],
                "avatar_url": user_data[4],
                "real_name": user_data[6],
                "company": user_data[7],
                "department": user_data[8],
                "position": user_data[9]
            })
        if "email" in scope:
            user_info["email"] = user_data[2]
        if "phone" in scope:
            user_info["phone"] = user_data[3]
        # Log only the subject and the keys returned, not the personal data itself.
        print(f"✅ 返回用户信息: sub={user_info['sub']}, fields={sorted(user_info)}")
        return user_info
    except Exception as e:
        print(f"❌ 获取用户信息错误: {e}")
        return {
            "error": "server_error",
            "error_description": "Internal server error"
        }
@app.get("/api/v1/apps")
async def get_apps(
    page: int = 1,
    page_size: int = 20,
    keyword: str = "",
    status: str = "",
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """List the OAuth apps visible to the caller, paginated.

    Users with an active ``super_admin``, ``admin`` or ``app_manager`` role
    see every app; everyone else sees only the apps they created.
    ``keyword`` matches name/description via LIKE. ``status`` may be
    ``"active"`` or ``"inactive"``; any other value disables the filter.
    ``page`` and ``page_size`` are clamped to at least 1.

    NOTE: ``today_requests`` / ``active_users`` are random placeholder
    numbers, not real statistics.
    """
    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(
                code=200002,
                message="无效的访问令牌",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        user_id = payload.get("sub")
        # page < 1 or page_size < 1 used to produce a negative OFFSET/LIMIT,
        # which MySQL rejects (surfacing as a generic 500).
        page = max(page, 1)
        page_size = max(page_size, 1)
        conn = get_db_connection()
        if not conn:
            return ApiResponse(
                code=500001,
                message="数据库连接失败",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        try:
            with conn.cursor() as cursor:
                # Managers see all apps; everyone else only sees their own.
                cursor.execute("""
                    SELECT COUNT(*) FROM user_roles ur
                    JOIN roles r ON ur.role_id = r.id
                    WHERE ur.user_id = %s AND r.name IN ('super_admin', 'admin', 'app_manager') AND ur.is_active = 1
                """, (user_id,))
                is_app_manager = cursor.fetchone()[0] > 0
                # Fragments are constant SQL; user input is bound via params.
                where_conditions = []
                params = []
                if not is_app_manager:
                    where_conditions.append("created_by = %s")
                    params.append(user_id)
                if keyword:
                    where_conditions.append("(name LIKE %s OR description LIKE %s)")
                    params.extend([f"%{keyword}%", f"%{keyword}%"])
                if status == "active":
                    where_conditions.append("is_active = 1")
                elif status == "inactive":
                    where_conditions.append("is_active = 0")
                where_clause = " AND ".join(where_conditions) if where_conditions else "1=1"
                cursor.execute(f"SELECT COUNT(*) FROM apps WHERE {where_clause}", params)
                total = cursor.fetchone()[0]
                offset = (page - 1) * page_size
                cursor.execute(f"""
                    SELECT id, name, app_key, description, icon_url, redirect_uris, scope,
                           is_active, is_trusted, access_token_expires, refresh_token_expires,
                           created_at, updated_at
                    FROM apps
                    WHERE {where_clause}
                    ORDER BY created_at DESC
                    LIMIT %s OFFSET %s
                """, params + [page_size, offset])
                rows = cursor.fetchall()
        finally:
            conn.close()
        apps = []
        # "item" rather than "app" avoids shadowing the module-level FastAPI app.
        for row in rows:
            item = {
                "id": row[0],
                "name": row[1],
                "app_key": row[2],
                "description": row[3],
                "icon_url": row[4],
                "redirect_uris": json.loads(row[5]) if row[5] else [],
                "scope": json.loads(row[6]) if row[6] else [],
                "is_active": bool(row[7]),
                "is_trusted": bool(row[8]),
                "access_token_expires": row[9],
                "refresh_token_expires": row[10],
                "created_at": row[11].isoformat() if row[11] else None,
                "updated_at": row[12].isoformat() if row[12] else None,
                # Placeholder statistics.
                "today_requests": secrets.randbelow(1000),
                "active_users": secrets.randbelow(100)
            }
            apps.append(item)
        return ApiResponse(
            code=0,
            message="获取应用列表成功",
            data={
                "items": apps,
                "total": total,
                "page": page,
                "page_size": page_size
            },
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
    except Exception as e:
        print(f"获取应用列表错误: {e}")
        return ApiResponse(
            code=500001,
            message="服务器内部错误",
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
@app.get("/api/v1/apps/{app_id}")
async def get_app_detail(
    app_id: str,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Return one app, including its ``app_secret``, to the user who created it.

    Only the creator can read it. Unlike get_apps, admin roles get no extra
    access here. Returns an ``ApiResponse`` dict with one of these codes:
    0 (details in ``data``), 200001 (app not found or not owned), 200002
    (invalid token) or 500001 (database/server error).
    """
    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(
                code=200002,
                message="无效的访问令牌",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        user_id = payload.get("sub")
        conn = get_db_connection()
        if not conn:
            return ApiResponse(
                code=500001,
                message="数据库连接失败",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        try:
            with conn.cursor() as cursor:
                cursor.execute("""
                    SELECT id, name, app_key, app_secret, description, icon_url,
                           redirect_uris, scope, is_active, is_trusted,
                           access_token_expires, refresh_token_expires,
                           created_at, updated_at
                    FROM apps
                    WHERE id = %s AND created_by = %s
                """, (app_id, user_id))
                app_data = cursor.fetchone()
        finally:
            conn.close()
        if not app_data:
            return ApiResponse(
                code=200001,
                message="应用不存在或无权限",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        app_detail = {
            "id": app_data[0],
            "name": app_data[1],
            "app_key": app_data[2],
            "app_secret": app_data[3],
            "description": app_data[4],
            "icon_url": app_data[5],
            "redirect_uris": json.loads(app_data[6]) if app_data[6] else [],
            "scope": json.loads(app_data[7]) if app_data[7] else [],
            "is_active": bool(app_data[8]),
            "is_trusted": bool(app_data[9]),
            "access_token_expires": app_data[10],
            "refresh_token_expires": app_data[11],
            "created_at": app_data[12].isoformat() if app_data[12] else None,
            "updated_at": app_data[13].isoformat() if app_data[13] else None
        }
        return ApiResponse(
            code=0,
            message="获取应用详情成功",
            data=app_detail,
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
    except Exception as e:
        print(f"获取应用详情错误: {e}")
        return ApiResponse(
            code=500001,
            message="服务器内部错误",
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
@app.post("/api/v1/apps")
async def create_app(
    request: Request,
    app_data: dict,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Register a new OAuth client app owned by the caller.

    Required keys: ``name`` and ``redirect_uris`` (a list of URI strings; a
    single string is wrapped in a list). Optional keys:
      * ``description``, ``icon_url``;
      * ``scope``: list of strings, or a space-delimited string; default ``["profile"]``;
      * ``is_trusted``: default False;
      * ``access_token_expires``: default 7200;
      * ``refresh_token_expires``: default 2592000.
    A new ``app_key`` / ``app_secret`` pair is generated and returned in the
    response.

    NOTE(review): any authenticated user may set ``is_trusted``; confirm
    whether that flag should be restricted to administrators.
    NOTE(review): ``generate_random_string`` is not defined in the visible
    part of this file; presumably it is defined further down. TODO confirm.
    """
    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(
                code=200002,
                message="无效的访问令牌",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        user_id = payload.get("sub")
        if not user_id:
            return ApiResponse(
                code=200002,
                message="无效的访问令牌",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        if not app_data.get('name') or not app_data.get('redirect_uris'):
            return ApiResponse(
                code=100001,
                message="缺少必要参数",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        redirect_uris = app_data['redirect_uris']
        scope = app_data.get('scope', ['profile'])
        # The OAuth endpoints check "redirect_uri in json.loads(stored)". If a
        # bare JSON string were stored, that becomes a substring test and
        # accepts attacker-chosen prefixes, so always store lists of strings.
        if isinstance(redirect_uris, str):
            redirect_uris = [redirect_uris]
        if isinstance(scope, str):
            scope = scope.split()

        def _is_string_list(values) -> bool:
            # True for a list whose items are all non-empty strings.
            return isinstance(values, list) and all(isinstance(v, str) and v for v in values)

        if not _is_string_list(redirect_uris) or not _is_string_list(scope):
            return ApiResponse(
                code=100001,
                message="参数格式错误",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        is_trusted = bool(app_data.get('is_trusted', False))
        conn = get_db_connection()
        if not conn:
            return ApiResponse(
                code=500001,
                message="数据库连接失败",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        app_id = str(uuid.uuid4())
        app_key = generate_random_string(32)
        app_secret = generate_random_string(64)
        try:
            with conn.cursor() as cursor:
                cursor.execute("""
                    INSERT INTO apps (
                        id, name, app_key, app_secret, description, icon_url,
                        redirect_uris, scope, is_active, is_trusted,
                        access_token_expires, refresh_token_expires, created_by,
                        created_at, updated_at
                    ) VALUES (
                        %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW(), NOW()
                    )
                """, (
                    app_id,
                    app_data['name'],
                    app_key,
                    app_secret,
                    app_data.get('description', ''),
                    app_data.get('icon_url', ''),
                    json.dumps(redirect_uris),
                    json.dumps(scope),
                    True,
                    is_trusted,
                    app_data.get('access_token_expires', 7200),
                    app_data.get('refresh_token_expires', 2592000),
                    user_id
                ))
            conn.commit()
        finally:
            conn.close()
        app_info = {
            "id": app_id,
            "name": app_data['name'],
            "app_key": app_key,
            "app_secret": app_secret,
            "description": app_data.get('description', ''),
            "redirect_uris": redirect_uris,
            "scope": scope,
            "is_active": True,
            "is_trusted": is_trusted
        }
        return ApiResponse(
            code=0,
            message="应用创建成功",
            data=app_info,
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
    except Exception as e:
        print(f"创建应用错误: {e}")
        return ApiResponse(
            code=500001,
            message="服务器内部错误",
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
@app.put("/api/v1/apps/{app_id}/status")
async def toggle_app_status(
    app_id: str,
    status_data: dict,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Enable or disable an application owned by the caller.

    Request body: ``{"is_active": true|false}``. Only the user recorded in
    ``apps.created_by`` (the token's ``sub`` claim) may change the status.

    Returns an ``ApiResponse`` dict whose ``code`` is:
        0       status updated
        200002  invalid access token
        100001  ``is_active`` missing or not a boolean/integer
        200001  app not found or not owned by the caller
        500001  database connection failure or unexpected error
    """
    conn = None
    cursor = None
    try:
        # Authenticate the caller.
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(
                code=200002,
                message="无效的访问令牌",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        user_id = payload.get("sub")

        is_active = status_data.get('is_active')
        if is_active is None:
            return ApiResponse(
                code=100001,
                message="缺少必要参数",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        # Only accept JSON booleans (or 0/1-style integers). A string such as
        # "false" is truthy in Python and would otherwise be treated as "enable".
        if not isinstance(is_active, (bool, int)):
            return ApiResponse(
                code=100001,
                message="参数格式错误",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        is_active = bool(is_active)

        conn = get_db_connection()
        if not conn:
            return ApiResponse(
                code=500001,
                message="数据库连接失败",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        cursor = conn.cursor()

        # Ownership check: the app must exist and belong to the caller.
        cursor.execute("""
            SELECT id, name FROM apps
            WHERE id = %s AND created_by = %s
        """, (app_id, user_id))
        app_row = cursor.fetchone()
        if not app_row:
            return ApiResponse(
                code=200001,
                message="应用不存在或无权限",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()

        cursor.execute("""
            UPDATE apps
            SET is_active = %s, updated_at = NOW()
            WHERE id = %s
        """, (is_active, app_id))
        conn.commit()

        action = "启用" if is_active else "禁用"
        print(f"✅ 应用状态已更新: {app_row[1]} -> {action}")
        return ApiResponse(
            code=0,
            message=f"应用已{action}",
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
    except Exception as e:
        print(f"切换应用状态错误: {e}")
        return ApiResponse(
            code=500001,
            message="服务器内部错误",
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
    finally:
        # Release DB resources on every path, including exceptions.
        if cursor is not None:
            cursor.close()
        if conn:
            conn.close()
@app.put("/api/v1/apps/{app_id}")
async def update_app(
    app_id: str,
    app_data: dict,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Update an application owned by the caller and return its new state.

    Recognised body fields:
        name (required), description, icon_url,
        redirect_uris (required, non-empty list of non-empty strings),
        scope (falls back to ['profile', 'email'] when missing/invalid),
        is_trusted (default False), access_token_expires (default 7200),
        refresh_token_expires (default 2592000).

    Returns an ``ApiResponse`` dict: code 0 with the updated app (the
    ``app_secret`` column is not selected), 200002 invalid token, 100001
    validation error, 200001 app not found / not owned / duplicate name,
    500001 database or unexpected error.
    """

    def as_text(value):
        # None and non-string values are treated as empty; strings are stripped.
        return value.strip() if isinstance(value, str) else ''

    conn = None
    cursor = None
    try:
        # Authenticate the caller.
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(
                code=200002,
                message="无效的访问令牌",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        user_id = payload.get("sub")

        # A null / non-string name must be a validation error, not an
        # AttributeError surfacing as "服务器内部错误".
        name = as_text(app_data.get('name'))
        if not name:
            return ApiResponse(
                code=100001,
                message="应用名称不能为空",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()

        conn = get_db_connection()
        if not conn:
            return ApiResponse(
                code=500001,
                message="数据库连接失败",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        cursor = conn.cursor()

        # Ownership check: the app must exist and belong to the caller.
        cursor.execute("""
            SELECT id, name FROM apps
            WHERE id = %s AND created_by = %s
        """, (app_id, user_id))
        if not cursor.fetchone():
            return ApiResponse(
                code=200001,
                message="应用不存在或无权限",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()

        # App names must be unique per owner.
        cursor.execute("""
            SELECT id FROM apps
            WHERE name = %s AND created_by = %s AND id != %s
        """, (name, user_id, app_id))
        if cursor.fetchone():
            return ApiResponse(
                code=200001,
                message="应用名称已存在",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()

        description = as_text(app_data.get('description'))
        icon_url = as_text(app_data.get('icon_url'))
        redirect_uris = app_data.get('redirect_uris', [])
        scope = app_data.get('scope', ['profile', 'email'])
        is_trusted = app_data.get('is_trusted', False)
        access_token_expires = app_data.get('access_token_expires', 7200)
        refresh_token_expires = app_data.get('refresh_token_expires', 2592000)

        if not redirect_uris or not isinstance(redirect_uris, list):
            return ApiResponse(
                code=100001,
                message="至少需要一个回调URL",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        # Every callback must be a non-empty string; anything else would be
        # stored verbatim in the JSON column.
        if not all(isinstance(uri, str) and uri.strip() for uri in redirect_uris):
            return ApiResponse(
                code=100001,
                message="回调URL格式错误",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()

        if not scope or not isinstance(scope, list):
            scope = ['profile', 'email']

        cursor.execute("""
            UPDATE apps
            SET name = %s, description = %s, icon_url = %s,
                redirect_uris = %s, scope = %s, is_trusted = %s,
                access_token_expires = %s, refresh_token_expires = %s,
                updated_at = NOW()
            WHERE id = %s
        """, (
            name, description, icon_url,
            json.dumps(redirect_uris), json.dumps(scope), is_trusted,
            access_token_expires, refresh_token_expires, app_id
        ))
        conn.commit()

        # Re-read the row so the response reflects what was actually stored.
        cursor.execute("""
            SELECT id, name, app_key, description, icon_url,
                   redirect_uris, scope, is_active, is_trusted,
                   access_token_expires, refresh_token_expires,
                   created_at, updated_at
            FROM apps
            WHERE id = %s
        """, (app_id,))
        app_info = cursor.fetchone()
        if not app_info:
            return ApiResponse(
                code=500001,
                message="获取更新后的应用信息失败",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()

        app_result = {
            "id": app_info[0],
            "name": app_info[1],
            "app_key": app_info[2],
            "description": app_info[3],
            "icon_url": app_info[4],
            "redirect_uris": json.loads(app_info[5]) if app_info[5] else [],
            "scope": json.loads(app_info[6]) if app_info[6] else [],
            "is_active": bool(app_info[7]),
            "is_trusted": bool(app_info[8]),
            "access_token_expires": app_info[9],
            "refresh_token_expires": app_info[10],
            "created_at": app_info[11].isoformat() if app_info[11] else None,
            "updated_at": app_info[12].isoformat() if app_info[12] else None
        }
        print(f"✅ 应用已更新: {name}")
        return ApiResponse(
            code=0,
            message="应用更新成功",
            data=app_result,
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
    except Exception as e:
        print(f"更新应用错误: {e}")
        return ApiResponse(
            code=500001,
            message="服务器内部错误",
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
    finally:
        # Release DB resources on every path, including exceptions.
        if cursor is not None:
            cursor.close()
        if conn:
            conn.close()
@app.delete("/api/v1/apps/{app_id}")
async def delete_app(
    app_id: str,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Delete an application owned by the caller.

    Returns an ``ApiResponse`` dict: code 0 on success, 200002 invalid
    token, 200001 app not found or not owned by the caller, 500001 database
    or unexpected error.
    """
    conn = None
    cursor = None
    try:
        # Authenticate the caller.
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(
                code=200002,
                message="无效的访问令牌",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        user_id = payload.get("sub")

        conn = get_db_connection()
        if not conn:
            return ApiResponse(
                code=500001,
                message="数据库连接失败",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        cursor = conn.cursor()

        # Ownership check: the app must exist and belong to the caller.
        cursor.execute("""
            SELECT id, name FROM apps
            WHERE id = %s AND created_by = %s
        """, (app_id, user_id))
        app_row = cursor.fetchone()
        if not app_row:
            return ApiResponse(
                code=200001,
                message="应用不存在或无权限",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()

        # NOTE(review): only the `apps` row is deleted here; removal of related
        # rows presumably relies on ON DELETE CASCADE foreign keys — confirm
        # against the schema.
        cursor.execute("DELETE FROM apps WHERE id = %s", (app_id,))
        conn.commit()

        print(f"✅ 应用已删除: {app_row[1]}")
        return ApiResponse(
            code=0,
            message="应用已删除",
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
    except Exception as e:
        print(f"删除应用错误: {e}")
        return ApiResponse(
            code=500001,
            message="服务器内部错误",
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
    finally:
        # Release DB resources on every path, including exceptions.
        if cursor is not None:
            cursor.close()
        if conn:
            conn.close()
@app.post("/api/v1/apps/{app_id}/reset-secret")
async def reset_app_secret(
    app_id: str,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Replace the client secret of an application owned by the caller.

    A new 64-character secret is generated with ``generate_random_string``
    and returned once in ``data.app_secret``; it is not logged.

    Returns an ``ApiResponse`` dict: code 0 on success, 200002 invalid
    token, 200001 app not found or not owned by the caller, 500001 database
    or unexpected error.
    """
    conn = None
    cursor = None
    try:
        # Authenticate the caller.
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(
                code=200002,
                message="无效的访问令牌",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        user_id = payload.get("sub")

        conn = get_db_connection()
        if not conn:
            return ApiResponse(
                code=500001,
                message="数据库连接失败",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        cursor = conn.cursor()

        # Ownership check: the app must exist and belong to the caller.
        cursor.execute("""
            SELECT id, name FROM apps
            WHERE id = %s AND created_by = %s
        """, (app_id, user_id))
        app_row = cursor.fetchone()
        if not app_row:
            return ApiResponse(
                code=200001,
                message="应用不存在或无权限",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()

        new_secret = generate_random_string(64)
        cursor.execute("""
            UPDATE apps
            SET app_secret = %s, updated_at = NOW()
            WHERE id = %s
        """, (new_secret, app_id))
        conn.commit()

        print(f"✅ 应用密钥已重置: {app_row[1]}")
        return ApiResponse(
            code=0,
            message="应用密钥已重置",
            data={"app_secret": new_secret},
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
    except Exception as e:
        print(f"重置应用密钥错误: {e}")
        return ApiResponse(
            code=500001,
            message="服务器内部错误",
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
    finally:
        # Release DB resources on every path, including exceptions.
        if cursor is not None:
            cursor.close()
        if conn:
            conn.close()
# NOTE(review): an unreachable captcha-endpoint body ("获取验证码") used to sit
# after this function's return statement; its route decorator and `def` line
# appear to have been lost, so it could never run and has been removed. If a
# captcha API is needed, add it back as a real endpoint built on
# generate_captcha(), and do not send the captcha text to the client.
def generate_random_string(length=32):
    """Return a cryptographically secure random alphanumeric string.

    Characters are drawn with ``secrets.choice`` from ASCII letters and
    digits, so the result is suitable for app keys and client secrets.

    Args:
        length: Number of characters to generate (default 32). Zero or
            negative values yield an empty string.

    Returns:
        A string of ``length`` characters from ``[A-Za-z0-9]``.
    """
    import secrets
    import string

    alphabet = string.ascii_letters + string.digits
    return ''.join(secrets.choice(alphabet) for _ in range(length))
def generate_captcha():
"""生成验证码"""
try:
from PIL import Image, ImageDraw, ImageFont
import io
import base64
import random
import string
# 生成随机验证码文本
captcha_text = ''.join(random.choices(string.ascii_uppercase + string.digits, k=4))
# 创建图片
width, height = 120, 40
image = Image.new('RGB', (width, height), color='white')
draw = ImageDraw.Draw(image)
# 尝试使用系统字体,如果失败则使用默认字体
try:
# Windows系统字体
font = ImageFont.truetype("arial.ttf", 20)
except:
try:
# 备用字体
font = ImageFont.truetype("C:/Windows/Fonts/arial.ttf", 20)
except:
# 使用默认字体
font = ImageFont.load_default()
# 绘制验证码文本
text_width = draw.textlength(captcha_text, font=font)
text_height = 20
x = (width - text_width) // 2
y = (height - text_height) // 2
# 添加一些随机颜色
colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEAA7', '#DDA0DD']
text_color = random.choice(colors)
draw.text((x, y), captcha_text, fill=text_color, font=font)
# 添加一些干扰线
for _ in range(3):
x1 = random.randint(0, width)
y1 = random.randint(0, height)
x2 = random.randint(0, width)
y2 = random.randint(0, height)
draw.line([(x1, y1), (x2, y2)], fill=random.choice(colors), width=1)
# 添加一些干扰点
for _ in range(20):
x = random.randint(0, width)
y = random.randint(0, height)
draw.point((x, y), fill=random.choice(colors))
# 转换为base64
buffer = io.BytesIO()
image.save(buffer, format='PNG')
image_data = buffer.getvalue()
image_base64 = base64.b64encode(image_data).decode('utf-8')
return captcha_text, f"data:image/png;base64,{image_base64}"
except ImportError:
# 如果PIL不可用,返回简单的文本验证码
captcha_text = ''.join(random.choices(string.ascii_uppercase + string.digits, k=4))
# 创建一个简单的SVG验证码
svg_captcha = f"""
"""
svg_base64 = base64.b64encode(svg_captcha.encode('utf-8')).decode('utf-8')
return captcha_text, f"data:image/svg+xml;base64,{svg_base64}"
except Exception as e:
print(f"生成验证码图片失败: {e}")
# 返回默认验证码
return "1234", "data:image/svg+xml;base64,PHN2ZyB3aWR0aD0iMTIwIiBoZWlnaHQ9IjQwIiB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciPjxyZWN0IHdpZHRoPSIxMjAiIGhlaWdodD0iNDAiIGZpbGw9IiNmMGYwZjAiIHN0cm9rZT0iI2NjYyIvPjx0ZXh0IHg9IjYwIiB5PSIyNSIgZm9udC1mYW1pbHk9IkFyaWFsIiBmb250LXNpemU9IjE4IiB0ZXh0LWFuY2hvcj0ibWlkZGxlIiBmaWxsPSIjMzMzIj4xMjM0PC90ZXh0Pjwvc3ZnPg=="
# RBAC权限管理API
@app.get("/api/v1/user/menus")
async def api_get_user_menus(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Return the caller's menu tree.

    Users holding an active ``super_admin`` role get every active menu.
    Everyone else gets the active menus linked to their active roles via
    ``role_menus``. Rows are ordered by ``sort_order`` and nested with
    ``build_menu_tree``.

    Returns an ``ApiResponse`` dict: code 0 with the tree, 401 invalid
    token, 500 database or unexpected error.
    """
    conn = None
    cursor = None
    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(
                code=401,
                message="无效的访问令牌",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        user_id = payload.get("sub")

        conn = get_db_connection()
        if not conn:
            return ApiResponse(
                code=500,
                message="数据库连接失败",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        cursor = conn.cursor()

        # Is the caller an active super_admin?
        cursor.execute("""
            SELECT COUNT(*) FROM user_roles ur
            JOIN roles r ON ur.role_id = r.id
            WHERE ur.user_id = %s AND r.name = 'super_admin' AND ur.is_active = 1
        """, (user_id,))
        is_super_admin = cursor.fetchone()[0] > 0

        if is_super_admin:
            # Super admins see every active menu.
            cursor.execute("""
                SELECT m.id, m.parent_id, m.name, m.title, m.path,
                       m.component, m.icon, m.sort_order, m.menu_type,
                       m.is_hidden, m.is_active
                FROM menus m
                WHERE m.is_active = 1
                ORDER BY m.sort_order
            """)
        else:
            # Regular users see menus granted through their active roles;
            # GROUP BY removes duplicates when several roles grant the same menu.
            cursor.execute("""
                SELECT m.id, m.parent_id, m.name, m.title, m.path,
                       m.component, m.icon, m.sort_order, m.menu_type,
                       m.is_hidden, m.is_active
                FROM menus m
                JOIN role_menus rm ON m.id = rm.menu_id
                JOIN user_roles ur ON rm.role_id = ur.role_id
                WHERE ur.user_id = %s
                AND ur.is_active = 1
                AND m.is_active = 1
                GROUP BY m.id, m.parent_id, m.name, m.title, m.path,
                         m.component, m.icon, m.sort_order, m.menu_type,
                         m.is_hidden, m.is_active
                ORDER BY m.sort_order
            """, (user_id,))

        menus = [
            {
                "id": row[0],
                "parent_id": row[1],
                "name": row[2],
                "title": row[3],
                "path": row[4],
                "component": row[5],
                "icon": row[6],
                "sort_order": row[7],
                "menu_type": row[8],
                "is_hidden": bool(row[9]),
                "is_active": bool(row[10]),
                "children": []
            }
            for row in cursor.fetchall()
        ]
        menu_tree = build_menu_tree(menus)

        return ApiResponse(
            code=0,
            message="获取用户菜单成功",
            data=menu_tree,
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
    except Exception as e:
        print(f"获取用户菜单错误: {e}")
        return ApiResponse(
            code=500,
            message="服务器内部错误",
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
    finally:
        # Release DB resources on every path, including exceptions.
        if cursor is not None:
            cursor.close()
        if conn:
            conn.close()
def build_menu_tree(menus):
    """Nest a flat list of menu dicts into a tree.

    Each menu must have ``"id"`` and ``"parent_id"`` keys. Menus with
    ``parent_id is None`` become roots. Every other menu is appended to its
    parent's ``"children"`` list, which is created if absent. Input order is
    preserved within each level.

    A menu whose parent is not present in *menus* (for example, the parent
    is inactive or not granted to the user) is left out of the tree.

    The input dicts are mutated in place: children are attached to them.

    Args:
        menus: Iterable of menu dicts.

    Returns:
        The list of root menu dicts.
    """
    menus = list(menus)
    by_id = {menu["id"]: menu for menu in menus}
    roots = []
    for menu in menus:
        parent_id = menu["parent_id"]
        if parent_id is None:
            roots.append(menu)
            continue
        parent = by_id.get(parent_id)
        if parent is not None:
            parent.setdefault("children", []).append(menu)
    return roots
@app.get("/api/v1/admin/menus")
async def api_get_all_menus(
    page: int = 1,
    page_size: int = 1000,  # large default so the admin UI receives every menu at once
    keyword: Optional[str] = None,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """List menus for administrators, with paging and keyword search.

    Access is decided solely by the token's ``is_superuser`` claim; there is
    no database re-check. ``keyword`` filters by ``title`` or ``name``
    (SQL LIKE). Root menus are listed first, then by ``sort_order``, then
    ``menu`` entries before other types, then by ``created_at``.

    Returns an ``ApiResponse`` dict: code 0 with ``{items, total, page,
    page_size}``, 401 invalid token, 403 not an admin, 500 database or
    unexpected error.
    """
    conn = None
    cursor = None
    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(
                code=401,
                message="无效的访问令牌",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        if not payload.get("is_superuser", False):
            return ApiResponse(
                code=403,
                message="权限不足",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()

        # Non-positive paging input would yield a negative LIMIT/OFFSET,
        # which MySQL rejects; clamp to the first page / at least one row.
        page = max(page, 1)
        page_size = max(page_size, 1)

        conn = get_db_connection()
        if not conn:
            return ApiResponse(
                code=500,
                message="数据库连接失败",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        cursor = conn.cursor()

        # WHERE clause is assembled only from constant fragments; user input
        # is passed as bound parameters.
        where_conditions = []
        params = []
        if keyword:
            where_conditions.append("(m.title LIKE %s OR m.name LIKE %s)")
            params.extend([f"%{keyword}%", f"%{keyword}%"])
        where_clause = " AND ".join(where_conditions) if where_conditions else "1=1"

        cursor.execute(f"SELECT COUNT(*) FROM menus m WHERE {where_clause}", params)
        total = cursor.fetchone()[0]

        cursor.execute(f"""
            SELECT m.id, m.parent_id, m.name, m.title, m.path, m.component,
                   m.icon, m.sort_order, m.menu_type, m.is_hidden, m.is_active,
                   m.description, m.created_at, m.updated_at,
                   pm.title as parent_title
            FROM menus m
            LEFT JOIN menus pm ON m.parent_id = pm.id
            WHERE {where_clause}
            ORDER BY
                CASE WHEN m.parent_id IS NULL THEN 0 ELSE 1 END,
                m.sort_order,
                CASE WHEN m.menu_type = 'menu' THEN 0 ELSE 1 END,
                m.created_at
            LIMIT %s OFFSET %s
        """, params + [page_size, (page - 1) * page_size])

        menus = [
            {
                "id": row[0],
                "parent_id": row[1],
                "name": row[2],
                "title": row[3],
                "path": row[4],
                "component": row[5],
                "icon": row[6],
                "sort_order": row[7],
                "menu_type": row[8],
                "is_hidden": bool(row[9]),
                "is_active": bool(row[10]),
                "description": row[11],
                "created_at": row[12].isoformat() if row[12] else None,
                "updated_at": row[13].isoformat() if row[13] else None,
                "parent_title": row[14]
            }
            for row in cursor.fetchall()
        ]

        return ApiResponse(
            code=0,
            message="获取菜单列表成功",
            data={
                "items": menus,
                "total": total,
                "page": page,
                "page_size": page_size
            },
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
    except Exception as e:
        print(f"获取菜单列表错误: {e}")
        return ApiResponse(
            code=500,
            message="服务器内部错误",
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
    finally:
        # Release DB resources on every path, including exceptions.
        if cursor is not None:
            cursor.close()
        if conn:
            conn.close()
@app.get("/api/v1/admin/roles")
async def api_get_all_roles(
    page: int = 1,
    page_size: int = 20,
    keyword: Optional[str] = None,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """List roles for administrators, with paging, keyword search and user counts.

    Access is decided solely by the token's ``is_superuser`` claim.
    ``keyword`` filters by ``display_name`` or ``name`` (SQL LIKE).
    ``user_count`` counts active ``user_roles`` rows per role. System roles
    are listed first, then roles are ordered by ``created_at``.

    Returns an ``ApiResponse`` dict: code 0 with ``{items, total, page,
    page_size}``, 401 invalid token, 403 not an admin, 500 database or
    unexpected error.
    """
    conn = None
    cursor = None
    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(
                code=401,
                message="无效的访问令牌",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        if not payload.get("is_superuser", False):
            return ApiResponse(
                code=403,
                message="权限不足",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()

        # Non-positive paging input would yield a negative LIMIT/OFFSET,
        # which MySQL rejects; clamp to the first page / at least one row.
        page = max(page, 1)
        page_size = max(page_size, 1)

        conn = get_db_connection()
        if not conn:
            return ApiResponse(
                code=500,
                message="数据库连接失败",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        cursor = conn.cursor()

        # WHERE clause is assembled only from constant fragments; user input
        # is passed as bound parameters.
        where_conditions = []
        params = []
        if keyword:
            where_conditions.append("(r.display_name LIKE %s OR r.name LIKE %s)")
            params.extend([f"%{keyword}%", f"%{keyword}%"])
        where_clause = " AND ".join(where_conditions) if where_conditions else "1=1"

        cursor.execute(f"SELECT COUNT(*) FROM roles r WHERE {where_clause}", params)
        total = cursor.fetchone()[0]

        offset = (page - 1) * page_size
        cursor.execute(f"""
            SELECT r.id, r.name, r.display_name, r.description, r.is_active,
                   r.is_system, r.created_at, r.updated_at,
                   COUNT(ur.user_id) as user_count
            FROM roles r
            LEFT JOIN user_roles ur ON r.id = ur.role_id AND ur.is_active = 1
            WHERE {where_clause}
            GROUP BY r.id
            ORDER BY r.is_system DESC, r.created_at
            LIMIT %s OFFSET %s
        """, params + [page_size, offset])

        roles = [
            {
                "id": row[0],
                "name": row[1],
                "display_name": row[2],
                "description": row[3],
                "is_active": bool(row[4]),
                "is_system": bool(row[5]),
                "created_at": row[6].isoformat() if row[6] else None,
                "updated_at": row[7].isoformat() if row[7] else None,
                "user_count": row[8]
            }
            for row in cursor.fetchall()
        ]

        return ApiResponse(
            code=0,
            message="获取角色列表成功",
            data={
                "items": roles,
                "total": total,
                "page": page,
                "page_size": page_size
            },
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
    except Exception as e:
        print(f"获取角色列表错误: {e}")
        return ApiResponse(
            code=500,
            message="服务器内部错误",
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
    finally:
        # Release DB resources on every path, including exceptions.
        if cursor is not None:
            cursor.close()
        if conn:
            conn.close()
@app.get("/api/v1/user/permissions")
async def api_get_user_permissions(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Return the distinct active permissions granted to the caller via active roles.

    Each item is ``{"name", "resource", "action"}``.

    Returns an ``ApiResponse`` dict: code 0 with the list, 401 invalid
    token, 500 database or unexpected error.
    """
    conn = None
    cursor = None
    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(
                code=401,
                message="无效的访问令牌",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        user_id = payload.get("sub")

        conn = get_db_connection()
        if not conn:
            return ApiResponse(
                code=500,
                message="数据库连接失败",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        cursor = conn.cursor()

        cursor.execute("""
            SELECT DISTINCT p.name, p.resource, p.action
            FROM permissions p
            JOIN role_permissions rp ON p.id = rp.permission_id
            JOIN user_roles ur ON rp.role_id = ur.role_id
            WHERE ur.user_id = %s
            AND ur.is_active = 1
            AND p.is_active = 1
        """, (user_id,))
        permissions = [
            {"name": name, "resource": resource, "action": action}
            for name, resource, action in cursor.fetchall()
        ]

        return ApiResponse(
            code=0,
            message="获取用户权限成功",
            data=permissions,
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
    except Exception as e:
        print(f"获取用户权限错误: {e}")
        return ApiResponse(
            code=500,
            message="服务器内部错误",
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
    finally:
        # Release DB resources on every path, including exceptions.
        if cursor is not None:
            cursor.close()
        if conn:
            conn.close()
# 用户管理API
@app.get("/api/v1/admin/users")
async def get_users(
    page: int = 1,
    page_size: int = 20,
    keyword: Optional[str] = None,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """List users for administrators, with paging and keyword search.

    Access is decided solely by the token's ``is_superuser`` claim.
    ``keyword`` matches username, email or profile ``real_name`` (SQL LIKE).
    ``roles`` is the list of display names of the user's active roles.
    Results are ordered newest first.

    Returns an ``ApiResponse`` dict: code 0 with ``{items, total, page,
    page_size}``, 401 invalid token, 403 not an admin, 500 database or
    unexpected error.
    """
    conn = None
    cursor = None
    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(code=401, message="无效的访问令牌", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        if not payload.get("is_superuser", False):
            return ApiResponse(code=403, message="权限不足", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()

        # Non-positive paging input would yield a negative LIMIT/OFFSET,
        # which MySQL rejects; clamp to the first page / at least one row.
        page = max(page, 1)
        page_size = max(page_size, 1)

        conn = get_db_connection()
        if not conn:
            return ApiResponse(code=500, message="数据库连接失败", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        cursor = conn.cursor()

        # WHERE clause is assembled only from constant fragments; user input
        # is passed as bound parameters.
        where_conditions = []
        params = []
        if keyword:
            where_conditions.append("(u.username LIKE %s OR u.email LIKE %s OR up.real_name LIKE %s)")
            params.extend([f"%{keyword}%", f"%{keyword}%", f"%{keyword}%"])
        where_clause = " AND ".join(where_conditions) if where_conditions else "1=1"

        cursor.execute(f"SELECT COUNT(*) FROM users u LEFT JOIN user_profiles up ON u.id = up.user_id WHERE {where_clause}", params)
        total = cursor.fetchone()[0]

        offset = (page - 1) * page_size
        cursor.execute(f"""
            SELECT u.id, u.username, u.email, u.phone, u.is_active, u.is_superuser,
                   u.last_login_at, u.created_at, up.real_name, up.company, up.department,
                   GROUP_CONCAT(r.display_name) as roles
            FROM users u
            LEFT JOIN user_profiles up ON u.id = up.user_id
            LEFT JOIN user_roles ur ON u.id = ur.user_id AND ur.is_active = 1
            LEFT JOIN roles r ON ur.role_id = r.id
            WHERE {where_clause}
            GROUP BY u.id, u.username, u.email, u.phone, u.is_active, u.is_superuser,
                     u.last_login_at, u.created_at, up.real_name, up.company, up.department
            ORDER BY u.created_at DESC
            LIMIT %s OFFSET %s
        """, params + [page_size, offset])

        users = [
            {
                "id": row[0],
                "username": row[1],
                "email": row[2],
                "phone": row[3],
                "is_active": bool(row[4]),
                "is_superuser": bool(row[5]),
                "last_login_at": row[6].isoformat() if row[6] else None,
                "created_at": row[7].isoformat() if row[7] else None,
                "real_name": row[8],
                "company": row[9],
                "department": row[10],
                # GROUP_CONCAT yields a comma-separated string, or NULL when no roles.
                "roles": row[11].split(',') if row[11] else []
            }
            for row in cursor.fetchall()
        ]

        return ApiResponse(
            code=0,
            message="获取用户列表成功",
            data={"items": users, "total": total, "page": page, "page_size": page_size},
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
    except Exception as e:
        print(f"获取用户列表错误: {e}")
        return ApiResponse(code=500, message="服务器内部错误", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
    finally:
        # Release DB resources on every path, including exceptions.
        if cursor is not None:
            cursor.close()
        if conn:
            conn.close()
@app.post("/api/v1/admin/users")
async def create_user(
    user_data: dict,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Create a user (admin only), optionally with a profile and role assignments.

    Required body fields: ``username``, ``email``, ``password``. Optional:
    ``phone``, ``is_active`` (default True), ``is_superuser`` (default False),
    ``real_name``/``company``/``department`` (stored in ``user_profiles``
    when any of them is present), and ``role_ids`` (list of role ids).
    All inserts are committed together.

    Returns an ``ApiResponse`` dict: code 0 on success, 401 invalid token,
    403 not an admin, 400 missing fields or duplicate username/email,
    500 database or unexpected error.
    """
    conn = None
    cursor = None
    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(code=401, message="无效的访问令牌", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        if not payload.get("is_superuser", False):
            return ApiResponse(code=403, message="权限不足", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()

        # Validate required fields up front so a malformed body yields a 400
        # instead of a KeyError reported as "服务器内部错误".
        if any(not user_data.get(field) for field in ('username', 'email', 'password')):
            return ApiResponse(code=400, message="缺少必要参数", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()

        conn = get_db_connection()
        if not conn:
            return ApiResponse(code=500, message="数据库连接失败", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        cursor = conn.cursor()

        # Username and email must both be unused.
        cursor.execute("SELECT id FROM users WHERE username = %s OR email = %s",
                       (user_data['username'], user_data['email']))
        if cursor.fetchone():
            return ApiResponse(code=400, message="用户名或邮箱已存在", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()

        user_id = str(uuid.uuid4())
        password_hash = hash_password_simple(user_data['password'])
        cursor.execute("""
            INSERT INTO users (id, username, email, phone, password_hash, is_active, is_superuser, created_at, updated_at)
            VALUES (%s, %s, %s, %s, %s, %s, %s, NOW(), NOW())
        """, (user_id, user_data['username'], user_data['email'], user_data.get('phone'),
              password_hash, user_data.get('is_active', True), user_data.get('is_superuser', False)))

        # A profile row is created only when at least one profile field is supplied.
        if any(key in user_data for key in ('real_name', 'company', 'department')):
            cursor.execute("""
                INSERT INTO user_profiles (id, user_id, real_name, company, department, created_at, updated_at)
                VALUES (%s, %s, %s, %s, %s, NOW(), NOW())
            """, (str(uuid.uuid4()), user_id, user_data.get('real_name'),
                  user_data.get('company'), user_data.get('department')))

        # Record the acting admin as the assigner of each role.
        for role_id in user_data.get('role_ids') or []:
            cursor.execute("""
                INSERT INTO user_roles (id, user_id, role_id, assigned_by, created_at)
                VALUES (%s, %s, %s, %s, NOW())
            """, (str(uuid.uuid4()), user_id, role_id, payload.get("sub")))

        conn.commit()
        return ApiResponse(code=0, message="用户创建成功", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
    except Exception as e:
        print(f"创建用户错误: {e}")
        return ApiResponse(code=500, message="服务器内部错误", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
    finally:
        # Release DB resources on every path, including exceptions.
        if cursor is not None:
            cursor.close()
        if conn:
            conn.close()
@app.put("/api/v1/admin/users/{user_id}")
async def update_user(
    user_id: str,
    user_data: dict,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Update a user's account fields, profile and (optionally) roles.

    Admin only (the token's ``is_superuser`` claim). Recognised body fields:
        email, phone, is_active, is_superuser      -> ``users`` row
        real_name, company, department             -> ``user_profiles`` (upsert)
        role_ids (list)                            -> replaces ALL of the user's
                                                      ``user_roles`` rows
    Column names come from fixed whitelists; only values are user-supplied.
    All writes are committed together.

    Returns an ``ApiResponse`` dict: code 0 on success, 401 invalid token,
    403 not an admin, 400 malformed ``role_ids``, 404 unknown user,
    500 database or unexpected error.
    """
    conn = None
    cursor = None
    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(code=401, message="无效的访问令牌", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        if not payload.get("is_superuser", False):
            return ApiResponse(code=403, message="权限不足", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()

        # Validate role_ids before any writes: a non-list value (e.g. null)
        # would otherwise raise TypeError after the existing roles were deleted.
        if 'role_ids' in user_data and not isinstance(user_data['role_ids'], list):
            return ApiResponse(code=400, message="参数格式错误", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()

        conn = get_db_connection()
        if not conn:
            return ApiResponse(code=500, message="数据库连接失败", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        cursor = conn.cursor()

        # Do not create profile/role rows for a user that does not exist.
        cursor.execute("SELECT id FROM users WHERE id = %s", (user_id,))
        if not cursor.fetchone():
            return ApiResponse(code=404, message="用户不存在", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()

        # Account fields (whitelisted column names).
        user_fields = [f for f in ('email', 'phone', 'is_active', 'is_superuser') if f in user_data]
        if user_fields:
            assignments = ', '.join(f'{f} = %s' for f in user_fields)
            cursor.execute(
                f"UPDATE users SET {assignments}, updated_at = NOW() WHERE id = %s",
                [user_data[f] for f in user_fields] + [user_id],
            )

        # Profile fields: update the existing row, or insert one if missing.
        profile_updates = {k: user_data[k] for k in ('real_name', 'company', 'department') if k in user_data}
        if profile_updates:
            cursor.execute("SELECT id FROM user_profiles WHERE user_id = %s", (user_id,))
            if cursor.fetchone():
                assignments = ', '.join(f'{f} = %s' for f in profile_updates)
                cursor.execute(
                    f"UPDATE user_profiles SET {assignments}, updated_at = NOW() WHERE user_id = %s",
                    list(profile_updates.values()) + [user_id],
                )
            else:
                columns = ['id', 'user_id'] + list(profile_updates)
                values = [str(uuid.uuid4()), user_id] + list(profile_updates.values())
                placeholders = ', '.join(['%s'] * len(values))
                cursor.execute(
                    f"INSERT INTO user_profiles ({', '.join(columns)}, created_at, updated_at) "
                    f"VALUES ({placeholders}, NOW(), NOW())",
                    values,
                )

        # Roles: full replacement when role_ids is present.
        if 'role_ids' in user_data:
            cursor.execute("DELETE FROM user_roles WHERE user_id = %s", (user_id,))
            for role_id in user_data['role_ids']:
                cursor.execute("""
                    INSERT INTO user_roles (id, user_id, role_id, assigned_by, created_at)
                    VALUES (%s, %s, %s, %s, NOW())
                """, (str(uuid.uuid4()), user_id, role_id, payload.get("sub")))

        conn.commit()
        return ApiResponse(code=0, message="用户更新成功", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
    except Exception as e:
        print(f"更新用户错误: {e}")
        return ApiResponse(code=500, message="服务器内部错误", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
    finally:
        # Release DB resources on every path, including exceptions.
        if cursor is not None:
            cursor.close()
        if conn:
            conn.close()
@app.delete("/api/v1/admin/users/{user_id}")
async def delete_user(
    user_id: str,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Delete a user together with their role assignments and profile (admin only).

    The caller cannot delete themselves. Users holding an active
    ``super_admin`` role cannot be deleted.

    Returns an ``ApiResponse`` dict: code 0 on success, 401 invalid token,
    403 not an admin, 400 self-delete or super_admin target, 404 unknown
    user, 500 database or unexpected error.
    """
    conn = None
    cursor = None
    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(code=401, message="无效的访问令牌", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        if not payload.get("is_superuser", False):
            return ApiResponse(code=403, message="权限不足", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()

        if user_id == payload.get("sub"):
            return ApiResponse(code=400, message="不能删除自己", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()

        conn = get_db_connection()
        if not conn:
            return ApiResponse(code=500, message="数据库连接失败", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        cursor = conn.cursor()

        # Report unknown ids instead of claiming a successful delete.
        cursor.execute("SELECT id FROM users WHERE id = %s", (user_id,))
        if not cursor.fetchone():
            return ApiResponse(code=404, message="用户不存在", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()

        # Protect users holding an active super_admin role.
        cursor.execute("""
            SELECT COUNT(*) FROM user_roles ur
            JOIN roles r ON ur.role_id = r.id
            WHERE ur.user_id = %s AND r.name = 'super_admin' AND ur.is_active = 1
        """, (user_id,))
        if cursor.fetchone()[0] > 0:
            return ApiResponse(code=400, message="不能删除超级管理员", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()

        # Remove dependent rows first, then the user; committed together.
        cursor.execute("DELETE FROM user_roles WHERE user_id = %s", (user_id,))
        cursor.execute("DELETE FROM user_profiles WHERE user_id = %s", (user_id,))
        cursor.execute("DELETE FROM users WHERE id = %s", (user_id,))
        conn.commit()
        return ApiResponse(code=0, message="用户删除成功", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
    except Exception as e:
        print(f"删除用户错误: {e}")
        return ApiResponse(code=500, message="服务器内部错误", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
    finally:
        # Release DB resources on every path, including exceptions.
        if cursor is not None:
            cursor.close()
        if conn:
            conn.close()
# 角色管理API
@app.post("/api/v1/admin/roles")
async def create_role(
    role_data: dict,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Create a non-system role (admin only).

    Required body fields: ``name`` (must be unique) and ``display_name``.
    Optional: ``description``, ``is_active`` (default True). ``is_system``
    is always False for roles created here.

    Returns an ``ApiResponse`` dict: code 0 on success, 401 invalid token,
    403 not an admin, 400 missing fields or duplicate name, 500 database or
    unexpected error.
    """
    conn = None
    cursor = None
    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(code=401, message="无效的访问令牌", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        if not payload.get("is_superuser", False):
            return ApiResponse(code=403, message="权限不足", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()

        # Validate required fields so a malformed body yields a 400 instead
        # of a KeyError reported as "服务器内部错误".
        if not role_data.get('name') or not role_data.get('display_name'):
            return ApiResponse(code=400, message="缺少必要参数", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()

        conn = get_db_connection()
        if not conn:
            return ApiResponse(code=500, message="数据库连接失败", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        cursor = conn.cursor()

        cursor.execute("SELECT id FROM roles WHERE name = %s", (role_data['name'],))
        if cursor.fetchone():
            return ApiResponse(code=400, message="角色名已存在", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()

        cursor.execute("""
            INSERT INTO roles (id, name, display_name, description, is_active, is_system, created_at, updated_at)
            VALUES (%s, %s, %s, %s, %s, %s, NOW(), NOW())
        """, (str(uuid.uuid4()), role_data['name'], role_data['display_name'], role_data.get('description'),
              role_data.get('is_active', True), False))
        conn.commit()
        return ApiResponse(code=0, message="角色创建成功", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
    except Exception as e:
        print(f"创建角色错误: {e}")
        return ApiResponse(code=500, message="服务器内部错误", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
    finally:
        # Release DB resources on every path, including exceptions.
        if cursor is not None:
            cursor.close()
        if conn:
            conn.close()
@app.put("/api/v1/admin/roles/{role_id}")
async def update_role(
    role_id: str,
    role_data: dict,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Update a non-system role's display_name, description or is_active (admin only).

    Only whitelisted columns that are present in the body are updated.
    System roles (``is_system``) are read-only.

    Returns an ``ApiResponse`` dict: code 0 on success, 401 invalid token,
    403 not an admin, 404 unknown role, 400 system role, 500 database or
    unexpected error.
    """
    conn = None
    cursor = None
    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(code=401, message="无效的访问令牌", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        if not payload.get("is_superuser", False):
            return ApiResponse(code=403, message="权限不足", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()

        conn = get_db_connection()
        if not conn:
            return ApiResponse(code=500, message="数据库连接失败", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        cursor = conn.cursor()

        cursor.execute("SELECT is_system FROM roles WHERE id = %s", (role_id,))
        role = cursor.fetchone()
        if not role:
            return ApiResponse(code=404, message="角色不存在", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        if role[0]:  # is_system
            return ApiResponse(code=400, message="不能修改系统角色", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()

        fields = [f for f in ('display_name', 'description', 'is_active') if f in role_data]
        if fields:
            assignments = ', '.join(f'{f} = %s' for f in fields)
            cursor.execute(
                f"UPDATE roles SET {assignments}, updated_at = NOW() WHERE id = %s",
                [role_data[f] for f in fields] + [role_id],
            )
        conn.commit()
        return ApiResponse(code=0, message="角色更新成功", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
    except Exception as e:
        print(f"更新角色错误: {e}")
        return ApiResponse(code=500, message="服务器内部错误", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
    finally:
        # Release DB resources on every path, including exceptions.
        if cursor is not None:
            cursor.close()
        if conn:
            conn.close()
@app.delete("/api/v1/admin/roles/{role_id}")
async def delete_role(
    role_id: str,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Delete a non-system role that is not assigned to any user (superuser only).

    The role's ``role_permissions`` and ``role_menus`` rows are removed along
    with the role itself, in a single commit.

    Returns an ``ApiResponse`` dict: code 0 on success, 401 for an invalid
    token, 403 for non-superusers, 404 if the role does not exist, 400 for a
    system role or a role still referenced by ``user_roles``, 500 on errors.
    """
    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(code=401, message="无效的访问令牌", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        if not payload.get("is_superuser", False):
            return ApiResponse(code=403, message="权限不足", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        conn = get_db_connection()
        if not conn:
            return ApiResponse(code=500, message="数据库连接失败", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        cursor = None
        try:
            cursor = conn.cursor()
            # System roles cannot be deleted.
            cursor.execute("SELECT is_system FROM roles WHERE id = %s", (role_id,))
            role = cursor.fetchone()
            if not role:
                return ApiResponse(code=404, message="角色不存在", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
            if role[0]:  # is_system
                return ApiResponse(code=400, message="不能删除系统角色", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
            # Refuse to delete a role that is still assigned to users.
            cursor.execute("SELECT COUNT(*) FROM user_roles WHERE role_id = %s", (role_id,))
            if cursor.fetchone()[0] > 0:
                return ApiResponse(code=400, message="该角色正在被使用,无法删除", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
            # Dependent rows first, then the role itself; committed together.
            cursor.execute("DELETE FROM role_permissions WHERE role_id = %s", (role_id,))
            cursor.execute("DELETE FROM role_menus WHERE role_id = %s", (role_id,))
            cursor.execute("DELETE FROM roles WHERE id = %s", (role_id,))
            conn.commit()
        except Exception:
            # Undo any partially applied deletes so the role is not left half-removed.
            conn.rollback()
            raise
        finally:
            if cursor is not None:
                cursor.close()
            conn.close()
        return ApiResponse(code=0, message="角色删除成功", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
    except Exception as e:
        print(f"删除角色错误: {e}")
        return ApiResponse(code=500, message="服务器内部错误", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
# 角色菜单权限管理API
@app.get("/api/v1/admin/roles/{role_id}/menus")
async def get_role_menus(
    role_id: str,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Return the active menus granted to a role (superuser only).

    A role named ``super_admin`` is reported as owning every active menu.
    Other roles get the active menus linked through ``role_menus``. Both
    lists are ordered by ``sort_order``.

    Returns an ``ApiResponse`` dict whose ``data`` holds ``role_id``,
    ``role_name``, ``menu_ids``, ``menu_details`` (id/name/title/parent_id/
    menu_type) and ``total``; or code 401/403/404/500 on failure.
    """
    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(
                code=401,
                message="无效的访问令牌",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        # Admin-only endpoint.
        if not payload.get("is_superuser", False):
            return ApiResponse(
                code=403,
                message="权限不足",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        conn = get_db_connection()
        if not conn:
            return ApiResponse(
                code=500,
                message="数据库连接失败",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        cursor = None
        try:
            cursor = conn.cursor()
            cursor.execute("SELECT id, name FROM roles WHERE id = %s", (role_id,))
            role = cursor.fetchone()
            if not role:
                return ApiResponse(
                    code=404,
                    message="角色不存在",
                    timestamp=datetime.now(timezone.utc).isoformat()
                ).model_dump()
            role_name = role[1]
            if role_name == "super_admin":
                # The super admin role implicitly owns every active menu.
                cursor.execute("""
                    SELECT id, name, title, parent_id, menu_type
                    FROM menus
                    WHERE is_active = 1
                    ORDER BY sort_order
                """)
            else:
                # Regular roles: only menus explicitly assigned via role_menus.
                cursor.execute("""
                    SELECT m.id, m.name, m.title, m.parent_id, m.menu_type
                    FROM role_menus rm
                    JOIN menus m ON rm.menu_id = m.id
                    WHERE rm.role_id = %s AND m.is_active = 1
                    ORDER BY m.sort_order
                """, (role_id,))
            menu_permissions = cursor.fetchall()
        finally:
            # Release DB resources on every path, including early returns and errors.
            if cursor is not None:
                cursor.close()
            conn.close()
        menu_details = [
            {
                "id": menu[0],
                "name": menu[1],
                "title": menu[2],
                "parent_id": menu[3],
                "menu_type": menu[4],
            }
            for menu in menu_permissions
        ]
        menu_ids = [detail["id"] for detail in menu_details]
        return ApiResponse(
            code=0,
            message="获取角色菜单权限成功",
            data={
                "role_id": role_id,
                "role_name": role_name,
                "menu_ids": menu_ids,
                "menu_details": menu_details,
                "total": len(menu_ids)
            },
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
    except Exception as e:
        print(f"获取角色菜单权限错误: {e}")
        return ApiResponse(
            code=500,
            message="服务器内部错误",
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
@app.put("/api/v1/admin/roles/{role_id}/menus")
async def update_role_menus(
    role_id: str,
    request: Request,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Replace the complete set of menus assigned to a role (superuser only).

    Expects a JSON object body ``{"menu_ids": ["<menu id>", ...]}``. Each ID
    must name an active menu. The role's existing ``role_menus`` rows are
    deleted and the new ones inserted in one transaction. The ``super_admin``
    role cannot be edited because it always owns every menu.

    Returns an ``ApiResponse`` dict.

    - Success: code 0, with ``role_id``, ``role_name``, ``menu_ids``
      (de-duplicated, request order kept) and ``updated_count``.
    - 400: malformed body, unknown menu IDs, or the super admin role.
    - 401/403: authentication or authorization failure.
    - 404: missing role.
    - 500: database errors.
    """
    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(
                code=401,
                message="无效的访问令牌",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        # Admin-only endpoint.
        if not payload.get("is_superuser", False):
            return ApiResponse(
                code=403,
                message="权限不足",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()

        # Malformed JSON or a non-object body is a client error, not a 500.
        try:
            body = await request.json()
        except ValueError:
            body = None
        if not isinstance(body, dict):
            return ApiResponse(
                code=400,
                message="请求数据格式错误",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()

        menu_ids = body.get("menu_ids", [])
        # IDs must be strings: they are matched against menus.id values and
        # joined into the error message below (a non-string would raise there).
        if not isinstance(menu_ids, list) or not all(isinstance(menu_id, str) for menu_id in menu_ids):
            return ApiResponse(
                code=400,
                message="菜单ID列表格式错误",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        # Drop duplicates (first occurrence wins) so the bulk INSERT never
        # inserts the same (role_id, menu_id) pair twice.
        menu_ids = list(dict.fromkeys(menu_ids))

        conn = get_db_connection()
        if not conn:
            return ApiResponse(
                code=500,
                message="数据库连接失败",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        cursor = None
        try:
            cursor = conn.cursor()
            cursor.execute("SELECT id, name FROM roles WHERE id = %s", (role_id,))
            role = cursor.fetchone()
            if not role:
                return ApiResponse(
                    code=404,
                    message="角色不存在",
                    timestamp=datetime.now(timezone.utc).isoformat()
                ).model_dump()
            role_name = role[1]
            if role_name == "super_admin":
                # The super admin role always owns every menu; it is not editable.
                return ApiResponse(
                    code=400,
                    message="超级管理员角色拥有全部权限,无需修改",
                    timestamp=datetime.now(timezone.utc).isoformat()
                ).model_dump()

            # Every requested ID must refer to an existing, active menu.
            if menu_ids:
                placeholders = ','.join(['%s'] * len(menu_ids))
                cursor.execute(f"""
                    SELECT id FROM menus
                    WHERE id IN ({placeholders}) AND is_active = 1
                """, menu_ids)
                valid_menu_ids = {row[0] for row in cursor.fetchall()}
                # Keep request order so the error message is deterministic.
                invalid_menu_ids = [menu_id for menu_id in menu_ids if menu_id not in valid_menu_ids]
                if invalid_menu_ids:
                    return ApiResponse(
                        code=400,
                        message=f"无效的菜单ID: {', '.join(invalid_menu_ids)}",
                        timestamp=datetime.now(timezone.utc).isoformat()
                    ).model_dump()

            # Replace the assignment set atomically.
            cursor.execute("START TRANSACTION")
            try:
                cursor.execute("DELETE FROM role_menus WHERE role_id = %s", (role_id,))
                if menu_ids:
                    cursor.executemany("""
                        INSERT INTO role_menus (role_id, menu_id, created_at)
                        VALUES (%s, %s, NOW())
                    """, [(role_id, menu_id) for menu_id in menu_ids])
                conn.commit()
            except Exception:
                conn.rollback()
                raise
        finally:
            # Release DB resources on every path, including early returns and errors.
            if cursor is not None:
                cursor.close()
            conn.close()

        return ApiResponse(
            code=0,
            message="角色菜单权限更新成功",
            data={
                "role_id": role_id,
                "role_name": role_name,
                "menu_ids": menu_ids,
                "updated_count": len(menu_ids)
            },
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
    except Exception as e:
        print(f"更新角色菜单权限错误: {e}")
        return ApiResponse(
            code=500,
            message="服务器内部错误",
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
# 菜单管理API
@app.post("/api/v1/admin/menus")
async def create_menu(
    menu_data: dict,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Create a menu entry (superuser only).

    ``menu_data`` must contain non-empty ``name`` (the unique menu key) and
    ``title``. The optional keys ``parent_id``, ``path``, ``component``,
    ``icon``, ``sort_order`` (default 0), ``menu_type`` (default ``'menu'``),
    ``is_hidden`` (default False), ``is_active`` (default True) and
    ``description`` are stored as given. A new UUID4 string is used as the id.

    Returns an ``ApiResponse`` dict: code 0 on success, 400 for missing
    required fields or a duplicate ``name``, 401/403 for auth failures,
    500 on database errors.
    """
    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(code=401, message="无效的访问令牌", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        if not payload.get("is_superuser", False):
            return ApiResponse(code=403, message="权限不足", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        # Required fields: reject up front instead of failing later with KeyError (-> 500).
        name = menu_data.get('name')
        title = menu_data.get('title')
        if not name or not title:
            return ApiResponse(code=400, message="菜单标识和标题不能为空", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        conn = get_db_connection()
        if not conn:
            return ApiResponse(code=500, message="数据库连接失败", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        cursor = None
        try:
            cursor = conn.cursor()
            # The menu key (`name`) must be unique.
            cursor.execute("SELECT id FROM menus WHERE name = %s", (name,))
            if cursor.fetchone():
                return ApiResponse(code=400, message="菜单标识已存在", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
            menu_id = str(uuid.uuid4())
            cursor.execute("""
                INSERT INTO menus (id, parent_id, name, title, path, component, icon,
                                   sort_order, menu_type, is_hidden, is_active, description, created_at, updated_at)
                VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW(), NOW())
            """, (
                menu_id, menu_data.get('parent_id'), name, title,
                menu_data.get('path'), menu_data.get('component'), menu_data.get('icon'),
                menu_data.get('sort_order', 0), menu_data.get('menu_type', 'menu'),
                menu_data.get('is_hidden', False), menu_data.get('is_active', True),
                menu_data.get('description')
            ))
            conn.commit()
        finally:
            # Release DB resources on every path, including early returns and errors.
            if cursor is not None:
                cursor.close()
            conn.close()
        return ApiResponse(code=0, message="菜单创建成功", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
    except Exception as e:
        print(f"创建菜单错误: {e}")
        return ApiResponse(code=500, message="服务器内部错误", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
@app.put("/api/v1/admin/menus/{menu_id}")
async def update_menu(
    menu_id: str,
    menu_data: dict,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Update the editable fields of an existing menu (superuser only).

    Only ``parent_id``, ``title``, ``path``, ``component``, ``icon``,
    ``sort_order``, ``menu_type``, ``is_hidden``, ``is_active`` and
    ``description`` are read from ``menu_data``. The menu key (``name``)
    cannot be changed.

    Returns an ``ApiResponse`` dict: code 0 on success, 400 if the menu
    would become its own parent, 401/403 for auth failures, 404 if the menu
    does not exist, 500 on database errors.
    """
    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(code=401, message="无效的访问令牌", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        if not payload.get("is_superuser", False):
            return ApiResponse(code=403, message="权限不足", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        # A menu directly parented to itself would form a cycle in the menu tree.
        # NOTE(review): longer cycles (A -> B -> A) are not detected here.
        if menu_data.get('parent_id') == menu_id:
            return ApiResponse(code=400, message="菜单不能将自身设为父菜单", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        conn = get_db_connection()
        if not conn:
            return ApiResponse(code=500, message="数据库连接失败", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        cursor = None
        try:
            cursor = conn.cursor()
            # Report a missing menu instead of silently "updating" nothing.
            cursor.execute("SELECT id FROM menus WHERE id = %s", (menu_id,))
            if not cursor.fetchone():
                return ApiResponse(code=404, message="菜单不存在", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
            # Column names come from this fixed whitelist, so interpolating them
            # into the SQL is safe; the values are always bound parameters.
            editable_fields = ('parent_id', 'title', 'path', 'component', 'icon', 'sort_order',
                               'menu_type', 'is_hidden', 'is_active', 'description')
            fields = [field for field in editable_fields if field in menu_data]
            if fields:
                assignments = ', '.join(f'{field} = %s' for field in fields)
                params = [menu_data[field] for field in fields] + [menu_id]
                cursor.execute(f"""
                    UPDATE menus
                    SET {assignments}, updated_at = NOW()
                    WHERE id = %s
                """, params)
                conn.commit()
        finally:
            # Release DB resources on every path, including early returns and errors.
            if cursor is not None:
                cursor.close()
            conn.close()
        return ApiResponse(code=0, message="菜单更新成功", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
    except Exception as e:
        print(f"更新菜单错误: {e}")
        return ApiResponse(code=500, message="服务器内部错误", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
@app.delete("/api/v1/admin/menus/{menu_id}")
async def delete_menu(
    menu_id: str,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Delete a menu that has no child menus (superuser only).

    The menu's ``role_menus`` assignments are removed together with the menu
    in a single commit.

    Returns an ``ApiResponse`` dict: code 0 on success, 401/403 for auth
    failures, 404 if the menu does not exist, 400 if it still has children,
    500 on database errors.
    """
    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(code=401, message="无效的访问令牌", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        if not payload.get("is_superuser", False):
            return ApiResponse(code=403, message="权限不足", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        conn = get_db_connection()
        if not conn:
            return ApiResponse(code=500, message="数据库连接失败", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        cursor = None
        try:
            cursor = conn.cursor()
            # Report a missing menu instead of silently "deleting" nothing.
            cursor.execute("SELECT id FROM menus WHERE id = %s", (menu_id,))
            if not cursor.fetchone():
                return ApiResponse(code=404, message="菜单不存在", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
            # Children must be removed or re-parented first.
            cursor.execute("SELECT COUNT(*) FROM menus WHERE parent_id = %s", (menu_id,))
            if cursor.fetchone()[0] > 0:
                return ApiResponse(code=400, message="该菜单下有子菜单,无法删除", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
            # Assignments first, then the menu; committed together.
            cursor.execute("DELETE FROM role_menus WHERE menu_id = %s", (menu_id,))
            cursor.execute("DELETE FROM menus WHERE id = %s", (menu_id,))
            conn.commit()
        except Exception:
            # Undo a partially applied delete so no dangling state is left behind.
            conn.rollback()
            raise
        finally:
            if cursor is not None:
                cursor.close()
            conn.close()
        return ApiResponse(code=0, message="菜单删除成功", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
    except Exception as e:
        print(f"删除菜单错误: {e}")
        return ApiResponse(code=500, message="服务器内部错误", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
# 获取所有角色(用于下拉选择)
@app.get("/api/v1/roles/all")
async def get_all_roles_simple(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """List active roles in a compact form for drop-down selectors.

    Any holder of a valid token may call this (no superuser check). System
    roles are listed first, then roles ordered by ``display_name``.

    Returns an ``ApiResponse`` dict whose ``data`` is a list of
    ``{"id", "name", "display_name", "is_system", "is_active"}`` dicts, or
    code 401/500 on failure.
    """
    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(code=401, message="无效的访问令牌", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        conn = get_db_connection()
        if not conn:
            return ApiResponse(code=500, message="数据库连接失败", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        cursor = None
        try:
            cursor = conn.cursor()
            cursor.execute("""
                SELECT id, name, display_name, is_system, is_active
                FROM roles
                WHERE is_active = 1
                ORDER BY is_system DESC, display_name
            """)
            rows = cursor.fetchall()
        finally:
            # Release DB resources even if the query fails.
            if cursor is not None:
                cursor.close()
            conn.close()
        roles = [
            {
                "id": row[0],
                "name": row[1],
                "display_name": row[2],
                "is_system": bool(row[3]),
                "is_active": bool(row[4]),
            }
            for row in rows
        ]
        return ApiResponse(code=0, message="获取角色列表成功", data=roles, timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
    except Exception as e:
        print(f"获取角色列表错误: {e}")
        return ApiResponse(code=500, message="服务器内部错误", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
@app.get("/api/v1/auth/captcha")
async def get_captcha():
    """Issue a new captcha image.

    Returns an ``ApiResponse`` dict whose ``data`` holds ``captcha_id``
    (random hex token), ``captcha_image`` (SVG data URI) and
    ``captcha_text``. On failure it returns code 500001.
    """
    try:
        captcha_text, captcha_image = generate_captcha()
        # NOTE(review): the captcha text is not stored anywhere server-side
        # (no cache/Redis here), so a submitted answer cannot be checked
        # against captcha_id. This should be implemented before relying on it.
        captcha_id = secrets.token_hex(16)
        return ApiResponse(
            code=0,
            message="获取验证码成功",
            data={
                "captcha_id": captcha_id,
                "captcha_image": captcha_image,
                # Exposing the answer defeats the captcha; kept only for
                # compatibility with existing clients. Remove in production.
                "captcha_text": captcha_text
            },
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
    except Exception as e:
        print(f"生成验证码错误: {e}")
        return ApiResponse(
            code=500001,
            message="生成验证码失败",
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()


def generate_captcha():
    """Generate a 4-character captcha and render it as an SVG data URI.

    Returns:
        ``(captcha_text, data_uri)``. ``captcha_text`` is 4 characters from
        ``[A-Z0-9]``. ``data_uri`` is a ``data:image/svg+xml;base64,...``
        string that draws that text. If anything fails, a fixed fallback
        with the text ``"1234"`` is returned instead.
    """
    try:
        import base64
        import secrets
        import string
        alphabet = string.ascii_uppercase + string.digits
        # Use `secrets` rather than `random` so captcha values are not predictable.
        captcha_text = ''.join(secrets.choice(alphabet) for _ in range(4))
        # Same layout as the fallback image below, with the generated text.
        # The alphabet is [A-Z0-9], so no XML escaping is needed.
        svg_captcha = (
            '<svg width="120" height="40" xmlns="http://www.w3.org/2000/svg">'
            '<rect width="120" height="40" fill="#f0f0f0" stroke="#ccc"/>'
            '<text x="60" y="25" font-family="Arial" font-size="18" '
            f'text-anchor="middle" fill="#333">{captcha_text}</text>'
            '</svg>'
        )
        svg_base64 = base64.b64encode(svg_captcha.encode('utf-8')).decode('utf-8')
        return captcha_text, f"data:image/svg+xml;base64,{svg_base64}"
    except Exception as e:
        print(f"生成验证码失败: {e}")
        # Fixed fallback captcha (the SVG shows "1234").
        return "1234", "data:image/svg+xml;base64,PHN2ZyB3aWR0aD0iMTIwIiBoZWlnaHQ9IjQwIiB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciPjxyZWN0IHdpZHRoPSIxMjAiIGhlaWdodD0iNDAiIGZpbGw9IiNmMGYwZjAiIHN0cm9rZT0iI2NjYyIvPjx0ZXh0IHg9IjYwIiB5PSIyNSIgZm9udC1mYW1pbHk9IkFyaWFsIiBmb250LXNpemU9IjE4IiB0ZXh0LWFuY2hvcj0ibWlkZGxlIiBmaWxsPSIjMzMzIj4xMjM0PC90ZXh0Pjwvc3ZnPg=="


# Script entry point. It must come after every route definition:
# `uvicorn.run` blocks, so any `@app.*` route defined below this guard would
# not be registered while the server runs.
if __name__ == "__main__":
    import uvicorn
    # Pick the first free port.
    port = find_available_port()
    if port is None:
        print("❌ 无法找到可用端口 (8000-8010)")
        print("请手动停止占用这些端口的进程")
        sys.exit(1)
    print("=" * 60)
    print("🚀 SSO认证中心完整服务器")
    print("=" * 60)
    print(f"✅ 找到可用端口: {port}")
    print(f"🌐 访问地址: http://localhost:{port}")
    print(f"📚 API文档: http://localhost:{port}/docs")
    print(f"❤️ 健康检查: http://localhost:{port}/health")
    print(f"🔐 登录API: http://localhost:{port}/api/v1/auth/login")
    print("=" * 60)
    print("📝 前端配置:")
    print(f" VITE_API_BASE_URL=http://localhost:{port}")
    print("=" * 60)
    print("👤 测试账号:")
    print(" 用户名: admin")
    print(" 密码: Admin123456")
    print("=" * 60)
    print("按 Ctrl+C 停止服务器")
    print()
    try:
        uvicorn.run(
            app,
            host="0.0.0.0",
            port=port,
            log_level="info"
        )
    except KeyboardInterrupt:
        print("\n👋 服务器已停止")
    except Exception as e:
        print(f"❌ 启动失败: {e}")
        sys.exit(1)