db_migration.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305
  1. import argparse
  2. import logging
  3. import os
  4. import re
  5. import shutil
  6. from glob import glob
  7. import sys
  8. from typing import Optional
  9. from urllib.parse import parse_qs, urlparse, urlunparse
  10. from sqlalchemy import MetaData, Table
  11. from sqlmodel.ext.asyncio.session import AsyncSession
  12. from sqlmodel import select, text, func
  13. from sqlalchemy.ext.asyncio import (
  14. AsyncEngine,
  15. create_async_engine,
  16. )
  17. from gpustack.cmd.start import get_gpustack_env
  18. from gpustack import envs
  19. logger = logging.getLogger(__name__)
  20. logger.setLevel(logging.INFO)
  21. logger.addHandler(logging.StreamHandler(sys.stdout))
  22. revision = "d19176de3b74" # 0.7.1
  23. TARGET_TABLES = [
  24. "system_loads",
  25. "workers",
  26. "users",
  27. "api_keys",
  28. "model_files",
  29. "models",
  30. "model_instances",
  31. "modelinstancemodelfilelink",
  32. "model_usages",
  33. ]
  34. migration_temp_file_prefix = "gpustack_migration_temp_"
  35. def setup_migrate_cmd(subparsers: argparse._SubParsersAction):
  36. parser: argparse.ArgumentParser = subparsers.add_parser("migrate")
  37. parser.add_argument(
  38. "--migration-data-dir",
  39. type=str,
  40. help="Data directory to include the original sqlite file database.db to migrate.",
  41. default=get_gpustack_env("MIGRATION_DATA_DIR"),
  42. required=True,
  43. )
  44. parser.add_argument(
  45. "--database-url",
  46. type=str,
  47. help="Target database URL, e.g. postgresql://user:password@host:port/db_name.",
  48. default=get_gpustack_env("DATABASE_URL"),
  49. required=True,
  50. )
  51. parser.set_defaults(func=run)
  52. def run(args: argparse.Namespace):
  53. import asyncio
  54. asyncio.run(_run(args))
  55. async def _run(args):
  56. try:
  57. logger.info("Starting database migration...")
  58. sqlite_db_url, postgres_db_url, old_engine, new_engine = await prepare_env(args)
  59. await upgrade_schema(sqlite_db_url, postgres_db_url, old_engine)
  60. await migrate_all_data(old_engine, new_engine)
  61. await old_engine.dispose()
  62. await new_engine.dispose()
  63. clean_env(args)
  64. logger.info("Migration completed successfully.")
  65. except Exception as e:
  66. logger.fatal(f"Failed to migrate: {e}")
  67. sys.exit(1)
  68. async def prepare_env(args):
  69. logger.info("=" * 30 + " Preparing " + "=" * 30)
  70. data_dir = args.migration_data_dir
  71. database_url = args.database_url
  72. migration_temp_sqlite_path = _copy_sqlite_file(data_dir)
  73. sqlite_db_url = f"sqlite:///{migration_temp_sqlite_path}"
  74. postgres_db_url = database_url
  75. old_engine = await init_db_engine(sqlite_db_url)
  76. new_engine = await init_db_engine(postgres_db_url)
  77. return sqlite_db_url, postgres_db_url, old_engine, new_engine
  78. async def upgrade_schema(sqlite_db_url, postgres_db_url, old_engine):
  79. logger.info("=" * 30 + " Drop views " + "=" * 30)
  80. await _drop_view(old_engine)
  81. logger.info("=" * 30 + " SQLite Upgrade " + "=" * 30)
  82. _run_schema_upgrade(sqlite_db_url, revision)
  83. logger.info("=" * 30 + " Postgres Upgrade " + "=" * 30)
  84. _run_schema_upgrade(postgres_db_url, revision)
  85. def clean_env(args):
  86. logger.info("=" * 30 + " Cleaning Up " + "=" * 30)
  87. data_dir = args.migration_data_dir
  88. for f in glob(os.path.join(data_dir, f"{migration_temp_file_prefix}*")):
  89. try:
  90. os.remove(f)
  91. logger.info(f"Cleaning up temporary files {f}")
  92. except Exception as e:
  93. logger.error(f"Failed to remove file {f}: {e}")
  94. def _copy_sqlite_file(data_dir: str):
  95. sqlite_path = ""
  96. required_files = ["database.db"]
  97. optional_files = ["database.db-wal"]
  98. for f in required_files + optional_files:
  99. file_path = os.path.join(data_dir, f)
  100. if os.path.exists(file_path) is False:
  101. if f in required_files:
  102. raise FileNotFoundError(f"Required sqlite file {file_path} not found.")
  103. else:
  104. continue
  105. copied_file_path = os.path.join(data_dir, f"{migration_temp_file_prefix}{f}")
  106. try:
  107. shutil.copyfile(file_path, copied_file_path)
  108. logger.info(f"Copied sqlite file to {copied_file_path}")
  109. if f == "database.db":
  110. sqlite_path = copied_file_path
  111. except Exception as e:
  112. raise RuntimeError(f"Failed to copy sqlite file: {e}") from e
  113. return sqlite_path
  114. def _run_schema_upgrade(db_url: str, revision: str = "head"):
  115. logger.info(f"Running schema upgrade for {db_url}.")
  116. from alembic import command
  117. from alembic.config import Config as AlembicConfig
  118. import importlib.util
  119. spec = importlib.util.find_spec("gpustack")
  120. if spec is None:
  121. raise ImportError("The 'gpustack' package is not found.")
  122. pkg_path = spec.submodule_search_locations[0]
  123. alembic_cfg = AlembicConfig()
  124. alembic_cfg.set_main_option("script_location", os.path.join(pkg_path, "migrations"))
  125. alembic_cfg.set_main_option("called_by_db_migration", "true")
  126. db_url_escaped = db_url.replace("%", "%%")
  127. alembic_cfg.set_main_option("sqlalchemy.url", db_url_escaped)
  128. try:
  129. command.upgrade(alembic_cfg, revision)
  130. except Exception as e:
  131. raise RuntimeError(f"Database upgrade failed: {e}") from e
  132. logger.info(f"Database schema upgrade for {db_url} completed.")
  133. async def _drop_view(engine: AsyncEngine):
  134. logger.info("Dropping views in the old database if any.")
  135. async with engine.begin() as conn:
  136. await conn.execute(text("DROP VIEW IF EXISTS gpu_devices_view"))
  137. async def migrate_all_data(old_engine: AsyncEngine, new_engine: AsyncEngine):
  138. logger.info("=" * 30 + " Migrate Data " + "=" * 30)
  139. old_meta = MetaData()
  140. new_meta = MetaData()
  141. async with old_engine.begin() as conn:
  142. await conn.run_sync(old_meta.reflect, only=TARGET_TABLES)
  143. async with new_engine.begin() as conn:
  144. await conn.run_sync(new_meta.reflect, only=TARGET_TABLES)
  145. for table_name in TARGET_TABLES:
  146. await _migrate_table(table_name, old_meta, new_meta, old_engine, new_engine)
  147. await _sync_table_sequence(new_meta, new_engine)
  148. async def _migrate_table(
  149. table_name: str,
  150. old_meta: MetaData,
  151. new_meta: MetaData,
  152. old_engine: AsyncEngine,
  153. new_engine: AsyncEngine,
  154. ):
  155. old_table: Optional[Table] = old_meta.tables.get(table_name)
  156. new_table: Optional[Table] = new_meta.tables.get(table_name)
  157. if old_table is None:
  158. logger.info(f"Old database lack of {table_name}, skip.")
  159. return
  160. if new_table is None:
  161. logger.info(f"New database lack of {table_name}, skip.")
  162. return
  163. common_cols = [c for c in old_table.columns.keys() if c in new_table.columns.keys()]
  164. if not common_cols:
  165. logger.info(f"Table {table_name} has no common columns, skipping.")
  166. return
  167. async with (
  168. AsyncSession(old_engine) as old_sess,
  169. AsyncSession(new_engine) as new_sess,
  170. ):
  171. stmt = select(*[old_table.c[col] for col in common_cols])
  172. result = await old_sess.execute(stmt)
  173. rows = result.fetchall()
  174. if not rows:
  175. logger.info(f"Old table {table_name} has no data, skipping.")
  176. return
  177. # Convert to dictionary
  178. data = [dict(zip(common_cols, row)) for row in rows]
  179. # Insert into new database
  180. await new_sess.execute(new_table.insert(), data)
  181. await new_sess.commit()
  182. logger.info(f"Table {table_name} has migrated {len(data)} records.")
  183. async def _sync_table_sequence(new_meta: MetaData, new_engine: AsyncEngine):
  184. synced = 0
  185. async with AsyncSession(new_engine) as session:
  186. for table_name, table in new_meta.tables.items():
  187. if table_name not in TARGET_TABLES:
  188. continue
  189. id_col = table.columns.get("id")
  190. if id_col is None:
  191. continue
  192. stmt = select(func.max(id_col))
  193. result = await session.execute(stmt)
  194. max_id = result.scalar()
  195. if max_id is None:
  196. continue
  197. seq_name = f"{table_name}_id_seq"
  198. setval_stmt = text('SELECT setval(:seq_name, :max_id)')
  199. await session.execute(setval_stmt, {"seq_name": seq_name, "max_id": max_id})
  200. synced += 1
  201. await session.commit()
  202. logger.info(f"Synced {synced} sequences.")
  203. async def init_db_engine(db_url: str):
  204. connect_args = {}
  205. if db_url.startswith("sqlite://"):
  206. connect_args = {"check_same_thread": False}
  207. # use async driver
  208. db_url = re.sub(r'^sqlite://', 'sqlite+aiosqlite://', db_url)
  209. elif db_url.startswith("postgresql://"):
  210. db_url = re.sub(r'^postgresql://', 'postgresql+asyncpg://', db_url)
  211. parsed = urlparse(db_url)
  212. # rewrite the parameters to use asyncpg with custom database schema
  213. query_params = parse_qs(parsed.query)
  214. qoptions = query_params.pop('options', None)
  215. schema_name = None
  216. if qoptions is not None and len(qoptions) > 0:
  217. option = qoptions[0]
  218. if option.startswith('-csearch_path='):
  219. schema_name = option[len('-csearch_path=') :]
  220. if schema_name:
  221. connect_args['server_settings'] = {'search_path': schema_name}
  222. new_parsed = parsed._replace(query={})
  223. db_url = urlunparse(new_parsed)
  224. elif db_url.startswith("mysql://"):
  225. db_url = re.sub(r'^mysql://', 'mysql+asyncmy://', db_url)
  226. else:
  227. raise Exception(f"Unsupported database URL: {db_url}")
  228. engine = create_async_engine(
  229. db_url,
  230. echo=envs.DB_ECHO,
  231. pool_size=envs.DB_POOL_SIZE,
  232. max_overflow=envs.DB_MAX_OVERFLOW,
  233. pool_timeout=envs.DB_POOL_TIMEOUT,
  234. connect_args=connect_args,
  235. )
  236. return engine