oauth_service.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. """
  2. OAuth 2.0 认证服务
  3. 处理与 OAuth 认证中心的交互
  4. """
  5. import httpx
  6. import secrets
  7. from typing import Dict, Any, Optional
  8. from datetime import datetime
  9. from config import settings
  10. from models import User
  11. from database import get_db_connection
  12. class OAuthService:
  13. """OAuth 2.0 认证服务"""
  14. @staticmethod
  15. def generate_state() -> str:
  16. """
  17. 生成随机 state 参数,用于防止 CSRF 攻击
  18. Returns:
  19. 随机字符串
  20. """
  21. return secrets.token_urlsafe(32)
  22. @staticmethod
  23. def get_authorization_url(state: str) -> str:
  24. """
  25. 构建 OAuth 授权 URL
  26. Args:
  27. state: 防CSRF的随机字符串
  28. Returns:
  29. 完整的授权URL
  30. """
  31. from urllib.parse import urlencode
  32. params = {
  33. "response_type": "code",
  34. "client_id": settings.OAUTH_CLIENT_ID,
  35. "redirect_uri": settings.OAUTH_REDIRECT_URI,
  36. "scope": settings.OAUTH_SCOPE,
  37. "state": state
  38. }
  39. authorize_url = f"{settings.OAUTH_BASE_URL}{settings.OAUTH_AUTHORIZE_ENDPOINT}"
  40. return f"{authorize_url}?{urlencode(params)}"
  41. @staticmethod
  42. async def exchange_code_for_token(code: str) -> Dict[str, Any]:
  43. """
  44. 用授权码换取访问令牌
  45. Args:
  46. code: OAuth 授权码
  47. Returns:
  48. 令牌信息字典,包含 access_token, token_type, expires_in 等
  49. Raises:
  50. Exception: 令牌交换失败
  51. """
  52. token_url = f"{settings.OAUTH_BASE_URL}{settings.OAUTH_TOKEN_ENDPOINT}"
  53. async with httpx.AsyncClient() as client:
  54. response = await client.post(
  55. token_url,
  56. data={
  57. "grant_type": "authorization_code",
  58. "code": code,
  59. "redirect_uri": settings.OAUTH_REDIRECT_URI,
  60. "client_id": settings.OAUTH_CLIENT_ID,
  61. "client_secret": settings.OAUTH_CLIENT_SECRET
  62. },
  63. headers={"Content-Type": "application/x-www-form-urlencoded"}
  64. )
  65. if response.status_code != 200:
  66. raise Exception(f"令牌交换失败 ({response.status_code}): {response.text}")
  67. data = response.json()
  68. # 处理不同的响应格式
  69. if "access_token" in data:
  70. return data
  71. elif data.get("code") == 0 and "data" in data:
  72. return data["data"]
  73. else:
  74. raise Exception(f"无效的令牌响应格式: {data}")
  75. @staticmethod
  76. async def get_user_info(access_token: str) -> Dict[str, Any]:
  77. """
  78. 使用访问令牌获取用户信息
  79. Args:
  80. access_token: OAuth 访问令牌
  81. Returns:
  82. 用户信息字典
  83. Raises:
  84. Exception: 获取用户信息失败
  85. """
  86. userinfo_url = f"{settings.OAUTH_BASE_URL}{settings.OAUTH_USERINFO_ENDPOINT}"
  87. async with httpx.AsyncClient() as client:
  88. response = await client.get(
  89. userinfo_url,
  90. headers={"Authorization": f"Bearer {access_token}"}
  91. )
  92. if response.status_code != 200:
  93. raise Exception(f"获取用户信息失败 ({response.status_code}): {response.text}")
  94. data = response.json()
  95. # 处理不同的响应格式
  96. if "sub" in data or "id" in data:
  97. return data
  98. elif data.get("code") == 0 and "data" in data:
  99. return data["data"]
  100. else:
  101. raise Exception(f"无效的用户信息响应格式: {data}")
  102. @staticmethod
  103. def sync_user_from_oauth(oauth_user_info: Dict[str, Any]) -> User:
  104. """
  105. 从 OAuth 用户信息同步到本地数据库
  106. 如果用户不存在则创建,如果存在则更新
  107. Args:
  108. oauth_user_info: OAuth 返回的用户信息
  109. Returns:
  110. 本地用户对象
  111. """
  112. with get_db_connection() as conn:
  113. cursor = conn.cursor()
  114. # 提取用户信息(兼容不同的字段名)
  115. oauth_id = oauth_user_info.get("sub") or oauth_user_info.get("id")
  116. username = oauth_user_info.get("username") or oauth_user_info.get("preferred_username") or oauth_user_info.get("name")
  117. email = oauth_user_info.get("email", "")
  118. if not oauth_id:
  119. raise ValueError("OAuth 用户信息缺少 ID 字段")
  120. if not username:
  121. raise ValueError("OAuth 用户信息缺少用户名字段")
  122. # 查找是否已存在该 OAuth 用户
  123. cursor.execute(
  124. "SELECT * FROM users WHERE oauth_provider = ? AND oauth_id = ?",
  125. ("sso", oauth_id)
  126. )
  127. row = cursor.fetchone()
  128. if row:
  129. # 用户已存在,更新信息
  130. user = User.from_row(row)
  131. # 更新用户名和邮箱(如果有变化)
  132. cursor.execute("""
  133. UPDATE users
  134. SET username = ?, email = ?
  135. WHERE id = ?
  136. """, (username, email, user.id))
  137. conn.commit()
  138. # 重新查询更新后的用户
  139. cursor.execute("SELECT * FROM users WHERE id = ?", (user.id,))
  140. row = cursor.fetchone()
  141. return User.from_row(row)
  142. else:
  143. # 新用户,创建记录
  144. user_id = f"user_{datetime.now().strftime('%Y%m%d%H%M%S')}_{secrets.token_hex(4)}"
  145. # 暂时所有用户都是 annotator 角色(SSO 未提供角色信息)
  146. role = "annotator"
  147. cursor.execute("""
  148. INSERT INTO users (
  149. id, username, email, password_hash, role,
  150. oauth_provider, oauth_id, created_at
  151. ) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
  152. """, (
  153. user_id,
  154. username,
  155. email,
  156. "", # OAuth 用户不需要密码
  157. role,
  158. "sso",
  159. oauth_id,
  160. datetime.now()
  161. ))
  162. conn.commit()
  163. # 查询新创建的用户
  164. cursor.execute("SELECT * FROM users WHERE id = ?", (user_id,))
  165. row = cursor.fetchone()
  166. return User.from_row(row)