token.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
import logging
import secrets
from datetime import timedelta
from typing import Optional, Annotated

from fastapi import APIRouter, Request, Response, Depends
from fastapi.security import HTTPAuthorizationCredentials

from gpustack.api.exceptions import (
    NotFoundException,
    ForbiddenException,
    UnauthorizedException,
    BadRequestException,
)
from gpustack.server.services import ModelRouteService, UserService
from gpustack.schemas.api_keys import ApiKey
from gpustack.schemas.users import User
from gpustack.schemas.models import AccessPolicyEnum
from gpustack.server.deps import SessionDep
from gpustack.api.auth import (
    GATEWAY_AUTH_TOKEN_HEADER,
    api_key_header_auth,
    basic_auth,
    cookie_auth,
    bearer_auth,
    get_current_user,
    credentials_exception,
    gateway_token_auth,
    inference_scope,
)
from gpustack.security import JWTManager, AUTH_CACHE_HEADER
  29. logger = logging.getLogger(__name__)
  30. router = APIRouter()
  31. model_name_missing_exception = BadRequestException(
  32. message="Missing 'model' field",
  33. is_openai_exception=True,
  34. )
  35. model_not_found_exception = NotFoundException(
  36. message="Model not found",
  37. is_openai_exception=True,
  38. )
  39. @router.get("")
  40. async def server_auth(
  41. request: Request,
  42. session: SessionDep,
  43. ):
  44. jwt_manager: JWTManager = request.app.state.jwt_manager
  45. cached = request.headers.get(AUTH_CACHE_HEADER)
  46. request_model = request.headers.get("x-higress-llm-model")
  47. if cached and request_model:
  48. try:
  49. data = jwt_manager.decode_jwt_data(cached)
  50. if data.get("model") == request_model:
  51. return Response(
  52. status_code=200,
  53. headers={
  54. "X-Mse-Consumer": data["consumer"],
  55. "Authorization": "Bearer " + data["token"],
  56. "cookie": "dummy=dummy",
  57. },
  58. )
  59. except Exception:
  60. pass
  61. user: Optional[User] = None
  62. api_key: Optional[ApiKey] = None
  63. access_key: Optional[str] = None
  64. consumer = 'none'
  65. cookie_token = await cookie_auth(request)
  66. x_api_key = await api_key_header_auth(request)
  67. try:
  68. user = await get_current_user(
  69. request=request,
  70. session=session,
  71. basic_credentials=await basic_auth(request),
  72. bearer_token=await bearer_auth(request),
  73. x_api_key=x_api_key,
  74. cookie_token=cookie_token,
  75. )
  76. api_key = getattr(request.state, "api_key", None)
  77. access_key = None if api_key is None else api_key.access_key
  78. consumer = '.'.join(
  79. [part for part in [access_key, f"gpustack-{user.id}"] if part is not None]
  80. )
  81. except UnauthorizedException:
  82. logger.debug("Unauthenticated request to server token-auth endpoint")
  83. except Exception as e:
  84. logger.error(f"Error during authentication: {e}")
  85. raise e
  86. if user is None:
  87. gateway_token_auth(
  88. request,
  89. token=request.headers.get(GATEWAY_AUTH_TOKEN_HEADER),
  90. )
  91. model_name = request.headers.get("x-higress-llm-model")
  92. if model_name is None or model_name == "":
  93. logger.debug(
  94. "Missing x-higress-llm-model header for token authentication",
  95. )
  96. raise credentials_exception if user is None else model_name_missing_exception
  97. pair = await ModelRouteService(session=session).get_model_auth_info_by_name(
  98. model_name
  99. )
  100. if pair is None:
  101. raise credentials_exception if user is None else model_not_found_exception
  102. policy = pair[0]
  103. registration_token = pair[1]
  104. if user is None and policy != AccessPolicyEnum.PUBLIC:
  105. logger.debug(
  106. f"Unauthenticated request to access model {model_name} with policy {policy}",
  107. )
  108. raise credentials_exception
  109. if policy != AccessPolicyEnum.PUBLIC:
  110. # llm_scope will raise exception if the api key is not allowed to access llm.
  111. inference_scope(request, user)
  112. if not await UserService(session).model_allowed_for_user(
  113. model_name=model_name,
  114. user_id=user.id,
  115. api_key=api_key,
  116. ):
  117. raise ForbiddenException(
  118. message=f"Api key not allowed to access model {model_name}"
  119. )
  120. cache_token = jwt_manager.create_token(
  121. {"consumer": consumer, "token": registration_token, "model": model_name},
  122. expires_delta=timedelta(minutes=5),
  123. )
  124. return Response(
  125. status_code=200,
  126. headers={
  127. "X-Mse-Consumer": consumer,
  128. "Authorization": f"Bearer {registration_token}",
  129. AUTH_CACHE_HEADER: cache_token,
  130. "cookie": "dummy=dummy",
  131. },
  132. )
  133. async def worker_auth(
  134. request: Request,
  135. bearer_token: Annotated[
  136. Optional[HTTPAuthorizationCredentials], Depends(bearer_auth)
  137. ] = None,
  138. x_api_key: Annotated[Optional[str], Depends(api_key_header_auth)] = None,
  139. ):
  140. token: str = request.app.state.token
  141. registration_token = request.app.state.config.token
  142. model_name = request.headers.get("X-Higress-Llm-Model")
  143. if model_name is None:
  144. logger.warning("Missing X-Higress-Llm-Model header for token authentication")
  145. raise credentials_exception
  146. token_value = (bearer_token.credentials if bearer_token else None) or x_api_key
  147. if token_value is None:
  148. raise credentials_exception
  149. if token_value != token and token_value != registration_token:
  150. raise credentials_exception
  151. return Response(
  152. status_code=200,
  153. headers={
  154. "X-Mse-Consumer": "gpustack-server",
  155. },
  156. )