""" SQLite → PostgreSQL 数据迁移脚本 用法: python migrate_sqlite_to_pg.py [--dry-run] --dry-run 只打印计划,不执行写入 """ import os import sys import sqlite3 import argparse from dotenv import load_dotenv load_dotenv() # ── PostgreSQL 连接 ── DB_USER = os.environ.get("DB_USER", "") DB_PASSWORD = os.environ.get("DB_PASSWORD", "") DB_HOST = os.environ.get("DB_HOST", "") DB_PORT = os.environ.get("DB_PORT", "5432") DB_NAME = os.environ.get("DB_NAME", "") if not all([DB_USER, DB_PASSWORD, DB_HOST, DB_NAME]): print("ERROR: 缺少 PG 配置 (DB_USER/DB_PASSWORD/DB_HOST/DB_NAME)") sys.exit(1) PG_URI = f"postgresql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}" # ── SQLite 路径 ── SQLITE_PATH = os.path.join(os.path.dirname(__file__), "app", "app.db") if not os.path.exists(SQLITE_PATH): print(f"ERROR: 找不到 SQLite 数据库: {SQLITE_PATH}") sys.exit(1) # ── 布尔列(SQLite 存 0/1,需转换为 PG boolean) ── BOOL_COLUMNS = {"has_pagination", "has_deep_collection", "is_active"} # ── 迁移顺序(按外键依赖) ── # (sqlite表名, pg表名, 跳过列, 需要额外处理的列) MIGRATION_PLAN = [ ("user", '"user"', set(), {}), ("spider_source", "spider_source", set(), {}), ("collection_task", "collection_task", set(), {}), # spider_task 是旧表,已不在 models.py 中,跳过 ("spider_result", "spider_result", set(), {}), ("deep_collection", "deep_collection", set(), {}), ("ai_model", "ai_model", set(), {}), ("token_usage_log", "token_usage_log", set(), {}), ("ai_conversation", "ai_conversation", set(), {}), ("ai_message", "ai_message", set(), {}), ("knowledge_import_task", "knowledge_import_task", set(), {}), ] def get_sqlite_tables(): conn = sqlite3.connect(SQLITE_PATH) cursor = conn.execute("SELECT name FROM sqlite_master WHERE type='table'") tables = {r[0] for r in cursor.fetchall()} conn.close() return tables def get_sqlite_data(table): conn = sqlite3.connect(SQLITE_PATH) cursor = conn.execute(f"SELECT * FROM {table}") columns = [desc[0] for desc in cursor.description] rows = cursor.fetchall() conn.close() return columns, rows def quote_identifier(name): """如果标识符未被引号包裹,添加双引号。""" if name.startswith('"') and name.endswith('"'): return name return f'"{name}"' def migrate(dry_run=False): import psycopg2 print(f"SQLite: {SQLITE_PATH}") print(f"PG: {PG_URI.replace(DB_PASSWORD, '***')}") print() pg = psycopg2.connect(PG_URI) pg.autocommit = False cur = pg.cursor() sqlite_tables = get_sqlite_tables() total_inserted = 0 total_skipped = 0 for sqlite_table, pg_table, skip_cols, _ in MIGRATION_PLAN: if sqlite_table not in sqlite_tables: print(f" [skip] SQLite 中不存在表 '{sqlite_table}'") continue columns, rows = get_sqlite_data(sqlite_table) insert_cols = [c for c in columns if c not in skip_cols] col_indices = [i for i, c in enumerate(columns) if c not in skip_cols] if not rows: print(f" [empty] {sqlite_table} -> {pg_table} (0 行)") continue print(f" [{sqlite_table}] -> [{pg_table}]: {len(rows)} 行") if dry_run: total_inserted += len(rows) continue # 禁用触发器(跳过 FK 检查)来清空表 quoted = quote_identifier(pg_table) try: cur.execute(f"ALTER TABLE {quoted} DISABLE TRIGGER ALL") cur.execute(f"TRUNCATE TABLE {quoted} RESTART IDENTITY CASCADE") cur.execute(f"ALTER TABLE {quoted} ENABLE TRIGGER ALL") except psycopg2.Error as e: print(f" WARN: TRUNCATE {pg_table} 失败: {e}") pg.rollback() continue # 批量插入 inserted = 0 skipped = 0 for row in rows: values = [] for i, idx in enumerate(col_indices): val = row[idx] col_name = columns[idx] if val in (0, 1) and col_name in BOOL_COLUMNS: val = bool(val) values.append(val) placeholders = ", ".join(["%s"] * len(values)) col_names = ", ".join(insert_cols) sql = f"INSERT INTO {quoted} ({col_names}) VALUES ({placeholders})" try: cur.execute(sql, values) inserted += 1 except psycopg2.Error as e: skipped += 1 pg.rollback() # 对于 url 超长问题,尝试用 TEXT 列存储 if "value too long" in str(e) and sqlite_table == "deep_collection": try: # 临时修改列类型 cur.execute("ALTER TABLE deep_collection ALTER COLUMN url TYPE TEXT") cur.execute(sql, values) inserted += 1 skipped -= 1 print(f" NOTE: deep_collection.url 已自动改为 TEXT 类型") except psycopg2.Error as e2: print(f" WARN: 插入 {pg_table} 失败: {e2}") continue else: print(f" WARN: 插入 {pg_table} 失败 (row {inserted + skipped}): {e}") continue # 重置序列 if inserted > 0: try: seq_sql = f"SELECT setval(pg_get_serial_sequence('{pg_table}', 'id'), (SELECT COALESCE(MAX(id), 1) FROM {quoted}))" cur.execute(seq_sql) except psycopg2.Error: pass total_inserted += inserted total_skipped += skipped status = "OK" if skipped == 0 else "PARTIAL" print(f" [{status}] 插入 {inserted} 行, 跳过 {skipped} 行") if not dry_run: pg.commit() print(f"\n迁移完成! 共插入 {total_inserted} 行, 跳过 {total_skipped} 行") else: print(f"\n[DRY RUN] 将插入 ~{total_inserted} 行") cur.close() pg.close() if __name__ == "__main__": parser = argparse.ArgumentParser(description="SQLite → PostgreSQL 数据迁移") parser.add_argument("--dry-run", action="store_true", help="只打印计划,不执行") args = parser.parse_args() migrate(dry_run=args.dry_run)