"""
SSO (LQAI-middle-platform) OAuth2 integration.

Implements the code exchange flow:
    code -> SSO access_token -> userinfo -> local JWT.
"""

import logging
import secrets
from typing import Any, Dict, List, Optional
from urllib.parse import urlencode

import httpx

from gpustack.config.config import Config
from gpustack.security import JWTManager, get_secret_hash
from gpustack.server.services import create_user_with_principal

logger = logging.getLogger(__name__)

SSO_TOKEN_TIMEOUT = httpx.Timeout(connect=15.0, read=30.0, write=30.0, pool=5.0)
SSO_USERINFO_TIMEOUT = httpx.Timeout(connect=15.0, read=30.0, write=30.0, pool=5.0)


class SSOError(Exception):
    """Raised when a step of the SSO exchange flow fails.

    Subclasses ``Exception`` so existing ``except Exception`` handlers in
    callers keep catching it; error messages are unchanged.
    """


def _sso_url(config: Config, path: str) -> str:
    """Join ``config.sso_base_url`` and *path* without producing ``//``."""
    return f"{config.sso_base_url.rstrip('/')}{path}"


def _sso_verify_tls(config: Config) -> bool:
    """Return the httpx ``verify`` flag: disabled only for plain-http base URLs."""
    return not config.sso_base_url.startswith("http://")


def _safe_json(resp: httpx.Response) -> Any:
    """Decode *resp* as JSON; return ``None`` for an empty or non-JSON body.

    Error pages from proxies/gateways are often HTML, and calling
    ``resp.json()`` on them raises ``JSONDecodeError`` (a ``ValueError``),
    which would mask the real SSO failure.
    """
    if not resp.content:
        return None
    try:
        return resp.json()
    except ValueError:
        return None


def build_sso_authorize_url(config: Config, redirect: bool = False) -> str:
    """Build the SSO OAuth2 authorization URL.

    Args:
        config: Server config providing ``sso_base_url``, ``sso_client_id``,
            ``sso_redirect_uri`` and ``sso_scope``.
        redirect: Not used by this function; kept for caller compatibility.

    Returns:
        ``{sso_base_url}/oauth/authorize?response_type=code&client_id=...
        &redirect_uri=...&scope=...``.
    """
    # NOTE(review): no ``state`` parameter is sent, so the callback is not
    # bound to the browser session (OAuth2 CSRF protection). Consider adding
    # one if the SSO platform supports it.
    params = {
        "response_type": "code",
        "client_id": config.sso_client_id,
        "redirect_uri": config.sso_redirect_uri,
        "scope": config.sso_scope,
    }
    return f"{_sso_url(config, '/oauth/authorize')}?{urlencode(params)}"


async def exchange_code_for_sso_token(
    config: Config, code: str
) -> Dict[str, Any]:
    """Step 4a: exchange an authorization code for an SSO access token.

    ``POST {SSO_BASE_URL}/oauth/token`` with a form-encoded
    ``authorization_code`` grant.

    Args:
        config: Server config with the SSO client credentials.
        code: Authorization code received on the redirect callback.

    Returns:
        The decoded token response; guaranteed to contain ``access_token``.

    Raises:
        SSOError: on a non-200 response, or a 200 response that is not a
            JSON object containing ``access_token``.
    """
    data = {
        "grant_type": "authorization_code",
        "code": code,
        "redirect_uri": config.sso_redirect_uri,
        "client_id": config.sso_client_id,
        "client_secret": config.sso_client_secret,
    }
    async with httpx.AsyncClient(
        timeout=SSO_TOKEN_TIMEOUT,
        verify=_sso_verify_tls(config),
    ) as client:
        resp = await client.post(
            _sso_url(config, "/oauth/token"),
            data=data,
            headers={"Content-Type": "application/x-www-form-urlencoded"},
        )

    if resp.status_code != 200:
        logger.error("SSO token exchange failed: %s %s", resp.status_code, resp.text)
        error_data = _safe_json(resp)
        # Only a JSON object can carry the standard OAuth2 error fields.
        if not isinstance(error_data, dict):
            error_data = {}
        error = error_data.get("error", "unknown_error")
        error_desc = error_data.get("error_description", "令牌交换失败")
        raise SSOError(f"SSO token exchange failed: {error} - {error_desc}")

    token_data = _safe_json(resp)
    if not isinstance(token_data, dict) or "access_token" not in token_data:
        logger.error(
            "SSO token exchange returned 200 but missing access_token. "
            "Response: %s. "
            "Request data: redirect_uri=%s, "
            "client_id=%s",
            resp.text,
            config.sso_redirect_uri,
            config.sso_client_id,
        )
        raise SSOError(
            f"SSO token exchange succeeded but no access_token in response: {resp.text}"
        )
    return token_data


async def get_sso_userinfo(
    config: Config, access_token: str
) -> Dict[str, Any]:
    """Step 4b: fetch the user's profile from the SSO platform.

    ``GET {SSO_BASE_URL}/oauth/userinfo`` with a Bearer token.

    Args:
        config: Server config with ``sso_base_url``.
        access_token: SSO access token from :func:`exchange_code_for_sso_token`.

    Returns:
        The decoded userinfo JSON object.

    Raises:
        SSOError: on a non-200 response or a body that is not a JSON object.
    """
    async with httpx.AsyncClient(
        timeout=SSO_USERINFO_TIMEOUT,
        verify=_sso_verify_tls(config),
    ) as client:
        resp = await client.get(
            _sso_url(config, "/oauth/userinfo"),
            headers={"Authorization": f"Bearer {access_token}"},
        )

    if resp.status_code != 200:
        logger.error("SSO userinfo failed: %s %s", resp.status_code, resp.text)
        raise SSOError("获取用户信息失败")

    userinfo = _safe_json(resp)
    # Callers use dict methods on the result, so reject anything else here
    # with a clear error instead of an AttributeError later.
    if not isinstance(userinfo, dict):
        logger.error("SSO userinfo returned a non-object body: %s", resp.text)
        raise SSOError("获取用户信息失败")
    return userinfo


def extract_role_codes(userinfo: Dict[str, Any]) -> List[str]:
    """Extract role codes from the ``roles`` field of SSO userinfo.

    Each entry may be a dict with a ``code`` key or a plain string; empty
    codes and entries of any other type are skipped. A missing or ``null``
    ``roles`` field yields ``[]``, and a single bare string is treated as a
    one-element list rather than being iterated character by character.
    """
    roles = userinfo.get("roles") or []
    if isinstance(roles, str):
        roles = [roles]
    elif not isinstance(roles, (list, tuple)):
        return []

    role_codes: List[str] = []
    for role in roles:
        if isinstance(role, dict):
            code = role.get("code")
            if code:
                role_codes.append(code)
        elif isinstance(role, str) and role:
            role_codes.append(role)
    return role_codes


async def sync_user_from_sso(
    session,
    config: Config,
    userinfo: Dict[str, Any],
) -> Any:
    """Step 5: find or create the local user matching the SSO identity.

    The username is ``userinfo["username"]``, falling back to ``sub``.
    ``is_admin`` is set from the presence of the ``super_admin`` role code.
    Existing users get full name, avatar, admin flag, active flag and source
    (OIDC) overwritten. New users are created with a random, unknown password.

    Args:
        session: Database session passed through to the user model helpers.
        config: Server config (not currently used here).
        userinfo: Userinfo object returned by :func:`get_sso_userinfo`.

    Returns:
        The updated or newly created ``User``.

    Raises:
        SSOError: if userinfo has neither ``username`` nor ``sub``.
    """
    username = userinfo.get("username") or userinfo.get("sub")
    if not username:
        raise SSOError("SSO 返回的用户信息中缺少 username")

    # ``or`` (rather than a .get default) also covers explicit JSON nulls.
    full_name = userinfo.get("real_name") or username
    avatar_url = userinfo.get("avatar_url") or ""
    is_admin = "super_admin" in extract_role_codes(userinfo)
    # NOTE: userinfo["email"] is not persisted on the local user; it is only
    # echoed back in the exchange response by handle_sso_exchange_code.

    # Function-scope import; presumably avoids an import cycle — verify.
    from gpustack.schemas.users import User, AuthProviderEnum

    existing = await User.first_by_field(session, "username", username)
    if existing:
        # NOTE(review): any pre-existing local account with the same username
        # is converted to an OIDC account and its admin flag is overwritten by
        # the SSO roles. Confirm this account-linking policy is intended.
        patch = {
            "full_name": full_name,
            "avatar_url": avatar_url,
            "is_admin": is_admin,
            "is_active": True,
            "source": AuthProviderEnum.OIDC,
        }
        await existing.update(session, patch)
        logger.info("Updated SSO user: %s", username)
        return existing

    # New user: build the User object first, then persist it through
    # create_user_with_principal(session, user). The password is random and
    # never disclosed, so the account is only reachable via SSO.
    random_password = secrets.token_urlsafe(32)
    user_info = User(
        username=username,
        full_name=full_name,
        avatar_url=avatar_url,
        hashed_password=get_secret_hash(random_password),
        is_admin=is_admin,
        is_active=True,
        source=AuthProviderEnum.OIDC,
        require_password_change=False,
    )
    user = await create_user_with_principal(session, user_info)
    logger.info("Created SSO user: %s", username)
    return user


async def handle_sso_exchange_code(
    session,
    config: Config,
    code: str,
    jwt_manager: JWTManager,
) -> Dict[str, Any]:
    """Run the core SSO exchange-code flow (steps 4-6).

    1. Exchange the authorization code for an SSO access token.
    2. Fetch the user's info from the SSO platform.
    3. Sync the user into the local DB and commit.
    4. Issue a local JWT.

    Args:
        session: Database session; committed after the user sync.
        config: Server config with the SSO settings.
        code: Authorization code from the SSO redirect.
        jwt_manager: Object whose ``create_jwt_token(username=...)`` issues
            the local token.

    Returns:
        ``{"token": <local JWT>, "refresh_token": "", "user": {...}}``.

    Raises:
        SSOError: if any SSO step fails or userinfo is malformed.
    """
    # Step 4a: SSO access token.
    token_data = await exchange_code_for_sso_token(config, code)
    sso_access_token = token_data.get("access_token")
    if not sso_access_token:
        raise SSOError("获取 SSO access_token 失败")

    # Step 4b: user info.
    userinfo = await get_sso_userinfo(config, sso_access_token)
    if not userinfo.get("username") and not userinfo.get("sub"):
        raise SSOError("SSO 用户信息格式异常")

    # Step 5: sync the local user.
    user = await sync_user_from_sso(session, config, userinfo)
    await session.commit()
    # NOTE(review): user attributes are read after commit; if the session
    # expires attributes on commit this relies on a refresh — confirm the
    # session is configured with expire_on_commit=False.

    # Step 6: issue the local JWT.
    local_token = jwt_manager.create_jwt_token(username=user.username)

    user_data = {
        "id": str(user.id),
        "username": user.username,
        "email": userinfo.get("email", ""),
        "phone": userinfo.get("phone", ""),
        "full_name": user.full_name,
        "avatar_url": user.avatar_url,
        "is_superuser": user.is_admin,
        "is_active": user.is_active,
        "roles": extract_role_codes(userinfo),
    }
    return {
        "token": local_token,
        "refresh_token": "",  # SSO flow doesn't need refresh token for now
        "user": user_data,
    }