"""
OAuth 2.0 authentication service.

Handles the interaction with the OAuth / SSO authentication center:
authorization-URL construction, authorization-code exchange, user-info
lookup, local user synchronisation, and SSO token verification and refresh.
"""
import logging
import secrets
from datetime import datetime
from typing import Dict, Any, Optional
from urllib.parse import urlencode

import httpx
from fastapi import HTTPException, status

from config import settings
from models import User
from database import get_db_connection

logger = logging.getLogger(__name__)

# SSO role -> local role mapping. SSO roles not listed here map to nothing
# and therefore fall back to DEFAULT_LOCAL_ROLE.
SSO_ROLE_MAPPING = {
    "super_admin": "admin",
    "label_admin": "admin",
    "admin": "admin",
    "labeler": "annotator",
}

DEFAULT_LOCAL_ROLE = "viewer"

# Business codes that mark a successful wrapped response:
# {"code": 0, "data": {...}} or {"code": "000000", "data": {...}}.
_ENVELOPE_SUCCESS_CODES = (0, "000000")


def _unwrap_envelope(data: Any) -> Optional[Dict[str, Any]]:
    """
    Return the payload of a successful wrapped response, or None.

    A wrapped response looks like {"code": 0 | "000000", "data": {...}}.
    None is returned when ``data`` is not a dict, its code is not a success
    code, the "data" key is missing, or the payload itself is not a dict.
    """
    if not isinstance(data, dict):
        return None
    if data.get("code") not in _ENVELOPE_SUCCESS_CODES or "data" not in data:
        return None
    payload = data["data"]
    return payload if isinstance(payload, dict) else None


def _describe_shape(data: Any) -> str:
    """
    Summarise a response body for logs/errors without echoing its values.

    Token-endpoint responses may carry access/refresh tokens, so only the
    top-level key names (or the type name for non-dicts) are reported.
    """
    if isinstance(data, dict):
        return f"keys={sorted(map(str, data))}"
    return f"type={type(data).__name__}"


def map_sso_roles_to_local(sso_roles: Optional[list], is_superuser: bool = False) -> str:
    """
    Map a list of SSO roles to a single local role.

    Priority: admin > annotator > viewer. A superuser is always "admin".

    Args:
        sso_roles: SSO role names. ``None`` is treated as an empty list, and
            a single role name passed as a plain string is treated as a
            one-element list.
        is_superuser: SSO superuser flag.

    Returns:
        "admin", "annotator", or DEFAULT_LOCAL_ROLE.
    """
    if is_superuser:
        return "admin"

    if sso_roles is None:
        sso_roles = []
    elif isinstance(sso_roles, str):
        # Iterating a bare string would yield single characters, never a role.
        sso_roles = [sso_roles]

    local_role = DEFAULT_LOCAL_ROLE
    for sso_role in sso_roles:
        mapped = SSO_ROLE_MAPPING.get(sso_role)
        if mapped == "admin":
            # Highest priority: no need to inspect the remaining roles.
            return "admin"
        if mapped == "annotator":
            local_role = "annotator"
    return local_role


class OAuthService:
    """OAuth 2.0 authentication service (stateless; all methods are static)."""

    @staticmethod
    def generate_state() -> str:
        """
        Generate a random ``state`` parameter used to prevent CSRF attacks.

        Returns:
            A URL-safe random string (32 random bytes).
        """
        return secrets.token_urlsafe(32)

    @staticmethod
    def get_authorization_url(state: str) -> str:
        """
        Build the OAuth authorization URL (authorization-code flow).

        Args:
            state: random anti-CSRF string, echoed back by the server.

        Returns:
            The full authorization URL including the encoded query string.
        """
        params = {
            "response_type": "code",
            "client_id": settings.OAUTH_CLIENT_ID,
            "redirect_uri": settings.OAUTH_REDIRECT_URI,
            "scope": settings.OAUTH_SCOPE,
            "state": state,
        }
        authorize_url = f"{settings.OAUTH_BASE_URL}{settings.OAUTH_AUTHORIZE_ENDPOINT}"
        return f"{authorize_url}?{urlencode(params)}"

    @staticmethod
    async def exchange_code_for_token(code: str) -> Dict[str, Any]:
        """
        Exchange an authorization code for an access token.

        Args:
            code: OAuth authorization code.

        Returns:
            Token dict (access_token, token_type, expires_in, ...). Both a
            bare token response and a wrapped
            {"code": 0 | "000000", "data": {...}} response are accepted.

        Raises:
            Exception: the exchange failed or the response format is not
                recognised.
        """
        token_url = f"{settings.OAUTH_BASE_URL}{settings.OAUTH_TOKEN_ENDPOINT}"

        async with httpx.AsyncClient() as client:
            response = await client.post(
                token_url,
                data={
                    "grant_type": "authorization_code",
                    "code": code,
                    "redirect_uri": settings.OAUTH_REDIRECT_URI,
                    "client_id": settings.OAUTH_CLIENT_ID,
                    "client_secret": settings.OAUTH_CLIENT_SECRET,
                },
                headers={"Content-Type": "application/x-www-form-urlencoded"},
            )

            if response.status_code != 200:
                raise Exception(f"令牌交换失败 ({response.status_code}): {response.text}")

            data = response.json()

            # Bare token response.
            if isinstance(data, dict) and "access_token" in data:
                return data

            # Wrapped response {"code": 0 | "000000", "data": {...}}.
            payload = _unwrap_envelope(data)
            if payload is not None:
                return payload

            # Do not echo the body: it may contain token material.
            raise Exception(f"无效的令牌响应格式: {_describe_shape(data)}")

    @staticmethod
    async def get_user_info(access_token: str) -> Dict[str, Any]:
        """
        Fetch user info with an access token.

        Args:
            access_token: OAuth access token, sent as a Bearer token.

        Returns:
            User-info dict. A bare response (containing "sub" or "id") and a
            wrapped {"code": 0 | "000000", "data": {...}} response are both
            accepted.

        Raises:
            Exception: the request failed or the response format is not
                recognised.
        """
        userinfo_url = f"{settings.OAUTH_BASE_URL}{settings.OAUTH_USERINFO_ENDPOINT}"

        async with httpx.AsyncClient() as client:
            response = await client.get(
                userinfo_url,
                headers={"Authorization": f"Bearer {access_token}"},
            )

            if response.status_code != 200:
                raise Exception(f"获取用户信息失败 ({response.status_code}): {response.text}")

            data = response.json()

            # Bare user-info response.
            if isinstance(data, dict) and ("sub" in data or "id" in data):
                return data

            # Wrapped response {"code": 0 | "000000", "data": {...}}.
            payload = _unwrap_envelope(data)
            if payload is not None:
                return payload

            raise Exception(f"无效的用户信息响应格式: {data}")

    @staticmethod
    def sync_user_from_oauth(oauth_user_info: Dict[str, Any]) -> User:
        """
        Sync an OAuth user into the local ``users`` table.

        Creates the user on first login; otherwise updates username, email
        and role of the existing row identified by
        (oauth_provider="sso", oauth_id).

        Args:
            oauth_user_info: user info from the OAuth / SSO server. The id is
                read from "sub" or "id"; the username from "username",
                "preferred_username" or "name"; an explicit "role" takes
                precedence over roles mapped from "sso_roles" / "roles".

        Returns:
            The local User, re-read from the database after the write.

        Raises:
            ValueError: the id or username field is missing.
        """
        # Extract user fields, tolerating different field names.
        oauth_id = oauth_user_info.get("sub") or oauth_user_info.get("id")
        username = (
            oauth_user_info.get("username")
            or oauth_user_info.get("preferred_username")
            or oauth_user_info.get("name")
        )
        email = oauth_user_info.get("email", "")

        # Validate before opening a database connection.
        if not oauth_id:
            raise ValueError("OAuth 用户信息缺少 ID 字段")
        if not username:
            raise ValueError("OAuth 用户信息缺少用户名字段")

        # Compute the local role: an explicit "role" wins over mapped SSO roles.
        sso_roles = oauth_user_info.get("sso_roles") or oauth_user_info.get("roles", [])
        is_superuser = bool(oauth_user_info.get("is_superuser", False))
        role = oauth_user_info.get("role") or map_sso_roles_to_local(sso_roles, is_superuser)

        with get_db_connection() as conn:
            cursor = conn.cursor()

            # Look up an existing user linked to this SSO identity.
            cursor.execute(
                "SELECT * FROM users WHERE oauth_provider = %s AND oauth_id = %s",
                ("sso", oauth_id),
            )
            row = cursor.fetchone()

            if row:
                # Existing user: refresh profile fields and role.
                user_id = User.from_row(row).id
                cursor.execute(
                    """
                    UPDATE users
                    SET username = %s, email = %s, role = %s
                    WHERE id = %s
                    """,
                    (username, email, role, user_id),
                )
            else:
                # New user. SSO users never log in with a local password,
                # so password_hash is stored as an empty string.
                now = datetime.now()
                user_id = f"user_{now.strftime('%Y%m%d%H%M%S')}_{secrets.token_hex(4)}"
                cursor.execute(
                    """
                    INSERT INTO users (
                        id, username, email, password_hash, role,
                        oauth_provider, oauth_id, created_at
                    ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
                    """,
                    (user_id, username, email, "", role, "sso", oauth_id, now),
                )
            conn.commit()

            # Re-read so the returned object reflects what was persisted.
            cursor.execute("SELECT * FROM users WHERE id = %s", (user_id,))
            return User.from_row(cursor.fetchone())

    @staticmethod
    async def verify_sso_token(access_token: str) -> Dict[str, Any]:
        """
        Verify a token against the SSO center and return the user (with role).

        Calls /api/v1/system/users/profile, which returns the user's
        ``roles`` list and ``is_superuser`` flag; these are mapped to a
        local role.

        Args:
            access_token: SSO access token.

        Returns:
            {id, username, email, role, sso_roles, is_superuser}.

        Raises:
            HTTPException(401): the token is invalid, or the SSO center
                rejected it with a non-5xx status.
            HTTPException(503): the SSO center is unreachable, answered with
                a 5xx status, or returned a non-JSON body.
        """
        profile_url = f"{settings.OAUTH_BASE_URL}/api/v1/system/users/profile"

        async with httpx.AsyncClient(timeout=10.0) as client:
            try:
                response = await client.get(
                    profile_url,
                    headers={"Authorization": f"Bearer {access_token}"},
                )
            except httpx.RequestError as e:
                logger.error(f"SSO profile request failed: {e}")
                raise HTTPException(
                    status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                    detail="SSO 认证中心不可用",
                ) from e

        if response.status_code == 401:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="无效的访问令牌",
            )
        if response.status_code >= 500:
            # An SSO outage is not an invalid token: answer 503 so clients
            # do not throw away a session that may still be valid.
            logger.error(f"SSO profile request returned {response.status_code}")
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="SSO 认证中心不可用",
            )
        if response.status_code != 200:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail=f"SSO 验证失败 ({response.status_code})",
            )

        try:
            data = response.json()
        except ValueError as e:
            # A non-JSON 200 means the SSO center (or a proxy) misbehaved.
            logger.error(f"SSO profile response is not JSON: {e}")
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="SSO 认证中心不可用",
            ) from e
        logger.debug(f"SSO profile response: {data}")

        # Wrapped response first, then a bare profile object.
        profile = _unwrap_envelope(data)
        if profile is None and isinstance(data, dict) and ("id" in data or "username" in data):
            profile = data
        if profile is None:
            logger.error(f"Invalid profile response format: {data}")
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="无效的访问令牌",
            )

        # Extract and map role information (a null "roles" counts as none).
        sso_roles = profile.get("roles") or []
        is_superuser = bool(profile.get("is_superuser", False))
        local_role = map_sso_roles_to_local(sso_roles, is_superuser)

        logger.info(
            f"SSO 用户 {profile.get('username')}: "
            f"sso_roles={sso_roles}, is_superuser={is_superuser} → local_role={local_role}"
        )

        # Normalised user info returned to callers.
        return {
            "id": profile.get("id"),
            "username": profile.get("username"),
            "email": profile.get("email", ""),
            "role": local_role,
            "sso_roles": sso_roles,
            "is_superuser": is_superuser,
        }

    @staticmethod
    async def refresh_sso_token(refresh_token: str) -> Dict[str, Any]:
        """
        Refresh a token at the SSO center (refresh_token grant).

        Args:
            refresh_token: SSO refresh token.

        Returns:
            New token info {access_token, refresh_token, ...}; both a bare
            and a wrapped {"code": 0 | "000000", "data": {...}} response
            are accepted.

        Raises:
            HTTPException(401): the refresh token is invalid/expired, or the
                response format is not recognised.
            HTTPException(503): the SSO center is unreachable, answered with
                a 5xx status, or returned a non-JSON body.
        """
        token_url = f"{settings.OAUTH_BASE_URL}{settings.OAUTH_TOKEN_ENDPOINT}"
        logger.debug(f"Refreshing token at: {token_url}")

        async with httpx.AsyncClient(timeout=10.0) as client:
            try:
                response = await client.post(
                    token_url,
                    data={
                        "grant_type": "refresh_token",
                        "refresh_token": refresh_token,
                        "client_id": settings.OAUTH_CLIENT_ID,
                        "client_secret": settings.OAUTH_CLIENT_SECRET,
                    },
                    headers={"Content-Type": "application/x-www-form-urlencoded"},
                )
            except httpx.RequestError as e:
                logger.error(f"SSO refresh request failed: {e}")
                raise HTTPException(
                    status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                    detail="SSO 认证中心不可用",
                ) from e

        logger.debug(f"SSO refresh response: status={response.status_code}")

        if response.status_code != 200:
            logger.error(f"SSO refresh failed: {response.status_code}, body={response.text}")
            if response.status_code >= 500:
                # SSO outage: do not tell the client its refresh token is dead.
                raise HTTPException(
                    status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                    detail="SSO 认证中心不可用",
                )
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="刷新令牌无效或已过期,请重新登录",
            )

        try:
            data = response.json()
        except ValueError as e:
            logger.error(f"SSO refresh response is not JSON: {e}")
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="SSO 认证中心不可用",
            ) from e
        # Log only the shape: the body contains fresh access/refresh tokens.
        logger.debug(f"SSO refresh response data: {_describe_shape(data)}")

        # Wrapped response first, then a bare token object.
        tokens = _unwrap_envelope(data)
        if tokens is None and isinstance(data, dict) and "access_token" in data:
            tokens = data
        if tokens is None:
            logger.error(f"Invalid refresh response format: {_describe_shape(data)}")
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="刷新令牌无效或已过期,请重新登录",
            )
        return tokens