dump_pg_to_sql.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. """
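Export the table structure (columns, primary keys, unique constraints) and the
data of the public schema of a remote PostgreSQL database to a SQL file.

Usage: python dump_pg_to_sql.py
The output is written to backup/maas_collect_init.sql next to this script.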
  4. """
import io
import json
import os
import sys
from urllib.parse import quote

import psycopg2
from dotenv import load_dotenv
  10. # 确保 stdout 使用 UTF-8
  11. if sys.stdout.encoding != 'utf-8':
  12. sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
  13. load_dotenv()
  14. DB_USER = os.environ.get("DB_USER", "")
  15. DB_PASSWORD = os.environ.get("DB_PASSWORD", "")
  16. DB_HOST = os.environ.get("DB_HOST", "")
  17. DB_PORT = os.environ.get("DB_PORT", "5432")
  18. DB_NAME = os.environ.get("DB_NAME", "")
  19. if not all([DB_USER, DB_PASSWORD, DB_HOST, DB_NAME]):
  20. print("ERROR: 缺少 PG 配置", file=sys.stderr)
  21. sys.exit(1)
  22. PG_URI = f"postgresql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}"
  23. # 需要跳过的系统表
  24. SKIP_TABLES = {"alembic_version"}
  25. def sql_escape_string(val):
  26. if val is None:
  27. return "NULL"
  28. return "'" + str(val).replace("'", "''") + "'"
  29. def dump():
  30. conn = psycopg2.connect(PG_URI)
  31. cur = conn.cursor()
  32. # 获取所有用户表
  33. cur.execute("""
  34. SELECT tablename FROM pg_tables
  35. WHERE schemaname = 'public'
  36. ORDER BY tablename
  37. """)
  38. tables = [r[0] for r in cur.fetchall() if r[0] not in SKIP_TABLES]
  39. print("-- ============================================")
  40. print(f"-- Dump of {DB_NAME} from {DB_HOST}")
  41. print("-- ============================================")
  42. print()
  43. # 需要特殊处理的保留字表名
  44. QUOTED_TABLES = {"user"}
  45. for table in tables:
  46. print(f"-- Table: {table}")
  47. safe_table = f'"{table}"' if table in QUOTED_TABLES else table
  48. # 获取列信息
  49. cur.execute(f"""
  50. SELECT column_name, data_type, character_maximum_length, is_nullable
  51. FROM information_schema.columns
  52. WHERE table_name = '{table}'
  53. ORDER BY ordinal_position
  54. """)
  55. columns = cur.fetchall()
  56. # 获取主键
  57. cur.execute(f"""
  58. SELECT a.attname
  59. FROM pg_index i
  60. JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
  61. WHERE i.indrelid = '{safe_table}'::regclass AND i.indisprimary
  62. """)
  63. pk_cols = [r[0] for r in cur.fetchall()]
  64. # 获取唯一约束
  65. cur.execute(f"""
  66. SELECT conname,
  67. array_agg(a.attname ORDER BY array_position(conkey, a.attnum))
  68. FROM pg_constraint c
  69. JOIN pg_attribute a ON a.attrelid = c.conrelid AND a.attnum = ANY(c.conkey)
  70. WHERE c.contype = 'u' AND c.conrelid = '{safe_table}'::regclass
  71. GROUP BY conname, c.conrelid
  72. """)
  73. unique_constraints = cur.fetchall()
  74. # 生成 CREATE TABLE
  75. col_defs = []
  76. pk_col_set = set(pk_cols)
  77. for col_name, data_type, char_max_len, is_nullable in columns:
  78. # 主键 integer 列:使用 SERIAL 替代 integer NOT NULL + 单独的 PRIMARY KEY
  79. if col_name in pk_col_set and data_type == "integer":
  80. col_defs.append(f" {col_name} SERIAL PRIMARY KEY")
  81. continue
  82. type_str = data_type
  83. if data_type == "character varying" and char_max_len:
  84. type_str = f"VARCHAR({char_max_len})"
  85. elif data_type == "character":
  86. type_str = f"CHAR({char_max_len or 1})"
  87. elif data_type == "timestamp without time zone":
  88. type_str = "TIMESTAMP"
  89. elif data_type == "double precision":
  90. type_str = "DOUBLE PRECISION"
  91. null_str = "NOT NULL" if is_nullable == "NO" else "NULL"
  92. col_defs.append(f" {col_name} {type_str} {null_str}")
  93. if pk_cols and not any(c.startswith(" " + pk_cols[0] + " SERIAL") for c in col_defs):
  94. col_defs.append(f" PRIMARY KEY ({', '.join(pk_cols)})")
  95. for uc_name, uc_cols in unique_constraints:
  96. col_defs.append(f" CONSTRAINT {uc_name} UNIQUE ({', '.join(uc_cols)})")
  97. print(f"DROP TABLE IF EXISTS {safe_table} CASCADE;")
  98. print(f"CREATE TABLE {safe_table} (")
  99. print(",\n".join(col_defs))
  100. print(f");")
  101. print()
  102. # 生成 INSERT 语句
  103. col_names = [c[0] for c in columns]
  104. order_col = pk_cols[0] if pk_cols else col_names[0]
  105. cur.execute(f"SELECT * FROM {safe_table} ORDER BY {order_col}")
  106. rows = cur.fetchall()
  107. if rows:
  108. for row in rows:
  109. values = []
  110. for i, val in enumerate(row):
  111. col_name = col_names[i]
  112. col_type = columns[i][1]
  113. if val is None:
  114. values.append("NULL")
  115. elif col_type == "boolean":
  116. values.append("TRUE" if val else "FALSE")
  117. elif col_type in ("integer", "bigint", "smallint", "numeric", "double precision", "real"):
  118. values.append(str(val))
  119. else:
  120. values.append(sql_escape_string(val))
  121. cols_str = ", ".join(col_names)
  122. vals_str = ", ".join(values)
  123. print(f"INSERT INTO {safe_table} ({cols_str}) VALUES ({vals_str});")
  124. print()
  125. # 重置序列
  126. if pk_cols:
  127. try:
  128. cur.execute(f"SELECT setval('{table}_{pk_cols[0]}_seq', (SELECT MAX({pk_cols[0]}) FROM {safe_table}))")
  129. print("-- Sequence reset done")
  130. except psycopg2.Error:
  131. pass
  132. print()
  133. cur.close()
  134. conn.close()
  135. print("-- Dump complete", file=sys.stderr)
  136. if __name__ == "__main__":
  137. output_file = os.path.join(os.path.dirname(__file__), "backup", "maas_collect_init.sql")
  138. os.makedirs(os.path.dirname(output_file), exist_ok=True)
  139. old_stdout = sys.stdout
  140. with open(output_file, "w", encoding="utf-8") as f:
  141. sys.stdout = f
  142. dump()
  143. sys.stdout = old_stdout
  144. print("Dump complete: backup/maas_collect_init.sql", file=sys.stderr)