| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385 |
- """
- OAuth 2.0 认证服务
- 处理与 OAuth 认证中心的交互,包括 token 验证和刷新
- """
- import httpx
- import logging
- import secrets
- from typing import Dict, Any, Optional
- from datetime import datetime
- from fastapi import HTTPException, status
- from config import settings
- from models import User
- from database import get_db_connection
- logger = logging.getLogger(__name__)
# Maps role names issued by the SSO center to this application's local roles.
# SSO roles that do not appear here contribute nothing to the result.
SSO_ROLE_MAPPING = {
    "super_admin": "admin",
    "label_admin": "admin",
    "admin": "admin",
    "labeler": "annotator",
}

# Local role used when no SSO role maps to anything more privileged.
DEFAULT_LOCAL_ROLE = "viewer"


def map_sso_roles_to_local(sso_roles: list, is_superuser: bool = False) -> str:
    """Collapse a list of SSO role names into a single local role.

    Precedence is ``admin`` > ``annotator`` > ``viewer``: the most privileged
    local role that any of the SSO roles maps to wins.

    Args:
        sso_roles: Role names reported by the SSO center. ``None`` and empty
            values are treated as "no roles"; a bare string is treated as a
            one-element list; entries that are not strings are ignored.
        is_superuser: When true the user is an admin regardless of roles.

    Returns:
        The local role name: ``"admin"``, ``"annotator"`` or the default
        ``"viewer"``.
    """
    if is_superuser:
        return "admin"

    # A JSON ``null`` or a missing field must not crash the login flow.
    if not sso_roles:
        return DEFAULT_LOCAL_ROLE

    # Iterating a bare string would walk its characters and never match.
    if isinstance(sso_roles, str):
        sso_roles = [sso_roles]

    local_role = DEFAULT_LOCAL_ROLE
    for sso_role in sso_roles:
        # Non-string entries (e.g. role objects) cannot be mapping keys here
        # and may be unhashable, so skip them instead of raising.
        if not isinstance(sso_role, str):
            continue
        mapped = SSO_ROLE_MAPPING.get(sso_role)
        if mapped == "admin":
            # Nothing outranks admin; stop early.
            return "admin"
        if mapped == "annotator":
            local_role = "annotator"
    return local_role
class OAuthService:
    """OAuth 2.0 authentication service.

    A namespace of static helpers for talking to the OAuth/SSO center:
    building the authorization URL, exchanging codes for tokens, fetching
    user information, verifying and refreshing SSO tokens, and mirroring the
    SSO user into the local ``users`` table.
    """

    @staticmethod
    def _is_success_envelope(payload: Any) -> bool:
        """Return True if *payload* is a ``{"code": <ok>, "data": ...}`` wrapper.

        The SSO center may wrap its responses as ``{"code": 0, "data": {...}}``
        or ``{"code": "000000", "data": {...}}``. Anything that is not a dict
        (a list, a string, ``None``) is never an envelope.
        """
        return (
            isinstance(payload, dict)
            and payload.get("code") in (0, "000000")
            and "data" in payload
        )

    @staticmethod
    def generate_state() -> str:
        """Generate a random ``state`` value used to prevent CSRF attacks.

        Returns:
            A URL-safe random string built from 32 random bytes.
        """
        return secrets.token_urlsafe(32)

    @staticmethod
    def get_authorization_url(state: str) -> str:
        """Build the OAuth authorization URL.

        Args:
            state: Random anti-CSRF string to round-trip through the provider.

        Returns:
            The complete authorization URL including the query string.
        """
        from urllib.parse import urlencode

        params = {
            "response_type": "code",
            "client_id": settings.OAUTH_CLIENT_ID,
            "redirect_uri": settings.OAUTH_REDIRECT_URI,
            "scope": settings.OAUTH_SCOPE,
            "state": state,
        }
        authorize_url = f"{settings.OAUTH_BASE_URL}{settings.OAUTH_AUTHORIZE_ENDPOINT}"
        return f"{authorize_url}?{urlencode(params)}"

    @staticmethod
    async def exchange_code_for_token(code: str) -> Dict[str, Any]:
        """Exchange an authorization code for an access token.

        Args:
            code: The OAuth authorization code.

        Returns:
            The token payload. Either the bare response (when it contains
            ``access_token``) or the ``data`` member of a success envelope.

        Raises:
            Exception: The token endpoint answered with a non-200 status or
                with a body in an unrecognised format.
        """
        token_url = f"{settings.OAUTH_BASE_URL}{settings.OAUTH_TOKEN_ENDPOINT}"
        form = {
            "grant_type": "authorization_code",
            "code": code,
            "redirect_uri": settings.OAUTH_REDIRECT_URI,
            "client_id": settings.OAUTH_CLIENT_ID,
            "client_secret": settings.OAUTH_CLIENT_SECRET,
        }

        async with httpx.AsyncClient() as client:
            response = await client.post(
                token_url,
                data=form,
                headers={"Content-Type": "application/x-www-form-urlencoded"},
            )

            if response.status_code != 200:
                raise Exception(f"令牌交换失败 ({response.status_code}): {response.text}")

            data = response.json()

            # Bare token response takes precedence over the wrapped format.
            # The isinstance guard keeps a string body from passing the
            # membership test as a substring match.
            if isinstance(data, dict) and "access_token" in data:
                return data
            if OAuthService._is_success_envelope(data):
                return data["data"]
            raise Exception(f"无效的令牌响应格式: {data}")

    @staticmethod
    async def get_user_info(access_token: str) -> Dict[str, Any]:
        """Fetch user information with an access token.

        Args:
            access_token: The OAuth access token.

        Returns:
            The user-info payload. Either the bare response (when it contains
            ``sub`` or ``id``) or the ``data`` member of a success envelope.

        Raises:
            Exception: The user-info endpoint answered with a non-200 status
                or with a body in an unrecognised format.
        """
        userinfo_url = f"{settings.OAUTH_BASE_URL}{settings.OAUTH_USERINFO_ENDPOINT}"

        async with httpx.AsyncClient() as client:
            response = await client.get(
                userinfo_url,
                headers={"Authorization": f"Bearer {access_token}"},
            )

            if response.status_code != 200:
                raise Exception(f"获取用户信息失败 ({response.status_code}): {response.text}")

            data = response.json()

            # Bare user-info response takes precedence over the wrapped format.
            if isinstance(data, dict) and ("sub" in data or "id" in data):
                return data
            if OAuthService._is_success_envelope(data):
                return data["data"]
            raise Exception(f"无效的用户信息响应格式: {data}")

    @staticmethod
    def sync_user_from_oauth(oauth_user_info: Dict[str, Any]) -> "User":
        """Mirror an OAuth user into the local database.

        Creates the user when no row exists for ``("sso", oauth_id)``;
        otherwise updates username, email and role on the existing row.

        Args:
            oauth_user_info: User information returned by the OAuth provider.
                The ID is read from ``sub`` or ``id``; the username from
                ``username``, ``preferred_username`` or ``name``.

        Returns:
            The local user object, re-read from the database after the write.

        Raises:
            ValueError: The user information has no ID or no username.
        """
        # Validate before touching the database so bad input never costs a
        # connection.
        oauth_id = oauth_user_info.get("sub") or oauth_user_info.get("id")
        username = (
            oauth_user_info.get("username")
            or oauth_user_info.get("preferred_username")
            or oauth_user_info.get("name")
        )

        if not oauth_id:
            raise ValueError("OAuth 用户信息缺少 ID 字段")
        if not username:
            raise ValueError("OAuth 用户信息缺少用户名字段")

        # ``or`` (rather than a .get default) also covers an explicit null.
        email = oauth_user_info.get("email") or ""
        sso_roles = (
            oauth_user_info.get("sso_roles")
            or oauth_user_info.get("roles")
            or []
        )
        is_superuser = bool(oauth_user_info.get("is_superuser", False))
        # NOTE(review): an explicit "role" is stored as-is, unmapped; this
        # presumably expects an already-local role — confirm with callers.
        role = oauth_user_info.get("role") or map_sso_roles_to_local(sso_roles, is_superuser)

        with get_db_connection() as conn:
            cursor = conn.cursor()

            cursor.execute(
                "SELECT * FROM users WHERE oauth_provider = %s AND oauth_id = %s",
                ("sso", oauth_id),
            )
            existing_row = cursor.fetchone()

            if existing_row:
                # Known user: refresh the mutable fields, role included.
                user_id = User.from_row(existing_row).id
                cursor.execute(
                    """
                    UPDATE users
                    SET username = %s, email = %s, role = %s
                    WHERE id = %s
                    """,
                    (username, email, role, user_id),
                )
            else:
                # New user: timestamp plus random suffix keeps IDs unique.
                user_id = f"user_{datetime.now().strftime('%Y%m%d%H%M%S')}_{secrets.token_hex(4)}"
                cursor.execute(
                    """
                    INSERT INTO users (
                        id, username, email, password_hash, role,
                        oauth_provider, oauth_id, created_at
                    ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
                    """,
                    (
                        user_id,
                        username,
                        email,
                        "",  # OAuth users have no local password
                        role,
                        "sso",
                        oauth_id,
                        datetime.now(),
                    ),
                )

            conn.commit()

            # Re-read so the returned object reflects what was stored.
            cursor.execute("SELECT * FROM users WHERE id = %s", (user_id,))
            return User.from_row(cursor.fetchone())

    @staticmethod
    async def verify_sso_token(access_token: str) -> Dict[str, Any]:
        """Verify a token against the SSO center and return the user profile.

        Calls ``/api/v1/system/users/profile`` to obtain the profile,
        including the ``roles`` list and the ``is_superuser`` flag, and maps
        them to a local role.

        Args:
            access_token: The SSO access token.

        Returns:
            A dict with keys ``id``, ``username``, ``email``, ``role``,
            ``sso_roles`` and ``is_superuser``.

        Raises:
            HTTPException(401): The token is invalid, or the SSO response is
                not in a recognised format.
            HTTPException(503): The SSO center is unreachable or answered
                with a 5xx status.
        """
        profile_url = f"{settings.OAUTH_BASE_URL}/api/v1/system/users/profile"

        async with httpx.AsyncClient(timeout=10.0) as client:
            try:
                response = await client.get(
                    profile_url,
                    headers={"Authorization": f"Bearer {access_token}"},
                )
            except httpx.RequestError as exc:
                logger.error("SSO profile request failed: %s", exc)
                raise HTTPException(
                    status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                    detail="SSO 认证中心不可用",
                ) from exc

            if response.status_code == 401:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="无效的访问令牌",
                )

            # A server-side failure at the SSO center says nothing about the
            # token; report it as an outage rather than forcing a re-login.
            if response.status_code >= 500:
                logger.error("SSO profile endpoint error: %s", response.status_code)
                raise HTTPException(
                    status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                    detail="SSO 认证中心不可用",
                )

            if response.status_code != 200:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail=f"SSO 验证失败 ({response.status_code})",
                )

            try:
                data = response.json()
            except ValueError as exc:
                # Non-JSON body: handled like any other unrecognised format.
                logger.error("SSO profile response is not valid JSON")
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="无效的访问令牌",
                ) from exc
            logger.debug("SSO profile response: %s", data)

            # Wrapped format takes precedence over the bare profile here.
            if OAuthService._is_success_envelope(data):
                profile = data["data"]
            elif isinstance(data, dict) and ("id" in data or "username" in data):
                profile = data
            else:
                profile = None

            if not isinstance(profile, dict):
                logger.error("Invalid profile response format: %s", data)
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="无效的访问令牌",
                )

            # ``or []`` also covers an explicit ``"roles": null``.
            sso_roles = profile.get("roles") or []
            is_superuser = bool(profile.get("is_superuser", False))
            local_role = map_sso_roles_to_local(sso_roles, is_superuser)

            logger.info(
                "SSO 用户 %s: sso_roles=%s, is_superuser=%s → local_role=%s",
                profile.get("username"),
                sso_roles,
                is_superuser,
                local_role,
            )

            # Uniform shape regardless of which response format was received.
            return {
                "id": profile.get("id"),
                "username": profile.get("username"),
                "email": profile.get("email") or "",
                "role": local_role,
                "sso_roles": sso_roles,
                "is_superuser": is_superuser,
            }

    @staticmethod
    async def refresh_sso_token(refresh_token: str) -> Dict[str, Any]:
        """Ask the SSO center for a fresh token pair.

        Args:
            refresh_token: The SSO refresh token.

        Returns:
            The new token payload. Either the ``data`` member of a success
            envelope or the bare response when it contains ``access_token``.

        Raises:
            HTTPException(401): The refresh token is invalid or expired, or
                the SSO response is not in a recognised format.
            HTTPException(503): The SSO center is unreachable or answered
                with a 5xx status.
        """
        token_url = f"{settings.OAUTH_BASE_URL}{settings.OAUTH_TOKEN_ENDPOINT}"
        logger.debug("Refreshing token at: %s", token_url)

        form = {
            "grant_type": "refresh_token",
            "refresh_token": refresh_token,
            "client_id": settings.OAUTH_CLIENT_ID,
            "client_secret": settings.OAUTH_CLIENT_SECRET,
        }

        async with httpx.AsyncClient(timeout=10.0) as client:
            try:
                response = await client.post(
                    token_url,
                    data=form,
                    headers={"Content-Type": "application/x-www-form-urlencoded"},
                )
            except httpx.RequestError as exc:
                logger.error("SSO refresh request failed: %s", exc)
                raise HTTPException(
                    status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                    detail="SSO 认证中心不可用",
                ) from exc

            logger.debug("SSO refresh response: status=%s", response.status_code)

            if response.status_code != 200:
                logger.error(
                    "SSO refresh failed: %s, body=%s",
                    response.status_code,
                    response.text,
                )
                # A 5xx is an SSO outage, not proof that the token is bad.
                if response.status_code >= 500:
                    raise HTTPException(
                        status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                        detail="SSO 认证中心不可用",
                    )
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="刷新令牌无效或已过期,请重新登录",
                )

            try:
                data = response.json()
            except ValueError as exc:
                # Non-JSON body: handled like any other unrecognised format.
                logger.error("SSO refresh response is not valid JSON")
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="刷新令牌无效或已过期,请重新登录",
                ) from exc

            # Log field names only: a successful payload carries live access
            # and refresh tokens, which must not end up in log files.
            logger.debug(
                "SSO refresh response fields: %s",
                sorted(data) if isinstance(data, dict) else type(data).__name__,
            )

            # Wrapped format takes precedence over the bare token response.
            if OAuthService._is_success_envelope(data):
                return data["data"]
            if isinstance(data, dict) and "access_token" in data:
                return data

            logger.error("Invalid refresh response format: %s", data)
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="刷新令牌无效或已过期,请重新登录",
            )
|