# security.py (4.2 KB)
import hashlib
import secrets
import string
from datetime import datetime, timedelta, timezone
from functools import lru_cache
from typing import Optional, Tuple, Union

import jwt
from argon2 import PasswordHasher

from gpustack import envs
  10. ph = PasswordHasher()
  11. API_KEY_PREFIX = "gpustack"
  12. @lru_cache(maxsize=2048)
  13. def verify_hashed_secret(hashed: Union[str, bytes], plain: Union[str, bytes]) -> bool:
  14. try:
  15. return ph.verify(hashed, plain)
  16. except Exception:
  17. return False
  18. def get_secret_hash(plain: Union[str, bytes]):
  19. return ph.hash(plain)
  20. def generate_secure_password(length=12):
  21. if length < 8:
  22. raise ValueError("Password length should be at least 8 characters")
  23. special_characters = "!@#$%^&*_+"
  24. characters = string.ascii_letters + string.digits + special_characters
  25. while True:
  26. password = ''.join(secrets.choice(characters) for i in range(length))
  27. if (
  28. any(c.islower() for c in password)
  29. and any(c.isupper() for c in password)
  30. and any(c.isdigit() for c in password)
  31. and any(c in special_characters for c in password)
  32. ):
  33. return password
  34. def custom_key_hash(secret_key: str) -> str:
  35. return hashlib.blake2b(secret_key.encode(), digest_size=16).hexdigest()
  36. def is_valid_format(key: str) -> Tuple[bool, str, str]:
  37. if not key.startswith(f"{API_KEY_PREFIX}_"):
  38. return False, "", ""
  39. parts = key.split("_", 2)
  40. if len(parts) != 3:
  41. return False, "", ""
  42. access_key, secret_key = parts[1], parts[2]
  43. return True, access_key, secret_key
  44. def get_key_pair(key: str) -> Tuple[str, str]:
  45. """
  46. Parse and validate an API key.
  47. Scenarios:
  48. 1. Standard format key: "gpustack_{access_key}_{secret_key}"
  49. - access_key: 8 characters (hex string, e.g. "3192253c")
  50. - secret_key: 16 characters (hex string, e.g. "c11c75ed6334ea9505da4ad9")
  51. - Used for normal API authentication via /v2/* routes
  52. 2. Legacy UUID format key: standard UUID format with dashes
  53. - Example: access_key: "3192253c-c11c-75ed-6334-ea9505da4ad9", the secret_key can be any string
  54. - Used by legacy worker tokens that use UUID as identifier
  55. - Falls back to custom_key_hash for backward compatibility
  56. 3. Custom/unrecognized format key:
  57. - Example: "any_random_string_here", "sk-xxx"
  58. - Any other string format that doesn't match standard format
  59. - Returns hashed value for storage, original value for lookup
  60. - Used for backward compatibility with non-standard API keys
  61. Returns:
  62. Tuple of (access_key, secret_key):
  63. - For standard format: returns the parsed access_key and secret_key
  64. - For non-standard format: returns (custom_key_hash(key), key)
  65. """
  66. valid, access_key, secret_key = is_valid_format(key)
  67. if not valid:
  68. return custom_key_hash(key), key
  69. return access_key, secret_key
  70. AUTH_CACHE_HEADER = "x-gpustack-auth-cache"
  71. class JWTManager:
  72. def __init__(
  73. self,
  74. secret_key: str,
  75. algorithm: str = "HS256",
  76. expires_delta: Optional[timedelta] = None,
  77. ):
  78. if expires_delta is None:
  79. expires_delta = timedelta(minutes=envs.JWT_TOKEN_EXPIRE_MINUTES)
  80. self.secret_key = secret_key
  81. self.algorithm = algorithm
  82. self.expires_delta = expires_delta
  83. def create_jwt_token(self, username: str):
  84. to_encode = {"sub": username}
  85. expire = datetime.now(timezone.utc) + self.expires_delta
  86. to_encode.update({"exp": expire})
  87. encoded_jwt = jwt.encode(to_encode, self.secret_key, algorithm=self.algorithm)
  88. return encoded_jwt
  89. def create_token(self, payload: dict, expires_delta: Optional[timedelta] = None):
  90. delta = expires_delta if expires_delta is not None else self.expires_delta
  91. to_encode = {"data": payload, "exp": datetime.now(timezone.utc) + delta}
  92. return jwt.encode(to_encode, self.secret_key, algorithm=self.algorithm)
  93. def decode_jwt_token(self, token: str):
  94. return jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
  95. def decode_jwt_data(self, token: str) -> dict:
  96. return self.decode_jwt_token(token)["data"]