""" OAuth 2.0 认证服务 处理与 OAuth 认证中心的交互 """ import httpx import secrets from typing import Dict, Any, Optional from datetime import datetime from config import settings from models import User from database import get_db_connection class OAuthService: """OAuth 2.0 认证服务""" @staticmethod def generate_state() -> str: """ 生成随机 state 参数,用于防止 CSRF 攻击 Returns: 随机字符串 """ return secrets.token_urlsafe(32) @staticmethod def get_authorization_url(state: str) -> str: """ 构建 OAuth 授权 URL Args: state: 防CSRF的随机字符串 Returns: 完整的授权URL """ from urllib.parse import urlencode params = { "response_type": "code", "client_id": settings.OAUTH_CLIENT_ID, "redirect_uri": settings.OAUTH_REDIRECT_URI, "scope": settings.OAUTH_SCOPE, "state": state } authorize_url = f"{settings.OAUTH_BASE_URL}{settings.OAUTH_AUTHORIZE_ENDPOINT}" return f"{authorize_url}?{urlencode(params)}" @staticmethod async def exchange_code_for_token(code: str) -> Dict[str, Any]: """ 用授权码换取访问令牌 Args: code: OAuth 授权码 Returns: 令牌信息字典,包含 access_token, token_type, expires_in 等 Raises: Exception: 令牌交换失败 """ token_url = f"{settings.OAUTH_BASE_URL}{settings.OAUTH_TOKEN_ENDPOINT}" async with httpx.AsyncClient() as client: response = await client.post( token_url, data={ "grant_type": "authorization_code", "code": code, "redirect_uri": settings.OAUTH_REDIRECT_URI, "client_id": settings.OAUTH_CLIENT_ID, "client_secret": settings.OAUTH_CLIENT_SECRET }, headers={"Content-Type": "application/x-www-form-urlencoded"} ) if response.status_code != 200: raise Exception(f"令牌交换失败 ({response.status_code}): {response.text}") data = response.json() # 处理不同的响应格式 if "access_token" in data: return data elif data.get("code") == 0 and "data" in data: return data["data"] else: raise Exception(f"无效的令牌响应格式: {data}") @staticmethod async def get_user_info(access_token: str) -> Dict[str, Any]: """ 使用访问令牌获取用户信息 Args: access_token: OAuth 访问令牌 Returns: 用户信息字典 Raises: Exception: 获取用户信息失败 """ userinfo_url = f"{settings.OAUTH_BASE_URL}{settings.OAUTH_USERINFO_ENDPOINT}" async with httpx.AsyncClient() as client: response = await client.get( userinfo_url, headers={"Authorization": f"Bearer {access_token}"} ) if response.status_code != 200: raise Exception(f"获取用户信息失败 ({response.status_code}): {response.text}") data = response.json() # 处理不同的响应格式 if "sub" in data or "id" in data: return data elif data.get("code") == 0 and "data" in data: return data["data"] else: raise Exception(f"无效的用户信息响应格式: {data}") @staticmethod def sync_user_from_oauth(oauth_user_info: Dict[str, Any]) -> User: """ 从 OAuth 用户信息同步到本地数据库 如果用户不存在则创建,如果存在则更新 Args: oauth_user_info: OAuth 返回的用户信息 Returns: 本地用户对象 """ with get_db_connection() as conn: cursor = conn.cursor() # 提取用户信息(兼容不同的字段名) oauth_id = oauth_user_info.get("sub") or oauth_user_info.get("id") username = oauth_user_info.get("username") or oauth_user_info.get("preferred_username") or oauth_user_info.get("name") email = oauth_user_info.get("email", "") if not oauth_id: raise ValueError("OAuth 用户信息缺少 ID 字段") if not username: raise ValueError("OAuth 用户信息缺少用户名字段") # 查找是否已存在该 OAuth 用户 cursor.execute( "SELECT * FROM users WHERE oauth_provider = ? AND oauth_id = ?", ("sso", oauth_id) ) row = cursor.fetchone() if row: # 用户已存在,更新信息 user = User.from_row(row) # 更新用户名和邮箱(如果有变化) cursor.execute(""" UPDATE users SET username = ?, email = ? WHERE id = ? """, (username, email, user.id)) conn.commit() # 重新查询更新后的用户 cursor.execute("SELECT * FROM users WHERE id = ?", (user.id,)) row = cursor.fetchone() return User.from_row(row) else: # 新用户,创建记录 user_id = f"user_{datetime.now().strftime('%Y%m%d%H%M%S')}_{secrets.token_hex(4)}" # 暂时所有用户都是 annotator 角色(SSO 未提供角色信息) role = "annotator" cursor.execute(""" INSERT INTO users ( id, username, email, password_hash, role, oauth_provider, oauth_id, created_at ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) """, ( user_id, username, email, "", # OAuth 用户不需要密码 role, "sso", oauth_id, datetime.now() )) conn.commit() # 查询新创建的用户 cursor.execute("SELECT * FROM users WHERE id = ?", (user_id,)) row = cursor.fetchone() return User.from_row(row)