|
|
@@ -0,0 +1,189 @@
|
|
|
+"""
|
|
|
+SQLite → PostgreSQL 数据迁移脚本
|
|
|
+
|
|
|
+用法:
|
|
|
+ python migrate_sqlite_to_pg.py [--dry-run]
|
|
|
+
|
|
|
+--dry-run 只打印计划,不执行写入
|
|
|
+"""
|
|
|
+import os
|
|
|
+import sys
|
|
|
+import sqlite3
|
|
|
+import argparse
|
|
|
+from dotenv import load_dotenv
|
|
|
+
|
|
|
+load_dotenv()
|
|
|
+
|
|
|
+# ── PostgreSQL 连接 ──
|
|
|
+DB_USER = os.environ.get("DB_USER", "")
|
|
|
+DB_PASSWORD = os.environ.get("DB_PASSWORD", "")
|
|
|
+DB_HOST = os.environ.get("DB_HOST", "")
|
|
|
+DB_PORT = os.environ.get("DB_PORT", "5432")
|
|
|
+DB_NAME = os.environ.get("DB_NAME", "")
|
|
|
+
|
|
|
+if not all([DB_USER, DB_PASSWORD, DB_HOST, DB_NAME]):
|
|
|
+ print("ERROR: 缺少 PG 配置 (DB_USER/DB_PASSWORD/DB_HOST/DB_NAME)")
|
|
|
+ sys.exit(1)
|
|
|
+
|
|
|
+PG_URI = f"postgresql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}"
|
|
|
+
|
|
|
+# ── SQLite 路径 ──
|
|
|
+SQLITE_PATH = os.path.join(os.path.dirname(__file__), "app", "app.db")
|
|
|
+if not os.path.exists(SQLITE_PATH):
|
|
|
+ print(f"ERROR: 找不到 SQLite 数据库: {SQLITE_PATH}")
|
|
|
+ sys.exit(1)
|
|
|
+
|
|
|
+# ── 布尔列(SQLite 存 0/1,需转换为 PG boolean) ──
|
|
|
+BOOL_COLUMNS = {"has_pagination", "has_deep_collection", "is_active"}
|
|
|
+
|
|
|
+# ── 迁移顺序(按外键依赖) ──
|
|
|
+# (sqlite表名, pg表名, 跳过列, 需要额外处理的列)
|
|
|
+MIGRATION_PLAN = [
|
|
|
+ ("user", '"user"', set(), {}),
|
|
|
+ ("spider_source", "spider_source", set(), {}),
|
|
|
+ ("collection_task", "collection_task", set(), {}),
|
|
|
+ # spider_task 是旧表,已不在 models.py 中,跳过
|
|
|
+ ("spider_result", "spider_result", set(), {}),
|
|
|
+ ("deep_collection", "deep_collection", set(), {}),
|
|
|
+ ("ai_model", "ai_model", set(), {}),
|
|
|
+ ("token_usage_log", "token_usage_log", set(), {}),
|
|
|
+ ("ai_conversation", "ai_conversation", set(), {}),
|
|
|
+ ("ai_message", "ai_message", set(), {}),
|
|
|
+ ("knowledge_import_task", "knowledge_import_task", set(), {}),
|
|
|
+]
|
|
|
+
|
|
|
+
|
|
|
+def get_sqlite_tables():
|
|
|
+ conn = sqlite3.connect(SQLITE_PATH)
|
|
|
+ cursor = conn.execute("SELECT name FROM sqlite_master WHERE type='table'")
|
|
|
+ tables = {r[0] for r in cursor.fetchall()}
|
|
|
+ conn.close()
|
|
|
+ return tables
|
|
|
+
|
|
|
+
|
|
|
+def get_sqlite_data(table):
|
|
|
+ conn = sqlite3.connect(SQLITE_PATH)
|
|
|
+ cursor = conn.execute(f"SELECT * FROM {table}")
|
|
|
+ columns = [desc[0] for desc in cursor.description]
|
|
|
+ rows = cursor.fetchall()
|
|
|
+ conn.close()
|
|
|
+ return columns, rows
|
|
|
+
|
|
|
+
|
|
|
+def quote_identifier(name):
|
|
|
+ """如果标识符未被引号包裹,添加双引号。"""
|
|
|
+ if name.startswith('"') and name.endswith('"'):
|
|
|
+ return name
|
|
|
+ return f'"{name}"'
|
|
|
+
|
|
|
+
|
|
|
+def migrate(dry_run=False):
|
|
|
+ import psycopg2
|
|
|
+
|
|
|
+ print(f"SQLite: {SQLITE_PATH}")
|
|
|
+ print(f"PG: {PG_URI.replace(DB_PASSWORD, '***')}")
|
|
|
+ print()
|
|
|
+
|
|
|
+ pg = psycopg2.connect(PG_URI)
|
|
|
+ pg.autocommit = False
|
|
|
+ cur = pg.cursor()
|
|
|
+
|
|
|
+ sqlite_tables = get_sqlite_tables()
|
|
|
+ total_inserted = 0
|
|
|
+ total_skipped = 0
|
|
|
+
|
|
|
+ for sqlite_table, pg_table, skip_cols, _ in MIGRATION_PLAN:
|
|
|
+ if sqlite_table not in sqlite_tables:
|
|
|
+ print(f" [skip] SQLite 中不存在表 '{sqlite_table}'")
|
|
|
+ continue
|
|
|
+
|
|
|
+ columns, rows = get_sqlite_data(sqlite_table)
|
|
|
+ insert_cols = [c for c in columns if c not in skip_cols]
|
|
|
+ col_indices = [i for i, c in enumerate(columns) if c not in skip_cols]
|
|
|
+
|
|
|
+ if not rows:
|
|
|
+ print(f" [empty] {sqlite_table} -> {pg_table} (0 行)")
|
|
|
+ continue
|
|
|
+
|
|
|
+ print(f" [{sqlite_table}] -> [{pg_table}]: {len(rows)} 行")
|
|
|
+
|
|
|
+ if dry_run:
|
|
|
+ total_inserted += len(rows)
|
|
|
+ continue
|
|
|
+
|
|
|
+ # 禁用触发器(跳过 FK 检查)来清空表
|
|
|
+ quoted = quote_identifier(pg_table)
|
|
|
+ try:
|
|
|
+ cur.execute(f"ALTER TABLE {quoted} DISABLE TRIGGER ALL")
|
|
|
+ cur.execute(f"TRUNCATE TABLE {quoted} RESTART IDENTITY CASCADE")
|
|
|
+ cur.execute(f"ALTER TABLE {quoted} ENABLE TRIGGER ALL")
|
|
|
+ except psycopg2.Error as e:
|
|
|
+ print(f" WARN: TRUNCATE {pg_table} 失败: {e}")
|
|
|
+ pg.rollback()
|
|
|
+ continue
|
|
|
+
|
|
|
+ # 批量插入
|
|
|
+ inserted = 0
|
|
|
+ skipped = 0
|
|
|
+ for row in rows:
|
|
|
+ values = []
|
|
|
+ for i, idx in enumerate(col_indices):
|
|
|
+ val = row[idx]
|
|
|
+ col_name = columns[idx]
|
|
|
+ if val in (0, 1) and col_name in BOOL_COLUMNS:
|
|
|
+ val = bool(val)
|
|
|
+ values.append(val)
|
|
|
+
|
|
|
+ placeholders = ", ".join(["%s"] * len(values))
|
|
|
+ col_names = ", ".join(insert_cols)
|
|
|
+ sql = f"INSERT INTO {quoted} ({col_names}) VALUES ({placeholders})"
|
|
|
+ try:
|
|
|
+ cur.execute(sql, values)
|
|
|
+ inserted += 1
|
|
|
+ except psycopg2.Error as e:
|
|
|
+ skipped += 1
|
|
|
+ pg.rollback()
|
|
|
+ # 对于 url 超长问题,尝试用 TEXT 列存储
|
|
|
+ if "value too long" in str(e) and sqlite_table == "deep_collection":
|
|
|
+ try:
|
|
|
+ # 临时修改列类型
|
|
|
+ cur.execute("ALTER TABLE deep_collection ALTER COLUMN url TYPE TEXT")
|
|
|
+ cur.execute(sql, values)
|
|
|
+ inserted += 1
|
|
|
+ skipped -= 1
|
|
|
+ print(f" NOTE: deep_collection.url 已自动改为 TEXT 类型")
|
|
|
+ except psycopg2.Error as e2:
|
|
|
+ print(f" WARN: 插入 {pg_table} 失败: {e2}")
|
|
|
+ continue
|
|
|
+ else:
|
|
|
+ print(f" WARN: 插入 {pg_table} 失败 (row {inserted + skipped}): {e}")
|
|
|
+ continue
|
|
|
+
|
|
|
+ # 重置序列
|
|
|
+ if inserted > 0:
|
|
|
+ try:
|
|
|
+ seq_sql = f"SELECT setval(pg_get_serial_sequence('{pg_table}', 'id'), (SELECT COALESCE(MAX(id), 1) FROM {quoted}))"
|
|
|
+ cur.execute(seq_sql)
|
|
|
+ except psycopg2.Error:
|
|
|
+ pass
|
|
|
+
|
|
|
+ total_inserted += inserted
|
|
|
+ total_skipped += skipped
|
|
|
+ status = "OK" if skipped == 0 else "PARTIAL"
|
|
|
+ print(f" [{status}] 插入 {inserted} 行, 跳过 {skipped} 行")
|
|
|
+
|
|
|
+ if not dry_run:
|
|
|
+ pg.commit()
|
|
|
+ print(f"\n迁移完成! 共插入 {total_inserted} 行, 跳过 {total_skipped} 行")
|
|
|
+ else:
|
|
|
+ print(f"\n[DRY RUN] 将插入 ~{total_inserted} 行")
|
|
|
+
|
|
|
+ cur.close()
|
|
|
+ pg.close()
|
|
|
+
|
|
|
+
|
|
|
+if __name__ == "__main__":
|
|
|
+ parser = argparse.ArgumentParser(description="SQLite → PostgreSQL 数据迁移")
|
|
|
+ parser.add_argument("--dry-run", action="store_true", help="只打印计划,不执行")
|
|
|
+ args = parser.parse_args()
|
|
|
+ migrate(dry_run=args.dry_run)
|