"""
OAuth 2.0 authentication service.

Handles the interaction with the unified SSO platform: exchanging an
authorization code for a token, fetching user information, and mapping SSO
roles onto local roles.
"""
import httpx
import logging
import secrets
from typing import Dict, Any, Optional
from datetime import datetime
from fastapi import HTTPException, status
from config import settings
from models import User
from database import get_db_connection

logger = logging.getLogger(__name__)

# SSO role -> local role mapping.
# Only ann_sys_admin (annotation admin), ann_operator (annotator) and
# ann_viewer (viewer) are recognised; any other role grants no permission.
SSO_ROLE_MAPPING = {
    # Role codes
    "ann_sys_admin": "admin",
    "ann_operator": "annotator",
    "ann_viewer": "viewer",
    # Role names (the Chinese display names of the codes above)
    "标注管理员": "admin",
    "标注员": "annotator",
    "查看者": "viewer",
}

# Keys whose string values are credentials and must never be logged in full.
_TOKEN_KEYS = ("access_token", "refresh_token", "token", "id_token")

# Values of the "code" field that mark a successful wrapped SSO response.
_SSO_SUCCESS_CODES = (0, "000000")


def _mask_value(value: Any) -> Any:
    """Return a copy of *value* with token strings masked at any nesting depth."""
    if isinstance(value, dict):
        masked = {}
        for key, item in value.items():
            if key in _TOKEN_KEYS and isinstance(item, str):
                masked[key] = item[:10] + "..." if len(item) > 10 else "***"
            else:
                masked[key] = _mask_value(item)
        return masked
    if isinstance(value, list):
        return [_mask_value(item) for item in value]
    return value


def mask_token_in_dict(data: dict) -> dict:
    """Hide token fields before logging so credentials do not leak.

    Token values longer than 10 characters are cut to their first 10
    characters followed by ``...``; shorter ones become ``***``. The input is
    not modified. Nested dicts and lists are masked as well, because the SSO
    may wrap the token payload as ``{"code": ..., "data": {...}}``.
    """
    return _mask_value(data)


def _flatten_sso_roles(raw_roles: Any) -> list:
    """Flatten SSO role entries into a flat list of role codes and names.

    Each entry may be a dict (its ``code`` is always added, its ``name`` only
    when non-empty) or a plain string. Entries of any other type are dropped.
    ``None`` is treated as an empty list.
    """
    flattened: list = []
    for role_item in raw_roles or []:
        if isinstance(role_item, dict):
            flattened.append(role_item.get("code", ""))
            name = role_item.get("name", "")
            if name:
                flattened.append(name)
        elif isinstance(role_item, str):
            flattened.append(role_item)
    return flattened


def _is_success_envelope(data: Any) -> bool:
    """Tell whether *data* is a wrapped SSO response carrying a ``data`` payload."""
    return (
        isinstance(data, dict)
        and data.get("code") in _SSO_SUCCESS_CODES
        and "data" in data
    )


def map_sso_roles_to_local(sso_roles: list) -> Optional[str]:
    """Map a list of SSO roles to a single local role.

    Only the entries of ``SSO_ROLE_MAPPING`` are recognised. Priority is
    admin > annotator > viewer.

    Args:
        sso_roles: Role codes/names as strings, or role objects shaped like
            ``{"code": ..., "name": ...}``. ``None`` is treated as empty.

    Returns:
        The local role name, or ``None`` when no role is recognised
        (meaning the user has no permission).
    """
    local_role: Optional[str] = None
    for sso_role in _flatten_sso_roles(sso_roles):
        # A role object may carry a non-string code; it cannot match a key.
        if not isinstance(sso_role, str):
            continue
        mapped = SSO_ROLE_MAPPING.get(sso_role)
        if mapped == "admin":
            return "admin"
        if mapped == "annotator":
            local_role = "annotator"
        elif mapped == "viewer" and local_role is None:
            local_role = "viewer"
    return local_role


class OAuthService:
    """OAuth 2.0 authentication service."""

    @staticmethod
    def generate_state() -> str:
        """Generate a random ``state`` parameter used to prevent CSRF attacks."""
        return secrets.token_urlsafe(32)

    @staticmethod
    def get_authorization_url(state: str) -> str:
        """Build the OAuth authorization URL for the given ``state``.

        The URL is assembled from the ``SSO_*`` values in ``settings``.
        """
        from urllib.parse import urlencode

        params = {
            "response_type": "code",
            "client_id": settings.SSO_CLIENT_ID,
            "redirect_uri": settings.SSO_REDIRECT_URI,
            "scope": settings.SSO_SCOPE,
            "state": state,
        }
        authorize_url = f"{settings.SSO_BASE_URL}{settings.SSO_AUTHORIZE_ENDPOINT}"
        full_url = f"{authorize_url}?{urlencode(params)}"
        logger.info(f"[OAuth.get_authorization_url] state={state}")
        logger.info(f"[OAuth.get_authorization_url] 完整授权URL: {full_url}")
        logger.info(f"[OAuth.get_authorization_url] scope={settings.SSO_SCOPE}")
        return full_url

    @staticmethod
    async def exchange_code_for_token(code: str) -> Dict[str, Any]:
        """Exchange an authorization code for an access token.

        Args:
            code: The authorization code received on the redirect.

        Returns:
            The token payload: either the response itself when it contains
            ``access_token``, or the ``data`` member of a wrapped response.

        Raises:
            Exception: If the SSO answers with a non-200 status or with a
                body in neither of the two accepted shapes.
        """
        token_url = f"{settings.SSO_BASE_URL}{settings.SSO_TOKEN_ENDPOINT}"
        post_data = {
            "grant_type": "authorization_code",
            "code": code,
            "redirect_uri": settings.SSO_REDIRECT_URI,
            "client_id": settings.SSO_CLIENT_ID,
            "client_secret": settings.SSO_CLIENT_SECRET,
        }
        logger.info(f"[OAuth.exchange_code] POST {token_url}")
        # The client secret is deliberately left out of this log line.
        logger.info(f"[OAuth.exchange_code] POST body (除secret): grant_type={post_data['grant_type']}, code={code[:15]}..., redirect_uri={post_data['redirect_uri']}, client_id={post_data['client_id']}")

        async with httpx.AsyncClient() as client:
            response = await client.post(
                token_url,
                data=post_data,
                headers={"Content-Type": "application/x-www-form-urlencoded"},
            )

        logger.info(f"[OAuth.exchange_code] SSO 响应状态码: {response.status_code}")
        if response.status_code != 200:
            logger.error(f"[OAuth.exchange_code] 令牌交换失败, 响应内容: {response.text}")
            raise Exception(f"令牌交换失败 ({response.status_code}): {response.text}")

        data = response.json()
        masked_data = mask_token_in_dict(data)
        logger.info(f"[OAuth.exchange_code] SSO JSON 响应 (mask token): {masked_data}")

        if isinstance(data, dict) and "access_token" in data:
            return data
        if _is_success_envelope(data):
            return data["data"]

        # Report the masked body only: an unexpected shape may still carry tokens.
        logger.error(f"[OAuth.exchange_code] 无效的令牌响应格式: {masked_data}")
        raise Exception(f"无效的令牌响应格式: {masked_data}")

    @staticmethod
    async def get_user_info(access_token: str) -> Dict[str, Any]:
        """Fetch user information (without role mapping) using an access token.

        Returns:
            The response itself when it contains ``sub`` or ``id``, otherwise
            the ``data`` member of a wrapped response.

        Raises:
            Exception: If the SSO answers with a non-200 status or with a
                body in neither of the two accepted shapes.
        """
        userinfo_url = f"{settings.SSO_BASE_URL}{settings.SSO_USERINFO_ENDPOINT}"
        async with httpx.AsyncClient() as client:
            response = await client.get(
                userinfo_url,
                headers={"Authorization": f"Bearer {access_token}"},
            )

        if response.status_code != 200:
            raise Exception(f"获取用户信息失败 ({response.status_code}): {response.text}")

        data = response.json()
        if isinstance(data, dict) and ("sub" in data or "id" in data):
            return data
        if _is_success_envelope(data):
            return data["data"]
        raise Exception(f"无效的用户信息响应格式: {data}")

    @staticmethod
    async def get_user_profile(access_token: str) -> Dict[str, Any]:
        """Fetch user information and roles from the SSO userinfo endpoint.

        Returns:
            A dict with the keys ``id``, ``username``, ``email``, ``role``
            (the mapped local role, or ``None``) and ``sso_roles`` (the
            flattened role codes/names).

        Raises:
            HTTPException: 503 when the SSO cannot be reached, 401 when the
                SSO rejects the access token.
        """
        userinfo_url = f"{settings.SSO_BASE_URL}{settings.SSO_USERINFO_ENDPOINT}"
        profile: Dict[str, Any] = {}
        logger.info(f"[OAuth.get_user_profile] 请求 SSO 用户信息, URL: {userinfo_url}")

        async with httpx.AsyncClient(timeout=10.0) as client:
            try:
                response = await client.get(
                    userinfo_url,
                    headers={"Authorization": f"Bearer {access_token}"},
                )
            except httpx.RequestError:
                logger.error(f"[OAuth.get_user_profile] SSO 请求失败, URL: {userinfo_url}")
                raise HTTPException(
                    status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                    detail="SSO 认证中心不可用",
                )

        logger.info(f"[OAuth.get_user_profile] SSO 响应状态码: {response.status_code}")
        if response.status_code == 401:
            logger.error("[OAuth.get_user_profile] SSO 返回 401, token 无效")
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="无效的访问令牌",
            )

        if response.status_code == 200:
            data = response.json()
            logger.info(f"[OAuth.get_user_profile] SSO userinfo 原始响应: {data}")
            if _is_success_envelope(data):
                profile = data["data"]
            elif isinstance(data, dict) and (
                "id" in data or "username" in data or "sub" in data
            ):
                profile = data
        else:
            # Any other status falls through with an empty profile, so the
            # returned id/username are None; make that visible in the log.
            logger.warning(
                f"[OAuth.get_user_profile] SSO 返回非预期状态码: {response.status_code}"
            )

        # Roles may be objects ({name, code}) or plain strings.
        raw_roles = profile.get("roles", [])
        sso_roles = _flatten_sso_roles(raw_roles)
        local_role = map_sso_roles_to_local(sso_roles)
        logger.info(
            f"SSO 用户 {profile.get('username')}: "
            f"roles={raw_roles}, sso_roles={sso_roles} → local_role={local_role}"
        )

        return {
            "id": profile.get("id") or profile.get("sub"),
            "username": profile.get("username") or profile.get("preferred_username") or profile.get("name"),
            "email": profile.get("email", ""),
            "role": local_role,
            "sso_roles": sso_roles,
        }

    @staticmethod
    def sync_user_from_oauth(oauth_user_info: Dict[str, Any]) -> User:
        """Synchronise OAuth user information into the local database.

        Creates the user when no row exists for the SSO id, otherwise updates
        the username, email and role of the existing row.

        Args:
            oauth_user_info: User information; the id is read from ``sub`` or
                ``id``, the role from ``role`` or, failing that, mapped from
                ``sso_roles`` / ``roles``.

        Returns:
            The ``User`` re-read from the database after the write.

        Raises:
            ValueError: If the id or username is missing, or if no SSO role
                is recognised.
        """
        with get_db_connection() as conn:
            cursor = conn.cursor()

            oauth_id = oauth_user_info.get("sub") or oauth_user_info.get("id")
            username = (
                oauth_user_info.get("username")
                or oauth_user_info.get("preferred_username")
                or oauth_user_info.get("name")
            )
            email = oauth_user_info.get("email", "")
            logger.info(f"[OAuth.sync_user] oauth_id={oauth_id}, username={username}, email={email}")

            if not oauth_id:
                logger.error("[OAuth.sync_user] OAuth 用户信息缺少 ID 字段")
                raise ValueError("OAuth 用户信息缺少 ID 字段")
            if not username:
                logger.error("[OAuth.sync_user] OAuth 用户信息缺少用户名字段")
                raise ValueError("OAuth 用户信息缺少用户名字段")

            # An already-mapped role wins; otherwise map the raw SSO roles.
            sso_roles = oauth_user_info.get("sso_roles") or oauth_user_info.get("roles", [])
            role = oauth_user_info.get("role") or map_sso_roles_to_local(sso_roles)
            logger.info(f"[OAuth.sync_user] sso_roles={sso_roles}, computed_role={role}")

            if role is None:
                logger.error(f"[OAuth.sync_user] 用户 {username} 没有被识别的 SSO 角色(sso_roles={sso_roles})")
                raise ValueError(f"用户 {username} 没有被识别的 SSO 角色(sso_roles={sso_roles}),无权限访问")

            cursor.execute(
                "SELECT * FROM users WHERE oauth_provider = %s AND oauth_id = %s",
                ("sso", oauth_id),
            )
            row = cursor.fetchone()

            if row:
                user = User.from_row(row)
                logger.info(f"[OAuth.sync_user] 用户已存在: id={user.id}, old_role={user.role}, new_role={role}")
                cursor.execute(
                    "UPDATE users SET username = %s, email = %s, role = %s WHERE id = %s",
                    (username, email, role, user.id),
                )
                conn.commit()
                logger.info(f"[OAuth.sync_user] 用户角色已更新: id={user.id}, role={role}")
                # Re-read the row so the returned object reflects the update.
                cursor.execute("SELECT * FROM users WHERE id = %s", (user.id,))
                row = cursor.fetchone()
                return User.from_row(row)

            user_id = f"user_{datetime.now().strftime('%Y%m%d%H%M%S')}_{secrets.token_hex(4)}"
            logger.info(f"[OAuth.sync_user] 创建新用户: user_id={user_id}, username={username}, role={role}")
            # SSO users get an empty password_hash.
            cursor.execute(
                """
                INSERT INTO users (
                    id, username, email, password_hash, role,
                    oauth_provider, oauth_id, created_at
                ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
                """,
                (user_id, username, email, "", role, "sso", oauth_id, datetime.now()),
            )
            conn.commit()
            cursor.execute("SELECT * FROM users WHERE id = %s", (user_id,))
            row = cursor.fetchone()
            return User.from_row(row)