sql_enum.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. from typing import Dict, Tuple, List
  2. from alembic import op
  3. import sqlalchemy as sa
  4. from sqlalchemy.engine import Connection
  5. def add_enum_values(
  6. table_columns: Dict[str, str], original_enum: sa.Enum, *to_add_values: str
  7. ):
  8. """
  9. add_enum_values add new values to an existing enum type in the database.
  10. :param table_columns: a dictionary mapping table names to their column definitions
  11. :type table_columns: Dict[str, str]
  12. :param original_enum: existing enum type in the database
  13. :type original_enum: sa.Enum
  14. :param to_add_values: new values to add to the enum
  15. :type to_add_values: Tuple[str, ...]
  16. """
  17. if len(to_add_values) == 0:
  18. return
  19. conn = op.get_bind()
  20. if conn.dialect.name == 'postgresql':
  21. for value in to_add_values:
  22. conn.execute(
  23. sa.text(f"ALTER TYPE {original_enum.name} ADD VALUE '{value}'")
  24. )
  25. elif conn.dialect.name == 'mysql':
  26. add_mysql_enum_values(table_columns, *to_add_values)
  27. def add_mysql_enum_values(table_columns: Dict[str, str], *to_add_values: str):
  28. conn = op.get_bind()
  29. for table_name, column_name in table_columns.items():
  30. modify_mysql_table_column_enum(
  31. conn, table_name, column_name, list(to_add_values), []
  32. )
  33. def modify_mysql_table_column_enum(
  34. conn: Connection,
  35. table_name: str,
  36. column_name: str,
  37. to_add_values: List[str],
  38. to_remove_values: List[str],
  39. ):
  40. result = conn.execute(
  41. sa.text(
  42. f"""
  43. SELECT COLUMN_TYPE
  44. FROM information_schema.COLUMNS
  45. WHERE TABLE_NAME = '{table_name}'
  46. AND COLUMN_NAME = '{column_name}'
  47. AND TABLE_SCHEMA = DATABASE()
  48. """
  49. )
  50. ).scalar()
  51. existing_values = []
  52. if result:
  53. enum_str = result.split("enum(")[1].split(")")[0]
  54. existing_values = [v.strip("'") for v in enum_str.split("','")]
  55. new_values = [v for v in existing_values if v not in to_remove_values]
  56. new_values.extend(to_add_values)
  57. if set(new_values) != set(existing_values):
  58. new_enum_str = "enum('" + "','".join(new_values) + "')"
  59. # Construct new ALTER TABLE statement
  60. alter_sql = (
  61. f"ALTER TABLE {table_name} MODIFY COLUMN {column_name} {new_enum_str};"
  62. )
  63. # Execute modification
  64. conn.execute(sa.text(alter_sql))
  65. def remove_postgres_enum_values(
  66. conn: Connection,
  67. table_name: str,
  68. column_name: str,
  69. original_enum: sa.Enum,
  70. *to_remove_values: str,
  71. ):
  72. new_enum_values_str = ','.join(
  73. [repr(v) for v in original_enum.enums if v not in to_remove_values]
  74. )
  75. conn.execute(
  76. sa.text(f"CREATE TYPE {original_enum.name}tmp AS ENUM ({new_enum_values_str});")
  77. )
  78. conn.execute(
  79. sa.text(
  80. f"ALTER TABLE {table_name} ALTER COLUMN {column_name} TYPE {original_enum.name}tmp USING {column_name}::text::{original_enum.name}tmp;"
  81. )
  82. )
  83. conn.execute(sa.text(f"DROP TYPE {original_enum.name};"))
  84. conn.execute(
  85. sa.text(f"ALTER TYPE {original_enum.name}tmp RENAME TO {original_enum.name};")
  86. )
  87. def remove_enum_values(
  88. table_columns: Dict[str, Tuple[str, str]],
  89. original_enum: sa.Enum,
  90. *to_remove_values: str,
  91. ):
  92. """
  93. remove_enum_values removes specified values from an existing enum type in the database.
  94. :param table_columns: a dictionary mapping table names to their column definitions
  95. :type table_columns: Dict[str, Tuple[str, str]]
  96. :param original_enum: existing enum type in the database
  97. :type original_enum: sa.Enum
  98. :param to_remove_values: values to remove from the enum
  99. :type to_remove_values: Tuple[str, ...]
  100. """
  101. if len(to_remove_values) == 0:
  102. return
  103. conn = op.get_bind()
  104. for table_name, (column_name, default_value) in table_columns.items():
  105. conn.execute(
  106. sa.text(
  107. f"""
  108. UPDATE {table_name}
  109. SET {column_name} = {repr(default_value)}
  110. WHERE {column_name} IN ({','.join([repr(v) for v in to_remove_values])});
  111. """
  112. )
  113. )
  114. if conn.dialect.name == 'mysql':
  115. modify_mysql_table_column_enum(
  116. conn, table_name, column_name, [], list(to_remove_values)
  117. )
  118. if conn.dialect.name == 'postgresql':
  119. remove_postgres_enum_values(
  120. conn, table_name, column_name, original_enum, *to_remove_values
  121. )