| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294 |
- """
- OAuth 2.0 认证服务
- 处理与统一认证平台的交互,包括 code 换 token、用户信息获取和角色映射
- """
- import httpx
- import logging
- import secrets
- from typing import Dict, Any, Optional
- from datetime import datetime
- from fastapi import HTTPException, status
- from config import settings
- from models import User
- from database import get_db_connection
logger = logging.getLogger(__name__)

# SSO role -> local role.
# Only the three annotation-system roles are recognized, either by role code
# (ann_sys_admin / ann_operator / ann_viewer) or by their Chinese display
# name. Any other SSO role is not recognized and grants no access.
SSO_ROLE_MAPPING = dict(
    [
        # Role codes
        ("ann_sys_admin", "admin"),
        ("ann_operator", "annotator"),
        ("ann_viewer", "viewer"),
        # Role display names (Chinese): annotation admin, annotator, viewer
        ("标注管理员", "admin"),
        ("标注员", "annotator"),
        ("查看者", "viewer"),
    ]
)
def mask_token_in_dict(data: dict) -> dict:
    """Return a copy of *data* that is safe to write to logs.

    String values under token-like keys are replaced by their first 10
    characters followed by "..." (or by "***" when 10 characters or shorter).
    Nested dicts and lists are masked recursively, because the SSO platform
    may wrap the token payload as ``{"code": 0, "data": {"access_token": ...}}``;
    masking only the top level leaked those nested tokens. The input object is
    never modified; non-container inputs are returned unchanged.
    """
    sensitive_keys = ("access_token", "refresh_token", "id_token", "token")

    def _preview(secret: str) -> str:
        return secret[:10] + "..." if len(secret) > 10 else "***"

    def _mask(value):
        if isinstance(value, dict):
            return {
                key: _preview(item)
                if key in sensitive_keys and isinstance(item, str)
                else _mask(item)
                for key, item in value.items()
            }
        if isinstance(value, list):
            return [_mask(item) for item in value]
        return value

    return _mask(data)
def map_sso_roles_to_local(
    sso_roles: Optional[list],
    mapping: Optional[Dict[str, str]] = None,
) -> Optional[str]:
    """Map SSO roles to the single highest-priority local role.

    Only roles present in the mapping are recognized (by default
    ``SSO_ROLE_MAPPING``: ann_sys_admin / ann_operator / ann_viewer and their
    Chinese display names). Priority is admin > annotator > viewer.

    Args:
        sso_roles: SSO role identifiers. Entries may be plain strings or raw
            SSO role objects ``{"code": ..., "name": ...}`` (both fields are
            tried). A bare string is treated as one role; ``None`` and
            non-string entries are ignored.
        mapping: Optional SSO->local lookup table; defaults to
            ``SSO_ROLE_MAPPING``.

    Returns:
        "admin", "annotator" or "viewer", or ``None`` when no role is
        recognized (i.e. the user has no access).
    """
    table = SSO_ROLE_MAPPING if mapping is None else mapping
    if isinstance(sso_roles, str):
        # A lone string would otherwise be iterated character by character.
        sso_roles = [sso_roles]

    priority = {"viewer": 1, "annotator": 2, "admin": 3}
    best: Optional[str] = None
    for entry in sso_roles or ():
        # Raw SSO role objects used to be looked up as dict keys directly,
        # which raised "unhashable type: 'dict'".
        if isinstance(entry, dict):
            keys = (entry.get("code"), entry.get("name"))
        else:
            keys = (entry,)
        for key in keys:
            if not isinstance(key, str):
                continue
            mapped = table.get(key)
            if mapped == "admin":
                return "admin"  # highest priority; nothing can outrank it
            if priority.get(mapped, 0) > priority.get(best, 0):
                best = mapped
    return best
class OAuthService:
    """OAuth 2.0 client for the unified SSO (authentication) platform.

    Implements the authorization-code flow: CSRF ``state`` generation,
    authorize-URL construction, code-for-token exchange, user-info / role
    retrieval, and syncing the SSO user into the local ``users`` table.
    Every method is static; the class only serves as a namespace.
    """

    # Timeout (seconds) for every outbound SSO HTTP call. Previously only
    # get_user_profile set one; the other calls relied on the httpx default.
    _HTTP_TIMEOUT = 10.0

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _is_wrapped_success(data: Any) -> bool:
        """Return True if *data* is the platform's success envelope.

        The SSO platform may wrap payloads as
        ``{"code": 0 | "000000", "data": ...}``.
        """
        return (
            isinstance(data, dict)
            and data.get("code") in (0, "000000")
            and "data" in data
        )

    @staticmethod
    def _mask_secret(value: Any, visible: int = 4) -> str:
        """Return a log-safe preview of a one-time secret (e.g. an auth code).

        Values that are not strings, or that are too short for a prefix to be
        meaningfully partial (``<= 2 * visible`` chars), become "***".
        """
        if not isinstance(value, str) or len(value) <= visible * 2:
            return "***"
        return value[:visible] + "..."

    @staticmethod
    def _extract_sso_role_keys(raw_roles: Any) -> list:
        """Flatten an SSO ``roles`` value into a list of role code/name strings.

        Entries may be ``{"code": ..., "name": ...}`` objects or plain strings.
        A single string or object is treated as a one-element list. ``None``,
        empty strings and non-string values are skipped.
        """
        if isinstance(raw_roles, (str, dict)):
            raw_roles = [raw_roles]
        elif not isinstance(raw_roles, (list, tuple, set)):
            raw_roles = []
        keys: list = []
        for item in raw_roles:
            if isinstance(item, dict):
                candidates = (item.get("code"), item.get("name"))
            else:
                candidates = (item,)
            keys.extend(c for c in candidates if isinstance(c, str) and c)
        return keys

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    @staticmethod
    def generate_state() -> str:
        """Return a random, URL-safe ``state`` value used to prevent CSRF."""
        return secrets.token_urlsafe(32)

    @staticmethod
    def get_authorization_url(state: str) -> str:
        """Build the SSO authorize URL the browser is redirected to.

        Args:
            state: Anti-CSRF value that the SSO echoes back on the callback.

        Returns:
            ``{SSO_BASE_URL}{SSO_AUTHORIZE_ENDPOINT}?...`` carrying
            ``response_type=code``, client id, redirect URI, scope and state.
        """
        from urllib.parse import urlencode

        params = {
            "response_type": "code",
            "client_id": settings.SSO_CLIENT_ID,
            "redirect_uri": settings.SSO_REDIRECT_URI,
            "scope": settings.SSO_SCOPE,
            "state": state,
        }
        authorize_url = f"{settings.SSO_BASE_URL}{settings.SSO_AUTHORIZE_ENDPOINT}"
        full_url = f"{authorize_url}?{urlencode(params)}"
        logger.info(f"[OAuth.get_authorization_url] state={state}")
        logger.info(f"[OAuth.get_authorization_url] 完整授权URL: {full_url}")
        logger.info(f"[OAuth.get_authorization_url] scope={settings.SSO_SCOPE}")
        return full_url

    @staticmethod
    async def exchange_code_for_token(code: str) -> Dict[str, Any]:
        """Exchange an authorization code for tokens at the SSO token endpoint.

        Args:
            code: Authorization code received on the redirect callback.

        Returns:
            The token payload: either the raw response when it has a top-level
            ``access_token``, or the ``data`` dict of a success envelope.

        Raises:
            Exception: non-200 status or an unrecognized response body.
            httpx.RequestError: network failure while contacting the SSO.
        """
        token_url = f"{settings.SSO_BASE_URL}{settings.SSO_TOKEN_ENDPOINT}"
        post_data = {
            "grant_type": "authorization_code",
            "code": code,
            "redirect_uri": settings.SSO_REDIRECT_URI,
            "client_id": settings.SSO_CLIENT_ID,
            "client_secret": settings.SSO_CLIENT_SECRET,
        }
        logger.info(f"[OAuth.exchange_code] POST {token_url}")
        # The auth code is a bearer credential until redeemed. Only a short
        # prefix is logged; the former code[:15] printed short codes in full.
        logger.info(
            f"[OAuth.exchange_code] POST body (除secret): grant_type={post_data['grant_type']}, "
            f"code={OAuthService._mask_secret(code)}, redirect_uri={post_data['redirect_uri']}, "
            f"client_id={post_data['client_id']}"
        )
        async with httpx.AsyncClient(timeout=OAuthService._HTTP_TIMEOUT) as client:
            response = await client.post(
                token_url,
                data=post_data,
                headers={"Content-Type": "application/x-www-form-urlencoded"},
            )
        logger.info(f"[OAuth.exchange_code] SSO 响应状态码: {response.status_code}")
        if response.status_code != 200:
            logger.error(f"[OAuth.exchange_code] 令牌交换失败, 响应内容: {response.text}")
            raise Exception(f"令牌交换失败 ({response.status_code}): {response.text}")

        data = response.json()
        if not isinstance(data, dict):
            logger.error(f"[OAuth.exchange_code] 无效的令牌响应格式: {data}")
            raise Exception(f"无效的令牌响应格式: {data}")

        # Mask once; the masked view is reused for both the info and error logs.
        masked = mask_token_in_dict(data)
        logger.info(f"[OAuth.exchange_code] SSO JSON 响应 (mask token): {masked}")
        if "access_token" in data:
            return data
        if OAuthService._is_wrapped_success(data) and isinstance(data["data"], dict):
            return data["data"]
        # Never log or raise with the raw body: it may still carry tokens.
        logger.error(f"[OAuth.exchange_code] 无效的令牌响应格式: {masked}")
        raise Exception(f"无效的令牌响应格式: {masked}")

    @staticmethod
    async def get_user_info(access_token: str) -> Dict[str, Any]:
        """Fetch the raw SSO user info (no role mapping) for *access_token*.

        Returns:
            The response itself when it has ``sub`` or ``id``, otherwise the
            ``data`` of a success envelope.

        Raises:
            Exception: non-200 status or an unrecognized response body.
            httpx.RequestError: network failure while contacting the SSO.
        """
        userinfo_url = f"{settings.SSO_BASE_URL}{settings.SSO_USERINFO_ENDPOINT}"
        async with httpx.AsyncClient(timeout=OAuthService._HTTP_TIMEOUT) as client:
            response = await client.get(
                userinfo_url,
                headers={"Authorization": f"Bearer {access_token}"},
            )
        if response.status_code != 200:
            raise Exception(f"获取用户信息失败 ({response.status_code}): {response.text}")
        data = response.json()
        if isinstance(data, dict):
            if "sub" in data or "id" in data:
                return data
            if OAuthService._is_wrapped_success(data):
                return data["data"]
        raise Exception(f"无效的用户信息响应格式: {data}")

    @staticmethod
    async def get_user_profile(access_token: str) -> Dict[str, Any]:
        """Fetch the user's identity and roles from the SSO userinfo endpoint.

        Args:
            access_token: Bearer token obtained from ``exchange_code_for_token``.

        Returns:
            ``{"id", "username", "email", "role", "sso_roles"}``. ``role`` is
            the mapped local role (``None`` when no SSO role is recognized).
            ``sso_roles`` is the flattened list of SSO role codes/names.

        Raises:
            HTTPException: 503 if the SSO is unreachable, 401 if it rejects
                the token, 502 for any other error status or an
                unrecognizable response body.
        """
        userinfo_url = f"{settings.SSO_BASE_URL}{settings.SSO_USERINFO_ENDPOINT}"
        logger.info(f"[OAuth.get_user_profile] 请求 SSO 用户信息, URL: {userinfo_url}")
        async with httpx.AsyncClient(timeout=OAuthService._HTTP_TIMEOUT) as client:
            try:
                response = await client.get(
                    userinfo_url,
                    headers={"Authorization": f"Bearer {access_token}"},
                )
            except httpx.RequestError as exc:
                logger.error(
                    f"[OAuth.get_user_profile] SSO 请求失败, URL: {userinfo_url}, error: {exc!r}"
                )
                raise HTTPException(
                    status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                    detail="SSO 认证中心不可用",
                ) from exc

        logger.info(f"[OAuth.get_user_profile] SSO 响应状态码: {response.status_code}")
        if response.status_code == 401:
            logger.error("[OAuth.get_user_profile] SSO 返回 401, token 无效")
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="无效的访问令牌",
            )
        if response.status_code != 200:
            # Previously any other status fell through to an all-None profile.
            # That later surfaced as a misleading "no recognized role" error
            # instead of an SSO failure.
            logger.error(
                f"[OAuth.get_user_profile] SSO 返回异常状态码 {response.status_code}: {response.text}"
            )
            raise HTTPException(
                status_code=status.HTTP_502_BAD_GATEWAY,
                detail="SSO 认证中心响应异常",
            )

        try:
            data = response.json()
        except ValueError as exc:  # json.JSONDecodeError subclasses ValueError
            logger.error(f"[OAuth.get_user_profile] SSO 响应不是合法 JSON: {response.text}")
            raise HTTPException(
                status_code=status.HTTP_502_BAD_GATEWAY,
                detail="SSO 认证中心响应异常",
            ) from exc
        logger.info(f"[OAuth.get_user_profile] SSO userinfo 原始响应: {data}")

        if OAuthService._is_wrapped_success(data) and isinstance(data["data"], dict):
            profile = data["data"]
        elif isinstance(data, dict) and ("id" in data or "username" in data or "sub" in data):
            profile = data
        else:
            logger.error(f"[OAuth.get_user_profile] 无法识别的 SSO userinfo 响应格式: {data}")
            raise HTTPException(
                status_code=status.HTTP_502_BAD_GATEWAY,
                detail="SSO 认证中心响应异常",
            )

        # Roles may be objects ({name, code}) or plain strings; "roles": null
        # is treated as no roles rather than crashing.
        raw_roles = profile.get("roles") or []
        sso_roles = OAuthService._extract_sso_role_keys(raw_roles)
        local_role = map_sso_roles_to_local(sso_roles)
        logger.info(
            f"SSO 用户 {profile.get('username')}: "
            f"roles={raw_roles}, sso_roles={sso_roles} → local_role={local_role}"
        )
        return {
            "id": profile.get("id") or profile.get("sub"),
            "username": profile.get("username") or profile.get("preferred_username") or profile.get("name"),
            "email": profile.get("email", ""),
            "role": local_role,
            "sso_roles": sso_roles,
        }

    @staticmethod
    def sync_user_from_oauth(oauth_user_info: Dict[str, Any]) -> "User":
        """Create or update the local user that matches an SSO identity.

        The user is looked up by ``(oauth_provider="sso", oauth_id)``. An
        existing row has its username, email and role overwritten. Otherwise
        a new row is inserted with an empty password hash.

        Args:
            oauth_user_info: Either the output of ``get_user_profile``
                (``id``, ``username``, ``email``, ``role``, ``sso_roles``) or a
                raw SSO payload (``sub``/``preferred_username``/``name``,
                ``roles``).

        Returns:
            The ``User`` re-read from the database after the write.

        Raises:
            ValueError: the payload has no ID or username, or carries no
                recognized SSO role.
        """
        # Validate the payload before opening a database connection.
        oauth_id = oauth_user_info.get("sub") or oauth_user_info.get("id")
        username = (
            oauth_user_info.get("username")
            or oauth_user_info.get("preferred_username")
            or oauth_user_info.get("name")
        )
        email = oauth_user_info.get("email") or ""
        logger.info(f"[OAuth.sync_user] oauth_id={oauth_id}, username={username}, email={email}")
        if not oauth_id:
            logger.error("[OAuth.sync_user] OAuth 用户信息缺少 ID 字段")
            raise ValueError("OAuth 用户信息缺少 ID 字段")
        if not username:
            logger.error("[OAuth.sync_user] OAuth 用户信息缺少用户名字段")
            raise ValueError("OAuth 用户信息缺少用户名字段")

        # Normalize either role-key strings or raw {code, name} role objects.
        raw_roles = oauth_user_info.get("sso_roles") or oauth_user_info.get("roles") or []
        sso_roles = OAuthService._extract_sso_role_keys(raw_roles)
        # Only trust a precomputed "role" if it is one the SSO mapping can
        # produce. A raw SSO payload may carry its own unrelated "role" field,
        # which must not be persisted verbatim as a local privilege.
        role = oauth_user_info.get("role")
        if role not in SSO_ROLE_MAPPING.values():
            role = map_sso_roles_to_local(sso_roles)
        logger.info(f"[OAuth.sync_user] sso_roles={sso_roles}, computed_role={role}")
        if role is None:
            logger.error(f"[OAuth.sync_user] 用户 {username} 没有被识别的 SSO 角色(sso_roles={sso_roles})")
            raise ValueError(f"用户 {username} 没有被识别的 SSO 角色(sso_roles={sso_roles}),无权限访问")

        with get_db_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(
                "SELECT * FROM users WHERE oauth_provider = %s AND oauth_id = %s",
                ("sso", oauth_id),
            )
            row = cursor.fetchone()
            if row:
                user = User.from_row(row)
                logger.info(f"[OAuth.sync_user] 用户已存在: id={user.id}, old_role={user.role}, new_role={role}")
                cursor.execute(
                    "UPDATE users SET username = %s, email = %s, role = %s WHERE id = %s",
                    (username, email, role, user.id),
                )
                conn.commit()
                logger.info(f"[OAuth.sync_user] 用户角色已更新: id={user.id}, role={role}")
                cursor.execute("SELECT * FROM users WHERE id = %s", (user.id,))
                return User.from_row(cursor.fetchone())

            user_id = f"user_{datetime.now().strftime('%Y%m%d%H%M%S')}_{secrets.token_hex(4)}"
            logger.info(f"[OAuth.sync_user] 创建新用户: user_id={user_id}, username={username}, role={role}")
            cursor.execute(
                """
                INSERT INTO users (
                    id, username, email, password_hash, role,
                    oauth_provider, oauth_id, created_at
                ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
                """,
                (user_id, username, email, "", role, "sso", oauth_id, datetime.now()),
            )
            conn.commit()
            cursor.execute("SELECT * FROM users WHERE id = %s", (user_id,))
            return User.from_row(cursor.fetchone())
|