sso.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. """
  2. SSO (LQAI-middle-platform) OAuth2 integration.
  3. Implements the code exchange flow: code -> SSO access_token -> userinfo -> local JWT.
  4. """
  5. import logging
  6. from typing import Optional, Dict, Any
  7. from urllib.parse import urlencode
  8. import httpx
  9. from gpustack.config.config import Config
  10. from gpustack.security import JWTManager
  11. from gpustack.server.services import create_user_with_principal
  12. logger = logging.getLogger(__name__)
  13. SSO_TOKEN_TIMEOUT = httpx.Timeout(connect=15.0, read=30.0, write=30.0, pool=5.0)
  14. SSO_USERINFO_TIMEOUT = httpx.Timeout(connect=15.0, read=30.0, write=30.0, pool=5.0)
  15. def build_sso_authorize_url(config: Config, redirect: bool = False) -> str:
  16. """Build the SSO OAuth2 authorization URL."""
  17. params = {
  18. "response_type": "code",
  19. "client_id": config.sso_client_id,
  20. "redirect_uri": config.sso_redirect_uri,
  21. "scope": config.sso_scope,
  22. }
  23. authorize_url = f"{config.sso_base_url}/oauth/authorize?{urlencode(params)}"
  24. return authorize_url
  25. async def exchange_code_for_sso_token(
  26. config: Config, code: str
  27. ) -> Dict[str, Any]:
  28. """
  29. Step 4a: Use authorization code to get SSO access_token.
  30. POST {SSO_BASE_URL}/oauth/token
  31. """
  32. data = {
  33. "grant_type": "authorization_code",
  34. "code": code,
  35. "redirect_uri": config.sso_redirect_uri,
  36. "client_id": config.sso_client_id,
  37. "client_secret": config.sso_client_secret,
  38. }
  39. async with httpx.AsyncClient(
  40. timeout=SSO_TOKEN_TIMEOUT,
  41. verify=not config.sso_base_url.startswith("http://"),
  42. ) as client:
  43. resp = await client.post(
  44. f"{config.sso_base_url}/oauth/token",
  45. data=data,
  46. headers={"Content-Type": "application/x-www-form-urlencoded"},
  47. )
  48. if resp.status_code != 200:
  49. logger.error(f"SSO token exchange failed: {resp.status_code} {resp.text}")
  50. error_data = resp.json() if resp.text else {}
  51. error = error_data.get("error", "unknown_error")
  52. error_desc = error_data.get("error_description", "令牌交换失败")
  53. raise Exception(f"SSO token exchange failed: {error} - {error_desc}")
  54. return resp.json()
  55. async def get_sso_userinfo(
  56. config: Config, access_token: str
  57. ) -> Dict[str, Any]:
  58. """
  59. Step 4b: Get user info from SSO platform.
  60. GET {SSO_BASE_URL}/oauth/userinfo
  61. """
  62. async with httpx.AsyncClient(
  63. timeout=SSO_USERINFO_TIMEOUT,
  64. verify=not config.sso_base_url.startswith("http://"),
  65. ) as client:
  66. resp = await client.get(
  67. f"{config.sso_base_url}/oauth/userinfo",
  68. headers={"Authorization": f"Bearer {access_token}"},
  69. )
  70. if resp.status_code != 200:
  71. logger.error(f"SSO userinfo failed: {resp.status_code} {resp.text}")
  72. raise Exception("获取用户信息失败")
  73. return resp.json()
  74. def extract_role_codes(userinfo: Dict[str, Any]) -> list:
  75. """Extract role codes from SSO userinfo roles field."""
  76. roles = userinfo.get("roles", [])
  77. role_codes = []
  78. for role in roles:
  79. if isinstance(role, dict):
  80. code = role.get("code")
  81. if code:
  82. role_codes.append(code)
  83. elif isinstance(role, str):
  84. role_codes.append(role)
  85. return role_codes
  86. async def sync_user_from_sso(
  87. session,
  88. config: Config,
  89. userinfo: Dict[str, Any],
  90. ) -> Any:
  91. """
  92. Step 5: Sync user from SSO to local database.
  93. Find or create user, sync roles.
  94. """
  95. username = userinfo.get("username") or userinfo.get("sub")
  96. if not username:
  97. raise Exception("SSO 返回的用户信息中缺少 username")
  98. email = userinfo.get("email", "")
  99. full_name = userinfo.get("real_name", username)
  100. avatar_url = userinfo.get("avatar_url", "")
  101. role_codes = extract_role_codes(userinfo)
  102. is_admin = "super_admin" in role_codes
  103. # Find existing user by username
  104. from gpustack.schemas.users import User, AuthProviderEnum
  105. existing = await User.first_by_field(
  106. session, "username", username
  107. )
  108. if existing:
  109. # Update user info
  110. patch = {
  111. "full_name": full_name,
  112. "avatar_url": avatar_url,
  113. "is_admin": is_admin,
  114. "is_active": True,
  115. "source": AuthProviderEnum.OIDC,
  116. }
  117. await existing.update(session, patch)
  118. logger.info(f"Updated SSO user: {username}")
  119. return existing
  120. else:
  121. # Create new user
  122. # SSO users don't have a local password; generate a random one
  123. import secrets
  124. random_password = secrets.token_urlsafe(32)
  125. user = await create_user_with_principal(
  126. session=session,
  127. username=username,
  128. password=random_password,
  129. is_admin=is_admin,
  130. full_name=full_name,
  131. avatar_url=avatar_url,
  132. source=AuthProviderEnum.OIDC,
  133. )
  134. logger.info(f"Created SSO user: {username}")
  135. return user
  136. async def handle_sso_exchange_code(
  137. session,
  138. config: Config,
  139. code: str,
  140. jwt_manager,
  141. ) -> Dict[str, Any]:
  142. """
  143. Core SSO exchange code flow (Steps 4-6):
  144. 1. Exchange code for SSO access_token
  145. 2. Get user info from SSO
  146. 3. Sync user to local DB
  147. 4. Issue local JWT
  148. """
  149. # Step 4a: Get SSO access_token
  150. token_data = await exchange_code_for_sso_token(config, code)
  151. sso_access_token = token_data.get("access_token")
  152. if not sso_access_token:
  153. raise Exception("获取 SSO access_token 失败")
  154. # Step 4b: Get user info
  155. userinfo = await get_sso_userinfo(config, sso_access_token)
  156. if not userinfo.get("username") and not userinfo.get("sub"):
  157. raise Exception("SSO 用户信息格式异常")
  158. # Step 5: Sync user
  159. user = await sync_user_from_sso(session, config, userinfo)
  160. # Step 6: Issue local JWT
  161. local_token = jwt_manager.create_jwt_token(username=user.username)
  162. # Build user response
  163. role_codes = extract_role_codes(userinfo)
  164. user_data = {
  165. "id": str(user.id),
  166. "username": user.username,
  167. "email": userinfo.get("email", ""),
  168. "phone": userinfo.get("phone", ""),
  169. "full_name": user.full_name,
  170. "avatar_url": user.avatar_url,
  171. "is_superuser": user.is_admin,
  172. "is_active": user.is_active,
  173. "roles": role_codes,
  174. }
  175. return {
  176. "token": local_token,
  177. "refresh_token": "", # SSO flow doesn't need refresh token for now
  178. "user": user_data,
  179. }