|
@@ -1,6 +1,6 @@
|
|
|
"""
|
|
"""
|
|
|
OAuth 2.0 认证服务
|
|
OAuth 2.0 认证服务
|
|
|
-处理与 OAuth 认证中心的交互,包括 token 验证和刷新
|
|
|
|
|
|
|
+处理与统一认证平台的交互,包括 code 换 token、用户信息获取和角色映射
|
|
|
"""
|
|
"""
|
|
|
import httpx
|
|
import httpx
|
|
|
import logging
|
|
import logging
|
|
@@ -14,388 +14,244 @@ from database import get_db_connection
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
-# SSO 角色 → 本地角色映射(支持中英文)
|
|
|
|
|
|
|
+# SSO 角色 → 本地角色映射
|
|
|
|
|
+# 仅识别 label_admin(标注管理员)、annotator(标注员)、viewer(查看者)
|
|
|
|
|
+# 其他角色一律不识别,无权限
|
|
|
SSO_ROLE_MAPPING = {
|
|
SSO_ROLE_MAPPING = {
|
|
|
- # 英文角色名
|
|
|
|
|
- "super_admin": "admin",
|
|
|
|
|
|
|
+ # 角色代码
|
|
|
"label_admin": "admin",
|
|
"label_admin": "admin",
|
|
|
- "admin": "admin",
|
|
|
|
|
- "labeler": "annotator",
|
|
|
|
|
- "user_manager": "admin",
|
|
|
|
|
- "app_manager": "admin",
|
|
|
|
|
- # 中文角色名
|
|
|
|
|
- "超级管理员": "admin",
|
|
|
|
|
|
|
+ "annotator": "annotator",
|
|
|
|
|
+ "viewer": "viewer",
|
|
|
|
|
+ # 角色名称(对应中文显示)
|
|
|
"标注管理员": "admin",
|
|
"标注管理员": "admin",
|
|
|
- "管理员": "admin",
|
|
|
|
|
"标注员": "annotator",
|
|
"标注员": "annotator",
|
|
|
- "用户管理员": "admin",
|
|
|
|
|
- "应用管理员": "admin",
|
|
|
|
|
|
|
+ "查看者": "viewer",
|
|
|
}
|
|
}
|
|
|
-DEFAULT_LOCAL_ROLE = "viewer"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
-def map_sso_roles_to_local(sso_roles: list, is_superuser: bool = False) -> str:
|
|
|
|
|
|
|
+def map_sso_roles_to_local(sso_roles: list) -> Optional[str]:
|
|
|
"""
|
|
"""
|
|
|
- 将 SSO 角色列表映射为本地单一角色。
|
|
|
|
|
|
|
+ 将 SSO 角色列表映射为本地角色。
|
|
|
|
|
+ 仅识别 label_admin、annotator、viewer,未识别到任何角色则返回 None(无权限)。
|
|
|
优先级: admin > annotator > viewer
|
|
优先级: admin > annotator > viewer
|
|
|
"""
|
|
"""
|
|
|
- if is_superuser:
|
|
|
|
|
- return "admin"
|
|
|
|
|
-
|
|
|
|
|
- local_role = DEFAULT_LOCAL_ROLE
|
|
|
|
|
|
|
+ local_role: Optional[str] = None
|
|
|
for sso_role in sso_roles:
|
|
for sso_role in sso_roles:
|
|
|
mapped = SSO_ROLE_MAPPING.get(sso_role)
|
|
mapped = SSO_ROLE_MAPPING.get(sso_role)
|
|
|
if mapped == "admin":
|
|
if mapped == "admin":
|
|
|
return "admin"
|
|
return "admin"
|
|
|
if mapped == "annotator":
|
|
if mapped == "annotator":
|
|
|
local_role = "annotator"
|
|
local_role = "annotator"
|
|
|
|
|
+ elif mapped == "viewer" and local_role is None:
|
|
|
|
|
+ local_role = "viewer"
|
|
|
|
|
|
|
|
return local_role
|
|
return local_role
|
|
|
|
|
|
|
|
|
|
|
|
|
class OAuthService:
|
|
class OAuthService:
|
|
|
"""OAuth 2.0 认证服务"""
|
|
"""OAuth 2.0 认证服务"""
|
|
|
-
|
|
|
|
|
|
|
+
|
|
|
@staticmethod
|
|
@staticmethod
|
|
|
def generate_state() -> str:
|
|
def generate_state() -> str:
|
|
|
- """
|
|
|
|
|
- 生成随机 state 参数,用于防止 CSRF 攻击
|
|
|
|
|
-
|
|
|
|
|
- Returns:
|
|
|
|
|
- 随机字符串
|
|
|
|
|
- """
|
|
|
|
|
|
|
+ """生成随机 state 参数,用于防止 CSRF 攻击"""
|
|
|
return secrets.token_urlsafe(32)
|
|
return secrets.token_urlsafe(32)
|
|
|
-
|
|
|
|
|
|
|
+
|
|
|
@staticmethod
|
|
@staticmethod
|
|
|
def get_authorization_url(state: str) -> str:
|
|
def get_authorization_url(state: str) -> str:
|
|
|
- """
|
|
|
|
|
- 构建 OAuth 授权 URL
|
|
|
|
|
-
|
|
|
|
|
- Args:
|
|
|
|
|
- state: 防CSRF的随机字符串
|
|
|
|
|
-
|
|
|
|
|
- Returns:
|
|
|
|
|
- 完整的授权URL
|
|
|
|
|
- """
|
|
|
|
|
|
|
+ """构建 OAuth 授权 URL"""
|
|
|
from urllib.parse import urlencode
|
|
from urllib.parse import urlencode
|
|
|
-
|
|
|
|
|
|
|
+
|
|
|
params = {
|
|
params = {
|
|
|
"response_type": "code",
|
|
"response_type": "code",
|
|
|
- "client_id": settings.OAUTH_CLIENT_ID,
|
|
|
|
|
- "redirect_uri": settings.OAUTH_REDIRECT_URI,
|
|
|
|
|
- "scope": settings.OAUTH_SCOPE,
|
|
|
|
|
- "state": state
|
|
|
|
|
|
|
+ "client_id": settings.SSO_CLIENT_ID,
|
|
|
|
|
+ "redirect_uri": settings.SSO_REDIRECT_URI,
|
|
|
|
|
+ "scope": settings.SSO_SCOPE,
|
|
|
|
|
+ "state": state,
|
|
|
}
|
|
}
|
|
|
-
|
|
|
|
|
- authorize_url = f"{settings.OAUTH_BASE_URL}{settings.OAUTH_AUTHORIZE_ENDPOINT}"
|
|
|
|
|
|
|
+
|
|
|
|
|
+ authorize_url = f"{settings.SSO_BASE_URL}{settings.SSO_AUTHORIZE_ENDPOINT}"
|
|
|
return f"{authorize_url}?{urlencode(params)}"
|
|
return f"{authorize_url}?{urlencode(params)}"
|
|
|
-
|
|
|
|
|
|
|
+
|
|
|
@staticmethod
|
|
@staticmethod
|
|
|
async def exchange_code_for_token(code: str) -> Dict[str, Any]:
|
|
async def exchange_code_for_token(code: str) -> Dict[str, Any]:
|
|
|
- """
|
|
|
|
|
- 用授权码换取访问令牌
|
|
|
|
|
-
|
|
|
|
|
- Args:
|
|
|
|
|
- code: OAuth 授权码
|
|
|
|
|
-
|
|
|
|
|
- Returns:
|
|
|
|
|
- 令牌信息字典,包含 access_token, token_type, expires_in 等
|
|
|
|
|
-
|
|
|
|
|
- Raises:
|
|
|
|
|
- Exception: 令牌交换失败
|
|
|
|
|
- """
|
|
|
|
|
- token_url = f"{settings.OAUTH_BASE_URL}{settings.OAUTH_TOKEN_ENDPOINT}"
|
|
|
|
|
-
|
|
|
|
|
|
|
+ """用授权码换取访问令牌"""
|
|
|
|
|
+ token_url = f"{settings.SSO_BASE_URL}{settings.SSO_TOKEN_ENDPOINT}"
|
|
|
|
|
+
|
|
|
async with httpx.AsyncClient() as client:
|
|
async with httpx.AsyncClient() as client:
|
|
|
response = await client.post(
|
|
response = await client.post(
|
|
|
token_url,
|
|
token_url,
|
|
|
data={
|
|
data={
|
|
|
"grant_type": "authorization_code",
|
|
"grant_type": "authorization_code",
|
|
|
"code": code,
|
|
"code": code,
|
|
|
- "redirect_uri": settings.OAUTH_REDIRECT_URI,
|
|
|
|
|
- "client_id": settings.OAUTH_CLIENT_ID,
|
|
|
|
|
- "client_secret": settings.OAUTH_CLIENT_SECRET
|
|
|
|
|
|
|
+ "redirect_uri": settings.SSO_REDIRECT_URI,
|
|
|
|
|
+ "client_id": settings.SSO_CLIENT_ID,
|
|
|
|
|
+ "client_secret": settings.SSO_CLIENT_SECRET,
|
|
|
},
|
|
},
|
|
|
- headers={"Content-Type": "application/x-www-form-urlencoded"}
|
|
|
|
|
|
|
+ headers={"Content-Type": "application/x-www-form-urlencoded"},
|
|
|
)
|
|
)
|
|
|
-
|
|
|
|
|
|
|
+
|
|
|
if response.status_code != 200:
|
|
if response.status_code != 200:
|
|
|
raise Exception(f"令牌交换失败 ({response.status_code}): {response.text}")
|
|
raise Exception(f"令牌交换失败 ({response.status_code}): {response.text}")
|
|
|
-
|
|
|
|
|
|
|
+
|
|
|
data = response.json()
|
|
data = response.json()
|
|
|
-
|
|
|
|
|
- # 处理不同的响应格式
|
|
|
|
|
|
|
+
|
|
|
if "access_token" in data:
|
|
if "access_token" in data:
|
|
|
return data
|
|
return data
|
|
|
- # 处理包装格式 {"code": 0, "data": {...}} 或 {"code": "000000", "data": {...}}
|
|
|
|
|
- code = data.get("code")
|
|
|
|
|
- if (code == 0 or code == "000000") and "data" in data:
|
|
|
|
|
|
|
+ code_val = data.get("code")
|
|
|
|
|
+ if (code_val == 0 or code_val == "000000") and "data" in data:
|
|
|
return data["data"]
|
|
return data["data"]
|
|
|
else:
|
|
else:
|
|
|
raise Exception(f"无效的令牌响应格式: {data}")
|
|
raise Exception(f"无效的令牌响应格式: {data}")
|
|
|
-
|
|
|
|
|
|
|
+
|
|
|
@staticmethod
|
|
@staticmethod
|
|
|
async def get_user_info(access_token: str) -> Dict[str, Any]:
|
|
async def get_user_info(access_token: str) -> Dict[str, Any]:
|
|
|
- """
|
|
|
|
|
- 使用访问令牌获取用户信息
|
|
|
|
|
-
|
|
|
|
|
- Args:
|
|
|
|
|
- access_token: OAuth 访问令牌
|
|
|
|
|
-
|
|
|
|
|
- Returns:
|
|
|
|
|
- 用户信息字典
|
|
|
|
|
-
|
|
|
|
|
- Raises:
|
|
|
|
|
- Exception: 获取用户信息失败
|
|
|
|
|
- """
|
|
|
|
|
- userinfo_url = f"{settings.OAUTH_BASE_URL}{settings.OAUTH_USERINFO_ENDPOINT}"
|
|
|
|
|
-
|
|
|
|
|
|
|
+ """使用访问令牌获取用户信息(不含角色)"""
|
|
|
|
|
+ userinfo_url = f"{settings.SSO_BASE_URL}{settings.SSO_USERINFO_ENDPOINT}"
|
|
|
|
|
+
|
|
|
async with httpx.AsyncClient() as client:
|
|
async with httpx.AsyncClient() as client:
|
|
|
response = await client.get(
|
|
response = await client.get(
|
|
|
userinfo_url,
|
|
userinfo_url,
|
|
|
- headers={"Authorization": f"Bearer {access_token}"}
|
|
|
|
|
|
|
+ headers={"Authorization": f"Bearer {access_token}"},
|
|
|
)
|
|
)
|
|
|
-
|
|
|
|
|
|
|
+
|
|
|
if response.status_code != 200:
|
|
if response.status_code != 200:
|
|
|
raise Exception(f"获取用户信息失败 ({response.status_code}): {response.text}")
|
|
raise Exception(f"获取用户信息失败 ({response.status_code}): {response.text}")
|
|
|
-
|
|
|
|
|
|
|
+
|
|
|
data = response.json()
|
|
data = response.json()
|
|
|
-
|
|
|
|
|
- # 处理不同的响应格式
|
|
|
|
|
|
|
+
|
|
|
if "sub" in data or "id" in data:
|
|
if "sub" in data or "id" in data:
|
|
|
return data
|
|
return data
|
|
|
- # 处理包装格式 {"code": 0, "data": {...}} 或 {"code": "000000", "data": {...}}
|
|
|
|
|
- code = data.get("code")
|
|
|
|
|
- if (code == 0 or code == "000000") and "data" in data:
|
|
|
|
|
|
|
+ code_val = data.get("code")
|
|
|
|
|
+ if (code_val == 0 or code_val == "000000") and "data" in data:
|
|
|
return data["data"]
|
|
return data["data"]
|
|
|
else:
|
|
else:
|
|
|
raise Exception(f"无效的用户信息响应格式: {data}")
|
|
raise Exception(f"无效的用户信息响应格式: {data}")
|
|
|
-
|
|
|
|
|
|
|
+
|
|
|
|
|
+ @staticmethod
|
|
|
|
|
+ async def get_user_profile(access_token: str) -> Dict[str, Any]:
|
|
|
|
|
+ """
|
|
|
|
|
+ 通过 SSO /oauth/userinfo 获取用户信息和角色。
|
|
|
|
|
+ 返回格式: {sub, username, email, roles: [{name, code}]}
|
|
|
|
|
+ """
|
|
|
|
|
+ userinfo_url = f"{settings.SSO_BASE_URL}{settings.SSO_USERINFO_ENDPOINT}"
|
|
|
|
|
+ profile = {}
|
|
|
|
|
+
|
|
|
|
|
+ async with httpx.AsyncClient(timeout=10.0) as client:
|
|
|
|
|
+ try:
|
|
|
|
|
+ response = await client.get(
|
|
|
|
|
+ userinfo_url,
|
|
|
|
|
+ headers={"Authorization": f"Bearer {access_token}"},
|
|
|
|
|
+ )
|
|
|
|
|
+ except httpx.RequestError:
|
|
|
|
|
+ raise HTTPException(
|
|
|
|
|
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
|
|
|
|
+ detail="SSO 认证中心不可用",
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ if response.status_code == 401:
|
|
|
|
|
+ raise HTTPException(
|
|
|
|
|
+ status_code=status.HTTP_401_UNAUTHORIZED,
|
|
|
|
|
+ detail="无效的访问令牌",
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ if response.status_code == 200:
|
|
|
|
|
+ data = response.json()
|
|
|
|
|
+ logger.debug(f"SSO userinfo response: {data}")
|
|
|
|
|
+
|
|
|
|
|
+ code_val = data.get("code")
|
|
|
|
|
+ if (code_val == 0 or code_val == "000000") and "data" in data:
|
|
|
|
|
+ profile = data["data"]
|
|
|
|
|
+ elif "id" in data or "username" in data or "sub" in data:
|
|
|
|
|
+ profile = data
|
|
|
|
|
+
|
|
|
|
|
+ # 解析角色(支持对象和字符串两种格式)
|
|
|
|
|
+ raw_roles = profile.get("roles", [])
|
|
|
|
|
+ sso_roles: list = []
|
|
|
|
|
+ for role_item in raw_roles:
|
|
|
|
|
+ if isinstance(role_item, dict):
|
|
|
|
|
+ sso_roles.append(role_item.get("code", ""))
|
|
|
|
|
+ name = role_item.get("name", "")
|
|
|
|
|
+ if name:
|
|
|
|
|
+ sso_roles.append(name)
|
|
|
|
|
+ elif isinstance(role_item, str):
|
|
|
|
|
+ sso_roles.append(role_item)
|
|
|
|
|
+
|
|
|
|
|
+ local_role = map_sso_roles_to_local(sso_roles)
|
|
|
|
|
+
|
|
|
|
|
+ logger.info(
|
|
|
|
|
+ f"SSO 用户 {profile.get('username')}: "
|
|
|
|
|
+ f"roles={raw_roles}, sso_roles={sso_roles} → local_role={local_role}"
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ return {
|
|
|
|
|
+ "id": profile.get("id") or profile.get("sub"),
|
|
|
|
|
+ "username": profile.get("username") or profile.get("preferred_username") or profile.get("name"),
|
|
|
|
|
+ "email": profile.get("email", ""),
|
|
|
|
|
+ "role": local_role,
|
|
|
|
|
+ "sso_roles": sso_roles,
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
@staticmethod
|
|
@staticmethod
|
|
|
def sync_user_from_oauth(oauth_user_info: Dict[str, Any]) -> User:
|
|
def sync_user_from_oauth(oauth_user_info: Dict[str, Any]) -> User:
|
|
|
"""
|
|
"""
|
|
|
从 OAuth 用户信息同步到本地数据库
|
|
从 OAuth 用户信息同步到本地数据库
|
|
|
如果用户不存在则创建,如果存在则更新(包括角色)
|
|
如果用户不存在则创建,如果存在则更新(包括角色)
|
|
|
-
|
|
|
|
|
- Args:
|
|
|
|
|
- oauth_user_info: OAuth 返回的用户信息
|
|
|
|
|
-
|
|
|
|
|
- Returns:
|
|
|
|
|
- 本地用户对象
|
|
|
|
|
"""
|
|
"""
|
|
|
with get_db_connection() as conn:
|
|
with get_db_connection() as conn:
|
|
|
cursor = conn.cursor()
|
|
cursor = conn.cursor()
|
|
|
-
|
|
|
|
|
- # 提取用户信息(兼容不同的字段名)
|
|
|
|
|
|
|
+
|
|
|
oauth_id = oauth_user_info.get("sub") or oauth_user_info.get("id")
|
|
oauth_id = oauth_user_info.get("sub") or oauth_user_info.get("id")
|
|
|
username = oauth_user_info.get("username") or oauth_user_info.get("preferred_username") or oauth_user_info.get("name")
|
|
username = oauth_user_info.get("username") or oauth_user_info.get("preferred_username") or oauth_user_info.get("name")
|
|
|
email = oauth_user_info.get("email", "")
|
|
email = oauth_user_info.get("email", "")
|
|
|
-
|
|
|
|
|
|
|
+
|
|
|
if not oauth_id:
|
|
if not oauth_id:
|
|
|
raise ValueError("OAuth 用户信息缺少 ID 字段")
|
|
raise ValueError("OAuth 用户信息缺少 ID 字段")
|
|
|
-
|
|
|
|
|
|
|
+
|
|
|
if not username:
|
|
if not username:
|
|
|
raise ValueError("OAuth 用户信息缺少用户名字段")
|
|
raise ValueError("OAuth 用户信息缺少用户名字段")
|
|
|
-
|
|
|
|
|
- # 计算本地角色
|
|
|
|
|
|
|
+
|
|
|
sso_roles = oauth_user_info.get("sso_roles") or oauth_user_info.get("roles", [])
|
|
sso_roles = oauth_user_info.get("sso_roles") or oauth_user_info.get("roles", [])
|
|
|
- is_superuser = bool(oauth_user_info.get("is_superuser", False))
|
|
|
|
|
- role = oauth_user_info.get("role") or map_sso_roles_to_local(sso_roles, is_superuser)
|
|
|
|
|
-
|
|
|
|
|
|
|
+ role = oauth_user_info.get("role") or map_sso_roles_to_local(sso_roles)
|
|
|
|
|
+
|
|
|
|
|
+ if role is None:
|
|
|
|
|
+ raise ValueError(f"用户 {username} 没有被识别的 SSO 角色(sso_roles={sso_roles}),无权限访问")
|
|
|
|
|
+
|
|
|
logger.debug(f"sync_user_from_oauth: oauth_id={oauth_id}, username={username}, sso_roles={sso_roles}, computed_role={role}")
|
|
logger.debug(f"sync_user_from_oauth: oauth_id={oauth_id}, username={username}, sso_roles={sso_roles}, computed_role={role}")
|
|
|
-
|
|
|
|
|
- # 查找是否已存在该 OAuth 用户
|
|
|
|
|
|
|
+
|
|
|
cursor.execute(
|
|
cursor.execute(
|
|
|
"SELECT * FROM users WHERE oauth_provider = %s AND oauth_id = %s",
|
|
"SELECT * FROM users WHERE oauth_provider = %s AND oauth_id = %s",
|
|
|
- ("sso", oauth_id)
|
|
|
|
|
|
|
+ ("sso", oauth_id),
|
|
|
)
|
|
)
|
|
|
row = cursor.fetchone()
|
|
row = cursor.fetchone()
|
|
|
-
|
|
|
|
|
|
|
+
|
|
|
if row:
|
|
if row:
|
|
|
- # 用户已存在,更新信息(包括角色)
|
|
|
|
|
user = User.from_row(row)
|
|
user = User.from_row(row)
|
|
|
logger.debug(f"User exists: id={user.id}, old_role={user.role}, new_role={role}")
|
|
logger.debug(f"User exists: id={user.id}, old_role={user.role}, new_role={role}")
|
|
|
-
|
|
|
|
|
- cursor.execute("""
|
|
|
|
|
- UPDATE users
|
|
|
|
|
- SET username = %s, email = %s, role = %s
|
|
|
|
|
- WHERE id = %s
|
|
|
|
|
- """, (username, email, role, user.id))
|
|
|
|
|
-
|
|
|
|
|
|
|
+
|
|
|
|
|
+ cursor.execute(
|
|
|
|
|
+ "UPDATE users SET username = %s, email = %s, role = %s WHERE id = %s",
|
|
|
|
|
+ (username, email, role, user.id),
|
|
|
|
|
+ )
|
|
|
conn.commit()
|
|
conn.commit()
|
|
|
logger.debug(f"User updated in database")
|
|
logger.debug(f"User updated in database")
|
|
|
-
|
|
|
|
|
- # 重新查询更新后的用户
|
|
|
|
|
|
|
+
|
|
|
cursor.execute("SELECT * FROM users WHERE id = %s", (user.id,))
|
|
cursor.execute("SELECT * FROM users WHERE id = %s", (user.id,))
|
|
|
row = cursor.fetchone()
|
|
row = cursor.fetchone()
|
|
|
- updated_user = User.from_row(row)
|
|
|
|
|
- logger.debug(f"User after update: role={updated_user.role}")
|
|
|
|
|
- return updated_user
|
|
|
|
|
|
|
+ return User.from_row(row)
|
|
|
else:
|
|
else:
|
|
|
- # 新用户,创建记录
|
|
|
|
|
user_id = f"user_{datetime.now().strftime('%Y%m%d%H%M%S')}_{secrets.token_hex(4)}"
|
|
user_id = f"user_{datetime.now().strftime('%Y%m%d%H%M%S')}_{secrets.token_hex(4)}"
|
|
|
-
|
|
|
|
|
- cursor.execute("""
|
|
|
|
|
|
|
+
|
|
|
|
|
+ cursor.execute(
|
|
|
|
|
+ """
|
|
|
INSERT INTO users (
|
|
INSERT INTO users (
|
|
|
id, username, email, password_hash, role,
|
|
id, username, email, password_hash, role,
|
|
|
oauth_provider, oauth_id, created_at
|
|
oauth_provider, oauth_id, created_at
|
|
|
) VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
|
|
) VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
|
|
|
- """, (
|
|
|
|
|
- user_id,
|
|
|
|
|
- username,
|
|
|
|
|
- email,
|
|
|
|
|
- "", # OAuth 用户不需要密码
|
|
|
|
|
- role,
|
|
|
|
|
- "sso",
|
|
|
|
|
- oauth_id,
|
|
|
|
|
- datetime.now()
|
|
|
|
|
- ))
|
|
|
|
|
-
|
|
|
|
|
|
|
+ """,
|
|
|
|
|
+ (user_id, username, email, "", role, "sso", oauth_id, datetime.now()),
|
|
|
|
|
+ )
|
|
|
conn.commit()
|
|
conn.commit()
|
|
|
-
|
|
|
|
|
- # 查询新创建的用户
|
|
|
|
|
|
|
+
|
|
|
cursor.execute("SELECT * FROM users WHERE id = %s", (user_id,))
|
|
cursor.execute("SELECT * FROM users WHERE id = %s", (user_id,))
|
|
|
row = cursor.fetchone()
|
|
row = cursor.fetchone()
|
|
|
return User.from_row(row)
|
|
return User.from_row(row)
|
|
|
-
|
|
|
|
|
- @staticmethod
|
|
|
|
|
- async def verify_sso_token(access_token: str) -> Dict[str, Any]:
|
|
|
|
|
- """
|
|
|
|
|
- 通过 SSO 验证 token 并获取用户信息(含角色)。
|
|
|
|
|
-
|
|
|
|
|
- 使用 /api/v1/system/users/profile 端点获取完整用户信息,
|
|
|
|
|
- 包括 roles 列表和 is_superuser 标记,然后映射为本地角色。
|
|
|
|
|
-
|
|
|
|
|
- Args:
|
|
|
|
|
- access_token: SSO 访问令牌
|
|
|
|
|
-
|
|
|
|
|
- Returns:
|
|
|
|
|
- 用户信息字典 {id, username, email, role, ...}
|
|
|
|
|
-
|
|
|
|
|
- Raises:
|
|
|
|
|
- HTTPException(401): token 无效
|
|
|
|
|
- HTTPException(503): SSO 中心不可用
|
|
|
|
|
- """
|
|
|
|
|
- profile_url = f"{settings.OAUTH_BASE_URL}/api/v1/system/users/profile"
|
|
|
|
|
-
|
|
|
|
|
- async with httpx.AsyncClient(timeout=10.0) as client:
|
|
|
|
|
- try:
|
|
|
|
|
- response = await client.get(
|
|
|
|
|
- profile_url,
|
|
|
|
|
- headers={"Authorization": f"Bearer {access_token}"}
|
|
|
|
|
- )
|
|
|
|
|
- except httpx.RequestError:
|
|
|
|
|
- raise HTTPException(
|
|
|
|
|
- status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
|
|
|
|
- detail="SSO 认证中心不可用"
|
|
|
|
|
- )
|
|
|
|
|
-
|
|
|
|
|
- if response.status_code == 401:
|
|
|
|
|
- raise HTTPException(
|
|
|
|
|
- status_code=status.HTTP_401_UNAUTHORIZED,
|
|
|
|
|
- detail="无效的访问令牌"
|
|
|
|
|
- )
|
|
|
|
|
-
|
|
|
|
|
- if response.status_code != 200:
|
|
|
|
|
- raise HTTPException(
|
|
|
|
|
- status_code=status.HTTP_401_UNAUTHORIZED,
|
|
|
|
|
- detail=f"SSO 验证失败 ({response.status_code})"
|
|
|
|
|
- )
|
|
|
|
|
-
|
|
|
|
|
- data = response.json()
|
|
|
|
|
- logger.debug(f"SSO profile response: {data}")
|
|
|
|
|
-
|
|
|
|
|
- # 处理包装格式 {"code": 0, "data": {...}} 或 {"code": "000000", "data": {...}}
|
|
|
|
|
- code = data.get("code")
|
|
|
|
|
- if (code == 0 or code == "000000") and "data" in data:
|
|
|
|
|
- profile = data["data"]
|
|
|
|
|
- elif "id" in data or "username" in data:
|
|
|
|
|
- profile = data
|
|
|
|
|
- else:
|
|
|
|
|
- logger.error(f"Invalid profile response format: {data}")
|
|
|
|
|
- raise HTTPException(
|
|
|
|
|
- status_code=status.HTTP_401_UNAUTHORIZED,
|
|
|
|
|
- detail="无效的访问令牌"
|
|
|
|
|
- )
|
|
|
|
|
-
|
|
|
|
|
- # 提取角色信息并映射
|
|
|
|
|
- sso_roles = profile.get("roles", [])
|
|
|
|
|
- is_superuser = bool(profile.get("is_superuser", False))
|
|
|
|
|
- local_role = map_sso_roles_to_local(sso_roles, is_superuser)
|
|
|
|
|
-
|
|
|
|
|
- logger.info(
|
|
|
|
|
- f"SSO 用户 {profile.get('username')}: "
|
|
|
|
|
- f"sso_roles={sso_roles}, is_superuser={is_superuser} → local_role={local_role}"
|
|
|
|
|
- )
|
|
|
|
|
-
|
|
|
|
|
- # 返回统一格式的用户信息
|
|
|
|
|
- return {
|
|
|
|
|
- "id": profile.get("id"),
|
|
|
|
|
- "username": profile.get("username"),
|
|
|
|
|
- "email": profile.get("email", ""),
|
|
|
|
|
- "role": local_role,
|
|
|
|
|
- "sso_roles": sso_roles,
|
|
|
|
|
- "is_superuser": is_superuser,
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- @staticmethod
|
|
|
|
|
- async def refresh_sso_token(refresh_token: str) -> Dict[str, Any]:
|
|
|
|
|
- """
|
|
|
|
|
- 向 SSO 中心刷新 token。
|
|
|
|
|
-
|
|
|
|
|
- Args:
|
|
|
|
|
- refresh_token: SSO 刷新令牌
|
|
|
|
|
-
|
|
|
|
|
- Returns:
|
|
|
|
|
- 新的 token 信息 {access_token, refresh_token, ...}
|
|
|
|
|
-
|
|
|
|
|
- Raises:
|
|
|
|
|
- HTTPException(401): refresh_token 无效
|
|
|
|
|
- HTTPException(503): SSO 中心不可用
|
|
|
|
|
- """
|
|
|
|
|
- token_url = f"{settings.OAUTH_BASE_URL}{settings.OAUTH_TOKEN_ENDPOINT}"
|
|
|
|
|
- logger.debug(f"Refreshing token at: {token_url}")
|
|
|
|
|
-
|
|
|
|
|
- async with httpx.AsyncClient(timeout=10.0) as client:
|
|
|
|
|
- try:
|
|
|
|
|
- response = await client.post(
|
|
|
|
|
- token_url,
|
|
|
|
|
- data={
|
|
|
|
|
- "grant_type": "refresh_token",
|
|
|
|
|
- "refresh_token": refresh_token,
|
|
|
|
|
- "client_id": settings.OAUTH_CLIENT_ID,
|
|
|
|
|
- "client_secret": settings.OAUTH_CLIENT_SECRET
|
|
|
|
|
- },
|
|
|
|
|
- headers={"Content-Type": "application/x-www-form-urlencoded"}
|
|
|
|
|
- )
|
|
|
|
|
- except httpx.RequestError as e:
|
|
|
|
|
- logger.error(f"SSO refresh request failed: {e}")
|
|
|
|
|
- raise HTTPException(
|
|
|
|
|
- status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
|
|
|
|
- detail="SSO 认证中心不可用"
|
|
|
|
|
- )
|
|
|
|
|
-
|
|
|
|
|
- logger.debug(f"SSO refresh response: status={response.status_code}")
|
|
|
|
|
-
|
|
|
|
|
- if response.status_code != 200:
|
|
|
|
|
- logger.error(f"SSO refresh failed: {response.status_code}, body={response.text}")
|
|
|
|
|
- raise HTTPException(
|
|
|
|
|
- status_code=status.HTTP_401_UNAUTHORIZED,
|
|
|
|
|
- detail="刷新令牌无效或已过期,请重新登录"
|
|
|
|
|
- )
|
|
|
|
|
-
|
|
|
|
|
- data = response.json()
|
|
|
|
|
- logger.debug(f"SSO refresh response data: {data}")
|
|
|
|
|
-
|
|
|
|
|
- # 处理包装格式 {"code": 0, "data": {...}} 或 {"code": "000000", "data": {...}}
|
|
|
|
|
- code = data.get("code")
|
|
|
|
|
- if (code == 0 or code == "000000") and "data" in data:
|
|
|
|
|
- return data["data"]
|
|
|
|
|
- elif "access_token" in data:
|
|
|
|
|
- return data
|
|
|
|
|
- else:
|
|
|
|
|
- logger.error(f"Invalid refresh response format: {data}")
|
|
|
|
|
- raise HTTPException(
|
|
|
|
|
- status_code=status.HTTP_401_UNAUTHORIZED,
|
|
|
|
|
- detail="刷新令牌无效或已过期,请重新登录"
|
|
|
|
|
- )
|
|
|