migrate_sqlite_to_pg.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. """
  2. SQLite → PostgreSQL 数据迁移脚本
  3. 用法:
  4. python migrate_sqlite_to_pg.py [--dry-run]
  5. --dry-run 只打印计划,不执行写入
  6. """
  7. import os
  8. import sys
  9. import sqlite3
  10. import argparse
  11. from dotenv import load_dotenv
  12. load_dotenv()
  13. # ── PostgreSQL 连接 ──
  14. DB_USER = os.environ.get("DB_USER", "")
  15. DB_PASSWORD = os.environ.get("DB_PASSWORD", "")
  16. DB_HOST = os.environ.get("DB_HOST", "")
  17. DB_PORT = os.environ.get("DB_PORT", "5432")
  18. DB_NAME = os.environ.get("DB_NAME", "")
  19. if not all([DB_USER, DB_PASSWORD, DB_HOST, DB_NAME]):
  20. print("ERROR: 缺少 PG 配置 (DB_USER/DB_PASSWORD/DB_HOST/DB_NAME)")
  21. sys.exit(1)
  22. PG_URI = f"postgresql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}"
  23. # ── SQLite 路径 ──
  24. SQLITE_PATH = os.path.join(os.path.dirname(__file__), "app", "app.db")
  25. if not os.path.exists(SQLITE_PATH):
  26. print(f"ERROR: 找不到 SQLite 数据库: {SQLITE_PATH}")
  27. sys.exit(1)
  28. # ── 布尔列(SQLite 存 0/1,需转换为 PG boolean) ──
  29. BOOL_COLUMNS = {"has_pagination", "has_deep_collection", "is_active"}
  30. # ── 迁移顺序(按外键依赖) ──
  31. # (sqlite表名, pg表名, 跳过列, 需要额外处理的列)
  32. MIGRATION_PLAN = [
  33. ("user", '"user"', set(), {}),
  34. ("spider_source", "spider_source", set(), {}),
  35. ("collection_task", "collection_task", set(), {}),
  36. # spider_task 是旧表,已不在 models.py 中,跳过
  37. ("spider_result", "spider_result", set(), {}),
  38. ("deep_collection", "deep_collection", set(), {}),
  39. ("ai_model", "ai_model", set(), {}),
  40. ("token_usage_log", "token_usage_log", set(), {}),
  41. ("ai_conversation", "ai_conversation", set(), {}),
  42. ("ai_message", "ai_message", set(), {}),
  43. ("knowledge_import_task", "knowledge_import_task", set(), {}),
  44. ]
  45. def get_sqlite_tables():
  46. conn = sqlite3.connect(SQLITE_PATH)
  47. cursor = conn.execute("SELECT name FROM sqlite_master WHERE type='table'")
  48. tables = {r[0] for r in cursor.fetchall()}
  49. conn.close()
  50. return tables
  51. def get_sqlite_data(table):
  52. conn = sqlite3.connect(SQLITE_PATH)
  53. cursor = conn.execute(f"SELECT * FROM {table}")
  54. columns = [desc[0] for desc in cursor.description]
  55. rows = cursor.fetchall()
  56. conn.close()
  57. return columns, rows
  58. def quote_identifier(name):
  59. """如果标识符未被引号包裹,添加双引号。"""
  60. if name.startswith('"') and name.endswith('"'):
  61. return name
  62. return f'"{name}"'
  63. def migrate(dry_run=False):
  64. import psycopg2
  65. print(f"SQLite: {SQLITE_PATH}")
  66. print(f"PG: {PG_URI.replace(DB_PASSWORD, '***')}")
  67. print()
  68. pg = psycopg2.connect(PG_URI)
  69. pg.autocommit = False
  70. cur = pg.cursor()
  71. sqlite_tables = get_sqlite_tables()
  72. total_inserted = 0
  73. total_skipped = 0
  74. for sqlite_table, pg_table, skip_cols, _ in MIGRATION_PLAN:
  75. if sqlite_table not in sqlite_tables:
  76. print(f" [skip] SQLite 中不存在表 '{sqlite_table}'")
  77. continue
  78. columns, rows = get_sqlite_data(sqlite_table)
  79. insert_cols = [c for c in columns if c not in skip_cols]
  80. col_indices = [i for i, c in enumerate(columns) if c not in skip_cols]
  81. if not rows:
  82. print(f" [empty] {sqlite_table} -> {pg_table} (0 行)")
  83. continue
  84. print(f" [{sqlite_table}] -> [{pg_table}]: {len(rows)} 行")
  85. if dry_run:
  86. total_inserted += len(rows)
  87. continue
  88. # 禁用触发器(跳过 FK 检查)来清空表
  89. quoted = quote_identifier(pg_table)
  90. try:
  91. cur.execute(f"ALTER TABLE {quoted} DISABLE TRIGGER ALL")
  92. cur.execute(f"TRUNCATE TABLE {quoted} RESTART IDENTITY CASCADE")
  93. cur.execute(f"ALTER TABLE {quoted} ENABLE TRIGGER ALL")
  94. except psycopg2.Error as e:
  95. print(f" WARN: TRUNCATE {pg_table} 失败: {e}")
  96. pg.rollback()
  97. continue
  98. # 批量插入
  99. inserted = 0
  100. skipped = 0
  101. for row in rows:
  102. values = []
  103. for i, idx in enumerate(col_indices):
  104. val = row[idx]
  105. col_name = columns[idx]
  106. if val in (0, 1) and col_name in BOOL_COLUMNS:
  107. val = bool(val)
  108. values.append(val)
  109. placeholders = ", ".join(["%s"] * len(values))
  110. col_names = ", ".join(insert_cols)
  111. sql = f"INSERT INTO {quoted} ({col_names}) VALUES ({placeholders})"
  112. try:
  113. cur.execute(sql, values)
  114. inserted += 1
  115. except psycopg2.Error as e:
  116. skipped += 1
  117. pg.rollback()
  118. # 对于 url 超长问题,尝试用 TEXT 列存储
  119. if "value too long" in str(e) and sqlite_table == "deep_collection":
  120. try:
  121. # 临时修改列类型
  122. cur.execute("ALTER TABLE deep_collection ALTER COLUMN url TYPE TEXT")
  123. cur.execute(sql, values)
  124. inserted += 1
  125. skipped -= 1
  126. print(f" NOTE: deep_collection.url 已自动改为 TEXT 类型")
  127. except psycopg2.Error as e2:
  128. print(f" WARN: 插入 {pg_table} 失败: {e2}")
  129. continue
  130. else:
  131. print(f" WARN: 插入 {pg_table} 失败 (row {inserted + skipped}): {e}")
  132. continue
  133. # 重置序列
  134. if inserted > 0:
  135. try:
  136. seq_sql = f"SELECT setval(pg_get_serial_sequence('{pg_table}', 'id'), (SELECT COALESCE(MAX(id), 1) FROM {quoted}))"
  137. cur.execute(seq_sql)
  138. except psycopg2.Error:
  139. pass
  140. total_inserted += inserted
  141. total_skipped += skipped
  142. status = "OK" if skipped == 0 else "PARTIAL"
  143. print(f" [{status}] 插入 {inserted} 行, 跳过 {skipped} 行")
  144. if not dry_run:
  145. pg.commit()
  146. print(f"\n迁移完成! 共插入 {total_inserted} 行, 跳过 {total_skipped} 行")
  147. else:
  148. print(f"\n[DRY RUN] 将插入 ~{total_inserted} 行")
  149. cur.close()
  150. pg.close()
  151. if __name__ == "__main__":
  152. parser = argparse.ArgumentParser(description="SQLite → PostgreSQL 数据迁移")
  153. parser.add_argument("--dry-run", action="store_true", help="只打印计划,不执行")
  154. args = parser.parse_args()
  155. migrate(dry_run=args.dry_run)