auth.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712
  1. import json
  2. import os
  3. from pathlib import Path
  4. import httpx
  5. import logging
  6. import jwt
  7. from jwt.algorithms import RSAAlgorithm
  8. from gpustack.config.config import Config
  9. from typing import Annotated, Dict, Optional
  10. from fastapi import APIRouter, Form, Request, Response
  11. from gpustack.api.exceptions import (
  12. InvalidException,
  13. UnauthorizedException,
  14. BadRequestException,
  15. )
  16. from gpustack.schemas.users import UpdatePassword
  17. from gpustack.schemas.users import User, AuthProviderEnum
  18. from gpustack.security import (
  19. JWTManager,
  20. get_secret_hash,
  21. verify_hashed_secret,
  22. )
  23. from gpustack import envs
  24. from gpustack.api.auth import (
  25. SESSION_COOKIE_NAME,
  26. OIDC_ID_TOKEN_COOKIE_NAME,
  27. SSO_LOGIN_COOKIE_NAME,
  28. authenticate_user,
  29. )
  30. from gpustack.server.deps import CurrentUserDep, SessionDep
  31. from gpustack.server.services import (
  32. create_user_with_principal,
  33. provision_user_principal,
  34. )
  35. from onelogin.saml2.auth import OneLogin_Saml2_Auth
  36. from fastapi.responses import RedirectResponse
  37. from lxml import etree
  38. from gpustack.utils.convert import safe_b64decode, inflate_data
  39. from urllib.parse import urlencode
  40. from gpustack.utils.network import (
  41. get_system_trust_store_ssl_context,
  42. use_proxy_env_for_url,
  43. )
  44. router = APIRouter()
  45. timeout = httpx.Timeout(connect=15.0, read=60.0, write=60.0, pool=10.0)
  46. logger = logging.getLogger(__name__)
  47. async def decode_and_validate_token(
  48. client: httpx.AsyncClient, token: str, config: Config
  49. ) -> Dict:
  50. """
  51. Decode the JWT token without verification and check if required fields are present.
  52. Args:
  53. token: token from OIDC provider
  54. config: Application configuration
  55. Returns:
  56. Dictionary containing decoded token data
  57. """
  58. jwks_uri = config.openid_configuration["jwks_uri"]
  59. jwks_res = await client.get(jwks_uri)
  60. jwks = jwks_res.json()
  61. unverified_header = jwt.get_unverified_header(token)
  62. kid = unverified_header.get("kid", None)
  63. public_key = None
  64. if kid:
  65. for key in jwks['keys']:
  66. if key['kid'] == kid:
  67. public_key = RSAAlgorithm.from_jwk(json.dumps(key))
  68. break
  69. else:
  70. public_key = RSAAlgorithm.from_jwk(json.dumps(jwks['keys'][0]))
  71. if public_key is None:
  72. raise UnauthorizedException(message="Public key not found in JWKS")
  73. claims = jwt.decode(
  74. token,
  75. public_key,
  76. algorithms=['RS256'],
  77. options={"verify_aud": False, "verify_iss": False},
  78. )
  79. return claims
  80. async def get_oidc_user_data(
  81. client: httpx.AsyncClient, token_res, config: Config
  82. ) -> Dict:
  83. """
  84. Retrieve user data from OIDC token or userinfo endpoint.
  85. By default, it uses the userinfo endpoint (standard OIDC).
  86. If `oidc_skip_userinfo` is set to True in config, it retrieves data from the ID token.
  87. Args:
  88. client: HTTP client for making requests
  89. token_res: The token response from OIDC provider
  90. config: Application configuration
  91. Returns:
  92. Dictionary containing user data
  93. """
  94. user_data = None
  95. if not isinstance(token_res, Dict):
  96. raise InvalidException(message="Invalid token response")
  97. if config.oidc_skip_userinfo:
  98. tokens = []
  99. if access_token := token_res.get("access_token", None):
  100. tokens.append(access_token)
  101. if id_token := token_res.get("id_token", None):
  102. tokens.append(id_token)
  103. for token in tokens:
  104. try:
  105. user_data = await decode_and_validate_token(client, token, config)
  106. if user_data:
  107. break
  108. except Exception as e:
  109. logger.warning(f"Token decoding/validation failed: {str(e)}")
  110. else:
  111. token = token_res.get("access_token", "")
  112. userinfo_endpoint = config.openid_configuration["userinfo_endpoint"]
  113. headers = {'Authorization': f'Bearer {token}'}
  114. userinfo_res = await client.get(userinfo_endpoint, headers=headers)
  115. if userinfo_res.status_code == 200:
  116. user_data = userinfo_res.json()
  117. else:
  118. raise UnauthorizedException(
  119. message="Failed to fetch user info from userinfo endpoint"
  120. )
  121. if not user_data:
  122. raise UnauthorizedException(message="Failed to retrieve valid user data")
  123. return user_data
  124. async def init_saml_auth(request: Request):
  125. """
  126. Initialize SAML authentication configuration.
  127. """
  128. config: Config = request.app.state.server_config
  129. form_data = await request.form()
  130. form_dict = dict(form_data)
  131. saml_settings = {
  132. "strict": True,
  133. "sp": {
  134. "entityId": config.saml_sp_entity_id, # sp_entityId
  135. "assertionConsumerService": {
  136. "url": config.saml_sp_acs_url, # callback url
  137. "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect",
  138. },
  139. "singleLogoutService": {
  140. "url": config.saml_sp_slo_url,
  141. "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect",
  142. },
  143. "x509cert": config.saml_sp_x509_cert, # SP public key
  144. "privateKey": config.saml_sp_private_key, # sp privateKey
  145. },
  146. "idp": {
  147. "entityId": config.saml_idp_entity_id, # idp_entityId
  148. "singleSignOnService": {
  149. "url": config.saml_idp_server_url, # server url
  150. "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect",
  151. },
  152. "singleLogoutService": {
  153. "url": config.saml_idp_logout_url,
  154. "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect",
  155. },
  156. "x509cert": config.saml_idp_x509_cert, # idp public key
  157. },
  158. "security": json.loads(config.saml_security),
  159. } # Signature configuration
  160. req = {
  161. "http_host": request.client.host,
  162. "script_name": request.url.path,
  163. "get_data": dict(request.query_params),
  164. "post_data": form_dict,
  165. }
  166. return OneLogin_Saml2_Auth(req, saml_settings)
  167. # SAML login and callback endpoints
  168. @router.get("/saml/login")
  169. async def saml_login(request: Request):
  170. auth = await init_saml_auth(request)
  171. return RedirectResponse(url=auth.login())
  172. @router.api_route("/saml/callback", methods=["GET", "POST"])
  173. async def saml_callback(request: Request, session: SessionDep):
  174. logger.debug("Invoke saml callback.")
  175. try:
  176. if request.method == "GET":
  177. query = dict(request.query_params)
  178. SAMLResponse = query['SAMLResponse']
  179. decoded = safe_b64decode(SAMLResponse)
  180. xml_bytes = inflate_data(decoded)
  181. else:
  182. form_data = await request.form()
  183. form_dict = dict(form_data)
  184. SAMLResponse = form_dict.get('SAMLResponse')
  185. xml_bytes = safe_b64decode(SAMLResponse)
  186. root = etree.fromstring(xml_bytes)
  187. name_id = root.find('.//{*}NameID').text
  188. ns = {'saml': 'urn:oasis:names:tc:SAML:2.0:assertion'}
  189. attributes = {}
  190. attributes['name_id'] = name_id
  191. for attr in root.xpath('//saml:Attribute', namespaces=ns):
  192. attr_name = attr.get('Name')
  193. values = [v.text for v in attr.xpath('saml:AttributeValue', namespaces=ns)]
  194. attributes[attr_name] = values[0] if len(values) == 1 else values
  195. config: Config = request.app.state.server_config
  196. if config.external_auth_name:
  197. # If external_auth_name is set, use it as username.
  198. username = get_saml_attributes(
  199. config, attributes, config.external_auth_name
  200. )
  201. else:
  202. # Try email or name_id for username if external_auth_name is not set.
  203. for key in ["email", "emailaddress", "name_id", "nameidentifier"]:
  204. username = get_saml_attributes(config, attributes, key)
  205. if username:
  206. break
  207. else:
  208. raise Exception(message="No valid username found in saml attributes")
  209. if config.external_auth_full_name and '+' not in config.external_auth_full_name:
  210. # If external_auth_full_name is set, use it as user's full name.
  211. full_name = get_saml_attributes(
  212. config, attributes, config.external_auth_full_name
  213. )
  214. elif config.external_auth_full_name:
  215. # external_auth_full_name is set with concat symbol '+'.
  216. full_name = ' '.join(
  217. [
  218. get_saml_attributes(config, attributes, v.strip())
  219. for v in config.external_auth_full_name.split('+')
  220. ]
  221. )
  222. else:
  223. full_name = ""
  224. # Try common claims. These are not guaranteed to be present.
  225. for key in ["displayName", "name"]:
  226. full_name = get_saml_attributes(config, attributes, key)
  227. if full_name:
  228. break
  229. avatar_url = None
  230. if config.external_auth_avatar_url:
  231. avatar_url = get_saml_attributes(
  232. config, attributes, config.external_auth_avatar_url
  233. )
  234. # determine whether the user already exists
  235. user = await User.first_by_field(
  236. session=session, field="username", value=username
  237. )
  238. # create user
  239. if not user:
  240. user_info = User(
  241. username=username,
  242. full_name=full_name,
  243. avatar_url=avatar_url,
  244. hashed_password="",
  245. is_admin=False,
  246. is_active=not config.external_auth_default_inactive,
  247. source=AuthProviderEnum.SAML,
  248. require_password_change=False,
  249. )
  250. user = await create_user_with_principal(session, user_info)
  251. await session.commit()
  252. elif (
  253. getattr(user, "id", None) is not None
  254. and getattr(user, "principal_id", None) is None
  255. ):
  256. # Backfill for SSO users created before Personal Org
  257. # provisioning was wired in. Idempotent: only fires when the
  258. # user has no Personal Org pointer.
  259. await provision_user_principal(session, user)
  260. await session.commit()
  261. jwt_manager: JWTManager = request.app.state.jwt_manager
  262. access_token = jwt_manager.create_jwt_token(
  263. username=username,
  264. )
  265. response = RedirectResponse(url='/', status_code=303)
  266. response.set_cookie(
  267. key=SESSION_COOKIE_NAME,
  268. value=access_token,
  269. httponly=True,
  270. max_age=envs.JWT_TOKEN_EXPIRE_MINUTES * 60,
  271. expires=envs.JWT_TOKEN_EXPIRE_MINUTES * 60,
  272. )
  273. response.set_cookie(
  274. key=SSO_LOGIN_COOKIE_NAME,
  275. value="true",
  276. httponly=True,
  277. max_age=envs.JWT_TOKEN_EXPIRE_MINUTES * 60,
  278. expires=envs.JWT_TOKEN_EXPIRE_MINUTES * 60,
  279. )
  280. except Exception as e:
  281. logger.error(f"SAML callback error: {str(e)}")
  282. raise UnauthorizedException(message=str(e))
  283. return response
  284. @router.api_route("/saml/logout/callback", methods=["GET", "POST"])
  285. async def saml_logout_callback(request: Request):
  286. try:
  287. auth = await init_saml_auth(request)
  288. auth.process_slo(False)
  289. except Exception:
  290. pass
  291. response = RedirectResponse(url="/")
  292. response.delete_cookie(key=SESSION_COOKIE_NAME)
  293. response.delete_cookie(key=SSO_LOGIN_COOKIE_NAME)
  294. return response
  295. def get_saml_attributes(
  296. config: Config, attributes: Dict[str, str], name: str
  297. ) -> Optional[str]:
  298. search_keys = []
  299. if config.saml_sp_attribute_prefix:
  300. search_keys.append(config.saml_sp_attribute_prefix + name)
  301. search_keys.extend(
  302. [
  303. f"http://schemas.xmlsoap.org/ws/2005/05/identity/claims/{name}",
  304. name,
  305. ]
  306. )
  307. for key in search_keys:
  308. if key in attributes:
  309. return attributes[key]
  310. return None
  311. # OIDC login and callback endpoints
  312. @router.get("/oidc/login")
  313. async def oidc_login(request: Request):
  314. config: Config = request.app.state.server_config
  315. authorization_endpoint = config.openid_configuration["authorization_endpoint"]
  316. authUrl = (
  317. f'{authorization_endpoint}?response_type=code&'
  318. f'client_id={config.oidc_client_id}&'
  319. f'redirect_uri={config.oidc_redirect_uri}&'
  320. f'scope=openid profile email&state=random_state_string'
  321. )
  322. return RedirectResponse(url=authUrl)
  323. @router.get("/oidc/callback")
  324. async def oidc_callback(request: Request, session: SessionDep):
  325. logger.debug("Invoke oidc callback.")
  326. config: Config = request.app.state.server_config
  327. query = dict(request.query_params)
  328. code = query['code']
  329. data = {
  330. "grant_type": "authorization_code",
  331. "code": code,
  332. "client_id": config.oidc_client_id,
  333. "client_secret": config.oidc_client_secret,
  334. "redirect_uri": config.oidc_redirect_uri,
  335. }
  336. token_endpoint = config.openid_configuration["token_endpoint"]
  337. use_proxy_env = use_proxy_env_for_url(token_endpoint)
  338. verify = get_system_trust_store_ssl_context()
  339. async with httpx.AsyncClient(
  340. timeout=timeout, verify=verify, trust_env=use_proxy_env
  341. ) as client:
  342. try:
  343. token_res = await client.request("POST", token_endpoint, data=data)
  344. res_data = json.loads(token_res.text)
  345. if token_res.status_code != 200:
  346. raise BadRequestException(
  347. message=f"Failed to get token, {res_data['error_description']}"
  348. )
  349. # Get user data from token or userinfo endpoint
  350. user_data = await get_oidc_user_data(client, res_data, config)
  351. if config.external_auth_name:
  352. # If external_auth_name is set, use it as username.
  353. username = user_data.get(config.external_auth_name)
  354. else:
  355. # Try common OIDC fields for username if external_auth_name is not set.
  356. # Ref: https://openid.net/specs/openid-connect-core-1_0.html#rfc.section.18.1.1
  357. for key in ["email", "sub"]:
  358. if key in user_data:
  359. username = user_data[key]
  360. break
  361. else:
  362. raise UnauthorizedException(
  363. message="No valid username found in user data"
  364. )
  365. if (
  366. config.external_auth_full_name
  367. and '+' not in config.external_auth_full_name
  368. ):
  369. full_name = user_data.get(config.external_auth_full_name)
  370. elif config.external_auth_full_name:
  371. full_name = ' '.join(
  372. [
  373. user_data.get(v.strip())
  374. for v in config.external_auth_full_name.split('+')
  375. ]
  376. )
  377. else:
  378. full_name = user_data.get("name", "")
  379. if config.external_auth_avatar_url:
  380. avatar_url = user_data.get(config.external_auth_avatar_url)
  381. else:
  382. avatar_url = user_data.get("picture", None)
  383. except Exception as e:
  384. logger.error(f"Get OIDC user info error: {str(e)}")
  385. raise UnauthorizedException(message=str(e))
  386. # determine whether the user already exists
  387. user = await User.first_by_field(session=session, field="username", value=username)
  388. # create user
  389. if not user:
  390. user_info = User(
  391. username=username,
  392. full_name=full_name,
  393. avatar_url=avatar_url,
  394. hashed_password="",
  395. is_admin=False,
  396. is_active=not config.external_auth_default_inactive,
  397. source=AuthProviderEnum.OIDC,
  398. require_password_change=False,
  399. )
  400. user = await create_user_with_principal(session, user_info)
  401. await session.commit()
  402. elif (
  403. getattr(user, "id", None) is not None
  404. and getattr(user, "principal_id", None) is None
  405. ):
  406. # Backfill for SSO users created before Personal Org
  407. # provisioning was wired in. Idempotent: only fires when the
  408. # user has no Personal Org pointer.
  409. await provision_user_principal(session, user)
  410. await session.commit()
  411. jwt_manager: JWTManager = request.app.state.jwt_manager
  412. access_token = jwt_manager.create_jwt_token(
  413. username=username,
  414. )
  415. response = RedirectResponse(url='/')
  416. response.set_cookie(
  417. key=SESSION_COOKIE_NAME,
  418. value=access_token,
  419. httponly=True,
  420. max_age=envs.JWT_TOKEN_EXPIRE_MINUTES * 60,
  421. expires=envs.JWT_TOKEN_EXPIRE_MINUTES * 60,
  422. )
  423. try:
  424. id_token = res_data.get("id_token")
  425. if id_token:
  426. response.set_cookie(
  427. key=OIDC_ID_TOKEN_COOKIE_NAME,
  428. value=id_token,
  429. httponly=True,
  430. max_age=envs.JWT_TOKEN_EXPIRE_MINUTES * 60,
  431. expires=envs.JWT_TOKEN_EXPIRE_MINUTES * 60,
  432. )
  433. response.set_cookie(
  434. key=SSO_LOGIN_COOKIE_NAME,
  435. value="true",
  436. httponly=True,
  437. max_age=envs.JWT_TOKEN_EXPIRE_MINUTES * 60,
  438. expires=envs.JWT_TOKEN_EXPIRE_MINUTES * 60,
  439. )
  440. except Exception as e:
  441. logger.warning(f"Failed to set id_token cookie: {str(e)}")
  442. return response
  443. # Local authentication endpoints
  444. @router.post("/login")
  445. async def login(
  446. request: Request,
  447. response: Response,
  448. session: SessionDep,
  449. username: Annotated[str, Form()] = "",
  450. password: Annotated[str, Form()] = "",
  451. ):
  452. user = await authenticate_user(session, username, password)
  453. user_name = user.username
  454. jwt_manager: JWTManager = request.app.state.jwt_manager
  455. access_token = jwt_manager.create_jwt_token(
  456. username=user_name,
  457. )
  458. response.set_cookie(
  459. key=SESSION_COOKIE_NAME,
  460. value=access_token,
  461. httponly=True,
  462. max_age=envs.JWT_TOKEN_EXPIRE_MINUTES * 60,
  463. expires=envs.JWT_TOKEN_EXPIRE_MINUTES * 60,
  464. )
  465. @router.post("/logout")
  466. async def logout(request: Request):
  467. config: Config = request.app.state.server_config
  468. external_logout_url = None
  469. sso_login = None # Ensure initialized before any conditional path
  470. if (
  471. config.external_auth_type == AuthProviderEnum.OIDC
  472. and config.openid_configuration
  473. ):
  474. end_session_endpoint = config.openid_configuration.get("end_session_endpoint")
  475. if end_session_endpoint:
  476. redirect_uri = str(config.server_external_url or request.base_url)
  477. params = {
  478. "client_id": config.oidc_client_id,
  479. "post_logout_redirect_uri": redirect_uri,
  480. "id_token_hint": request.cookies.get(OIDC_ID_TOKEN_COOKIE_NAME),
  481. }
  482. if config.external_auth_post_logout_redirect_key:
  483. params[config.external_auth_post_logout_redirect_key] = redirect_uri
  484. query = urlencode({k: v for k, v in params.items() if v})
  485. external_logout_url = (
  486. end_session_endpoint if not query else f"{end_session_endpoint}?{query}"
  487. )
  488. elif config.external_auth_type == AuthProviderEnum.SAML:
  489. try:
  490. auth = await init_saml_auth(request)
  491. redirect_uri = str(config.server_external_url or request.base_url)
  492. params = {}
  493. if config.external_auth_post_logout_redirect_key:
  494. params[config.external_auth_post_logout_redirect_key] = redirect_uri
  495. external_logout_url = auth.logout(return_to=redirect_uri)
  496. query = urlencode({k: v for k, v in params.items() if v})
  497. if query:
  498. external_logout_url += f"&{query}"
  499. except Exception as e:
  500. logger.error(f"Failed to get SAML logout url: {str(e)}")
  501. external_logout_url = None
  502. # SSO logout: return SSO platform logout URL
  503. sso_login = request.cookies.get(SSO_LOGIN_COOKIE_NAME)
  504. sso_logout_url = config.sso_logout_redirect_url
  505. if sso_login and sso_logout_url:
  506. external_logout_url = sso_logout_url
  507. content = json.dumps({"logout_url": external_logout_url}) if sso_login else ""
  508. resp = Response(content=content, media_type="application/json")
  509. resp.delete_cookie(key=SESSION_COOKIE_NAME)
  510. resp.delete_cookie(key=OIDC_ID_TOKEN_COOKIE_NAME)
  511. resp.delete_cookie(key=SSO_LOGIN_COOKIE_NAME)
  512. return resp
  513. @router.post("/update-password")
  514. async def update_password(
  515. request: Request,
  516. session: SessionDep,
  517. user: CurrentUserDep,
  518. update_in: UpdatePassword,
  519. ):
  520. if not verify_hashed_secret(user.hashed_password, update_in.current_password):
  521. raise InvalidException(message="Incorrect current password")
  522. hashed_password = get_secret_hash(update_in.new_password)
  523. patch = {"hashed_password": hashed_password, "require_password_change": False}
  524. await user.update(session, patch)
  525. remove_initial_password_file_if_exists(request.app.state.server_config)
  526. @router.get("/config")
  527. async def get_auth_config(request: Request):
  528. req_dict = {}
  529. config: Config = request.app.state.server_config
  530. auth_type = (config.external_auth_type or "Local").lower()
  531. if auth_type == "oidc":
  532. req_dict = {"is_oidc": True, "is_saml": False}
  533. elif auth_type == "saml":
  534. req_dict = {"is_oidc": False, "is_saml": True}
  535. initial_password_file = Path(config.data_dir) / "initial_admin_password"
  536. if initial_password_file.exists():
  537. req_dict["first_time_setup"] = True
  538. req_dict["get_initial_password_command"] = _get_initial_password_command(
  539. initial_password_file
  540. )
  541. return req_dict
  542. def _get_initial_password_command(initial_password_file: Path) -> str:
  543. """
  544. Get the command to retrieve the initial admin password.
  545. """
  546. if os.getenv("KUBERNETES_SERVICE_HOST") is not None:
  547. # Kubernetes
  548. pod_name = os.getenv("HOSTNAME", "<pod_name>")
  549. namespace_file = Path("/var/run/secrets/kubernetes.io/serviceaccount/namespace")
  550. namespace = (
  551. namespace_file.read_text().strip()
  552. if namespace_file.exists()
  553. else "<namespace>"
  554. )
  555. return f"kubectl exec {pod_name} -n {namespace} -- cat {initial_password_file}"
  556. elif Path("/.dockerenv").exists():
  557. # Docker
  558. return f"docker exec <container_name_or_id> cat {initial_password_file}"
  559. else:
  560. # Non-containerized
  561. return f"cat {initial_password_file}"
  562. def remove_initial_password_file_if_exists(config: Config):
  563. """
  564. Remove the initial admin password file if it exists.
  565. """
  566. initial_password_file = Path(config.data_dir) / "initial_admin_password"
  567. if initial_password_file.exists():
  568. try:
  569. initial_password_file.unlink()
  570. logger.debug(f"Initial password file deleted: {initial_password_file}")
  571. except Exception as e:
  572. logger.warning(f"Failed to delete initial password file: {e}")
  573. # SSO (LQAI-middle-platform) OAuth2 integration endpoints
  574. from gpustack.api.sso import (
  575. build_sso_authorize_url,
  576. handle_sso_exchange_code,
  577. )
  578. from pydantic import BaseModel
  579. class ExchangeCodeRequest(BaseModel):
  580. code: str
  581. @router.get("/sso/authorize")
  582. async def sso_authorize(request: Request, redirect: bool = False):
  583. """
  584. Build SSO OAuth2 authorization URL.
  585. If redirect=True, directly 302 redirect to SSO authorization page.
  586. """
  587. config: Config = request.app.state.server_config
  588. if not config.sso_base_url or not config.sso_client_id:
  589. raise InvalidException(message="SSO 未配置,请先配置 SSO_BASE_URL 和 SSO_CLIENT_ID")
  590. authorize_url = build_sso_authorize_url(config)
  591. if redirect:
  592. return RedirectResponse(url=authorize_url)
  593. return {
  594. "code": "000000",
  595. "message": "获取授权URL成功",
  596. "data": {"authorize_url": authorize_url},
  597. }
  598. @router.post("/oauth/exchange-code")
  599. async def oauth_exchange_code(
  600. request: Request,
  601. session: SessionDep,
  602. body: ExchangeCodeRequest,
  603. ):
  604. """
  605. Exchange SSO authorization code for local JWT.
  606. Core SSO login endpoint.
  607. """
  608. config: Config = request.app.state.server_config
  609. if not config.sso_base_url or not config.sso_client_id:
  610. raise InvalidException(message="SSO 未配置")
  611. if not body.code:
  612. raise BadRequestException(message="缺少授权码")
  613. try:
  614. jwt_manager: JWTManager = request.app.state.jwt_manager
  615. result = await handle_sso_exchange_code(session, config, body.code, jwt_manager)
  616. return {
  617. "code": "000000",
  618. "message": "登录成功",
  619. "data": result,
  620. }
  621. except Exception as e:
  622. logger.error(f"SSO exchange failed: {e}")
  623. error_msg = str(e)
  624. if "invalid_grant" in error_msg or "授权码" in error_msg:
  625. raise BadRequestException(message=f"登录失败: 授权码无效")
  626. raise InvalidException(message=f"登录失败: {error_msg}")