| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206 |
- """
- OAuth 2.0 认证服务
- 处理与 OAuth 认证中心的交互
- """
- import httpx
- import secrets
- from typing import Dict, Any, Optional
- from datetime import datetime
- from config import settings
- from models import User
- from database import get_db_connection
- class OAuthService:
- """OAuth 2.0 认证服务"""
-
- @staticmethod
- def generate_state() -> str:
- """
- 生成随机 state 参数,用于防止 CSRF 攻击
-
- Returns:
- 随机字符串
- """
- return secrets.token_urlsafe(32)
-
- @staticmethod
- def get_authorization_url(state: str) -> str:
- """
- 构建 OAuth 授权 URL
-
- Args:
- state: 防CSRF的随机字符串
-
- Returns:
- 完整的授权URL
- """
- from urllib.parse import urlencode
-
- params = {
- "response_type": "code",
- "client_id": settings.OAUTH_CLIENT_ID,
- "redirect_uri": settings.OAUTH_REDIRECT_URI,
- "scope": settings.OAUTH_SCOPE,
- "state": state
- }
-
- authorize_url = f"{settings.OAUTH_BASE_URL}{settings.OAUTH_AUTHORIZE_ENDPOINT}"
- return f"{authorize_url}?{urlencode(params)}"
-
- @staticmethod
- async def exchange_code_for_token(code: str) -> Dict[str, Any]:
- """
- 用授权码换取访问令牌
-
- Args:
- code: OAuth 授权码
-
- Returns:
- 令牌信息字典,包含 access_token, token_type, expires_in 等
-
- Raises:
- Exception: 令牌交换失败
- """
- token_url = f"{settings.OAUTH_BASE_URL}{settings.OAUTH_TOKEN_ENDPOINT}"
-
- async with httpx.AsyncClient() as client:
- response = await client.post(
- token_url,
- data={
- "grant_type": "authorization_code",
- "code": code,
- "redirect_uri": settings.OAUTH_REDIRECT_URI,
- "client_id": settings.OAUTH_CLIENT_ID,
- "client_secret": settings.OAUTH_CLIENT_SECRET
- },
- headers={"Content-Type": "application/x-www-form-urlencoded"}
- )
-
- if response.status_code != 200:
- raise Exception(f"令牌交换失败 ({response.status_code}): {response.text}")
-
- data = response.json()
-
- # 处理不同的响应格式
- if "access_token" in data:
- return data
- elif data.get("code") == 0 and "data" in data:
- return data["data"]
- else:
- raise Exception(f"无效的令牌响应格式: {data}")
-
- @staticmethod
- async def get_user_info(access_token: str) -> Dict[str, Any]:
- """
- 使用访问令牌获取用户信息
-
- Args:
- access_token: OAuth 访问令牌
-
- Returns:
- 用户信息字典
-
- Raises:
- Exception: 获取用户信息失败
- """
- userinfo_url = f"{settings.OAUTH_BASE_URL}{settings.OAUTH_USERINFO_ENDPOINT}"
-
- async with httpx.AsyncClient() as client:
- response = await client.get(
- userinfo_url,
- headers={"Authorization": f"Bearer {access_token}"}
- )
-
- if response.status_code != 200:
- raise Exception(f"获取用户信息失败 ({response.status_code}): {response.text}")
-
- data = response.json()
-
- # 处理不同的响应格式
- if "sub" in data or "id" in data:
- return data
- elif data.get("code") == 0 and "data" in data:
- return data["data"]
- else:
- raise Exception(f"无效的用户信息响应格式: {data}")
-
- @staticmethod
- def sync_user_from_oauth(oauth_user_info: Dict[str, Any]) -> User:
- """
- 从 OAuth 用户信息同步到本地数据库
- 如果用户不存在则创建,如果存在则更新
-
- Args:
- oauth_user_info: OAuth 返回的用户信息
-
- Returns:
- 本地用户对象
- """
- with get_db_connection() as conn:
- cursor = conn.cursor()
-
- # 提取用户信息(兼容不同的字段名)
- oauth_id = oauth_user_info.get("sub") or oauth_user_info.get("id")
- username = oauth_user_info.get("username") or oauth_user_info.get("preferred_username") or oauth_user_info.get("name")
- email = oauth_user_info.get("email", "")
-
- if not oauth_id:
- raise ValueError("OAuth 用户信息缺少 ID 字段")
-
- if not username:
- raise ValueError("OAuth 用户信息缺少用户名字段")
-
- # 查找是否已存在该 OAuth 用户
- cursor.execute(
- "SELECT * FROM users WHERE oauth_provider = ? AND oauth_id = ?",
- ("sso", oauth_id)
- )
- row = cursor.fetchone()
-
- if row:
- # 用户已存在,更新信息
- user = User.from_row(row)
-
- # 更新用户名和邮箱(如果有变化)
- cursor.execute("""
- UPDATE users
- SET username = ?, email = ?
- WHERE id = ?
- """, (username, email, user.id))
-
- conn.commit()
-
- # 重新查询更新后的用户
- cursor.execute("SELECT * FROM users WHERE id = ?", (user.id,))
- row = cursor.fetchone()
- return User.from_row(row)
- else:
- # 新用户,创建记录
- user_id = f"user_{datetime.now().strftime('%Y%m%d%H%M%S')}_{secrets.token_hex(4)}"
-
- # 暂时所有用户都是 annotator 角色(SSO 未提供角色信息)
- role = "annotator"
-
- cursor.execute("""
- INSERT INTO users (
- id, username, email, password_hash, role,
- oauth_provider, oauth_id, created_at
- ) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
- """, (
- user_id,
- username,
- email,
- "", # OAuth 用户不需要密码
- role,
- "sso",
- oauth_id,
- datetime.now()
- ))
-
- conn.commit()
-
- # 查询新创建的用户
- cursor.execute("SELECT * FROM users WHERE id = ?", (user_id,))
- row = cursor.fetchone()
- return User.from_row(row)
|