| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226 |
- """
- SSO (LQAI-middle-platform) OAuth2 integration.
- Implements the code exchange flow: code -> SSO access_token -> userinfo -> local JWT.
- """
- import logging
- from typing import Optional, Dict, Any
- from urllib.parse import urlencode
- import httpx
- from gpustack.config.config import Config
- from gpustack.security import JWTManager, get_secret_hash
- from gpustack.server.services import create_user_with_principal
logger = logging.getLogger(__name__)

# Both SSO round-trips (code -> token and token -> userinfo) use the same
# timeout budget, in seconds: 15s to connect, 30s per read/write, and 5s to
# acquire a connection from the client's pool.
_SSO_TIMEOUT_SECONDS = {"connect": 15.0, "read": 30.0, "write": 30.0, "pool": 5.0}

SSO_TOKEN_TIMEOUT = httpx.Timeout(**_SSO_TIMEOUT_SECONDS)
SSO_USERINFO_TIMEOUT = httpx.Timeout(**_SSO_TIMEOUT_SECONDS)
def build_sso_authorize_url(
    config: Config,
    redirect: bool = False,
    state: Optional[str] = None,
) -> str:
    """Build the SSO OAuth2 authorization URL the browser is sent to.

    Args:
        config: Server configuration providing ``sso_base_url``,
            ``sso_client_id``, ``sso_redirect_uri`` and ``sso_scope``.
        redirect: Currently unused; kept so existing callers keep working.
        state: Optional opaque OAuth2 ``state`` value. When given, it is
            added to the query so the callback can check it against the
            value stored for this login (CSRF protection, RFC 6749 §10.12).
            Generating and verifying it is the caller's responsibility.

    Returns:
        ``{sso_base_url}/oauth/authorize?...`` with URL-encoded query
        parameters. Parameters whose value is ``None`` (e.g. an unset scope)
        are omitted instead of being sent as the literal string "None".
    """
    params = {
        "response_type": "code",
        "client_id": config.sso_client_id,
        "redirect_uri": config.sso_redirect_uri,
        "scope": config.sso_scope,
        "state": state,
    }
    query = urlencode({key: value for key, value in params.items() if value is not None})
    # Tolerate a configured base URL with a trailing slash so we never emit
    # "...//oauth/authorize".
    base_url = config.sso_base_url.rstrip("/")
    return f"{base_url}/oauth/authorize?{query}"
async def exchange_code_for_sso_token(
    config: Config, code: str
) -> Dict[str, Any]:
    """Exchange an OAuth2 authorization code for an SSO access token (step 4a).

    Sends ``POST {sso_base_url}/oauth/token`` with the ``authorization_code``
    grant, passing client_id/client_secret in the form body.

    Args:
        config: Server configuration with the SSO base URL, client
            credentials and the redirect URI used when ``code`` was issued.
        code: Authorization code received on the redirect callback.

    Returns:
        The decoded token response. It is always a dict containing an
        ``access_token`` key.

    Raises:
        Exception: on a non-200 status, a body that is not a JSON object, or
            a 200 response without ``access_token``. Transport errors raised
            by httpx propagate unchanged.
    """
    base_url = config.sso_base_url.rstrip("/")
    data = {
        "grant_type": "authorization_code",
        "code": code,
        "redirect_uri": config.sso_redirect_uri,
        "client_id": config.sso_client_id,
        "client_secret": config.sso_client_secret,
    }
    async with httpx.AsyncClient(
        timeout=SSO_TOKEN_TIMEOUT,
        # verify is False only for plain "http://" base URLs.
        verify=not base_url.startswith("http://"),
    ) as client:
        resp = await client.post(
            f"{base_url}/oauth/token",
            data=data,
            headers={"Content-Type": "application/x-www-form-urlencoded"},
        )

    if resp.status_code != 200:
        logger.error(f"SSO token exchange failed: {resp.status_code} {resp.text}")
        # Error bodies are not guaranteed to be JSON (e.g. an HTML page from a
        # reverse proxy). Decoding them must not replace the real failure
        # with a JSONDecodeError.
        try:
            error_data = resp.json() if resp.text else {}
        except ValueError:
            error_data = {}
        if not isinstance(error_data, dict):
            error_data = {}
        error = error_data.get("error") or "unknown_error"
        error_desc = error_data.get("error_description") or "令牌交换失败"
        raise Exception(f"SSO token exchange failed: {error} - {error_desc}")

    try:
        token_data = resp.json()
    except ValueError as exc:
        logger.error(f"SSO token exchange returned a non-JSON body: {resp.text}")
        raise Exception(
            f"SSO token exchange succeeded but no access_token in response: {resp.text}"
        ) from exc

    if not isinstance(token_data, dict) or "access_token" not in token_data:
        logger.error(
            f"SSO token exchange returned 200 but missing access_token. "
            f"Response: {resp.text}. "
            f"Request data: redirect_uri={config.sso_redirect_uri}, "
            f"client_id={config.sso_client_id}"
        )
        raise Exception(
            f"SSO token exchange succeeded but no access_token in response: {resp.text}"
        )
    return token_data
async def get_sso_userinfo(
    config: Config, access_token: str
) -> Dict[str, Any]:
    """Fetch the SSO user's profile with an SSO access token (step 4b).

    Sends ``GET {sso_base_url}/oauth/userinfo`` with a Bearer authorization
    header.

    Args:
        config: Server configuration providing ``sso_base_url``.
        access_token: Access token returned by the token exchange.

    Returns:
        The decoded userinfo payload. It is always a dict.

    Raises:
        Exception: with message "获取用户信息失败" on a non-200 status or when
            the body is not a JSON object. Transport errors raised by httpx
            propagate unchanged.
    """
    base_url = config.sso_base_url.rstrip("/")
    async with httpx.AsyncClient(
        timeout=SSO_USERINFO_TIMEOUT,
        # verify is False only for plain "http://" base URLs.
        verify=not base_url.startswith("http://"),
    ) as client:
        resp = await client.get(
            f"{base_url}/oauth/userinfo",
            headers={"Authorization": f"Bearer {access_token}"},
        )

    if resp.status_code != 200:
        logger.error(f"SSO userinfo failed: {resp.status_code} {resp.text}")
        raise Exception("获取用户信息失败")

    # Callers index the result with .get(), so reject anything that is not a
    # JSON object here rather than failing later with an AttributeError.
    try:
        userinfo = resp.json()
    except ValueError as exc:
        logger.error(f"SSO userinfo returned a non-JSON body: {resp.text}")
        raise Exception("获取用户信息失败") from exc
    if not isinstance(userinfo, dict):
        logger.error(f"SSO userinfo returned an unexpected payload: {resp.text}")
        raise Exception("获取用户信息失败")
    return userinfo
def extract_role_codes(userinfo: Dict[str, Any]) -> list:
    """Return the role codes listed in an SSO userinfo payload.

    Entries of ``userinfo["roles"]`` may be dicts or plain strings:

    - For a dict, its ``code`` value is taken if truthy.
    - A string is used as-is.
    - Any other entry type is ignored.

    A missing, ``null`` or empty ``roles`` field yields ``[]``. A single role
    given directly (a bare string or dict instead of a list) is treated as a
    one-element list, so it is not iterated character by character or key
    by key.
    """
    roles = userinfo.get("roles") or []
    if not isinstance(roles, (list, tuple)):
        roles = [roles]

    codes = []
    for role in roles:
        if isinstance(role, dict):
            code = role.get("code")
            if code:
                codes.append(code)
        elif isinstance(role, str):
            codes.append(role)
    return codes
async def sync_user_from_sso(
    session,
    config: Config,
    userinfo: Dict[str, Any],
) -> Any:
    """Create or update the local user matching an SSO userinfo payload (step 5).

    The local account is looked up by ``username``, falling back to the SSO
    ``sub`` claim. Two paths follow:

    - An existing account has ``full_name``, ``avatar_url``, ``is_admin``
      and ``source`` overwritten from SSO, and is re-activated.
    - Otherwise a new active user is created, with a random password that
      is hashed and never returned.

    ``is_admin`` is True exactly when the SSO roles contain ``super_admin``.

    Args:
        session: Database session used for the lookup and the write.
        config: Server configuration (not used by this function).
        userinfo: Decoded SSO userinfo payload.

    Returns:
        The updated or newly created ``User``.

    Raises:
        Exception: if the payload has neither ``username`` nor ``sub``.
    """
    username = userinfo.get("username") or userinfo.get("sub")
    if not username:
        raise Exception("SSO 返回的用户信息中缺少 username")

    # ``or`` (rather than a .get default) also replaces explicit null/empty
    # values sent by the SSO server, which would otherwise be stored as-is.
    full_name = userinfo.get("real_name") or username
    avatar_url = userinfo.get("avatar_url") or ""
    is_admin = "super_admin" in extract_role_codes(userinfo)

    # Imported here rather than at module top (presumably to avoid an import
    # cycle — TODO confirm).
    from gpustack.schemas.users import User, AuthProviderEnum

    # NOTE(review): matching is by username only, whatever the existing
    # account's ``source``. An SSO identity whose username equals a local
    # account's (e.g. a built-in admin) takes that account over and
    # overwrites its is_admin flag. Confirm this linking policy is intended.
    existing = await User.first_by_field(session, "username", username)
    if existing:
        patch = {
            "full_name": full_name,
            "avatar_url": avatar_url,
            "is_admin": is_admin,
            "is_active": True,
            "source": AuthProviderEnum.OIDC,
        }
        await existing.update(session, patch)
        logger.info(f"Updated SSO user: {username}")
        return existing

    # New user: build the User object first, then persist it through
    # create_user_with_principal(session, user). The random password is
    # never disclosed, so the account has no usable local password.
    import secrets

    random_password = secrets.token_urlsafe(32)
    user_info = User(
        username=username,
        full_name=full_name,
        avatar_url=avatar_url,
        hashed_password=get_secret_hash(random_password),
        is_admin=is_admin,
        is_active=True,
        source=AuthProviderEnum.OIDC,
        require_password_change=False,
    )
    user = await create_user_with_principal(session, user_info)
    logger.info(f"Created SSO user: {username}")
    return user
async def handle_sso_exchange_code(
    session,
    config: Config,
    code: str,
    jwt_manager,
) -> Dict[str, Any]:
    """Run the full SSO login for an authorization code (steps 4-6).

    The steps are:

    1. Exchange ``code`` for an SSO access token.
    2. Fetch the SSO userinfo.
    3. Create or update the matching local user and commit the session.
    4. Issue a local JWT for that user.

    Args:
        session: Database session. It is committed after the user sync.
        config: Server configuration holding the SSO settings.
        code: Authorization code received on the SSO redirect callback.
        jwt_manager: Object whose ``create_jwt_token(username=...)`` issues
            the local token.

    Returns:
        ``{"token": <local JWT>, "refresh_token": "", "user": {...}}``. The
        ``user`` dict combines the local record with ``email``, ``phone``
        and ``roles`` taken directly from the SSO userinfo.

    Raises:
        Exception: when no SSO access token is obtained, when the userinfo
            has neither ``username`` nor ``sub``, or when any step fails.
    """
    # Step 4a: authorization code -> SSO access token.
    sso_tokens = await exchange_code_for_sso_token(config, code)
    sso_access_token = sso_tokens.get("access_token")
    if not sso_access_token:
        raise Exception("获取 SSO access_token 失败")

    # Step 4b: access token -> userinfo. At least one identity claim is required.
    userinfo = await get_sso_userinfo(config, sso_access_token)
    if not (userinfo.get("username") or userinfo.get("sub")):
        raise Exception("SSO 用户信息格式异常")

    # Step 5: mirror the SSO identity into the local user table.
    local_user = await sync_user_from_sso(session, config, userinfo)
    await session.commit()

    # Step 6: issue our own JWT for the synced user.
    local_token = jwt_manager.create_jwt_token(username=local_user.username)

    sso_roles = extract_role_codes(userinfo)
    profile = dict(
        id=str(local_user.id),
        username=local_user.username,
        email=userinfo.get("email", ""),
        phone=userinfo.get("phone", ""),
        full_name=local_user.full_name,
        avatar_url=local_user.avatar_url,
        is_superuser=local_user.is_admin,
        is_active=local_user.is_active,
        roles=sso_roles,
    )

    return {
        "token": local_token,
        # The SSO flow does not issue a refresh token yet.
        "refresh_token": "",
        "user": profile,
    }
|