init_db.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. import asyncio
  2. import logging
  3. import threading
  4. import time
  5. import re
  6. from urllib.parse import urlparse, parse_qs, urlunparse
  7. from sqlalchemy.ext.asyncio import (
  8. AsyncEngine,
  9. create_async_engine,
  10. )
  11. from sqlmodel import SQLModel
  12. from sqlalchemy import DDL, event, text
  13. from gpustack import envs
  14. from gpustack.server import db
  15. from gpustack.schemas.api_keys import ApiKey
  16. from gpustack.schemas.inference_backend import InferenceBackend
  17. from gpustack.schemas.model_usage import ModelUsage
  18. from gpustack.schemas.models import Model, ModelInstance
  19. from gpustack.schemas.system_load import SystemLoad
  20. from gpustack.schemas.users import User
  21. from gpustack.schemas.workers import Worker
  22. from gpustack.schemas.clusters import (
  23. Cluster,
  24. CloudCredential,
  25. WorkerPool,
  26. Credential,
  27. )
  28. from gpustack.schemas.stmt import (
  29. worker_after_create_view_stmt_sqlite,
  30. worker_after_drop_view_stmt_sqlite,
  31. worker_after_create_view_stmt_postgres,
  32. worker_after_drop_view_stmt_postgres,
  33. worker_after_create_view_stmt_opengauss,
  34. worker_after_create_view_stmt_mysql,
  35. worker_after_drop_view_stmt_mysql,
  36. model_user_after_drop_view_stmt,
  37. model_user_after_create_view_stmt,
  38. principal_users_after_drop_view_stmt,
  39. principal_users_after_create_view_stmt,
  40. )
  41. logger = logging.getLogger(__name__)
  42. SLOW_QUERY_THRESHOLD_SECOND = 0.5
  43. # Query counter for performance monitoring
  44. _query_counter = 0
  45. _query_counter_lock = threading.Lock()
  46. def increment_query_count_sync():
  47. """Increment the global query counter (synchronous version)."""
  48. global _query_counter
  49. with _query_counter_lock:
  50. _query_counter += 1
  51. def get_query_count() -> int:
  52. """Get the current query count."""
  53. global _query_counter
  54. with _query_counter_lock:
  55. return _query_counter
  56. async def init_db(db_url: str):
  57. if db.engine is None:
  58. db.engine = await init_db_engine(db_url)
  59. listen_events(db.engine)
  60. await create_db_and_tables(db.engine)
  61. async def init_db_engine(db_url: str):
  62. connect_args = {}
  63. if db_url.startswith("postgresql://"):
  64. db_url = re.sub(r'^postgresql://', 'postgresql+asyncpg://', db_url)
  65. parsed = urlparse(db_url)
  66. # rewrite the parameters to use asyncpg with custom database schema
  67. query_params = parse_qs(parsed.query)
  68. qoptions = query_params.pop('options', None)
  69. schema_name = None
  70. if qoptions is not None and len(qoptions) > 0:
  71. option = qoptions[0]
  72. if option.startswith('-csearch_path='):
  73. schema_name = option[len('-csearch_path=') :]
  74. if schema_name:
  75. connect_args['server_settings'] = {'search_path': schema_name}
  76. new_parsed = parsed._replace(query={})
  77. db_url = urlunparse(new_parsed)
  78. elif db_url.startswith("mysql://"):
  79. db_url = re.sub(r'^mysql://', 'mysql+asyncmy://', db_url)
  80. elif db_url.startswith("sqlite"):
  81. # Convert sqlite:// to sqlite+aiosqlite://
  82. db_url = re.sub(r'^sqlite(\+aiosqlite)?://', 'sqlite+aiosqlite://', db_url)
  83. else:
  84. raise Exception(f"Unsupported database URL: {db_url}")
  85. engine = create_async_engine(
  86. db_url,
  87. echo=envs.DB_ECHO,
  88. pool_size=envs.DB_POOL_SIZE,
  89. max_overflow=envs.DB_MAX_OVERFLOW,
  90. pool_timeout=envs.DB_POOL_TIMEOUT,
  91. pool_pre_ping=True,
  92. connect_args=connect_args,
  93. )
  94. return engine
  95. async def create_db_and_tables(engine: AsyncEngine):
  96. async with engine.begin() as conn:
  97. await conn.run_sync(
  98. SQLModel.metadata.create_all,
  99. tables=[
  100. ApiKey.__table__,
  101. InferenceBackend.__table__,
  102. ModelUsage.__table__,
  103. Model.__table__,
  104. ModelInstance.__table__,
  105. SystemLoad.__table__,
  106. User.__table__,
  107. Worker.__table__,
  108. Cluster.__table__,
  109. CloudCredential.__table__,
  110. WorkerPool.__table__,
  111. Credential.__table__,
  112. ],
  113. )
def listen_events(engine: AsyncEngine):
    """Register the SQLAlchemy event listeners used by GPUStack for ``engine``.

    Three groups of listeners are attached, in this order:

    * ``after_create`` hooks on the model metadata that drop and re-create
      views: first the worker view (``_manage_worker_view``), then the
      model-user and ``principal_users`` views via ``DDL`` objects;
    * SQLite only: PRAGMA setup on ``connect`` and a ``close`` hook that
      swallows ``asyncio.CancelledError``;
    * cursor hooks: slow-query timing when this module's logger is enabled for
      DEBUG at registration time, and query counting unconditionally.

    NOTE(review): registration order matters, because listeners for the same
    event run in the order they were added. The metadata listeners go onto
    the global ``SQLModel.metadata`` (``Worker.metadata`` is presumably the
    same object — confirm). ``_manage_worker_view`` and the ``DDL`` objects
    are new objects on every call, so calling this more than once presumably
    registers duplicate view hooks.
    """
    dialect_name = engine.dialect.name

    def _manage_worker_view(target, connection, **kw):
        # Drop and re-create the worker view with dialect-specific SQL.
        # Within the "postgresql" dialect, openGauss is told apart by looking
        # for "openGauss" in the SELECT version() output. Any dialect other
        # than postgresql/mysql falls back to the SQLite statements.
        d = connection.dialect.name
        if d == "postgresql":
            ver = connection.execute(text("SELECT version()")).scalar()
            create_stmt = (
                worker_after_create_view_stmt_opengauss
                if 'openGauss' in (ver or '')
                else worker_after_create_view_stmt_postgres
            )
            connection.execute(text(worker_after_drop_view_stmt_postgres))
            connection.execute(text(create_stmt))
        elif d == "mysql":
            connection.execute(text(worker_after_drop_view_stmt_mysql))
            connection.execute(text(worker_after_create_view_stmt_mysql))
        else:
            connection.execute(text(worker_after_drop_view_stmt_sqlite))
            connection.execute(text(worker_after_create_view_stmt_sqlite))

    event.listen(Worker.metadata, "after_create", _manage_worker_view)
    # ``non_admin_user_models`` references ``principal_users``; drop the
    # dependent view first and create the helper before the dependent.
    event.listen(
        SQLModel.metadata, "after_create", DDL(model_user_after_drop_view_stmt)
    )
    event.listen(
        SQLModel.metadata, "after_create", DDL(principal_users_after_drop_view_stmt)
    )
    event.listen(
        SQLModel.metadata,
        "after_create",
        DDL(principal_users_after_create_view_stmt()),
    )
    # The model-user view SQL is generated per dialect.
    event.listen(
        SQLModel.metadata,
        "after_create",
        DDL(model_user_after_create_view_stmt(dialect_name)),
    )
    if engine.dialect.name == "sqlite":
        # Pool-level hooks on the underlying sync engine.
        event.listen(engine.sync_engine, "connect", setup_sqlite_pragmas)
        event.listen(engine.sync_engine, "close", ignore_cancel_on_close)
    if logger.isEnabledFor(logging.DEBUG):
        # Log slow queries on debugging
        event.listen(
            engine.sync_engine, "before_cursor_execute", before_cursor_execute
        )
        event.listen(
            engine.sync_engine, "after_cursor_execute", after_cursor_execute
        )
    # Always count queries for performance monitoring
    event.listen(engine.sync_engine, "after_cursor_execute", count_query)
  165. def count_query(conn, cursor, statement, parameters, context, executemany):
  166. """Increment the global query counter for each query executed."""
  167. increment_query_count_sync()
  168. def setup_sqlite_pragmas(conn, record):
  169. # Enable foreign keys for SQLite, since it's disabled by default
  170. conn.execute("PRAGMA foreign_keys=ON")
  171. # Performance tuning
  172. conn.execute("PRAGMA journal_mode=WAL")
  173. conn.execute("PRAGMA synchronous=normal")
  174. conn.execute("PRAGMA temp_store=memory")
  175. conn.execute("PRAGMA mmap_size=30000000000")
  176. def ignore_cancel_on_close(dbapi_connection, connection_record):
  177. try:
  178. dbapi_connection.close()
  179. except asyncio.CancelledError:
  180. pass
  181. def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
  182. context._query_start_time = time.time()
  183. def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
  184. total = time.time() - context._query_start_time
  185. if total > SLOW_QUERY_THRESHOLD_SECOND:
  186. logger.debug(f"[SLOW SQL] {total:.3f}s\nSQL: {statement}\nParams: {parameters}")