| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711 |
- import json
- import os
- from pathlib import Path
- import httpx
- import logging
- import jwt
- from jwt.algorithms import RSAAlgorithm
- from gpustack.config.config import Config
- from typing import Annotated, Dict, Optional
- from fastapi import APIRouter, Form, Request, Response
- from gpustack.api.exceptions import (
- InvalidException,
- UnauthorizedException,
- BadRequestException,
- )
- from gpustack.schemas.users import UpdatePassword
- from gpustack.schemas.users import User, AuthProviderEnum
- from gpustack.security import (
- JWTManager,
- get_secret_hash,
- verify_hashed_secret,
- )
- from gpustack import envs
- from gpustack.api.auth import (
- SESSION_COOKIE_NAME,
- OIDC_ID_TOKEN_COOKIE_NAME,
- SSO_LOGIN_COOKIE_NAME,
- authenticate_user,
- )
- from gpustack.server.deps import CurrentUserDep, SessionDep
- from gpustack.server.services import (
- create_user_with_principal,
- provision_user_principal,
- )
- from onelogin.saml2.auth import OneLogin_Saml2_Auth
- from fastapi.responses import RedirectResponse
- from lxml import etree
- from gpustack.utils.convert import safe_b64decode, inflate_data
- from urllib.parse import urlencode
- from gpustack.utils.network import (
- get_system_trust_store_ssl_context,
- use_proxy_env_for_url,
- )
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Router on which every authentication endpoint in this module is registered.
router = APIRouter()

# Timeouts (in seconds) for the httpx client used by the OIDC callback.
timeout = httpx.Timeout(connect=15.0, read=60.0, write=60.0, pool=10.0)
async def decode_and_validate_token(
    client: httpx.AsyncClient, token: str, config: Config
) -> Dict:
    """
    Verify an RS256-signed JWT against the provider's JWKS and return its claims.

    The signing key is picked by the ``kid`` found in the (unverified) token
    header. When the header has no ``kid``, the first key of the JWKS is used.

    Args:
        client: HTTP client used to download the JWKS document
        token: token from OIDC provider
        config: Application configuration; ``openid_configuration["jwks_uri"]``
            is read from it

    Returns:
        Dictionary containing decoded token data

    Raises:
        UnauthorizedException: If the JWKS has no key usable for this token.
    """
    jwks_uri = config.openid_configuration["jwks_uri"]
    jwks_res = await client.get(jwks_uri)
    # Tolerate a JWKS document without (or with an empty) "keys" member.
    jwks_keys = jwks_res.json().get("keys") or []

    kid = jwt.get_unverified_header(token).get("kid", None)
    if kid:
        # "kid" is optional on a JWK, so compare via .get() instead of indexing.
        jwk = next((key for key in jwks_keys if key.get("kid") == kid), None)
    else:
        jwk = jwks_keys[0] if jwks_keys else None

    if jwk is None:
        raise UnauthorizedException(message="Public key not found in JWKS")
    public_key = RSAAlgorithm.from_jwk(json.dumps(jwk))

    # NOTE(review): audience and issuer validation are switched off here, so a
    # token minted by the same provider for another client is accepted. Left
    # unchanged to keep existing deployments working - confirm this is intended.
    claims = jwt.decode(
        token,
        public_key,
        algorithms=['RS256'],
        options={"verify_aud": False, "verify_iss": False},
    )
    return claims
async def get_oidc_user_data(
    client: httpx.AsyncClient, token_res, config: Config
) -> Dict:
    """
    Resolve the user's claims after an OIDC token exchange.

    The userinfo endpoint is queried with the access token unless
    `oidc_skip_userinfo` is enabled in config; in that case the access token
    and then the ID token are decoded, and the first one that yields data wins.

    Args:
        client: HTTP client for making requests
        token_res: The token response from OIDC provider
        config: Application configuration

    Returns:
        Dictionary containing user data

    Raises:
        InvalidException: If the token response is not a dictionary.
        UnauthorizedException: If the userinfo request does not return 200, or
            no user data could be obtained.
    """
    if not isinstance(token_res, Dict):
        raise InvalidException(message="Invalid token response")

    claims = None
    if not config.oidc_skip_userinfo:
        bearer = token_res.get("access_token", "")
        endpoint = config.openid_configuration["userinfo_endpoint"]
        reply = await client.get(
            endpoint, headers={'Authorization': f'Bearer {bearer}'}
        )
        if reply.status_code != 200:
            raise UnauthorizedException(
                message="Failed to fetch user info from userinfo endpoint"
            )
        claims = reply.json()
    else:
        # Access token first, ID token second; absent/empty tokens are skipped.
        candidates = [
            value
            for value in (
                token_res.get("access_token", None),
                token_res.get("id_token", None),
            )
            if value
        ]
        for candidate in candidates:
            try:
                claims = await decode_and_validate_token(client, candidate, config)
            except Exception as exc:
                logger.warning(f"Token decoding/validation failed: {str(exc)}")
                continue
            if claims:
                break

    if not claims:
        raise UnauthorizedException(message="Failed to retrieve valid user data")
    return claims
async def init_saml_auth(request: Request):
    """
    Build a ``OneLogin_Saml2_Auth`` object for the current request.

    SP/IdP settings are taken from the server config stored on
    ``request.app.state.server_config``; the request's query parameters and
    form fields are handed to the SAML toolkit as ``get_data``/``post_data``.

    Args:
        request: The incoming request (its form body is read here).

    Returns:
        The initialized ``OneLogin_Saml2_Auth`` instance.
    """
    config: Config = request.app.state.server_config
    form_data = await request.form()
    form_dict = dict(form_data)
    saml_settings = {
        "strict": True,
        "sp": {
            "entityId": config.saml_sp_entity_id,  # sp_entityId
            "assertionConsumerService": {
                "url": config.saml_sp_acs_url,  # callback url
                "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect",
            },
            "singleLogoutService": {
                "url": config.saml_sp_slo_url,
                "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect",
            },
            "x509cert": config.saml_sp_x509_cert,  # SP public key
            "privateKey": config.saml_sp_private_key,  # sp privateKey
        },
        "idp": {
            "entityId": config.saml_idp_entity_id,  # idp_entityId
            "singleSignOnService": {
                "url": config.saml_idp_server_url,  # server url
                "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect",
            },
            "singleLogoutService": {
                "url": config.saml_idp_logout_url,
                "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect",
            },
            "x509cert": config.saml_idp_x509_cert,  # idp public key
        },
        # Signature configuration; an unset value falls back to toolkit defaults.
        "security": json.loads(config.saml_security) if config.saml_security else {},
    }
    req = {
        # The toolkit rebuilds the URL of *this* service from these two fields,
        # so they must describe the server side of the request (Host header and
        # scheme), not the remote client's address.
        "https": "on" if request.url.scheme == "https" else "off",
        "http_host": request.headers.get("host") or request.url.netloc,
        "script_name": request.url.path,
        "get_data": dict(request.query_params),
        "post_data": form_dict,
    }
    return OneLogin_Saml2_Auth(req, saml_settings)
- # SAML login and callback endpoints
@router.get("/saml/login")
async def saml_login(request: Request):
    """Redirect the caller to the URL returned by the SAML toolkit's ``login()``."""
    saml_auth = await init_saml_auth(request)
    login_url = saml_auth.login()
    return RedirectResponse(url=login_url)
def _extract_saml_attributes(xml_bytes) -> Dict:
    """Parse a SAML response document into a dict of NameID plus attributes."""
    # The document comes straight from the request, so entity resolution and
    # network access are disabled while parsing.
    parser = etree.XMLParser(resolve_entities=False, no_network=True)
    root = etree.fromstring(xml_bytes, parser=parser)

    name_id_node = root.find('.//{*}NameID')
    if name_id_node is None:
        raise UnauthorizedException(message="No NameID found in SAML response")

    ns = {'saml': 'urn:oasis:names:tc:SAML:2.0:assertion'}
    attributes = {'name_id': name_id_node.text}
    for attr in root.xpath('//saml:Attribute', namespaces=ns):
        values = [v.text for v in attr.xpath('saml:AttributeValue', namespaces=ns)]
        # Single-valued attributes are stored as a scalar, others as a list.
        attributes[attr.get('Name')] = values[0] if len(values) == 1 else values
    return attributes


def _resolve_saml_full_name(config: Config, attributes: Dict):
    """Work out the user's display name from config and SAML attributes."""
    spec = config.external_auth_full_name
    if spec and '+' not in spec:
        # If external_auth_full_name is set, use it as user's full name.
        return get_saml_attributes(config, attributes, spec)
    if spec:
        # external_auth_full_name is set with concat symbol '+'.
        parts = [
            get_saml_attributes(config, attributes, part.strip())
            for part in spec.split('+')
        ]
        # Missing or multi-valued parts are skipped instead of breaking join().
        return ' '.join(part for part in parts if isinstance(part, str) and part)
    # Try common claims. These are not guaranteed to be present.
    for key in ["displayName", "name"]:
        full_name = get_saml_attributes(config, attributes, key)
        if full_name:
            return full_name
    return ""


@router.api_route("/saml/callback", methods=["GET", "POST"])
async def saml_callback(request: Request, session: SessionDep):
    """
    Handle the IdP's SAML response, provision the user and start a session.

    The ``SAMLResponse`` parameter is read from the query string (GET, base64
    decoded and then passed through ``inflate_data``) or from the form body
    (POST, base64 decoded only). On success the browser is redirected to ``/``
    with the session and SSO-login cookies set.

    Raises:
        UnauthorizedException: On any failure while processing the response.
    """
    # SECURITY NOTE(review): the response XML is parsed directly and nothing in
    # this handler verifies its signature, issuer, audience or validity window
    # (``OneLogin_Saml2_Auth.process_response`` is not used). As written, a
    # forged SAMLResponse would be accepted. Fixing this changes the login flow
    # and needs testing against the configured IdP - please prioritize.
    logger.debug("Invoke saml callback.")
    try:
        is_get = request.method == "GET"
        if is_get:
            saml_response = request.query_params.get('SAMLResponse')
        else:
            saml_response = (await request.form()).get('SAMLResponse')
        if not saml_response:
            raise UnauthorizedException(message="Missing SAMLResponse parameter")

        xml_bytes = safe_b64decode(saml_response)
        if is_get:
            xml_bytes = inflate_data(xml_bytes)
        attributes = _extract_saml_attributes(xml_bytes)

        config: Config = request.app.state.server_config
        if config.external_auth_name:
            # If external_auth_name is set, use it as username.
            username = get_saml_attributes(
                config, attributes, config.external_auth_name
            )
        else:
            # Try email or name_id for username if external_auth_name is not set.
            username = None
            for key in ["email", "emailaddress", "name_id", "nameidentifier"]:
                username = get_saml_attributes(config, attributes, key)
                if username:
                    break
        # A missing or multi-valued attribute cannot serve as a username.
        if not username or not isinstance(username, str):
            raise UnauthorizedException(
                message="No valid username found in saml attributes"
            )

        full_name = _resolve_saml_full_name(config, attributes)

        avatar_url = None
        if config.external_auth_avatar_url:
            avatar_url = get_saml_attributes(
                config, attributes, config.external_auth_avatar_url
            )

        # determine whether the user already exists
        user = await User.first_by_field(
            session=session, field="username", value=username
        )
        # create user
        if not user:
            user_info = User(
                username=username,
                full_name=full_name,
                avatar_url=avatar_url,
                hashed_password="",
                is_admin=False,
                is_active=not config.external_auth_default_inactive,
                source=AuthProviderEnum.SAML,
                require_password_change=False,
            )
            user = await create_user_with_principal(session, user_info)
            await session.commit()
        elif (
            getattr(user, "id", None) is not None
            and getattr(user, "principal_id", None) is None
        ):
            # Backfill for SSO users created before Personal Org
            # provisioning was wired in. Idempotent: only fires when the
            # user has no Personal Org pointer.
            await provision_user_principal(session, user)
            await session.commit()

        jwt_manager: JWTManager = request.app.state.jwt_manager
        access_token = jwt_manager.create_jwt_token(
            username=username,
        )
        cookie_lifetime = envs.JWT_TOKEN_EXPIRE_MINUTES * 60
        # 303 so that a POSTed callback is followed with a GET to "/".
        response = RedirectResponse(url='/', status_code=303)
        response.set_cookie(
            key=SESSION_COOKIE_NAME,
            value=access_token,
            httponly=True,
            max_age=cookie_lifetime,
            expires=cookie_lifetime,
        )
        response.set_cookie(
            key=SSO_LOGIN_COOKIE_NAME,
            value="true",
            httponly=True,
            max_age=cookie_lifetime,
            expires=cookie_lifetime,
        )
    except UnauthorizedException:
        # Already the right error type; re-raise so its message is kept as is.
        raise
    except Exception as e:
        logger.error(f"SAML callback error: {str(e)}")
        raise UnauthorizedException(message=str(e))
    return response
@router.api_route("/saml/logout/callback", methods=["GET", "POST"])
async def saml_logout_callback(request: Request):
    """
    Finish a SAML single logout and clear the local session cookies.

    Processing of the IdP's logout message is best effort: failures are logged
    but never prevent the local cookies from being removed and the browser
    from being redirected to ``/``.
    """
    try:
        auth = await init_saml_auth(request)
        auth.process_slo(False)
        errors = auth.get_errors()
        if errors:
            logger.warning(f"SAML logout callback reported errors: {errors}")
    except Exception as e:
        # Deliberately not re-raised: local logout must still complete.
        logger.warning(f"SAML logout callback processing failed: {str(e)}")
    response = RedirectResponse(url="/")
    response.delete_cookie(key=SESSION_COOKIE_NAME)
    response.delete_cookie(key=SSO_LOGIN_COOKIE_NAME)
    return response
def get_saml_attributes(
    config: Config, attributes: Dict[str, str], name: str
) -> Optional[str]:
    """
    Look up a SAML attribute under the naming variants it may appear as.

    Keys are tried in this order: the configured ``saml_sp_attribute_prefix``
    followed by ``name`` (only when a prefix is configured), the
    ``http://schemas.xmlsoap.org/ws/2005/05/identity/claims/`` claim URI for
    ``name``, and finally ``name`` itself.

    Returns:
        The value stored under the first key present in ``attributes``, or
        None when none of the keys is present.
    """
    candidates = [
        f"http://schemas.xmlsoap.org/ws/2005/05/identity/claims/{name}",
        name,
    ]
    prefix = config.saml_sp_attribute_prefix
    if prefix:
        # The prefixed form takes precedence over the generic ones.
        candidates.insert(0, prefix + name)
    return next(
        (attributes[candidate] for candidate in candidates if candidate in attributes),
        None,
    )
- # OIDC login and callback endpoints
@router.get("/oidc/login")
async def oidc_login(request: Request):
    """
    Redirect the browser to the OIDC provider's authorization endpoint.

    The authorization request uses the ``code`` response type with the
    ``openid profile email`` scope; client id and redirect URI come from the
    server config.
    """
    from urllib.parse import quote, urlencode

    config: Config = request.app.state.server_config
    authorization_endpoint = config.openid_configuration["authorization_endpoint"]
    # NOTE(review): ``state`` is a fixed string and is not checked by the
    # callback, so it gives no CSRF protection. A per-login random value has to
    # be generated here AND verified in the callback - left as is for now.
    query = urlencode(
        {
            "response_type": "code",
            "client_id": config.oidc_client_id,
            "redirect_uri": config.oidc_redirect_uri,
            "scope": "openid profile email",
            "state": "random_state_string",
        },
        # Percent-encode every value (spaces as %20) so that characters such as
        # '&', '#' or spaces in a value cannot break the query string.
        quote_via=quote,
    )
    separator = "&" if "?" in authorization_endpoint else "?"
    return RedirectResponse(url=f"{authorization_endpoint}{separator}{query}")
@router.get("/oidc/callback")
async def oidc_callback(request: Request, session: SessionDep):
    """
    Complete the OIDC authorization-code flow and start a local session.

    The ``code`` query parameter is exchanged at the token endpoint, the user's
    claims are resolved via ``get_oidc_user_data``, the user is created (or
    backfilled) if needed, and the browser is redirected to ``/`` with the
    session, ID-token and SSO-login cookies set.

    Raises:
        BadRequestException: If the request carries no authorization code.
        UnauthorizedException: If the token exchange or claim resolution fails,
            or no username can be derived from the claims.
    """
    # NOTE(review): the ``state`` parameter is not verified here (the login
    # endpoint sends a fixed value), so this flow has no CSRF protection yet.
    logger.debug("Invoke oidc callback.")
    config: Config = request.app.state.server_config
    query = request.query_params
    code = query.get("code")
    if not code:
        # Providers report a denied/failed authorization via error parameters.
        detail = query.get("error_description") or query.get("error")
        raise BadRequestException(
            message=(
                f"Missing authorization code: {detail}"
                if detail
                else "Missing authorization code"
            )
        )
    data = {
        "grant_type": "authorization_code",
        "code": code,
        "client_id": config.oidc_client_id,
        "client_secret": config.oidc_client_secret,
        "redirect_uri": config.oidc_redirect_uri,
    }
    token_endpoint = config.openid_configuration["token_endpoint"]
    use_proxy_env = use_proxy_env_for_url(token_endpoint)
    verify = get_system_trust_store_ssl_context()
    async with httpx.AsyncClient(
        timeout=timeout, verify=verify, trust_env=use_proxy_env
    ) as client:
        try:
            token_res = await client.request("POST", token_endpoint, data=data)
            res_data = json.loads(token_res.text)
            if token_res.status_code != 200:
                detail = None
                if isinstance(res_data, dict):
                    # "error_description" is optional in an error response.
                    detail = res_data.get("error_description") or res_data.get(
                        "error"
                    )
                # Raised as Unauthorized directly: every failure in this block
                # is reported as Unauthorized by the handlers below anyway.
                raise UnauthorizedException(
                    message=f"Failed to get token, {detail or token_res.text}"
                )
            # Get user data from token or userinfo endpoint
            user_data = await get_oidc_user_data(client, res_data, config)
            if config.external_auth_name:
                # If external_auth_name is set, use it as username.
                username = user_data.get(config.external_auth_name)
            else:
                # Try common OIDC fields for username if external_auth_name is not set.
                # Ref: https://openid.net/specs/openid-connect-core-1_0.html#rfc.section.18.1.1
                username = next(
                    (user_data[key] for key in ["email", "sub"] if user_data.get(key)),
                    None,
                )
            if not username:
                raise UnauthorizedException(
                    message="No valid username found in user data"
                )
            if (
                config.external_auth_full_name
                and '+' not in config.external_auth_full_name
            ):
                full_name = user_data.get(config.external_auth_full_name)
            elif config.external_auth_full_name:
                parts = [
                    user_data.get(v.strip())
                    for v in config.external_auth_full_name.split('+')
                ]
                # Claims that are absent are skipped instead of breaking join().
                full_name = ' '.join(str(part) for part in parts if part)
            else:
                full_name = user_data.get("name", "")
            if config.external_auth_avatar_url:
                avatar_url = user_data.get(config.external_auth_avatar_url)
            else:
                avatar_url = user_data.get("picture", None)
        except UnauthorizedException:
            # Already the right error type; re-raise so its message is kept as is.
            raise
        except Exception as e:
            logger.error(f"Get OIDC user info error: {str(e)}")
            raise UnauthorizedException(message=str(e))
    # determine whether the user already exists
    user = await User.first_by_field(session=session, field="username", value=username)
    # create user
    if not user:
        user_info = User(
            username=username,
            full_name=full_name,
            avatar_url=avatar_url,
            hashed_password="",
            is_admin=False,
            is_active=not config.external_auth_default_inactive,
            source=AuthProviderEnum.OIDC,
            require_password_change=False,
        )
        user = await create_user_with_principal(session, user_info)
        await session.commit()
    elif (
        getattr(user, "id", None) is not None
        and getattr(user, "principal_id", None) is None
    ):
        # Backfill for SSO users created before Personal Org
        # provisioning was wired in. Idempotent: only fires when the
        # user has no Personal Org pointer.
        await provision_user_principal(session, user)
        await session.commit()
    jwt_manager: JWTManager = request.app.state.jwt_manager
    access_token = jwt_manager.create_jwt_token(
        username=username,
    )
    cookie_lifetime = envs.JWT_TOKEN_EXPIRE_MINUTES * 60
    response = RedirectResponse(url='/')
    response.set_cookie(
        key=SESSION_COOKIE_NAME,
        value=access_token,
        httponly=True,
        max_age=cookie_lifetime,
        expires=cookie_lifetime,
    )
    try:
        # The ID token is kept so that logout can pass it as id_token_hint.
        id_token = res_data.get("id_token")
        if id_token:
            response.set_cookie(
                key=OIDC_ID_TOKEN_COOKIE_NAME,
                value=id_token,
                httponly=True,
                max_age=cookie_lifetime,
                expires=cookie_lifetime,
            )
        response.set_cookie(
            key=SSO_LOGIN_COOKIE_NAME,
            value="true",
            httponly=True,
            max_age=cookie_lifetime,
            expires=cookie_lifetime,
        )
    except Exception as e:
        logger.warning(f"Failed to set id_token cookie: {str(e)}")
    return response
- # Local authentication endpoints
@router.post("/login")
async def login(
    request: Request,
    response: Response,
    session: SessionDep,
    username: Annotated[str, Form()] = "",
    password: Annotated[str, Form()] = "",
):
    """
    Authenticate with username/password form fields and set the session cookie.

    Credentials are checked by ``authenticate_user``; on success a JWT for the
    authenticated user's name is stored in an HTTP-only session cookie.
    """
    authenticated = await authenticate_user(session, username, password)
    manager: JWTManager = request.app.state.jwt_manager
    session_token = manager.create_jwt_token(username=authenticated.username)
    # Cookie lifetime in seconds, derived from the configured minutes.
    lifetime = envs.JWT_TOKEN_EXPIRE_MINUTES * 60
    response.set_cookie(
        key=SESSION_COOKIE_NAME,
        value=session_token,
        httponly=True,
        max_age=lifetime,
        expires=lifetime,
    )
@router.post("/logout")
async def logout(request: Request):
    """
    Clear the local session cookies and report where to log out externally.

    For OIDC the provider's ``end_session_endpoint`` (if published) is used,
    for SAML the toolkit's logout URL; a configured
    ``sso_logout_redirect_url`` overrides both when the SSO-login cookie is
    present. The JSON body ``{"logout_url": ...}`` is only returned when the
    SSO-login cookie is present; otherwise the body is empty.
    """
    cfg: Config = request.app.state.server_config
    logout_target = None

    if cfg.external_auth_type == AuthProviderEnum.OIDC and cfg.openid_configuration:
        endpoint = cfg.openid_configuration.get("end_session_endpoint")
        if endpoint:
            return_to = str(cfg.server_external_url or request.base_url)
            logout_params = {
                "client_id": cfg.oidc_client_id,
                "post_logout_redirect_uri": return_to,
                "id_token_hint": request.cookies.get(OIDC_ID_TOKEN_COOKIE_NAME),
            }
            extra_key = cfg.external_auth_post_logout_redirect_key
            if extra_key:
                logout_params[extra_key] = return_to
            # Parameters without a value are left out of the query string.
            encoded = urlencode(
                {name: value for name, value in logout_params.items() if value}
            )
            if encoded:
                logout_target = f"{endpoint}?{encoded}"
            else:
                logout_target = endpoint
    elif cfg.external_auth_type == AuthProviderEnum.SAML:
        try:
            saml_auth = await init_saml_auth(request)
            return_to = str(cfg.server_external_url or request.base_url)
            extra_key = cfg.external_auth_post_logout_redirect_key
            logout_target = saml_auth.logout(return_to=return_to)
            if extra_key and return_to:
                # Appended with '&': the toolkit URL is expected to already
                # carry a query string.
                logout_target += "&" + urlencode({extra_key: return_to})
        except Exception as e:
            logger.error(f"Failed to get SAML logout url: {str(e)}")
            logout_target = None

    # SSO logout: return SSO platform logout URL
    logged_in_via_sso = request.cookies.get(SSO_LOGIN_COOKIE_NAME)
    sso_override = cfg.sso_logout_redirect_url
    if logged_in_via_sso and sso_override:
        logout_target = sso_override

    if logged_in_via_sso:
        payload = json.dumps({"logout_url": logout_target})
    else:
        payload = ""
    resp = Response(content=payload, media_type="application/json")
    for cookie_name in (
        SESSION_COOKIE_NAME,
        OIDC_ID_TOKEN_COOKIE_NAME,
        SSO_LOGIN_COOKIE_NAME,
    ):
        resp.delete_cookie(key=cookie_name)
    return resp
@router.post("/update-password")
async def update_password(
    request: Request,
    session: SessionDep,
    user: CurrentUserDep,
    update_in: UpdatePassword,
):
    """
    Change the current user's password after checking the existing one.

    On success the new hash is stored, ``require_password_change`` is cleared
    and the initial admin password file is removed if it still exists.

    Raises:
        InvalidException: If the supplied current password does not match.
    """
    current_matches = verify_hashed_secret(
        user.hashed_password, update_in.current_password
    )
    if not current_matches:
        raise InvalidException(message="Incorrect current password")

    await user.update(
        session,
        {
            "hashed_password": get_secret_hash(update_in.new_password),
            "require_password_change": False,
        },
    )
    remove_initial_password_file_if_exists(request.app.state.server_config)
@router.get("/config")
async def get_auth_config(request: Request):
    """
    Describe the authentication setup.

    Returns a dict with ``is_oidc``/``is_saml`` flags when an external provider
    of that type is configured, plus first-time-setup hints while the initial
    admin password file still exists in the data directory.
    """
    cfg: Config = request.app.state.server_config
    provider_flags = {
        "oidc": {"is_oidc": True, "is_saml": False},
        "saml": {"is_oidc": False, "is_saml": True},
    }
    auth_kind = (cfg.external_auth_type or "Local").lower()
    # Local (or any other) auth type yields no provider flags at all.
    payload = dict(provider_flags.get(auth_kind, {}))

    password_file = Path(cfg.data_dir) / "initial_admin_password"
    if password_file.exists():
        payload["first_time_setup"] = True
        payload["get_initial_password_command"] = _get_initial_password_command(
            password_file
        )
    return payload
def _get_initial_password_command(initial_password_file: Path) -> str:
    """
    Build the shell command a user can run to read the initial admin password.

    The environment is detected from ``KUBERNETES_SERVICE_HOST`` (Kubernetes)
    and the presence of ``/.dockerenv`` (Docker); anything else is treated as
    a non-containerized install.
    """
    running_in_kubernetes = os.getenv("KUBERNETES_SERVICE_HOST") is not None
    if not running_in_kubernetes:
        if Path("/.dockerenv").exists():
            # Docker
            return f"docker exec <container_name_or_id> cat {initial_password_file}"
        # Non-containerized
        return f"cat {initial_password_file}"

    # Kubernetes: placeholders are used when pod name or namespace is unknown.
    pod = os.getenv("HOSTNAME", "<pod_name>")
    namespace_path = Path("/var/run/secrets/kubernetes.io/serviceaccount/namespace")
    if namespace_path.exists():
        namespace = namespace_path.read_text().strip()
    else:
        namespace = "<namespace>"
    return f"kubectl exec {pod} -n {namespace} -- cat {initial_password_file}"
def remove_initial_password_file_if_exists(config: Config):
    """
    Delete ``initial_admin_password`` from the data directory when present.

    Deletion failures are logged as warnings and not raised.
    """
    password_file = Path(config.data_dir) / "initial_admin_password"
    if not password_file.exists():
        return
    try:
        password_file.unlink()
        logger.debug(f"Initial password file deleted: {password_file}")
    except Exception as exc:
        logger.warning(f"Failed to delete initial password file: {exc}")
- # SSO (LQAI-middle-platform) OAuth2 integration endpoints
- from gpustack.api.sso import (
- build_sso_authorize_url,
- handle_sso_exchange_code,
- )
- from pydantic import BaseModel
class ExchangeCodeRequest(BaseModel):
    """Request body of the ``/oauth/exchange-code`` endpoint."""

    # Authorization code to be exchanged for a local JWT.
    code: str
@router.get("/sso/authorize")
async def sso_authorize(request: Request, redirect: bool = False):
    """
    Provide the SSO OAuth2 authorization URL.

    With ``redirect`` set, the caller is redirected straight to that URL;
    otherwise the URL is returned inside a JSON envelope.

    Raises:
        InvalidException: If the SSO base URL or client id is not configured.
    """
    cfg: Config = request.app.state.server_config
    sso_configured = cfg.sso_base_url and cfg.sso_client_id
    if not sso_configured:
        raise InvalidException(message="SSO 未配置,请先配置 SSO_BASE_URL 和 SSO_CLIENT_ID")

    target_url = build_sso_authorize_url(cfg)
    if not redirect:
        return {
            "code": "000000",
            "message": "获取授权URL成功",
            "data": {"authorize_url": target_url},
        }
    return RedirectResponse(url=target_url)
@router.post("/oauth/exchange-code")
async def oauth_exchange_code(
    request: Request,
    session: SessionDep,
    body: ExchangeCodeRequest,
):
    """
    Trade an SSO authorization code for a local JWT (the SSO login endpoint).

    The exchange itself is delegated to ``handle_sso_exchange_code``; its
    result is returned inside a JSON envelope.

    Raises:
        InvalidException: If SSO is not configured, or the exchange fails for
            a reason other than a rejected authorization code.
        BadRequestException: If the code is missing, or the failure message
            indicates an invalid authorization code.
    """
    cfg: Config = request.app.state.server_config
    if not (cfg.sso_base_url and cfg.sso_client_id):
        raise InvalidException(message="SSO 未配置")
    if not body.code:
        raise BadRequestException(message="缺少授权码")

    try:
        manager: JWTManager = request.app.state.jwt_manager
        login_data = await handle_sso_exchange_code(session, cfg, body.code, manager)
    except Exception as e:
        logger.error(f"SSO exchange failed: {e}")
        detail = str(e)
        # Failures that mention the grant/authorization code become a 4xx
        # "invalid code" answer; everything else is reported with its message.
        code_rejected = any(
            marker in detail for marker in ("invalid_grant", "授权码")
        )
        if code_rejected:
            raise BadRequestException(message="登录失败: 授权码无效")
        raise InvalidException(message=f"登录失败: {detail}")

    return {
        "code": "000000",
        "message": "登录成功",
        "data": login_data,
    }
|