db.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. """Database-related utilities shared across GPUStack components."""
  2. import re
  3. from sqlalchemy import create_engine, text
  4. from sqlalchemy.dialects.postgresql import base as pg_base
# Module-level guard read and set by patch_pg_version_info() so that calling
# it more than once does not wrap PGDialect._get_server_version_info again.
_pg_version_patched: bool = False
  6. def test_db_connection(db_url: str, timeout: int = 5) -> bool:
  7. """Test if a database connection can be established."""
  8. # For async URLs, convert to sync for the test
  9. if db_url.startswith("sqlite"):
  10. # SQLite doesn't need a pre-connection test, the file will be created
  11. return True
  12. try:
  13. engine = create_engine(db_url, connect_args={"connect_timeout": timeout})
  14. with engine.connect() as conn:
  15. conn.execute(text("SELECT 1"))
  16. engine.dispose()
  17. return True
  18. except Exception:
  19. return False
  20. def patch_pg_version_info() -> None:
  21. """Teach SQLAlchemy's PGDialect to parse openGauss version strings.
  22. openGauss presents itself with the PostgreSQL dialect but reports
  23. ``(openGauss X.Y.Z build ...)`` instead of ``PostgreSQL X.Y.Z``,
  24. which SQLAlchemy's default regex rejects with ``AssertionError``.
  25. We delegate to the original parser first so future upstream fixes
  26. are preserved, and only fall back to an openGauss regex on failure.
  27. Idempotent: safe to call multiple times.
  28. """
  29. global _pg_version_patched
  30. if _pg_version_patched:
  31. return
  32. _pg_version_patched = True
  33. orig_get_server_version_info = pg_base.PGDialect._get_server_version_info
  34. def _patched(self, connection):
  35. try:
  36. return orig_get_server_version_info(self, connection)
  37. except AssertionError:
  38. v = connection.exec_driver_sql("select pg_catalog.version()").scalar()
  39. m = re.search(r"openGauss (\d+)\.(\d+)(?:\.(\d+))?", v or "")
  40. if not m:
  41. raise
  42. return tuple(int(x) if x is not None else 0 for x in m.group(1, 2, 3))
  43. pg_base.PGDialect._get_server_version_info = _patched