auth.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450
  1. import re
  2. import uuid
  3. from datetime import datetime, timezone
  4. import logging
  5. import aiohttp
  6. from aiocache import cached
  7. from fastapi import Depends, Request, WebSocket
  8. from starlette.datastructures import Headers
  9. from gpustack.config.config import Config
  10. from gpustack.schemas.config import GatewayModeEnum
  11. from gpustack.server.db import get_session, async_session
  12. from typing import Annotated, Optional, Tuple, List, Dict
  13. from fastapi.security import (
  14. APIKeyCookie,
  15. APIKeyHeader,
  16. HTTPAuthorizationCredentials,
  17. HTTPBasic,
  18. HTTPBasicCredentials,
  19. HTTPBearer,
  20. )
  21. from fastapi.security.utils import get_authorization_scheme_param
  22. from sqlmodel.ext.asyncio.session import AsyncSession
  23. from gpustack.api.exceptions import (
  24. ForbiddenException,
  25. InternalServerErrorException,
  26. UnauthorizedException,
  27. HTTPException,
  28. )
  29. from gpustack.schemas.api_keys import ApiKey, PermissionScope
  30. from gpustack.schemas.users import User, UserRole
  31. from gpustack.security import (
  32. JWTManager,
  33. verify_hashed_secret,
  34. get_key_pair,
  35. )
  36. from gpustack.server.services import APIKeyService, UserService, WorkerService
  37. from gpustack.websocket_proxy.authenticator import (
  38. Authenticator as WebsocketAuthenticator,
  39. )
  40. logger = logging.getLogger(__name__)
  41. SESSION_COOKIE_NAME = "gpustack_session"
  42. OIDC_ID_TOKEN_COOKIE_NAME = "gpustack_oidc_id_token"
  43. SSO_LOGIN_COOKIE_NAME = "gpustack_sso_login"
  44. SYSTEM_USER_PREFIX = "system/"
  45. SYSTEM_WORKER_USER_PREFIX = "system/worker/"
  46. GATEWAY_AUTH_TOKEN_HEADER = "X-GPUStack-Auth-Token"
  47. basic_auth = HTTPBasic(auto_error=False)
  48. bearer_auth = HTTPBearer(auto_error=False)
  49. api_key_header_auth = APIKeyHeader(name="X-API-Key", auto_error=False)
  50. cookie_auth = APIKeyCookie(name=SESSION_COOKIE_NAME, auto_error=False)
  51. _gateway_auth_header = APIKeyHeader(name=GATEWAY_AUTH_TOKEN_HEADER, auto_error=False)
  52. credentials_exception = UnauthorizedException(
  53. message="Invalid authentication credentials"
  54. )
  55. def gateway_token_auth(
  56. request: Request,
  57. token: Annotated[Optional[str], Depends(_gateway_auth_header)] = None,
  58. ):
  59. if not token:
  60. raise UnauthorizedException(message="Missing authentication token")
  61. cfg: Config = request.app.state.server_config
  62. if token != cfg.get_derived_gateway_token():
  63. raise UnauthorizedException(message="Invalid gateway token")
  64. def client_ip_getter(request: Request) -> str:
  65. if request.app.state.server_config.gateway_mode == GatewayModeEnum.embedded:
  66. return request.headers.get("X-GPUStack-Real-IP", "")
  67. else:
  68. return request.client.host
  69. async def get_current_user(
  70. request: Request,
  71. session: Annotated[AsyncSession, Depends(get_session)],
  72. basic_credentials: Annotated[
  73. Optional[HTTPBasicCredentials], Depends(basic_auth)
  74. ] = None,
  75. bearer_token: Annotated[
  76. Optional[HTTPAuthorizationCredentials], Depends(bearer_auth)
  77. ] = None,
  78. x_api_key: Annotated[Optional[str], Depends(api_key_header_auth)] = None,
  79. cookie_token: Annotated[Optional[str], Depends(cookie_auth)] = None,
  80. ) -> User:
  81. if hasattr(request.state, "user"):
  82. user: User = getattr(request.state, "user")
  83. return user
  84. api_key: Optional[ApiKey] = None
  85. user = None
  86. try:
  87. server_config: Config = request.app.state.server_config
  88. if basic_credentials and is_system_user(basic_credentials.username):
  89. user = await authenticate_system_user(server_config, basic_credentials)
  90. elif basic_credentials:
  91. user = await authenticate_basic_user(session, basic_credentials)
  92. elif cookie_token:
  93. jwt_manager: JWTManager = request.app.state.jwt_manager
  94. user = await get_user_from_jwt_token(session, jwt_manager, cookie_token)
  95. elif bearer_token or x_api_key:
  96. token = (bearer_token.credentials if bearer_token else None) or x_api_key
  97. if token is not None:
  98. # Try API key first, fall back to JWT token validation
  99. user, api_key = await get_user_from_api_token(session, token)
  100. if user is None:
  101. # Bearer token might be a JWT (e.g. from SSO callback)
  102. jwt_manager: JWTManager = request.app.state.jwt_manager
  103. user = await get_user_from_jwt_token(session, jwt_manager, token)
  104. if user is None and client_ip_getter(request=request) == "127.0.0.1":
  105. if not server_config.force_auth_localhost:
  106. user = await User.first_by_field(session, "is_admin", True)
  107. if user:
  108. if not user.is_active:
  109. raise UnauthorizedException(message="User account is deactivated")
  110. request.state.user = user
  111. if api_key is not None:
  112. request.state.api_key = api_key
  113. return user
  114. except HTTPException:
  115. raise
  116. except Exception as e:
  117. raise InternalServerErrorException(message=f"Failed to authenticate user: {e}")
  118. raise credentials_exception
  119. async def get_admin_user(
  120. current_user: Annotated[User, Depends(get_current_user)],
  121. ) -> User:
  122. if not current_user.is_admin:
  123. raise ForbiddenException(message="No permission to access")
  124. return current_user
  125. async def get_cluster_user(
  126. current_user: Annotated[User, Depends(get_current_user)],
  127. ) -> User:
  128. if (
  129. current_user.is_system
  130. and current_user.role == UserRole.Cluster
  131. and current_user.cluster_id is not None
  132. ):
  133. return current_user
  134. return await get_admin_user(current_user)
  135. async def get_worker_user(
  136. current_user: Annotated[User, Depends(get_current_user)],
  137. ) -> User:
  138. if (
  139. current_user.is_system
  140. and current_user.role == UserRole.Worker
  141. and current_user.worker is not None
  142. ):
  143. return current_user
  144. return await get_admin_user(current_user)
  145. def is_system_user(username: str) -> bool:
  146. return username.startswith(SYSTEM_USER_PREFIX)
  147. async def authenticate_system_user(
  148. config: Config,
  149. credentials: HTTPBasicCredentials,
  150. ) -> Optional[User]:
  151. if credentials.username.startswith(SYSTEM_WORKER_USER_PREFIX):
  152. if credentials.password == config.token:
  153. return User(username=credentials.username, is_admin=True)
  154. return None
  155. async def authenticate_basic_user(
  156. session: AsyncSession,
  157. basic_credentials: HTTPBasicCredentials,
  158. ) -> Optional[User]:
  159. try:
  160. user = await authenticate_user(
  161. session, basic_credentials.username, basic_credentials.password
  162. )
  163. return user
  164. except Exception:
  165. return None
  166. def get_access_token(
  167. bearer_token: Optional[HTTPAuthorizationCredentials],
  168. oauth2_bearer_token: Optional[str],
  169. cookie_token: Optional[str],
  170. ) -> str:
  171. if bearer_token:
  172. return bearer_token.credentials
  173. elif oauth2_bearer_token:
  174. return oauth2_bearer_token
  175. elif cookie_token:
  176. return cookie_token
  177. else:
  178. raise credentials_exception
  179. async def get_user_from_jwt_token(
  180. session: AsyncSession, jwt_manager: JWTManager, access_token: str
  181. ) -> Optional[User]:
  182. try:
  183. payload = jwt_manager.decode_jwt_token(access_token)
  184. username = payload.get("sub")
  185. except Exception:
  186. logger.debug("Failed to decode JWT token")
  187. return None
  188. if username is None:
  189. return None
  190. try:
  191. user = await UserService(session).get_by_username(username)
  192. except Exception as e:
  193. raise InternalServerErrorException(message=f"Failed to get user: {e}")
  194. return user
  195. def parse_hyphen_uuid(value: str) -> Optional[str]:
  196. if not re.match(
  197. r'^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$', value, re.I
  198. ):
  199. return None
  200. try:
  201. uuid.UUID(value)
  202. return value
  203. except ValueError:
  204. return None
  205. async def get_user_from_api_token(
  206. session: AsyncSession, token: str
  207. ) -> Tuple[Optional[User], Optional[ApiKey]]:
  208. try:
  209. access_key, secret_key = get_key_pair(token)
  210. worker_uuid = parse_hyphen_uuid(access_key)
  211. if worker_uuid is not None:
  212. # if access_key is a valid uuid, it's legacy worker re-registering with legacy token
  213. access_key = ""
  214. access_keys = [access_key]
  215. # the custom key should have 32 chars access key as it is generated by security.custom_key_hash which will return 32 chars hex string.
  216. if len(access_key) == 32 and "" not in access_keys:
  217. # this means it is custom key or legacy worker token, we should also try to find api key with empty access key for backward compatibility
  218. access_keys.append("")
  219. for access_key in access_keys:
  220. api_key: ApiKey = await APIKeyService(session).get_by_access_key(access_key)
  221. if api_key:
  222. logger.trace(f"Found API key for access key: {access_key}")
  223. break
  224. if (
  225. api_key is not None
  226. and verify_hashed_secret(api_key.hashed_secret_key, secret_key)
  227. and (
  228. api_key.expires_at is None
  229. or api_key.expires_at > datetime.now(timezone.utc)
  230. )
  231. ):
  232. user: Optional[User] = await UserService(session).get_by_id(
  233. user_id=api_key.user_id,
  234. )
  235. if user is not None:
  236. return user, api_key
  237. except Exception as e:
  238. raise InternalServerErrorException(message=f"Failed to get user: {e}")
  239. return None, None
  240. async def authenticate_user(
  241. session: AsyncSession, username: str, password: str
  242. ) -> User:
  243. user = await UserService(session).get_by_username(username)
  244. if not user:
  245. raise UnauthorizedException(message="Incorrect username or password")
  246. if not verify_hashed_secret(user.hashed_password, password):
  247. raise UnauthorizedException(message="Incorrect username or password")
  248. if not user.is_active:
  249. raise UnauthorizedException(message="User account is deactivated")
  250. return user
  251. async def worker_auth(
  252. request: Request,
  253. bearer_token: Annotated[
  254. Optional[HTTPAuthorizationCredentials], Depends(bearer_auth)
  255. ] = None,
  256. x_api_key: Annotated[Optional[str], Depends(api_key_header_auth)] = None,
  257. ):
  258. token_value = (bearer_token.credentials if bearer_token else None) or x_api_key
  259. if not token_value:
  260. raise UnauthorizedException(message="Invalid authentication credentials")
  261. token = request.app.state.token
  262. config: Config = request.app.state.config
  263. registration_token = config.token
  264. server_url = config.get_server_url()
  265. if token_value in [token, registration_token]:
  266. return
  267. model_name = request.headers.get("X-Higress-Llm-Model")
  268. if model_name is not None:
  269. cred = token_value
  270. show_len = max(1, min(6, len(cred)))
  271. masked_token = f"{'*' * (len(cred) - show_len)}{cred[-show_len:]}"
  272. logger.debug(f"Verifying worker token {masked_token} via server authentication")
  273. cached_auth = make_auth_token_via_server(request.app.state.http_client_no_proxy)
  274. is_valid = await cached_auth(server_url, token_value, model_name)
  275. if is_valid:
  276. return
  277. raise UnauthorizedException(message="Invalid authentication credentials")
  278. def make_auth_token_via_server(client: aiohttp.ClientSession):
  279. @cached(ttl=60)
  280. async def inner(server_url: str, token: str, model_name: str) -> bool:
  281. auth_url = f"{server_url.rstrip('/')}/token-auth"
  282. headers = {
  283. "Authorization": f"Bearer {token}",
  284. "X-Higress-Llm-Model": model_name,
  285. }
  286. try:
  287. async with client.get(auth_url, headers=headers) as resp:
  288. return resp.status == 200
  289. except aiohttp.ClientError as e:
  290. logger.error(f"Error verifying token via server: {e}")
  291. return False
  292. return inner
  293. def get_scopes(
  294. request: Request, _current_user: Annotated[User, Depends(get_current_user)]
  295. ) -> List[str]:
  296. api_key: ApiKey = getattr(request.state, "api_key", None)
  297. if api_key is not None:
  298. return api_key.scope
  299. return [PermissionScope.ALL]
  300. def inference_scope(
  301. request: Request, _current_user: Annotated[User, Depends(get_current_user)]
  302. ):
  303. scopes = get_scopes(request, _current_user)
  304. if PermissionScope.ALL not in scopes and PermissionScope.INFERENCE not in scopes:
  305. raise ForbiddenException(
  306. message="API key does not have permission to access inference features"
  307. )
  308. def management_scope(
  309. request: Request, _current_user: Annotated[User, Depends(get_current_user)]
  310. ):
  311. scopes = get_scopes(request, _current_user)
  312. if PermissionScope.ALL not in scopes and PermissionScope.MANAGEMENT not in scopes:
  313. raise ForbiddenException(
  314. message="API key does not have permission to access management features"
  315. )
  316. async def authenticate_worker_by_request_headers(
  317. header_dict: Dict[str, str],
  318. validate_proxy: Optional[bool] = None,
  319. ) -> Optional[User]:
  320. """
  321. Authenticate a worker based on request headers, used for both WebSocket and non-WebSocket requests.
  322. For WebSocket requests, the Bearer token is expected in the "Authorization" header.
  323. For non-WebSocket requests (e.g. HTTP requests to the proxy), the Bearer token can be in either "Authorization"
  324. or "Proxy-Authorization" header, with "Proxy-Authorization" taking precedence if both are present.
  325. """
  326. headers = Headers(header_dict)
  327. authorization: Optional[str] = None
  328. if validate_proxy:
  329. authorization = headers.get("Proxy-Authorization")
  330. elif validate_proxy is not None:
  331. authorization = headers.get("Authorization")
  332. else:
  333. # if validate_proxy is None, it means we are in a context where both headers could be used (e.g. WebSocket connection from the proxy)
  334. # in this case we give precedence to Proxy-Authorization if it exists, otherwise fall back to Authorization
  335. authorization = headers.get("Proxy-Authorization") or headers.get(
  336. "Authorization"
  337. )
  338. async with async_session() as session:
  339. scheme, credentials = get_authorization_scheme_param(authorization)
  340. if not (authorization and scheme and credentials) or scheme.lower() != "bearer":
  341. return None
  342. bearer_token = HTTPAuthorizationCredentials(
  343. scheme=scheme, credentials=credentials
  344. )
  345. user, _ = await get_user_from_api_token(session, bearer_token.credentials)
  346. if user is None:
  347. return None
  348. if user.worker_id is not None:
  349. user.worker = await WorkerService(session).get_by_id(user.worker_id)
  350. return user
  351. class BearerTokenAuthenticator(WebsocketAuthenticator):
  352. """Websocket authenticator that verifies bearer tokens via the main server."""
  353. token: Optional[str]
  354. def __init__(
  355. self,
  356. token: Optional[str] = None,
  357. headers: Optional[Dict[str, str]] = None,
  358. ):
  359. self.token = token
  360. if not self.token and headers:
  361. parsed_headers = Headers(headers)
  362. self.token = parsed_headers.get("Authorization", "").replace("Bearer ", "")
  363. def inject_headers(
  364. self,
  365. headers: Dict[str, str],
  366. ) -> None:
  367. # No need to inject headers for outgoing connections from the proxy
  368. for key in list(headers.keys()):
  369. if key.lower() == "authorization":
  370. headers.pop(key)
  371. if self.token:
  372. headers.setdefault("Authorization", f"Bearer {self.token}")
  373. async def authenticate(self, websocket: WebSocket) -> bool:
  374. user = await authenticate_worker_by_request_headers(
  375. websocket.headers, validate_proxy=False
  376. )
  377. if user is None:
  378. return False
  379. if user.worker is None:
  380. logger.debug(
  381. f"Authenticated user {user.id} with bearer token but it is not associated with any worker"
  382. )
  383. return False
  384. if websocket.headers.get("x-client-id") != user.worker.worker_uuid:
  385. logger.debug(
  386. f"Authenticated worker {user.worker_id} with bearer token but client_id {websocket.headers.get('x-client-id')} does not match worker_uuid {user.worker.worker_uuid}"
  387. )
  388. return False
  389. websocket.scope["user"] = user
  390. return True