api_keys.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. from enum import Enum
  2. from datetime import datetime
  3. from typing import ClassVar, Optional, List, TYPE_CHECKING
  4. from sqlalchemy import Column, ForeignKey, Integer, UniqueConstraint
  5. from sqlmodel import Field, SQLModel, Text, JSON, Relationship
  6. from gpustack.mixins import BaseModelMixin
  7. from gpustack.schemas.common import ListParams, PaginatedList, UTCDateTime
  8. from gpustack.schemas.principals import PLATFORM_PRINCIPAL_ID
  9. if TYPE_CHECKING:
  10. from gpustack.schemas.users import User
  11. class PermissionScope(str, Enum):
  12. """
  13. Permission scope for API key access control.
  14. Currently supports coarse-grained scopes. Future extensions may include:
  15. - management.readonly: Read-only API access (GET requests only)
  16. - management.write: Full API write access
  17. - inference.chat: Chat completion endpoints only
  18. - inference.embeddings: Embeddings endpoints only
  19. - inference.completions: Completions endpoints only
  20. """
  21. ALL = "*"
  22. MANAGEMENT = "management"
  23. INFERENCE = "inference"
  24. class ApiKeyUpdate(SQLModel):
  25. allowed_model_names: Optional[List[str]] = Field(
  26. default=None,
  27. sa_column=Column(JSON, nullable=True),
  28. )
  29. description: Optional[str] = Field(
  30. default=None, sa_column=Column(Text, nullable=True)
  31. )
  32. scope: List[PermissionScope] = Field(
  33. default=[PermissionScope.ALL],
  34. sa_column=Column(JSON, nullable=False),
  35. )
  36. class ApiKeyBase(ApiKeyUpdate):
  37. name: str
  38. class ApiKey(ApiKeyBase, BaseModelMixin, table=True):
  39. __tablename__ = 'api_keys'
  40. __table_args__ = (
  41. UniqueConstraint(
  42. 'user_id', 'owner_principal_id', 'name', name='uix_user_org_name'
  43. ),
  44. )
  45. id: Optional[int] = Field(default=None, primary_key=True)
  46. access_key: str = Field(unique=True, index=True)
  47. hashed_secret_key: str = Field(unique=True)
  48. user_id: int = Field(foreign_key='users.id', nullable=False)
  49. owner_principal_id: int = Field(
  50. default=PLATFORM_PRINCIPAL_ID,
  51. sa_column=Column(
  52. Integer,
  53. ForeignKey("principals.id", ondelete="CASCADE"),
  54. nullable=False,
  55. ),
  56. )
  57. expires_at: Optional[datetime] = Field(sa_column=Column(UTCDateTime), default=None)
  58. user: Optional["User"] = Relationship(
  59. back_populates="api_keys",
  60. sa_relationship_kwargs={"lazy": "noload"},
  61. )
  62. is_custom: bool = Field(default=False, nullable=False)
  63. @property
  64. def user_name(self) -> Optional[str]:
  65. return self.user.username if self.user else None
  66. class ApiKeyListParams(ListParams):
  67. sortable_fields: ClassVar[List[str]] = [
  68. "name",
  69. "expires_at",
  70. "created_at",
  71. "updated_at",
  72. ]
  73. class ApiKeyCreate(ApiKeyBase):
  74. expires_in: Optional[int] = None
  75. custom: Optional[str] = None
  76. class ApiKeyPublic(ApiKeyBase):
  77. id: int
  78. user_name: Optional[str] = None
  79. value: Optional[str] = None # only available when creating
  80. masked_value: Optional[str] = None # partial characters for identification
  81. is_custom: bool
  82. created_at: datetime
  83. updated_at: datetime
  84. expires_at: Optional[datetime] = None
  85. ApiKeysPublic = PaginatedList[ApiKeyPublic]