| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189 |
- """
- SQLite → PostgreSQL 数据迁移脚本
- 用法:
- python migrate_sqlite_to_pg.py [--dry-run]
- --dry-run 只打印计划,不执行写入
- """
- import os
- import sys
- import sqlite3
- import argparse
- from dotenv import load_dotenv
- load_dotenv()
# -- PostgreSQL connection settings, read from the process environment --
DB_USER = os.getenv("DB_USER", "")
DB_PASSWORD = os.getenv("DB_PASSWORD", "")
DB_HOST = os.getenv("DB_HOST", "")
DB_PORT = os.getenv("DB_PORT", "5432")
DB_NAME = os.getenv("DB_NAME", "")

# Every setting except the port (which has a default) must be non-empty.
if not (DB_USER and DB_PASSWORD and DB_HOST and DB_NAME):
    print("ERROR: 缺少 PG 配置 (DB_USER/DB_PASSWORD/DB_HOST/DB_NAME)")
    sys.exit(1)

PG_URI = "postgresql://{}:{}@{}:{}/{}".format(
    DB_USER, DB_PASSWORD, DB_HOST, DB_PORT, DB_NAME
)

# -- SQLite source file: app/app.db next to this script --
SQLITE_PATH = os.path.join(os.path.dirname(__file__), "app", "app.db")
if not os.path.exists(SQLITE_PATH):
    print(f"ERROR: 找不到 SQLite 数据库: {SQLITE_PATH}")
    sys.exit(1)

# -- Boolean columns (SQLite stores 0/1; converted to PG boolean on insert) --
BOOL_COLUMNS = {
    "has_pagination",
    "has_deep_collection",
    "is_active",
}

# -- Migration order (follows foreign-key dependencies) --
# Each entry is:
#   (sqlite table name, pg table name, columns to skip, columns needing extra handling)
# Only "user" needs a pre-quoted PostgreSQL name; every other table keeps its
# SQLite name. spider_task is a legacy table that is no longer in models.py,
# so it is deliberately not listed.
MIGRATION_PLAN = [
    (table, '"user"' if table == "user" else table, set(), {})
    for table in (
        "user",
        "spider_source",
        "collection_task",
        "spider_result",
        "deep_collection",
        "ai_model",
        "token_usage_log",
        "ai_conversation",
        "ai_message",
        "knowledge_import_task",
    )
]
def get_sqlite_tables(db_path=None):
    """Return the names of all tables present in a SQLite database.

    Args:
        db_path: Path of the SQLite file to inspect. When omitted, the
            module-level ``SQLITE_PATH`` is used.

    Returns:
        A set containing the ``name`` of every ``sqlite_master`` row whose
        type is ``'table'``.
    """
    conn = sqlite3.connect(SQLITE_PATH if db_path is None else db_path)
    try:
        cursor = conn.execute("SELECT name FROM sqlite_master WHERE type='table'")
        return {row[0] for row in cursor.fetchall()}
    finally:
        # Close the handle even when the query raises, so it is never leaked.
        conn.close()
- def get_sqlite_data(table):
- conn = sqlite3.connect(SQLITE_PATH)
- cursor = conn.execute(f"SELECT * FROM {table}")
- columns = [desc[0] for desc in cursor.description]
- rows = cursor.fetchall()
- conn.close()
- return columns, rows
def quote_identifier(name):
    """Return *name* as a double-quoted SQL identifier.

    A name that is already wrapped in double quotes (at least two characters,
    starting and ending with ``"``) is returned unchanged. Any other name is
    wrapped in double quotes, with embedded double quotes doubled so the
    result stays a single valid identifier.

    Args:
        name: The identifier, either bare or already double-quoted.

    Returns:
        The double-quoted identifier string.
    """
    # The length check keeps a lone '"' from being mistaken for a quoted name.
    if len(name) >= 2 and name.startswith('"') and name.endswith('"'):
        return name
    return '"' + name.replace('"', '""') + '"'
def _convert_row(row, columns, col_indices):
    """Pick the migrated columns out of one SQLite row, turning 0/1 flags into bools."""
    values = []
    for idx in col_indices:
        val = row[idx]
        if columns[idx] in BOOL_COLUMNS and val in (0, 1):
            val = bool(val)
        values.append(val)
    return values


def _execute_in_savepoint(cur, statements):
    """Run ``(sql, params)`` pairs atomically; return the psycopg2 error or None.

    The statements run inside a savepoint. On failure only the work done since
    the savepoint is undone, so the surrounding transaction (and everything
    migrated so far) stays intact and usable.
    """
    import psycopg2

    cur.execute("SAVEPOINT migrate_sp")
    try:
        for sql, params in statements:
            cur.execute(sql, params)
    except psycopg2.Error as exc:
        cur.execute("ROLLBACK TO SAVEPOINT migrate_sp")
        cur.execute("RELEASE SAVEPOINT migrate_sp")
        return exc
    cur.execute("RELEASE SAVEPOINT migrate_sp")
    return None


def migrate(dry_run=False):
    """Copy every table listed in ``MIGRATION_PLAN`` from SQLite to PostgreSQL.

    For each planned table that exists in SQLite and has rows, the PostgreSQL
    table is truncated, the rows are inserted one by one (0/1 values in
    ``BOOL_COLUMNS`` become booleans), and the ``id`` sequence is reset to the
    largest inserted id. Everything runs in one transaction that is committed
    at the end. Each risky statement runs inside its own savepoint, so a
    failing row or table is skipped without discarding earlier work.

    Args:
        dry_run: When true, only print the per-table row counts. The
            PostgreSQL connection is still opened, but nothing is written or
            committed.
    """
    import psycopg2

    print(f"SQLite: {SQLITE_PATH}")
    # Build the masked URI directly: str.replace() on the real URI would also
    # rewrite any other part that happens to contain the password text.
    print(f"PG: postgresql://{DB_USER}:***@{DB_HOST}:{DB_PORT}/{DB_NAME}")
    print()

    # Keyword arguments instead of the URI, so passwords containing '@', '/'
    # or ':' do not need URL-escaping.
    pg = psycopg2.connect(
        host=DB_HOST,
        port=DB_PORT,
        dbname=DB_NAME,
        user=DB_USER,
        password=DB_PASSWORD,
    )
    try:
        pg.autocommit = False
        cur = pg.cursor()
        sqlite_tables = get_sqlite_tables()
        total_inserted = 0
        total_skipped = 0

        for sqlite_table, pg_table, skip_cols, _ in MIGRATION_PLAN:
            if sqlite_table not in sqlite_tables:
                print(f" [skip] SQLite 中不存在表 '{sqlite_table}'")
                continue

            columns, rows = get_sqlite_data(sqlite_table)
            if not rows:
                print(f" [empty] {sqlite_table} -> {pg_table} (0 行)")
                continue

            print(f" [{sqlite_table}] -> [{pg_table}]: {len(rows)} 行")
            if dry_run:
                total_inserted += len(rows)
                continue

            # Empty the target table with triggers disabled. A failure here
            # skips this table only.
            quoted = quote_identifier(pg_table)
            truncate_error = _execute_in_savepoint(
                cur,
                [
                    (f"ALTER TABLE {quoted} DISABLE TRIGGER ALL", None),
                    (f"TRUNCATE TABLE {quoted} RESTART IDENTITY CASCADE", None),
                    (f"ALTER TABLE {quoted} ENABLE TRIGGER ALL", None),
                ],
            )
            if truncate_error is not None:
                print(f" WARN: TRUNCATE {pg_table} 失败: {truncate_error}")
                continue

            # The INSERT text is identical for every row, so build it once.
            # Column names are quoted so reserved words are accepted.
            col_indices = [i for i, c in enumerate(columns) if c not in skip_cols]
            col_names = ", ".join(quote_identifier(columns[i]) for i in col_indices)
            placeholders = ", ".join(["%s"] * len(col_indices))
            insert_sql = f"INSERT INTO {quoted} ({col_names}) VALUES ({placeholders})"

            inserted = 0
            skipped = 0
            for row in rows:
                values = _convert_row(row, columns, col_indices)
                error = _execute_in_savepoint(cur, [(insert_sql, values)])
                if error is None:
                    inserted += 1
                    continue

                skipped += 1
                if "value too long" in str(error) and sqlite_table == "deep_collection":
                    # Over-long url: widen the column to TEXT and retry the
                    # row. Both statements share one savepoint, so a failed
                    # retry also undoes the ALTER.
                    retry_error = _execute_in_savepoint(
                        cur,
                        [
                            ("ALTER TABLE deep_collection ALTER COLUMN url TYPE TEXT", None),
                            (insert_sql, values),
                        ],
                    )
                    if retry_error is None:
                        inserted += 1
                        skipped -= 1
                        print(" NOTE: deep_collection.url 已自动改为 TEXT 类型")
                    else:
                        print(f" WARN: 插入 {pg_table} 失败: {retry_error}")
                else:
                    print(f" WARN: 插入 {pg_table} 失败 (row {inserted + skipped}): {error}")

            # Move the id sequence past the migrated ids. This is best-effort:
            # a failure is reported but does not stop the migration.
            if inserted > 0:
                seq_sql = (
                    "SELECT setval(pg_get_serial_sequence(%s, 'id'), "
                    f"(SELECT COALESCE(MAX(id), 1) FROM {quoted}))"
                )
                seq_error = _execute_in_savepoint(cur, [(seq_sql, (quoted,))])
                if seq_error is not None:
                    print(f" WARN: 重置 {pg_table} 序列失败: {seq_error}")

            total_inserted += inserted
            total_skipped += skipped
            status = "OK" if skipped == 0 else "PARTIAL"
            print(f" [{status}] 插入 {inserted} 行, 跳过 {skipped} 行")

        if not dry_run:
            pg.commit()
            print(f"\n迁移完成! 共插入 {total_inserted} 行, 跳过 {total_skipped} 行")
        else:
            print(f"\n[DRY RUN] 将插入 ~{total_inserted} 行")
        cur.close()
    finally:
        # Always release the connection. If an error escaped before the
        # commit, nothing has been committed.
        pg.close()
if __name__ == "__main__":
    # Command-line entry point: the only option is --dry-run, which is
    # forwarded to migrate().
    cli = argparse.ArgumentParser(description="SQLite → PostgreSQL 数据迁移")
    cli.add_argument(
        "--dry-run",
        action="store_true",
        help="只打印计划,不执行",
    )
    options = cli.parse_args()
    migrate(dry_run=options.dry_run)
|