""" 从远程 PostgreSQL 导出完整结构和数据为 SQL 文件。 用法: python dump_pg_to_sql.py > backup/maas_collect_init.sql """ import os import sys import io import psycopg2 from dotenv import load_dotenv # 确保 stdout 使用 UTF-8 if sys.stdout.encoding != 'utf-8': sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8') load_dotenv() DB_USER = os.environ.get("DB_USER", "") DB_PASSWORD = os.environ.get("DB_PASSWORD", "") DB_HOST = os.environ.get("DB_HOST", "") DB_PORT = os.environ.get("DB_PORT", "5432") DB_NAME = os.environ.get("DB_NAME", "") if not all([DB_USER, DB_PASSWORD, DB_HOST, DB_NAME]): print("ERROR: 缺少 PG 配置", file=sys.stderr) sys.exit(1) PG_URI = f"postgresql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}" # 需要跳过的系统表 SKIP_TABLES = {"alembic_version"} def sql_escape_string(val): if val is None: return "NULL" return "'" + str(val).replace("'", "''") + "'" def dump(): conn = psycopg2.connect(PG_URI) cur = conn.cursor() # 获取所有用户表 cur.execute(""" SELECT tablename FROM pg_tables WHERE schemaname = 'public' ORDER BY tablename """) tables = [r[0] for r in cur.fetchall() if r[0] not in SKIP_TABLES] print("-- ============================================") print(f"-- Dump of {DB_NAME} from {DB_HOST}") print("-- ============================================") print() # 需要特殊处理的保留字表名 QUOTED_TABLES = {"user"} for table in tables: print(f"-- Table: {table}") safe_table = f'"{table}"' if table in QUOTED_TABLES else table # 获取列信息 cur.execute(f""" SELECT column_name, data_type, character_maximum_length, is_nullable FROM information_schema.columns WHERE table_name = '{table}' ORDER BY ordinal_position """) columns = cur.fetchall() # 获取主键 cur.execute(f""" SELECT a.attname FROM pg_index i JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey) WHERE i.indrelid = '{safe_table}'::regclass AND i.indisprimary """) pk_cols = [r[0] for r in cur.fetchall()] # 获取唯一约束 cur.execute(f""" SELECT conname, array_agg(a.attname ORDER BY array_position(conkey, a.attnum)) FROM pg_constraint c JOIN pg_attribute a ON a.attrelid = c.conrelid AND a.attnum = ANY(c.conkey) WHERE c.contype = 'u' AND c.conrelid = '{safe_table}'::regclass GROUP BY conname, c.conrelid """) unique_constraints = cur.fetchall() # 生成 CREATE TABLE col_defs = [] pk_col_set = set(pk_cols) for col_name, data_type, char_max_len, is_nullable in columns: # 主键 integer 列:使用 SERIAL 替代 integer NOT NULL + 单独的 PRIMARY KEY if col_name in pk_col_set and data_type == "integer": col_defs.append(f" {col_name} SERIAL PRIMARY KEY") continue type_str = data_type if data_type == "character varying" and char_max_len: type_str = f"VARCHAR({char_max_len})" elif data_type == "character": type_str = f"CHAR({char_max_len or 1})" elif data_type == "timestamp without time zone": type_str = "TIMESTAMP" elif data_type == "double precision": type_str = "DOUBLE PRECISION" null_str = "NOT NULL" if is_nullable == "NO" else "NULL" col_defs.append(f" {col_name} {type_str} {null_str}") if pk_cols and not any(c.startswith(" " + pk_cols[0] + " SERIAL") for c in col_defs): col_defs.append(f" PRIMARY KEY ({', '.join(pk_cols)})") for uc_name, uc_cols in unique_constraints: col_defs.append(f" CONSTRAINT {uc_name} UNIQUE ({', '.join(uc_cols)})") print(f"DROP TABLE IF EXISTS {safe_table} CASCADE;") print(f"CREATE TABLE {safe_table} (") print(",\n".join(col_defs)) print(f");") print() # 生成 INSERT 语句 col_names = [c[0] for c in columns] order_col = pk_cols[0] if pk_cols else col_names[0] cur.execute(f"SELECT * FROM {safe_table} ORDER BY {order_col}") rows = cur.fetchall() if rows: for row in rows: values = [] for i, val in enumerate(row): col_name = col_names[i] col_type = columns[i][1] if val is None: values.append("NULL") elif col_type == "boolean": values.append("TRUE" if val else "FALSE") elif col_type in ("integer", "bigint", "smallint", "numeric", "double precision", "real"): values.append(str(val)) else: values.append(sql_escape_string(val)) cols_str = ", ".join(col_names) vals_str = ", ".join(values) print(f"INSERT INTO {safe_table} ({cols_str}) VALUES ({vals_str});") print() # 重置序列 if pk_cols: try: cur.execute(f"SELECT setval('{table}_{pk_cols[0]}_seq', (SELECT MAX({pk_cols[0]}) FROM {safe_table}))") print("-- Sequence reset done") except psycopg2.Error: pass print() cur.close() conn.close() print("-- Dump complete", file=sys.stderr) if __name__ == "__main__": output_file = os.path.join(os.path.dirname(__file__), "backup", "maas_collect_init.sql") os.makedirs(os.path.dirname(output_file), exist_ok=True) old_stdout = sys.stdout with open(output_file, "w", encoding="utf-8") as f: sys.stdout = f dump() sys.stdout = old_stdout print("Dump complete: backup/maas_collect_init.sql", file=sys.stderr)