|
@@ -0,0 +1,212 @@
|
|
|
|
|
+"""
|
|
|
|
|
+SSO (LQAI-middle-platform) OAuth2 integration.
|
|
|
|
|
+Implements the code exchange flow: code -> SSO access_token -> userinfo -> local JWT.
|
|
|
|
|
+"""
|
|
|
|
|
+
|
|
|
|
|
+import logging
|
|
|
|
|
+from typing import Optional, Dict, Any
|
|
|
|
|
+from urllib.parse import urlencode
|
|
|
|
|
+
|
|
|
|
|
+import httpx
|
|
|
|
|
+
|
|
|
|
|
+from gpustack.config.config import Config
|
|
|
|
|
+from gpustack.security import JWTManager
|
|
|
|
|
+from gpustack.server.services import create_user_with_principal
|
|
|
|
|
+
|
|
|
|
|
+logger = logging.getLogger(__name__)
|
|
|
|
|
+
|
|
|
|
|
+SSO_TOKEN_TIMEOUT = httpx.Timeout(connect=15.0, read=30.0, write=30.0, pool=5.0)
|
|
|
|
|
+SSO_USERINFO_TIMEOUT = httpx.Timeout(connect=15.0, read=30.0, write=30.0, pool=5.0)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def build_sso_authorize_url(config: Config, redirect: bool = False) -> str:
|
|
|
|
|
+ """Build the SSO OAuth2 authorization URL."""
|
|
|
|
|
+ params = {
|
|
|
|
|
+ "response_type": "code",
|
|
|
|
|
+ "client_id": config.sso_client_id,
|
|
|
|
|
+ "redirect_uri": config.sso_redirect_uri,
|
|
|
|
|
+ "scope": config.sso_scope,
|
|
|
|
|
+ }
|
|
|
|
|
+ authorize_url = f"{config.sso_base_url}/oauth/authorize?{urlencode(params)}"
|
|
|
|
|
+ return authorize_url
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+async def exchange_code_for_sso_token(
|
|
|
|
|
+ config: Config, code: str
|
|
|
|
|
+) -> Dict[str, Any]:
|
|
|
|
|
+ """
|
|
|
|
|
+ Step 4a: Use authorization code to get SSO access_token.
|
|
|
|
|
+ POST {SSO_BASE_URL}/oauth/token
|
|
|
|
|
+ """
|
|
|
|
|
+ data = {
|
|
|
|
|
+ "grant_type": "authorization_code",
|
|
|
|
|
+ "code": code,
|
|
|
|
|
+ "redirect_uri": config.sso_redirect_uri,
|
|
|
|
|
+ "client_id": config.sso_client_id,
|
|
|
|
|
+ "client_secret": config.sso_client_secret,
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ async with httpx.AsyncClient(
|
|
|
|
|
+ timeout=SSO_TOKEN_TIMEOUT,
|
|
|
|
|
+ verify=not config.sso_base_url.startswith("http://"),
|
|
|
|
|
+ ) as client:
|
|
|
|
|
+ resp = await client.post(
|
|
|
|
|
+ f"{config.sso_base_url}/oauth/token",
|
|
|
|
|
+ data=data,
|
|
|
|
|
+ headers={"Content-Type": "application/x-www-form-urlencoded"},
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ if resp.status_code != 200:
|
|
|
|
|
+ logger.error(f"SSO token exchange failed: {resp.status_code} {resp.text}")
|
|
|
|
|
+ error_data = resp.json() if resp.text else {}
|
|
|
|
|
+ error = error_data.get("error", "unknown_error")
|
|
|
|
|
+ error_desc = error_data.get("error_description", "令牌交换失败")
|
|
|
|
|
+ raise Exception(f"SSO token exchange failed: {error} - {error_desc}")
|
|
|
|
|
+
|
|
|
|
|
+ return resp.json()
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+async def get_sso_userinfo(
|
|
|
|
|
+ config: Config, access_token: str
|
|
|
|
|
+) -> Dict[str, Any]:
|
|
|
|
|
+ """
|
|
|
|
|
+ Step 4b: Get user info from SSO platform.
|
|
|
|
|
+ GET {SSO_BASE_URL}/oauth/userinfo
|
|
|
|
|
+ """
|
|
|
|
|
+ async with httpx.AsyncClient(
|
|
|
|
|
+ timeout=SSO_USERINFO_TIMEOUT,
|
|
|
|
|
+ verify=not config.sso_base_url.startswith("http://"),
|
|
|
|
|
+ ) as client:
|
|
|
|
|
+ resp = await client.get(
|
|
|
|
|
+ f"{config.sso_base_url}/oauth/userinfo",
|
|
|
|
|
+ headers={"Authorization": f"Bearer {access_token}"},
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ if resp.status_code != 200:
|
|
|
|
|
+ logger.error(f"SSO userinfo failed: {resp.status_code} {resp.text}")
|
|
|
|
|
+ raise Exception("获取用户信息失败")
|
|
|
|
|
+
|
|
|
|
|
+ return resp.json()
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def extract_role_codes(userinfo: Dict[str, Any]) -> list:
|
|
|
|
|
+ """Extract role codes from SSO userinfo roles field."""
|
|
|
|
|
+ roles = userinfo.get("roles", [])
|
|
|
|
|
+ role_codes = []
|
|
|
|
|
+ for role in roles:
|
|
|
|
|
+ if isinstance(role, dict):
|
|
|
|
|
+ code = role.get("code")
|
|
|
|
|
+ if code:
|
|
|
|
|
+ role_codes.append(code)
|
|
|
|
|
+ elif isinstance(role, str):
|
|
|
|
|
+ role_codes.append(role)
|
|
|
|
|
+ return role_codes
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+async def sync_user_from_sso(
|
|
|
|
|
+ session,
|
|
|
|
|
+ config: Config,
|
|
|
|
|
+ userinfo: Dict[str, Any],
|
|
|
|
|
+) -> Any:
|
|
|
|
|
+ """
|
|
|
|
|
+ Step 5: Sync user from SSO to local database.
|
|
|
|
|
+ Find or create user, sync roles.
|
|
|
|
|
+ """
|
|
|
|
|
+ username = userinfo.get("username") or userinfo.get("sub")
|
|
|
|
|
+ if not username:
|
|
|
|
|
+ raise Exception("SSO 返回的用户信息中缺少 username")
|
|
|
|
|
+
|
|
|
|
|
+ email = userinfo.get("email", "")
|
|
|
|
|
+ full_name = userinfo.get("real_name", username)
|
|
|
|
|
+ avatar_url = userinfo.get("avatar_url", "")
|
|
|
|
|
+ role_codes = extract_role_codes(userinfo)
|
|
|
|
|
+
|
|
|
|
|
+ is_admin = "super_admin" in role_codes
|
|
|
|
|
+
|
|
|
|
|
+ # Find existing user by username
|
|
|
|
|
+ from gpustack.schemas.users import User, AuthProviderEnum
|
|
|
|
|
+
|
|
|
|
|
+ existing = await User.first_by_field(
|
|
|
|
|
+ session, "username", username
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ if existing:
|
|
|
|
|
+ # Update user info
|
|
|
|
|
+ patch = {
|
|
|
|
|
+ "full_name": full_name,
|
|
|
|
|
+ "avatar_url": avatar_url,
|
|
|
|
|
+ "is_admin": is_admin,
|
|
|
|
|
+ "is_active": True,
|
|
|
|
|
+ "source": AuthProviderEnum.OIDC,
|
|
|
|
|
+ }
|
|
|
|
|
+ await existing.update(session, patch)
|
|
|
|
|
+ logger.info(f"Updated SSO user: {username}")
|
|
|
|
|
+ return existing
|
|
|
|
|
+ else:
|
|
|
|
|
+ # Create new user
|
|
|
|
|
+ # SSO users don't have a local password; generate a random one
|
|
|
|
|
+ import secrets
|
|
|
|
|
+ random_password = secrets.token_urlsafe(32)
|
|
|
|
|
+
|
|
|
|
|
+ user = await create_user_with_principal(
|
|
|
|
|
+ session=session,
|
|
|
|
|
+ username=username,
|
|
|
|
|
+ password=random_password,
|
|
|
|
|
+ is_admin=is_admin,
|
|
|
|
|
+ full_name=full_name,
|
|
|
|
|
+ avatar_url=avatar_url,
|
|
|
|
|
+ source=AuthProviderEnum.OIDC,
|
|
|
|
|
+ )
|
|
|
|
|
+ logger.info(f"Created SSO user: {username}")
|
|
|
|
|
+ return user
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+async def handle_sso_exchange_code(
|
|
|
|
|
+ session,
|
|
|
|
|
+ config: Config,
|
|
|
|
|
+ code: str,
|
|
|
|
|
+ jwt_manager,
|
|
|
|
|
+) -> Dict[str, Any]:
|
|
|
|
|
+ """
|
|
|
|
|
+ Core SSO exchange code flow (Steps 4-6):
|
|
|
|
|
+ 1. Exchange code for SSO access_token
|
|
|
|
|
+ 2. Get user info from SSO
|
|
|
|
|
+ 3. Sync user to local DB
|
|
|
|
|
+ 4. Issue local JWT
|
|
|
|
|
+ """
|
|
|
|
|
+ # Step 4a: Get SSO access_token
|
|
|
|
|
+ token_data = await exchange_code_for_sso_token(config, code)
|
|
|
|
|
+ sso_access_token = token_data.get("access_token")
|
|
|
|
|
+ if not sso_access_token:
|
|
|
|
|
+ raise Exception("获取 SSO access_token 失败")
|
|
|
|
|
+
|
|
|
|
|
+ # Step 4b: Get user info
|
|
|
|
|
+ userinfo = await get_sso_userinfo(config, sso_access_token)
|
|
|
|
|
+ if not userinfo.get("username") and not userinfo.get("sub"):
|
|
|
|
|
+ raise Exception("SSO 用户信息格式异常")
|
|
|
|
|
+
|
|
|
|
|
+ # Step 5: Sync user
|
|
|
|
|
+ user = await sync_user_from_sso(session, config, userinfo)
|
|
|
|
|
+
|
|
|
|
|
+ # Step 6: Issue local JWT
|
|
|
|
|
+ local_token = jwt_manager.create_jwt_token(username=user.username)
|
|
|
|
|
+
|
|
|
|
|
+ # Build user response
|
|
|
|
|
+ role_codes = extract_role_codes(userinfo)
|
|
|
|
|
+ user_data = {
|
|
|
|
|
+ "id": str(user.id),
|
|
|
|
|
+ "username": user.username,
|
|
|
|
|
+ "email": userinfo.get("email", ""),
|
|
|
|
|
+ "phone": userinfo.get("phone", ""),
|
|
|
|
|
+ "full_name": user.full_name,
|
|
|
|
|
+ "avatar_url": user.avatar_url,
|
|
|
|
|
+ "is_superuser": user.is_admin,
|
|
|
|
|
+ "is_active": user.is_active,
|
|
|
|
|
+ "roles": role_codes,
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ return {
|
|
|
|
|
+ "token": local_token,
|
|
|
|
|
+ "refresh_token": "", # SSO flow doesn't need refresh token for now
|
|
|
|
|
+ "user": user_data,
|
|
|
|
|
+ }
|