| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455 |
- """Database-related utilities shared across GPUStack components."""
- import re
- from sqlalchemy import create_engine, text
- from sqlalchemy.dialects.postgresql import base as pg_base
# Guard flag for patch_pg_version_info(): flipped to True by that function
# so that repeated calls return early instead of re-wrapping PGDialect.
_pg_version_patched: bool = False
def test_db_connection(db_url: str, timeout: int = 5) -> bool:
    """Return whether a database connection can be established.

    Builds a throwaway engine for ``db_url``, runs ``SELECT 1`` on it and
    disposes the engine again on every path. Any failure (malformed URL,
    missing driver, unreachable server, authentication error, ...) is
    reported as ``False`` instead of being raised.

    Args:
        db_url: SQLAlchemy database URL. It is used as-is: no async-to-sync
            driver rewriting is performed, so a URL naming an async-only
            driver (e.g. ``postgresql+asyncpg``) will presumably fail the
            probe and return ``False`` — NOTE(review): convert such URLs
            before calling if that is not intended.
        timeout: Forwarded to the DBAPI driver as the ``connect_timeout``
            connect argument (units are driver-defined; seconds for the
            common PostgreSQL/MySQL drivers — confirm for the one in use).

    Returns:
        ``True`` immediately for URLs starting with ``sqlite``. For any
        other URL, ``True`` only if the probe query succeeded.
    """
    # SQLite doesn't need a pre-connection test; the database file is
    # created on first use.
    if db_url.startswith("sqlite"):
        return True

    try:
        engine = create_engine(db_url, connect_args={"connect_timeout": timeout})
    except Exception:
        # Malformed URL or DBAPI driver not installed.
        return False

    try:
        with engine.connect() as conn:
            conn.execute(text("SELECT 1"))
    except Exception:
        # Best-effort probe: callers only need a yes/no answer.
        return False
    finally:
        # Dispose on every path so a failed probe does not leave pooled
        # connections behind; a cleanup error must not mask the result.
        try:
            engine.dispose()
        except Exception:
            pass
    return True


# The helper's name starts with ``test_``; stop pytest from collecting it as
# a test case when it is imported into a test module.
test_db_connection.__test__ = False
def patch_pg_version_info() -> None:
    """Teach SQLAlchemy's PGDialect to parse openGauss version strings.

    openGauss presents itself with the PostgreSQL dialect but reports
    ``(openGauss X.Y.Z build ...)`` instead of ``PostgreSQL X.Y.Z``,
    which SQLAlchemy's default regex rejects with ``AssertionError``.
    We delegate to the original parser first so future upstream fixes
    are preserved, and only fall back to an openGauss regex on failure.

    The fallback re-reads ``pg_catalog.version()`` and returns ``(X, Y, Z)``
    from the openGauss banner, using ``0`` for a missing third component.
    If the banner does not match either, the original ``AssertionError``
    is re-raised.

    NOTE(review): the tuple is the openGauss *product* version, not a
    PostgreSQL compatibility level; SQLAlchemy's PG dialect appears to
    compare ``server_version_info`` against PostgreSQL versions to gate
    features — confirm this is what callers want.

    Idempotent: safe to call multiple times. The "already patched" flag is
    set only after the replacement has been installed, so a failure while
    patching does not make later calls silently skip it.
    """
    import functools

    global _pg_version_patched
    if _pg_version_patched:
        return

    original = pg_base.PGDialect._get_server_version_info
    opengauss_re = re.compile(r"openGauss (\d+)\.(\d+)(?:\.(\d+))?")

    @functools.wraps(original)
    def _patched(self, connection):
        try:
            return original(self, connection)
        except AssertionError:
            banner = connection.exec_driver_sql(
                "select pg_catalog.version()"
            ).scalar()
            match = opengauss_re.search(banner or "")
            if match is None:
                # Not openGauss either: surface the upstream error unchanged.
                raise
            return tuple(
                0 if part is None else int(part) for part in match.group(1, 2, 3)
            )

    pg_base.PGDialect._get_server_version_info = _patched
    _pg_version_patched = True
|