run_migrations.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383
  1. """
  2. 数据库迁移管理脚本
  3. 执行方式:
  4. # 执行所有待执行的迁移
  5. python scripts/run_migrations.py
  6. # 执行特定迁移
  7. python scripts/run_migrations.py --migration migrate_add_soft_delete
  8. # 查看迁移状态
  9. python scripts/run_migrations.py --status
  10. # 回滚最后一次迁移
  11. python scripts/run_migrations.py --rollback
  12. """
  13. import sys
  14. import os
  15. from datetime import datetime
  16. # 添加项目根目录到 Python 路径
  17. sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
  18. from sqlalchemy import text
  19. from app.database import SessionLocal
  20. # 迁移列表(按执行顺序)
  21. MIGRATIONS = [
  22. {
  23. 'name': 'migrate_add_soft_delete',
  24. 'description': '为图像翻译表添加软删除字段',
  25. 'module': 'scripts.migrate_add_soft_delete',
  26. 'version': '20260122_001'
  27. },
  28. # 在这里添加新的迁移...
  29. ]
  30. def ensure_migration_table():
  31. """确保迁移记录表存在"""
  32. db = SessionLocal()
  33. try:
  34. db.execute(text("""
  35. CREATE TABLE IF NOT EXISTS aigcspace.schema_migrations (
  36. id SERIAL PRIMARY KEY,
  37. version VARCHAR(50) UNIQUE NOT NULL,
  38. name VARCHAR(255) NOT NULL,
  39. description TEXT,
  40. executed_at TIMESTAMP NOT NULL DEFAULT NOW(),
  41. execution_time_ms INTEGER,
  42. status VARCHAR(20) NOT NULL DEFAULT 'success'
  43. )
  44. """))
  45. db.commit()
  46. print("✓ 迁移记录表已就绪")
  47. except Exception as e:
  48. print(f"✗ 创建迁移记录表失败: {e}")
  49. db.rollback()
  50. raise
  51. finally:
  52. db.close()
  53. def is_migration_executed(version):
  54. """检查迁移是否已执行"""
  55. db = SessionLocal()
  56. try:
  57. result = db.execute(
  58. text("SELECT COUNT(*) FROM aigcspace.schema_migrations WHERE version = :version"),
  59. {"version": version}
  60. )
  61. count = result.scalar()
  62. return count > 0
  63. finally:
  64. db.close()
  65. def record_migration(version, name, description, execution_time_ms, status='success'):
  66. """记录迁移执行"""
  67. db = SessionLocal()
  68. try:
  69. db.execute(
  70. text("""
  71. INSERT INTO aigcspace.schema_migrations
  72. (version, name, description, execution_time_ms, status)
  73. VALUES (:version, :name, :description, :execution_time_ms, :status)
  74. """),
  75. {
  76. "version": version,
  77. "name": name,
  78. "description": description,
  79. "execution_time_ms": execution_time_ms,
  80. "status": status
  81. }
  82. )
  83. db.commit()
  84. except Exception as e:
  85. print(f"✗ 记录迁移失败: {e}")
  86. db.rollback()
  87. finally:
  88. db.close()
  89. def run_migration(migration_info):
  90. """执行单个迁移"""
  91. version = migration_info['version']
  92. name = migration_info['name']
  93. description = migration_info['description']
  94. module_name = migration_info['module']
  95. print("\n" + "=" * 70)
  96. print(f"执行迁移: {name}")
  97. print(f"版本: {version}")
  98. print(f"描述: {description}")
  99. print("=" * 70)
  100. # 检查是否已执行
  101. if is_migration_executed(version):
  102. print(f"⚠ 迁移 {version} 已执行,跳过")
  103. return True
  104. try:
  105. # 动态导入迁移模块
  106. module = __import__(module_name, fromlist=['run_migration'])
  107. # 执行迁移
  108. start_time = datetime.now()
  109. success = module.run_migration()
  110. end_time = datetime.now()
  111. execution_time_ms = int((end_time - start_time).total_seconds() * 1000)
  112. if success:
  113. # 记录成功的迁移
  114. record_migration(version, name, description, execution_time_ms, 'success')
  115. print(f"\n✓ 迁移 {name} 执行成功 (耗时: {execution_time_ms}ms)")
  116. return True
  117. else:
  118. # 记录失败的迁移
  119. record_migration(version, name, description, execution_time_ms, 'failed')
  120. print(f"\n✗ 迁移 {name} 执行失败")
  121. return False
  122. except Exception as e:
  123. print(f"\n✗ 迁移 {name} 执行异常: {e}")
  124. import traceback
  125. traceback.print_exc()
  126. # 记录失败
  127. record_migration(version, name, description, 0, 'error')
  128. return False
  129. def run_all_migrations():
  130. """执行所有待执行的迁移"""
  131. print("\n" + "=" * 70)
  132. print("开始执行数据库迁移")
  133. print("=" * 70)
  134. # 确保迁移表存在
  135. ensure_migration_table()
  136. # 统计
  137. total = len(MIGRATIONS)
  138. executed = 0
  139. skipped = 0
  140. failed = 0
  141. # 执行每个迁移
  142. for migration in MIGRATIONS:
  143. if is_migration_executed(migration['version']):
  144. skipped += 1
  145. print(f"\n⚠ 跳过已执行的迁移: {migration['name']} ({migration['version']})")
  146. continue
  147. success = run_migration(migration)
  148. if success:
  149. executed += 1
  150. else:
  151. failed += 1
  152. print(f"\n✗ 迁移失败,停止执行后续迁移")
  153. break
  154. # 输出总结
  155. print("\n" + "=" * 70)
  156. print("迁移执行总结")
  157. print("=" * 70)
  158. print(f"总迁移数: {total}")
  159. print(f"已执行: {executed}")
  160. print(f"已跳过: {skipped}")
  161. print(f"失败: {failed}")
  162. print("=" * 70)
  163. return failed == 0
  164. def show_migration_status():
  165. """显示迁移状态"""
  166. print("\n" + "=" * 70)
  167. print("数据库迁移状态")
  168. print("=" * 70)
  169. # 确保迁移表存在
  170. ensure_migration_table()
  171. db = SessionLocal()
  172. try:
  173. # 查询已执行的迁移
  174. result = db.execute(text("""
  175. SELECT version, name, description, executed_at, execution_time_ms, status
  176. FROM aigcspace.schema_migrations
  177. ORDER BY executed_at DESC
  178. """))
  179. executed_migrations = {row[0]: row for row in result.fetchall()}
  180. print(f"\n待执行的迁移:")
  181. print("-" * 70)
  182. pending_count = 0
  183. for migration in MIGRATIONS:
  184. version = migration['version']
  185. if version not in executed_migrations:
  186. pending_count += 1
  187. print(f" [{version}] {migration['name']}")
  188. print(f" 描述: {migration['description']}")
  189. print()
  190. if pending_count == 0:
  191. print(" 无待执行的迁移")
  192. print(f"\n已执行的迁移:")
  193. print("-" * 70)
  194. if executed_migrations:
  195. for migration in MIGRATIONS:
  196. version = migration['version']
  197. if version in executed_migrations:
  198. row = executed_migrations[version]
  199. status_icon = "✓" if row[5] == 'success' else "✗"
  200. print(f" {status_icon} [{version}] {row[1]}")
  201. print(f" 执行时间: {row[3]}")
  202. print(f" 耗时: {row[4]}ms")
  203. print(f" 状态: {row[5]}")
  204. print()
  205. else:
  206. print(" 无已执行的迁移")
  207. print("=" * 70)
  208. finally:
  209. db.close()
  210. def rollback_last_migration():
  211. """回滚最后一次迁移"""
  212. print("\n" + "=" * 70)
  213. print("警告:准备回滚最后一次迁移")
  214. print("=" * 70)
  215. db = SessionLocal()
  216. try:
  217. # 查询最后一次迁移
  218. result = db.execute(text("""
  219. SELECT version, name, description
  220. FROM aigcspace.schema_migrations
  221. ORDER BY executed_at DESC
  222. LIMIT 1
  223. """))
  224. last_migration = result.fetchone()
  225. if not last_migration:
  226. print("\n没有可回滚的迁移")
  227. return
  228. version, name, description = last_migration
  229. print(f"\n最后一次迁移:")
  230. print(f" 版本: {version}")
  231. print(f" 名称: {name}")
  232. print(f" 描述: {description}")
  233. confirm = input("\n确定要回滚此迁移吗?(yes/no): ")
  234. if confirm.lower() != 'yes':
  235. print("已取消回滚")
  236. return
  237. # 查找对应的迁移配置
  238. migration_info = None
  239. for m in MIGRATIONS:
  240. if m['version'] == version:
  241. migration_info = m
  242. break
  243. if not migration_info:
  244. print(f"\n✗ 未找到迁移配置: {version}")
  245. return
  246. # 动态导入迁移模块
  247. module_name = migration_info['module']
  248. module = __import__(module_name, fromlist=['rollback_migration'])
  249. if not hasattr(module, 'rollback_migration'):
  250. print(f"\n✗ 迁移 {name} 不支持回滚")
  251. return
  252. # 执行回滚
  253. module.rollback_migration()
  254. # 删除迁移记录
  255. db.execute(
  256. text("DELETE FROM aigcspace.schema_migrations WHERE version = :version"),
  257. {"version": version}
  258. )
  259. db.commit()
  260. print(f"\n✓ 迁移 {name} 已回滚")
  261. except Exception as e:
  262. print(f"\n✗ 回滚失败: {e}")
  263. import traceback
  264. traceback.print_exc()
  265. db.rollback()
  266. finally:
  267. db.close()
  268. if __name__ == "__main__":
  269. import argparse
  270. parser = argparse.ArgumentParser(description='数据库迁移管理工具')
  271. parser.add_argument(
  272. '--migration',
  273. type=str,
  274. help='执行指定的迁移'
  275. )
  276. parser.add_argument(
  277. '--status',
  278. action='store_true',
  279. help='显示迁移状态'
  280. )
  281. parser.add_argument(
  282. '--rollback',
  283. action='store_true',
  284. help='回滚最后一次迁移'
  285. )
  286. args = parser.parse_args()
  287. try:
  288. if args.status:
  289. show_migration_status()
  290. elif args.rollback:
  291. rollback_last_migration()
  292. elif args.migration:
  293. # 执行指定迁移
  294. migration_info = None
  295. for m in MIGRATIONS:
  296. if m['name'] == args.migration:
  297. migration_info = m
  298. break
  299. if migration_info:
  300. ensure_migration_table()
  301. success = run_migration(migration_info)
  302. sys.exit(0 if success else 1)
  303. else:
  304. print(f"✗ 未找到迁移: {args.migration}")
  305. sys.exit(1)
  306. else:
  307. # 执行所有待执行的迁移
  308. success = run_all_migrations()
  309. sys.exit(0 if success else 1)
  310. except KeyboardInterrupt:
  311. print("\n\n用户中断执行")
  312. sys.exit(1)
  313. except Exception as e:
  314. print(f"\n✗ 执行失败: {e}")
  315. import traceback
  316. traceback.print_exc()
  317. sys.exit(1)