oauth_service.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385
  1. """
  2. OAuth 2.0 认证服务
  3. 处理与 OAuth 认证中心的交互,包括 token 验证和刷新
  4. """
  5. import httpx
  6. import logging
  7. import secrets
  8. from typing import Dict, Any, Optional
  9. from datetime import datetime
  10. from fastapi import HTTPException, status
  11. from config import settings
  12. from models import User
  13. from database import get_db_connection
  14. logger = logging.getLogger(__name__)
  15. # SSO 角色 → 本地角色映射
  16. SSO_ROLE_MAPPING = {
  17. "super_admin": "admin",
  18. "label_admin": "admin",
  19. "admin": "admin",
  20. "labeler": "annotator",
  21. }
  22. DEFAULT_LOCAL_ROLE = "viewer"
  23. def map_sso_roles_to_local(sso_roles: list, is_superuser: bool = False) -> str:
  24. """
  25. 将 SSO 角色列表映射为本地单一角色。
  26. 优先级: admin > annotator > viewer
  27. """
  28. if is_superuser:
  29. return "admin"
  30. local_role = DEFAULT_LOCAL_ROLE
  31. for sso_role in sso_roles:
  32. mapped = SSO_ROLE_MAPPING.get(sso_role)
  33. if mapped == "admin":
  34. return "admin"
  35. if mapped == "annotator":
  36. local_role = "annotator"
  37. return local_role
  38. class OAuthService:
  39. """OAuth 2.0 认证服务"""
  40. @staticmethod
  41. def generate_state() -> str:
  42. """
  43. 生成随机 state 参数,用于防止 CSRF 攻击
  44. Returns:
  45. 随机字符串
  46. """
  47. return secrets.token_urlsafe(32)
  48. @staticmethod
  49. def get_authorization_url(state: str) -> str:
  50. """
  51. 构建 OAuth 授权 URL
  52. Args:
  53. state: 防CSRF的随机字符串
  54. Returns:
  55. 完整的授权URL
  56. """
  57. from urllib.parse import urlencode
  58. params = {
  59. "response_type": "code",
  60. "client_id": settings.OAUTH_CLIENT_ID,
  61. "redirect_uri": settings.OAUTH_REDIRECT_URI,
  62. "scope": settings.OAUTH_SCOPE,
  63. "state": state
  64. }
  65. authorize_url = f"{settings.OAUTH_BASE_URL}{settings.OAUTH_AUTHORIZE_ENDPOINT}"
  66. return f"{authorize_url}?{urlencode(params)}"
  67. @staticmethod
  68. async def exchange_code_for_token(code: str) -> Dict[str, Any]:
  69. """
  70. 用授权码换取访问令牌
  71. Args:
  72. code: OAuth 授权码
  73. Returns:
  74. 令牌信息字典,包含 access_token, token_type, expires_in 等
  75. Raises:
  76. Exception: 令牌交换失败
  77. """
  78. token_url = f"{settings.OAUTH_BASE_URL}{settings.OAUTH_TOKEN_ENDPOINT}"
  79. async with httpx.AsyncClient() as client:
  80. response = await client.post(
  81. token_url,
  82. data={
  83. "grant_type": "authorization_code",
  84. "code": code,
  85. "redirect_uri": settings.OAUTH_REDIRECT_URI,
  86. "client_id": settings.OAUTH_CLIENT_ID,
  87. "client_secret": settings.OAUTH_CLIENT_SECRET
  88. },
  89. headers={"Content-Type": "application/x-www-form-urlencoded"}
  90. )
  91. if response.status_code != 200:
  92. raise Exception(f"令牌交换失败 ({response.status_code}): {response.text}")
  93. data = response.json()
  94. # 处理不同的响应格式
  95. if "access_token" in data:
  96. return data
  97. # 处理包装格式 {"code": 0, "data": {...}} 或 {"code": "000000", "data": {...}}
  98. code = data.get("code")
  99. if (code == 0 or code == "000000") and "data" in data:
  100. return data["data"]
  101. else:
  102. raise Exception(f"无效的令牌响应格式: {data}")
  103. @staticmethod
  104. async def get_user_info(access_token: str) -> Dict[str, Any]:
  105. """
  106. 使用访问令牌获取用户信息
  107. Args:
  108. access_token: OAuth 访问令牌
  109. Returns:
  110. 用户信息字典
  111. Raises:
  112. Exception: 获取用户信息失败
  113. """
  114. userinfo_url = f"{settings.OAUTH_BASE_URL}{settings.OAUTH_USERINFO_ENDPOINT}"
  115. async with httpx.AsyncClient() as client:
  116. response = await client.get(
  117. userinfo_url,
  118. headers={"Authorization": f"Bearer {access_token}"}
  119. )
  120. if response.status_code != 200:
  121. raise Exception(f"获取用户信息失败 ({response.status_code}): {response.text}")
  122. data = response.json()
  123. # 处理不同的响应格式
  124. if "sub" in data or "id" in data:
  125. return data
  126. # 处理包装格式 {"code": 0, "data": {...}} 或 {"code": "000000", "data": {...}}
  127. code = data.get("code")
  128. if (code == 0 or code == "000000") and "data" in data:
  129. return data["data"]
  130. else:
  131. raise Exception(f"无效的用户信息响应格式: {data}")
  132. @staticmethod
  133. def sync_user_from_oauth(oauth_user_info: Dict[str, Any]) -> User:
  134. """
  135. 从 OAuth 用户信息同步到本地数据库
  136. 如果用户不存在则创建,如果存在则更新(包括角色)
  137. Args:
  138. oauth_user_info: OAuth 返回的用户信息
  139. Returns:
  140. 本地用户对象
  141. """
  142. with get_db_connection() as conn:
  143. cursor = conn.cursor()
  144. # 提取用户信息(兼容不同的字段名)
  145. oauth_id = oauth_user_info.get("sub") or oauth_user_info.get("id")
  146. username = oauth_user_info.get("username") or oauth_user_info.get("preferred_username") or oauth_user_info.get("name")
  147. email = oauth_user_info.get("email", "")
  148. if not oauth_id:
  149. raise ValueError("OAuth 用户信息缺少 ID 字段")
  150. if not username:
  151. raise ValueError("OAuth 用户信息缺少用户名字段")
  152. # 计算本地角色
  153. sso_roles = oauth_user_info.get("sso_roles") or oauth_user_info.get("roles", [])
  154. is_superuser = bool(oauth_user_info.get("is_superuser", False))
  155. role = oauth_user_info.get("role") or map_sso_roles_to_local(sso_roles, is_superuser)
  156. # 查找是否已存在该 OAuth 用户
  157. cursor.execute(
  158. "SELECT * FROM users WHERE oauth_provider = %s AND oauth_id = %s",
  159. ("sso", oauth_id)
  160. )
  161. row = cursor.fetchone()
  162. if row:
  163. # 用户已存在,更新信息(包括角色)
  164. user = User.from_row(row)
  165. cursor.execute("""
  166. UPDATE users
  167. SET username = %s, email = %s, role = %s
  168. WHERE id = %s
  169. """, (username, email, role, user.id))
  170. conn.commit()
  171. # 重新查询更新后的用户
  172. cursor.execute("SELECT * FROM users WHERE id = %s", (user.id,))
  173. row = cursor.fetchone()
  174. return User.from_row(row)
  175. else:
  176. # 新用户,创建记录
  177. user_id = f"user_{datetime.now().strftime('%Y%m%d%H%M%S')}_{secrets.token_hex(4)}"
  178. cursor.execute("""
  179. INSERT INTO users (
  180. id, username, email, password_hash, role,
  181. oauth_provider, oauth_id, created_at
  182. ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
  183. """, (
  184. user_id,
  185. username,
  186. email,
  187. "", # OAuth 用户不需要密码
  188. role,
  189. "sso",
  190. oauth_id,
  191. datetime.now()
  192. ))
  193. conn.commit()
  194. # 查询新创建的用户
  195. cursor.execute("SELECT * FROM users WHERE id = %s", (user_id,))
  196. row = cursor.fetchone()
  197. return User.from_row(row)
  198. @staticmethod
  199. async def verify_sso_token(access_token: str) -> Dict[str, Any]:
  200. """
  201. 通过 SSO 验证 token 并获取用户信息(含角色)。
  202. 使用 /api/v1/system/users/profile 端点获取完整用户信息,
  203. 包括 roles 列表和 is_superuser 标记,然后映射为本地角色。
  204. Args:
  205. access_token: SSO 访问令牌
  206. Returns:
  207. 用户信息字典 {id, username, email, role, ...}
  208. Raises:
  209. HTTPException(401): token 无效
  210. HTTPException(503): SSO 中心不可用
  211. """
  212. profile_url = f"{settings.OAUTH_BASE_URL}/api/v1/system/users/profile"
  213. async with httpx.AsyncClient(timeout=10.0) as client:
  214. try:
  215. response = await client.get(
  216. profile_url,
  217. headers={"Authorization": f"Bearer {access_token}"}
  218. )
  219. except httpx.RequestError:
  220. raise HTTPException(
  221. status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
  222. detail="SSO 认证中心不可用"
  223. )
  224. if response.status_code == 401:
  225. raise HTTPException(
  226. status_code=status.HTTP_401_UNAUTHORIZED,
  227. detail="无效的访问令牌"
  228. )
  229. if response.status_code != 200:
  230. raise HTTPException(
  231. status_code=status.HTTP_401_UNAUTHORIZED,
  232. detail=f"SSO 验证失败 ({response.status_code})"
  233. )
  234. data = response.json()
  235. logger.debug(f"SSO profile response: {data}")
  236. # 处理包装格式 {"code": 0, "data": {...}} 或 {"code": "000000", "data": {...}}
  237. code = data.get("code")
  238. if (code == 0 or code == "000000") and "data" in data:
  239. profile = data["data"]
  240. elif "id" in data or "username" in data:
  241. profile = data
  242. else:
  243. logger.error(f"Invalid profile response format: {data}")
  244. raise HTTPException(
  245. status_code=status.HTTP_401_UNAUTHORIZED,
  246. detail="无效的访问令牌"
  247. )
  248. # 提取角色信息并映射
  249. sso_roles = profile.get("roles", [])
  250. is_superuser = bool(profile.get("is_superuser", False))
  251. local_role = map_sso_roles_to_local(sso_roles, is_superuser)
  252. logger.info(
  253. f"SSO 用户 {profile.get('username')}: "
  254. f"sso_roles={sso_roles}, is_superuser={is_superuser} → local_role={local_role}"
  255. )
  256. # 返回统一格式的用户信息
  257. return {
  258. "id": profile.get("id"),
  259. "username": profile.get("username"),
  260. "email": profile.get("email", ""),
  261. "role": local_role,
  262. "sso_roles": sso_roles,
  263. "is_superuser": is_superuser,
  264. }
  265. @staticmethod
  266. async def refresh_sso_token(refresh_token: str) -> Dict[str, Any]:
  267. """
  268. 向 SSO 中心刷新 token。
  269. Args:
  270. refresh_token: SSO 刷新令牌
  271. Returns:
  272. 新的 token 信息 {access_token, refresh_token, ...}
  273. Raises:
  274. HTTPException(401): refresh_token 无效
  275. HTTPException(503): SSO 中心不可用
  276. """
  277. token_url = f"{settings.OAUTH_BASE_URL}{settings.OAUTH_TOKEN_ENDPOINT}"
  278. logger.debug(f"Refreshing token at: {token_url}")
  279. async with httpx.AsyncClient(timeout=10.0) as client:
  280. try:
  281. response = await client.post(
  282. token_url,
  283. data={
  284. "grant_type": "refresh_token",
  285. "refresh_token": refresh_token,
  286. "client_id": settings.OAUTH_CLIENT_ID,
  287. "client_secret": settings.OAUTH_CLIENT_SECRET
  288. },
  289. headers={"Content-Type": "application/x-www-form-urlencoded"}
  290. )
  291. except httpx.RequestError as e:
  292. logger.error(f"SSO refresh request failed: {e}")
  293. raise HTTPException(
  294. status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
  295. detail="SSO 认证中心不可用"
  296. )
  297. logger.debug(f"SSO refresh response: status={response.status_code}")
  298. if response.status_code != 200:
  299. logger.error(f"SSO refresh failed: {response.status_code}, body={response.text}")
  300. raise HTTPException(
  301. status_code=status.HTTP_401_UNAUTHORIZED,
  302. detail="刷新令牌无效或已过期,请重新登录"
  303. )
  304. data = response.json()
  305. logger.debug(f"SSO refresh response data: {data}")
  306. # 处理包装格式 {"code": 0, "data": {...}} 或 {"code": "000000", "data": {...}}
  307. code = data.get("code")
  308. if (code == 0 or code == "000000") and "data" in data:
  309. return data["data"]
  310. elif "access_token" in data:
  311. return data
  312. else:
  313. logger.error(f"Invalid refresh response format: {data}")
  314. raise HTTPException(
  315. status_code=status.HTTP_401_UNAUTHORIZED,
  316. detail="刷新令牌无效或已过期,请重新登录"
  317. )