sso.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. """
  2. SSO (LQAI-middle-platform) OAuth2 integration.
  3. Implements the code exchange flow: code -> SSO access_token -> userinfo -> local JWT.
  4. """
  5. import logging
  6. from typing import Optional, Dict, Any
  7. from urllib.parse import urlencode
  8. import httpx
  9. from gpustack.config.config import Config
  10. from gpustack.security import JWTManager
  11. from gpustack.server.services import create_user_with_principal
  12. logger = logging.getLogger(__name__)
  13. SSO_TOKEN_TIMEOUT = httpx.Timeout(connect=15.0, read=30.0, write=30.0, pool=5.0)
  14. SSO_USERINFO_TIMEOUT = httpx.Timeout(connect=15.0, read=30.0, write=30.0, pool=5.0)
  15. def build_sso_authorize_url(config: Config, redirect: bool = False) -> str:
  16. """Build the SSO OAuth2 authorization URL."""
  17. params = {
  18. "response_type": "code",
  19. "client_id": config.sso_client_id,
  20. "redirect_uri": config.sso_redirect_uri,
  21. "scope": config.sso_scope,
  22. }
  23. authorize_url = f"{config.sso_base_url}/oauth/authorize?{urlencode(params)}"
  24. return authorize_url
  25. async def exchange_code_for_sso_token(
  26. config: Config, code: str
  27. ) -> Dict[str, Any]:
  28. """
  29. Step 4a: Use authorization code to get SSO access_token.
  30. POST {SSO_BASE_URL}/oauth/token
  31. """
  32. data = {
  33. "grant_type": "authorization_code",
  34. "code": code,
  35. "redirect_uri": config.sso_redirect_uri,
  36. "client_id": config.sso_client_id,
  37. "client_secret": config.sso_client_secret,
  38. }
  39. async with httpx.AsyncClient(
  40. timeout=SSO_TOKEN_TIMEOUT,
  41. verify=not config.sso_base_url.startswith("http://"),
  42. ) as client:
  43. resp = await client.post(
  44. f"{config.sso_base_url}/oauth/token",
  45. data=data,
  46. headers={"Content-Type": "application/x-www-form-urlencoded"},
  47. )
  48. if resp.status_code != 200:
  49. logger.error(f"SSO token exchange failed: {resp.status_code} {resp.text}")
  50. error_data = resp.json() if resp.text else {}
  51. error = error_data.get("error", "unknown_error")
  52. error_desc = error_data.get("error_description", "令牌交换失败")
  53. raise Exception(f"SSO token exchange failed: {error} - {error_desc}")
  54. token_data = resp.json()
  55. if "access_token" not in token_data:
  56. logger.error(
  57. f"SSO token exchange returned 200 but missing access_token. "
  58. f"Response: {resp.text}. "
  59. f"Request data: redirect_uri={config.sso_redirect_uri}, "
  60. f"client_id={config.sso_client_id}"
  61. )
  62. raise Exception(
  63. f"SSO token exchange succeeded but no access_token in response: {resp.text}"
  64. )
  65. return token_data
  66. async def get_sso_userinfo(
  67. config: Config, access_token: str
  68. ) -> Dict[str, Any]:
  69. """
  70. Step 4b: Get user info from SSO platform.
  71. GET {SSO_BASE_URL}/oauth/userinfo
  72. """
  73. async with httpx.AsyncClient(
  74. timeout=SSO_USERINFO_TIMEOUT,
  75. verify=not config.sso_base_url.startswith("http://"),
  76. ) as client:
  77. resp = await client.get(
  78. f"{config.sso_base_url}/oauth/userinfo",
  79. headers={"Authorization": f"Bearer {access_token}"},
  80. )
  81. if resp.status_code != 200:
  82. logger.error(f"SSO userinfo failed: {resp.status_code} {resp.text}")
  83. raise Exception("获取用户信息失败")
  84. return resp.json()
  85. def extract_role_codes(userinfo: Dict[str, Any]) -> list:
  86. """Extract role codes from SSO userinfo roles field."""
  87. roles = userinfo.get("roles", [])
  88. role_codes = []
  89. for role in roles:
  90. if isinstance(role, dict):
  91. code = role.get("code")
  92. if code:
  93. role_codes.append(code)
  94. elif isinstance(role, str):
  95. role_codes.append(role)
  96. return role_codes
  97. async def sync_user_from_sso(
  98. session,
  99. config: Config,
  100. userinfo: Dict[str, Any],
  101. ) -> Any:
  102. """
  103. Step 5: Sync user from SSO to local database.
  104. Find or create user, sync roles.
  105. """
  106. username = userinfo.get("username") or userinfo.get("sub")
  107. if not username:
  108. raise Exception("SSO 返回的用户信息中缺少 username")
  109. email = userinfo.get("email", "")
  110. full_name = userinfo.get("real_name", username)
  111. avatar_url = userinfo.get("avatar_url", "")
  112. role_codes = extract_role_codes(userinfo)
  113. is_admin = "super_admin" in role_codes
  114. # Find existing user by username
  115. from gpustack.schemas.users import User, AuthProviderEnum
  116. existing = await User.first_by_field(
  117. session, "username", username
  118. )
  119. if existing:
  120. # Update user info
  121. patch = {
  122. "full_name": full_name,
  123. "avatar_url": avatar_url,
  124. "is_admin": is_admin,
  125. "is_active": True,
  126. "source": AuthProviderEnum.OIDC,
  127. }
  128. await existing.update(session, patch)
  129. logger.info(f"Updated SSO user: {username}")
  130. return existing
  131. else:
  132. # Create new user
  133. # SSO users don't have a local password; generate a random one
  134. import secrets
  135. random_password = secrets.token_urlsafe(32)
  136. user = await create_user_with_principal(
  137. session=session,
  138. username=username,
  139. password=random_password,
  140. is_admin=is_admin,
  141. full_name=full_name,
  142. avatar_url=avatar_url,
  143. source=AuthProviderEnum.OIDC,
  144. )
  145. logger.info(f"Created SSO user: {username}")
  146. return user
  147. async def handle_sso_exchange_code(
  148. session,
  149. config: Config,
  150. code: str,
  151. jwt_manager,
  152. ) -> Dict[str, Any]:
  153. """
  154. Core SSO exchange code flow (Steps 4-6):
  155. 1. Exchange code for SSO access_token
  156. 2. Get user info from SSO
  157. 3. Sync user to local DB
  158. 4. Issue local JWT
  159. """
  160. # Step 4a: Get SSO access_token
  161. token_data = await exchange_code_for_sso_token(config, code)
  162. sso_access_token = token_data.get("access_token")
  163. if not sso_access_token:
  164. raise Exception("获取 SSO access_token 失败")
  165. # Step 4b: Get user info
  166. userinfo = await get_sso_userinfo(config, sso_access_token)
  167. if not userinfo.get("username") and not userinfo.get("sub"):
  168. raise Exception("SSO 用户信息格式异常")
  169. # Step 5: Sync user
  170. user = await sync_user_from_sso(session, config, userinfo)
  171. # Step 6: Issue local JWT
  172. local_token = jwt_manager.create_jwt_token(username=user.username)
  173. # Build user response
  174. role_codes = extract_role_codes(userinfo)
  175. user_data = {
  176. "id": str(user.id),
  177. "username": user.username,
  178. "email": userinfo.get("email", ""),
  179. "phone": userinfo.get("phone", ""),
  180. "full_name": user.full_name,
  181. "avatar_url": user.avatar_url,
  182. "is_superuser": user.is_admin,
  183. "is_active": user.is_active,
  184. "roles": role_codes,
  185. }
  186. return {
  187. "token": local_token,
  188. "refresh_token": "", # SSO flow doesn't need refresh token for now
  189. "user": user_data,
  190. }