2025_10_09_1037-eeacfbc6a2bf_model_access_control.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. """Model Access Control
  2. Revision ID: eeacfbc6a2bf
  3. Revises: 2025_10_07_add_is_active
  4. Create Date: 2025-10-09 10:37:20.646154
  5. """
  6. from typing import Sequence, Union
  7. from alembic import op
  8. import sqlalchemy as sa
  9. import sqlmodel
  10. import gpustack
  11. from gpustack.migrations.utils import table_exists
# Revision identifiers, used by Alembic.
# ``revision`` is this migration's ID; ``down_revision`` is the migration it
# is applied on top of (the "Revises" entry in the module docstring).
# No branch labels and no extra cross-revision dependencies are declared.
revision: str = 'eeacfbc6a2bf'
down_revision: Union[str, None] = '2025_10_07_add_is_active'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
  17. def access_control_upgrade() -> None:
  18. access_policy_enum = sa.Enum(
  19. 'PUBLIC',
  20. 'AUTHED',
  21. 'ALLOWED_USERS',
  22. name='accesspolicyenum',
  23. )
  24. bind = op.get_bind()
  25. if bind.dialect.name in ('postgresql', 'mysql'):
  26. access_policy_enum.create(bind, checkfirst=True)
  27. with op.batch_alter_table('models', schema=None) as batch_op:
  28. batch_op.add_column(sa.Column('access_policy', access_policy_enum, nullable=True, server_default='AUTHED'))
  29. op.execute(
  30. "UPDATE models SET access_policy='AUTHED' WHERE access_policy IS NULL"
  31. )
  32. with op.batch_alter_table('models', schema=None) as batch_op:
  33. batch_op.alter_column('access_policy', existing_type=access_policy_enum, nullable=False)
  34. with op.batch_alter_table('api_keys', schema=None) as batch_op:
  35. batch_op.add_column(sa.Column('allowed_model_names', sa.JSON(), nullable=True))
  36. if not table_exists('modeluserlink'):
  37. op.create_table('modeluserlink',
  38. sa.Column('model_id', sa.Integer(), nullable=False),
  39. sa.Column('user_id', sa.Integer(), nullable=False),
  40. sa.ForeignKeyConstraint(['model_id'], ['models.id'], name='fk_model_user_link_models', ondelete='CASCADE'),
  41. sa.ForeignKeyConstraint(['user_id'], ['users.id'], name='fk_model_user_link_users', ondelete='CASCADE'),
  42. sa.PrimaryKeyConstraint('model_id', 'user_id')
  43. )
  44. def access_control_downgrade() -> None:
  45. with op.batch_alter_table('api_keys', schema=None) as batch_op:
  46. batch_op.drop_column('allowed_model_names')
  47. if table_exists('modeluserlink'):
  48. op.drop_table('modeluserlink')
  49. with op.batch_alter_table('models', schema=None) as batch_op:
  50. batch_op.drop_column('access_policy')
  51. access_policy_enum = sa.Enum(
  52. 'PUBLIC',
  53. 'AUTHED',
  54. 'ALLOWED_USERS',
  55. name='accesspolicyenum',
  56. )
  57. access_policy_enum.drop(op.get_bind(), checkfirst=True)
  58. def upgrade() -> None:
  59. access_control_upgrade()
  60. def downgrade() -> None:
  61. access_control_downgrade()