| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173 |
- """
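Export the table definitions (columns, primary keys, unique constraints) and
data of a remote PostgreSQL database to a SQL file.

Usage: python dump_pg_to_sql.py
The dump is written to backup/maas_collect_init.sql next to this script, so
stdout does not need to be redirected.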
- """
import contextlib
import io
import json
import os
import sys
from urllib.parse import quote

import psycopg2
from dotenv import load_dotenv
# Make sure text printed to stdout is encoded as UTF-8.  Encoding names are
# compared case-insensitively ("UTF-8", "utf8", ...) so a stream that is
# already UTF-8 is not wrapped a second time.
_stdout_encoding = (getattr(sys.stdout, "encoding", None) or "").lower()
if _stdout_encoding.replace("_", "-") not in ("utf-8", "utf8"):
    sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')

# Load connection settings from a .env file into the process environment.
load_dotenv()

DB_USER = os.environ.get("DB_USER", "")
DB_PASSWORD = os.environ.get("DB_PASSWORD", "")
DB_HOST = os.environ.get("DB_HOST", "")
DB_PORT = os.environ.get("DB_PORT", "5432")
DB_NAME = os.environ.get("DB_NAME", "")

# DB_PORT has a default; every other setting is mandatory.
if not all([DB_USER, DB_PASSWORD, DB_HOST, DB_NAME]):
    print("ERROR: 缺少 PG 配置", file=sys.stderr)
    sys.exit(1)

# User, password and database name are percent-encoded so that characters
# with a meaning inside a URI ("@", ":", "/", "#", "%", ...) cannot corrupt
# the connection string.
PG_URI = (
    f"postgresql://{quote(DB_USER, safe='')}:{quote(DB_PASSWORD, safe='')}"
    f"@{DB_HOST}:{DB_PORT}/{quote(DB_NAME, safe='')}"
)

# Tables that are left out of the dump.
SKIP_TABLES = {"alembic_version"}
def sql_escape_string(val):
    """Render *val* as a single-quoted SQL string literal.

    Args:
        val: The value to render.  ``None`` becomes the bare keyword
            ``NULL``.  Bytes-like values (``bytes``, ``bytearray``,
            ``memoryview``) are written in PostgreSQL's hex bytea form
            (``'\\x0aff'``), and ``dict`` values are serialised as JSON.
            Anything else is converted with ``str()``.

    Returns:
        The SQL literal as a string, with embedded single quotes doubled.

    Note:
        Backslashes are emitted unchanged, which is only correct when the
        target server runs with ``standard_conforming_strings = on``.
    """
    if val is None:
        return "NULL"
    if isinstance(val, (bytes, bytearray, memoryview)):
        # str() on binary data would give "b'..'" / "<memory at 0x...>".
        # Hex digits never contain a quote, so no escaping is needed.
        return "'\\x" + bytes(val).hex() + "'"
    if isinstance(val, dict):
        # str(dict) is a Python repr (single quotes, True/None), not JSON.
        text = json.dumps(val, ensure_ascii=False)
    else:
        text = str(val)
    return "'" + text.replace("'", "''") + "'"
# information_schema data types whose Python values can be written as bare
# numeric tokens.
_NUMERIC_TYPES = frozenset(
    ("integer", "bigint", "smallint", "numeric", "double precision", "real")
)

# str() spellings of non-finite floats / Decimals, mapped to quoted literals;
# a bare ``nan`` or ``inf`` token is not valid SQL.
_NON_FINITE_LITERALS = {
    "nan": "'NaN'",
    "inf": "'Infinity'",
    "infinity": "'Infinity'",
    "-inf": "'-Infinity'",
    "-infinity": "'-Infinity'",
}


def _quote_ident(name):
    """Return *name* as a double-quoted SQL identifier."""
    return '"' + str(name).replace('"', '""') + '"'


def _column_type(data_type, char_max_len, udt_name):
    """Map an information_schema column description to a DDL type string."""
    if data_type == "character varying" and char_max_len:
        return f"VARCHAR({char_max_len})"
    if data_type == "character":
        return f"CHAR({char_max_len or 1})"
    if data_type == "timestamp without time zone":
        return "TIMESTAMP"
    if data_type == "double precision":
        return "DOUBLE PRECISION"
    if data_type == "ARRAY" and udt_name and udt_name.startswith("_"):
        # information_schema reports every array as "ARRAY"; udt_name holds
        # the element type name prefixed with an underscore (e.g. "_int4").
        return _quote_ident(udt_name[1:]) + "[]"
    if data_type == "USER-DEFINED" and udt_name:
        # "USER-DEFINED" is a placeholder, not a type name.  The named type
        # itself is not created by this dump and must exist at restore time.
        return _quote_ident(udt_name)
    return data_type


def _array_text(items):
    """Render a Python sequence as PostgreSQL array input text."""
    parts = []
    for item in items:
        if item is None:
            parts.append("NULL")
        elif isinstance(item, (list, tuple)):
            parts.append(_array_text(item))
        else:
            if isinstance(item, (bytes, bytearray, memoryview)):
                text = "\\x" + bytes(item).hex()
            else:
                text = str(item)
            # Every element is double-quoted, so only backslash and the
            # double quote itself need escaping.
            escaped = text.replace("\\", "\\\\").replace('"', '\\"')
            parts.append('"' + escaped + '"')
    return "{" + ",".join(parts) + "}"


def _sql_literal(val, col_type):
    """Render one fetched value as a SQL literal for a column of *col_type*."""
    if val is None:
        return "NULL"
    if col_type == "boolean":
        return "TRUE" if val else "FALSE"
    if col_type in _NUMERIC_TYPES:
        text = str(val)
        return _NON_FINITE_LITERALS.get(text.lower(), text)
    if col_type in ("json", "jsonb"):
        # Assumes the driver hands json/jsonb back as parsed Python objects;
        # str() of those would be a Python repr rather than JSON.
        return sql_escape_string(json.dumps(val, ensure_ascii=False))
    if isinstance(val, (bytes, bytearray, memoryview)):
        # bytea in hex form; hex digits never need quote escaping.
        return "'\\x" + bytes(val).hex() + "'"
    if col_type == "ARRAY" and isinstance(val, (list, tuple)):
        return sql_escape_string(_array_text(val))
    return sql_escape_string(val)


def _dump_table(cur, table):
    """Print DROP/CREATE TABLE, INSERTs and a sequence reset for one table."""
    q_table = _quote_ident(table)
    # Source-side queries are schema-qualified so they do not depend on the
    # connection's search_path; the emitted SQL keeps the unqualified name.
    src_table = "public." + q_table

    print(f"-- Table: {table}")

    # Column metadata, restricted to the public schema so a same-named table
    # in another schema cannot mix its columns in.
    cur.execute(
        """
        SELECT column_name, data_type, character_maximum_length,
               is_nullable, udt_name
        FROM information_schema.columns
        WHERE table_schema = 'public' AND table_name = %s
        ORDER BY ordinal_position
        """,
        (table,),
    )
    columns = cur.fetchall()

    # Primary-key columns in constraint order.
    cur.execute(
        """
        SELECT a.attname
        FROM pg_constraint c
        JOIN pg_attribute a
          ON a.attrelid = c.conrelid AND a.attnum = ANY(c.conkey)
        WHERE c.contype = 'p' AND c.conrelid = %s::regclass
        ORDER BY array_position(c.conkey, a.attnum)
        """,
        (src_table,),
    )
    pk_cols = [r[0] for r in cur.fetchall()]

    # Unique constraints; attname is cast to text so the aggregate is a
    # text[] that the driver returns as a Python list.
    cur.execute(
        """
        SELECT c.conname,
               array_agg(a.attname::text
                         ORDER BY array_position(c.conkey, a.attnum))
        FROM pg_constraint c
        JOIN pg_attribute a
          ON a.attrelid = c.conrelid AND a.attnum = ANY(c.conkey)
        WHERE c.contype = 'u' AND c.conrelid = %s::regclass
        GROUP BY c.conname, c.conrelid
        ORDER BY c.conname
        """,
        (src_table,),
    )
    unique_constraints = cur.fetchall()

    # A single-column integer primary key is emitted as SERIAL PRIMARY KEY.
    # Composite keys never are: several column-level PRIMARY KEY clauses in
    # one CREATE TABLE are rejected by PostgreSQL.
    serial_pk = None
    if len(pk_cols) == 1:
        pk_types = [c[1] for c in columns if c[0] == pk_cols[0]]
        if pk_types and pk_types[0] == "integer":
            serial_pk = pk_cols[0]

    col_defs = []
    for col_name, data_type, char_max_len, is_nullable, udt_name in columns:
        q_col = _quote_ident(col_name)
        if col_name == serial_pk:
            col_defs.append(f"    {q_col} SERIAL PRIMARY KEY")
            continue
        type_str = _column_type(data_type, char_max_len, udt_name)
        null_str = "NOT NULL" if is_nullable == "NO" else "NULL"
        col_defs.append(f"    {q_col} {type_str} {null_str}")

    if pk_cols and serial_pk is None:
        pk_list = ", ".join(_quote_ident(c) for c in pk_cols)
        col_defs.append(f"    PRIMARY KEY ({pk_list})")

    for uc_name, uc_cols in unique_constraints:
        uc_list = ", ".join(_quote_ident(c) for c in uc_cols)
        col_defs.append(
            f"    CONSTRAINT {_quote_ident(uc_name)} UNIQUE ({uc_list})"
        )

    print(f"DROP TABLE IF EXISTS {q_table} CASCADE;")
    print(f"CREATE TABLE {q_table} (")
    print(",\n".join(col_defs))
    print(");")
    print()

    if not columns:
        # Nothing to select or insert for a table without columns.
        return

    col_names = [c[0] for c in columns]
    col_types = [c[1] for c in columns]
    cols_sql = ", ".join(_quote_ident(c) for c in col_names)
    order_sql = ", ".join(_quote_ident(c) for c in (pk_cols or col_names[:1]))

    # Columns are selected explicitly so values line up with col_types.
    cur.execute(f"SELECT {cols_sql} FROM {src_table} ORDER BY {order_sql}")
    wrote_rows = False
    for row in cur:
        values = [_sql_literal(v, t) for v, t in zip(row, col_types)]
        print(
            f"INSERT INTO {q_table} ({cols_sql}) "
            f"VALUES ({', '.join(values)});"
        )
        wrote_rows = True
    if wrote_rows:
        print()

    if serial_pk is not None:
        # The reset is written into the dump so it runs on the restored
        # database; the source database is never modified.  For an empty
        # table the sequence is left at 1 with is_called = false.
        q_pk = _quote_ident(serial_pk)
        print(
            "SELECT setval(pg_get_serial_sequence("
            f"{sql_escape_string(q_table)}, {sql_escape_string(serial_pk)}), "
            f"COALESCE(MAX({q_pk}), 1), MAX({q_pk}) IS NOT NULL) "
            f"FROM {q_table};"
        )
        print()


def dump():
    """Print a SQL dump of every table in the ``public`` schema to stdout.

    Connects with ``PG_URI``, skips the tables named in ``SKIP_TABLES`` and,
    for each remaining table, prints ``DROP TABLE`` / ``CREATE TABLE``, one
    ``INSERT`` per row and, for tables with a single integer primary key, a
    ``setval`` statement that realigns the SERIAL sequence after restore.

    Only columns, primary keys and unique constraints are reproduced; column
    defaults, foreign keys, indexes and custom types are not emitted.  The
    generated literals assume ``standard_conforming_strings = on``.

    Raises:
        psycopg2.Error: If connecting or any catalog/data query fails.  The
            cursor and connection are closed in that case as well.
    """
    conn = psycopg2.connect(PG_URI)
    try:
        cur = conn.cursor()
        try:
            cur.execute("""
                SELECT tablename FROM pg_tables
                WHERE schemaname = 'public'
                ORDER BY tablename
            """)
            tables = [r[0] for r in cur.fetchall() if r[0] not in SKIP_TABLES]

            print("-- ============================================")
            print(f"-- Dump of {DB_NAME} from {DB_HOST}")
            print("-- ============================================")
            print()

            for table in tables:
                _dump_table(cur, table)
        finally:
            cur.close()
    finally:
        conn.close()

    print("-- Dump complete", file=sys.stderr)
if __name__ == "__main__":
    # Write the dump to backup/maas_collect_init.sql beside this script,
    # creating the backup directory on first use.
    output_file = os.path.join(os.path.dirname(__file__), "backup", "maas_collect_init.sql")
    os.makedirs(os.path.dirname(output_file), exist_ok=True)

    # dump() prints to sys.stdout; redirect_stdout points it at the file and
    # restores the previous stream even when dump() raises.
    with open(output_file, "w", encoding="utf-8") as f, contextlib.redirect_stdout(f):
        dump()

    print("Dump complete: backup/maas_collect_init.sql", file=sys.stderr)
|