diff --git a/changelog/8148-engine-creator-pattern.yaml b/changelog/8148-engine-creator-pattern.yaml new file mode 100644 index 00000000000..f2aa0517eab --- /dev/null +++ b/changelog/8148-engine-creator-pattern.yaml @@ -0,0 +1,4 @@ +type: Changed +description: Refactored database engines to use SQLAlchemy creator pattern for per-connection credential resolution +pr: 8148 +labels: [] diff --git a/design-docs/dynamic-database-credentials.md b/design-docs/dynamic-database-credentials.md index 737162bd9ab..f8b4623c9e0 100644 --- a/design-docs/dynamic-database-credentials.md +++ b/design-docs/dynamic-database-credentials.md @@ -107,7 +107,7 @@ This depends on a SQLAlchemy internal (`greenlet_spawn`), which is acceptable be - The planned SQLAlchemy 2.0 upgrade will replace this with the public `async_creator` API. - The code should include a clear TODO and comments explaining this constraint. -The module-level engines in `ctl_session.py` need to be refactored into lazy factories (similar to how `session_management.py` already works) so the `creator` can be injected at construction time. +The module-level engines in `ctl_session.py` remain as module-level singletons. The `creator` closure captures a provider reference, not credentials themselves — credentials are resolved inside the closure body on every call. This means the engine can be constructed at any time (including module import) and credential rotation still works correctly. ### 4. Automatic Retry on Auth Failure diff --git a/src/fides/api/db/ctl_session.py b/src/fides/api/db/ctl_session.py index 0737c06f3be..75112b5e9ce 100644 --- a/src/fides/api/db/ctl_session.py +++ b/src/fides/api/db/ctl_session.py @@ -1,4 +1,3 @@ -import ssl from asyncio import Lock, gather from contextlib import _AsyncGeneratorContextManager, asynccontextmanager from typing import Any, AsyncGenerator, Callable, Dict @@ -11,24 +10,17 @@ from fides.api.db.session import ExtendedSession from fides.api.db.util import custom_json_deserializer, custom_json_serializer +from fides.common.engine_creators import make_async_creator, make_sync_creator from fides.config import CONFIG # asyncio lock and flag for warming up the async pool ASYNC_READONLY_POOL_LOCK = Lock() ASYNC_READONLY_POOL_WARMED = False -# Associated with a workaround in fides.core.config.database_settings -# ref: https://github.com/sqlalchemy/sqlalchemy/discussions/5975 -connect_args: Dict[str, Any] = {} -if CONFIG.database.params.get("sslrootcert"): - ssl_ctx = ssl.create_default_context(cafile=CONFIG.database.params["sslrootcert"]) - ssl_ctx.verify_mode = ssl.CERT_REQUIRED - connect_args["ssl"] = ssl_ctx - -# Parameters are hidden for security +# Primary async engine — credentials resolved per-connection via creator async_engine = create_async_engine( - CONFIG.database.async_database_uri, - connect_args=connect_args, + "postgresql+asyncpg://", + creator=make_async_creator(), echo=False, hide_parameters=not CONFIG.dev_mode, logging_name="AsyncEngine", @@ -49,21 +41,12 @@ if CONFIG.database.async_readonly_database_uri: logger.info("Creating read-only async engine and session factory") - # Build connect_args for readonly (similar to primary) - readonly_connect_args: Dict[str, Any] = {} - readonly_params = CONFIG.database.readonly_params or {} - - if readonly_params.get("sslrootcert"): - ssl_ctx = ssl.create_default_context(cafile=readonly_params["sslrootcert"]) - ssl_ctx.verify_mode = ssl.CERT_REQUIRED - readonly_connect_args["ssl"] = ssl_ctx - logger.info( f"Read-only async settings: max-overflow: {CONFIG.database.api_async_engine_max_overflow}, pool-size: {CONFIG.database.async_readonly_database_pool_size}, pre-warm = {CONFIG.database.async_readonly_database_prewarm}, autocommit = {CONFIG.database.async_readonly_database_autocommit}, skip rollback = {CONFIG.database.async_readonly_database_pool_skip_rollback}" ) readonly_async_engine = create_async_engine( - CONFIG.database.async_readonly_database_uri, - connect_args=readonly_connect_args, + "postgresql+asyncpg://", + creator=make_async_creator(readonly=True), echo=False, hide_parameters=not CONFIG.dev_mode, logging_name="ReadOnlyAsyncEngine", @@ -92,7 +75,8 @@ # and they do not respect engine settings like pool_size, max_overflow, etc. # these should be removed, and we should standardize on what's provided in `session.py` sync_engine = create_engine( - CONFIG.database.sync_database_uri, + "postgresql+psycopg2://", + creator=make_sync_creator(), echo=False, hide_parameters=not CONFIG.dev_mode, logging_name="SyncEngine", diff --git a/src/fides/api/db/session.py b/src/fides/api/db/session.py index f16eb8684af..560ce09f830 100644 --- a/src/fides/api/db/session.py +++ b/src/fides/api/db/session.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Dict +from typing import Any, Callable, Dict from loguru import logger from sqlalchemy import create_engine @@ -18,6 +18,7 @@ def get_db_engine( *, config: FidesConfig | None = None, database_uri: str | URL | None = None, + creator: Callable[[], Any] | None = None, pool_size: int = 50, max_overflow: int = 50, keepalives_idle: int | None = None, @@ -28,36 +29,59 @@ def get_db_engine( ) -> Engine: """Return a database engine. + When *creator* is provided, it is called by the pool to open each new + connection — credentials and connect_args are handled inside the creator. + A dialect-only URL is used for engine construction. + + When *database_uri* or *config* is provided, the engine uses a fixed + connection URI (existing behavior). + If the TESTING environment var is set the database engine returned will be connected to the test DB. """ - if not config and not database_uri: - raise ValueError("Either a config or database_uri is required") - - if not database_uri and config: - # Don't override any database_uri explicitly passed in - if config.test_mode: - database_uri = config.database.sqlalchemy_test_database_uri - else: - database_uri = config.database.sqlalchemy_database_uri - engine_args: Dict[str, Any] = { "json_serializer": custom_json_serializer, "json_deserializer": custom_json_deserializer, } - # keepalives settings - connect_args = {} - if keepalives_idle: - connect_args["keepalives_idle"] = keepalives_idle - if keepalives_interval: - connect_args["keepalives_interval"] = keepalives_interval - if keepalives_count: - connect_args["keepalives_count"] = keepalives_count - - if connect_args: - connect_args["keepalives"] = 1 - engine_args["connect_args"] = connect_args + if creator: + # Creator handles credentials and connect_args internally, + # so creator needs to set keepalives settings. + if database_uri or config: + raise ValueError( + "database_uri/config cannot be used with creator — " + "the creator handles connection construction" + ) + if keepalives_idle or keepalives_interval or keepalives_count: + raise ValueError( + "keepalives_idle/interval/count cannot be used with creator — " + "pass them as connect_args to the creator instead" + ) + engine_args["creator"] = creator + database_uri = "postgresql+psycopg2://" + else: + # URI-based path. + if not config and not database_uri: + raise ValueError("Either a config, database_uri, or creator is required") + + if not database_uri and config: + if config.test_mode: + database_uri = config.database.sqlalchemy_test_database_uri + else: + database_uri = config.database.sqlalchemy_database_uri + + # keepalives settings (only for URI path; creator handles its own) + connect_args = {} + if keepalives_idle: + connect_args["keepalives_idle"] = keepalives_idle + if keepalives_interval: + connect_args["keepalives_interval"] = keepalives_interval + if keepalives_count: + connect_args["keepalives_count"] = keepalives_count + + if connect_args: + connect_args["keepalives"] = 1 + engine_args["connect_args"] = connect_args if disable_pooling: engine_args["poolclass"] = NullPool diff --git a/src/fides/api/tasks/__init__.py b/src/fides/api/tasks/__init__.py index cf113747a02..ef077c441e7 100644 --- a/src/fides/api/tasks/__init__.py +++ b/src/fides/api/tasks/__init__.py @@ -19,6 +19,7 @@ from fides.api.request_context import get_request_id, set_request_id from fides.api.tasks import celery_healthcheck from fides.api.util.logger import setup as setup_logging +from fides.common.engine_creators import make_sync_creator from fides.config import CONFIG, FidesConfig MESSAGING_QUEUE_NAME = "fidesops.messaging" @@ -77,12 +78,16 @@ def get_new_session(self) -> ContextManager[Session]: # once per celery process. if self._task_engine is None: self._task_engine = get_db_engine( - config=CONFIG, + creator=make_sync_creator( + connect_args={ + "keepalives": 1, + "keepalives_idle": CONFIG.database.task_engine_keepalives_idle, + "keepalives_interval": CONFIG.database.task_engine_keepalives_interval, + "keepalives_count": CONFIG.database.task_engine_keepalives_count, + }, + ), pool_size=CONFIG.database.task_engine_pool_size, max_overflow=CONFIG.database.task_engine_max_overflow, - keepalives_idle=CONFIG.database.task_engine_keepalives_idle, - keepalives_interval=CONFIG.database.task_engine_keepalives_interval, - keepalives_count=CONFIG.database.task_engine_keepalives_count, pool_pre_ping=CONFIG.database.task_engine_pool_pre_ping, ) diff --git a/src/fides/common/engine_creators.py b/src/fides/common/engine_creators.py new file mode 100644 index 00000000000..a1fa5d40b77 --- /dev/null +++ b/src/fides/common/engine_creators.py @@ -0,0 +1,190 @@ +""" +SQLAlchemy engine ``creator`` callables for dynamic credential resolution. + +The ``creator`` pattern passes a callable to ``create_engine`` / +``create_async_engine`` instead of a connection URI. SQLAlchemy calls the +creator every time the pool needs a new connection, so credentials are +resolved at **connection time** rather than engine construction time. + +Today the credential helpers read from static config (``CONFIG.database``). +A future secret-provider integration will swap them to call +``provider.get_secret()`` — the rest of the engine code stays the same. + +Because creators run on every new pool connection, they must stay +lightweight — avoid expensive I/O, network calls, or heavy computation. +Credential lookups should return cached values in the common case. +""" + +from __future__ import annotations + +import ssl +from copy import deepcopy +from typing import Any, Callable, Dict, Optional + +import asyncpg # type: ignore[import-untyped] +import psycopg2 # type: ignore[import-untyped] +from sqlalchemy.dialects.postgresql.asyncpg import ( + AsyncAdapt_asyncpg_connection, + AsyncAdapt_asyncpg_dbapi, +) +from sqlalchemy.util.concurrency import await_only # type: ignore[import-untyped] + +from fides.config import CONFIG + +# Shared dbapi instance for async creators — reused across connections. +_asyncpg_dbapi = AsyncAdapt_asyncpg_dbapi(asyncpg) + + +# --------------------------------------------------------------------------- +# Credential helpers +# --------------------------------------------------------------------------- + + +def get_db_credentials() -> Dict[str, Any]: + """Return DB credentials from static config.""" + db_settings = CONFIG.database + dbname = db_settings.test_db if CONFIG.test_mode else db_settings.db + return { + "host": db_settings.server, + "port": int(db_settings.port), + "user": db_settings.user, + "password": db_settings.raw_password, + "dbname": dbname, + } + + +def get_readonly_db_credentials() -> Optional[Dict[str, Any]]: + """Return readonly DB credentials, or ``None`` if not configured. + + Falls back to primary fields where readonly-specific values are absent. + """ + db_settings = CONFIG.database + if not db_settings.readonly_server: + return None + return { + "host": db_settings.readonly_server, + "port": int(db_settings.readonly_port or db_settings.port), + "user": db_settings.readonly_user or db_settings.user, + "password": db_settings.raw_readonly_password or db_settings.raw_password, + "dbname": db_settings.readonly_db or db_settings.db, + } + + +# --------------------------------------------------------------------------- +# Sync creators (psycopg2) +# --------------------------------------------------------------------------- + + +def make_sync_creator( + connect_args: Optional[Dict[str, Any]] = None, + readonly: bool = False, +) -> Callable[[], Any]: + """Return a creator callable for psycopg2 engines. + + The factory captures per-engine config (keepalives, SSL) in the closure. + Credentials are resolved from CONFIG on every call — the seam for future + dynamic credential rotation. + + When using ``creator``, SQLAlchemy ignores ``connect_args`` passed to + ``create_engine``, so all connection parameters must be baked in here. + """ + + def creator() -> Any: + if readonly: + kw = get_readonly_db_credentials() or get_db_credentials() + else: + kw = get_db_credentials() + if connect_args: + kw.update(connect_args) + return psycopg2.connect(**kw) + + return creator + + +# --------------------------------------------------------------------------- +# Async creators (asyncpg) +# --------------------------------------------------------------------------- + + +def make_async_creator( + readonly: bool = False, +) -> Callable[[], Any]: + """Return a creator callable for asyncpg engines (SA 1.4.27). + + The factory builds the SSL context and asyncpg-compatible params from + CONFIG, capturing them in the closure. Credentials are resolved from + CONFIG on every call. + + The creator replaces ``dialect.connect()`` in SQLAlchemy's pool. For + async engines the pool runs inside a greenlet bridge, so ``await_only`` + is valid. Must return ``AsyncAdapt_asyncpg_connection`` (SA's sync + wrapper) since the pool operates in sync mode through greenlets. + + TODO: Replace with ``async_creator`` API after SQLAlchemy 2.0 upgrade. + """ + db_params = ( + (CONFIG.database.readonly_params or CONFIG.database.params) + if readonly + else CONFIG.database.params + ) + ssl_context = _build_ssl_context(db_params) + async_params = _convert_asyncpg_params(db_params) + + # When we have a full SSLContext (from sslrootcert), it takes priority + # over the raw ssl string (from sslmode). Otherwise kw.update(async_params) + # would overwrite the SSLContext with e.g. "require", losing cert verification. + if ssl_context: + async_params.pop("ssl", None) + + def creator() -> Any: + if readonly: + creds = get_readonly_db_credentials() or get_db_credentials() + else: + creds = get_db_credentials() + kw: Dict[str, Any] = { + "host": creds["host"], + "port": creds["port"], + "user": creds["user"], + "password": creds["password"], + "database": creds["dbname"], + } + if ssl_context: + kw["ssl"] = ssl_context + if async_params: + kw.update(async_params) + raw_conn = await_only(asyncpg.connect(**kw)) + return AsyncAdapt_asyncpg_connection(_asyncpg_dbapi, raw_conn) + + return creator + + +# --------------------------------------------------------------------------- +# Private helpers +# --------------------------------------------------------------------------- + + +def _build_ssl_context(params: Dict[str, Any]) -> Optional[ssl.SSLContext]: + """Build an ``ssl.SSLContext`` from DB params if ``sslrootcert`` is set.""" + sslrootcert = params.get("sslrootcert") + if not sslrootcert: + return None + ctx = ssl.create_default_context(cafile=sslrootcert) + ctx.verify_mode = ssl.CERT_REQUIRED + return ctx + + +def _convert_asyncpg_params(params: Dict[str, Any]) -> Dict[str, Any]: + """Convert DB params dict for asyncpg compatibility. + + asyncpg uses ``ssl`` instead of ``sslmode`` and does not accept + ``sslrootcert`` as a connection parameter (it's handled via + ``ssl.SSLContext`` passed separately). + + See: https://github.com/MagicStack/asyncpg/issues/737 + ref: https://github.com/sqlalchemy/sqlalchemy/discussions/5975 + """ + converted = deepcopy(params) + if "sslmode" in converted: + converted["ssl"] = converted.pop("sslmode") + converted.pop("sslrootcert", None) + return converted diff --git a/src/fides/common/session_management.py b/src/fides/common/session_management.py index ac0844195bc..5024074b284 100644 --- a/src/fides/common/session_management.py +++ b/src/fides/common/session_management.py @@ -16,6 +16,7 @@ from fides.api.db.ctl_session import async_session from fides.api.db.session import get_db_engine, get_db_session +from fides.common.engine_creators import make_sync_creator from fides.config import CONFIG T = TypeVar("T") @@ -28,12 +29,16 @@ def get_api_session() -> Session: global _engine # pylint: disable=W0603 if not _engine: _engine = get_db_engine( - config=CONFIG, + creator=make_sync_creator( + connect_args={ + "keepalives": 1, + "keepalives_idle": CONFIG.database.api_engine_keepalives_idle, + "keepalives_interval": CONFIG.database.api_engine_keepalives_interval, + "keepalives_count": CONFIG.database.api_engine_keepalives_count, + }, + ), pool_size=CONFIG.database.api_engine_pool_size, max_overflow=CONFIG.database.api_engine_max_overflow, - keepalives_idle=CONFIG.database.api_engine_keepalives_idle, - keepalives_interval=CONFIG.database.api_engine_keepalives_interval, - keepalives_count=CONFIG.database.api_engine_keepalives_count, pool_pre_ping=CONFIG.database.api_engine_pool_pre_ping, ) SessionLocal = get_db_session(CONFIG, engine=_engine) @@ -140,12 +145,17 @@ def get_readonly_api_session() -> Session: global _readonly_engine # pylint: disable=W0603 if not _readonly_engine: _readonly_engine = get_db_engine( - database_uri=CONFIG.database.sqlalchemy_readonly_database_uri, + creator=make_sync_creator( + connect_args={ + "keepalives": 1, + "keepalives_idle": CONFIG.database.api_engine_keepalives_idle, + "keepalives_interval": CONFIG.database.api_engine_keepalives_interval, + "keepalives_count": CONFIG.database.api_engine_keepalives_count, + }, + readonly=True, + ), pool_size=CONFIG.database.api_engine_pool_size, max_overflow=CONFIG.database.api_engine_max_overflow, - keepalives_idle=CONFIG.database.api_engine_keepalives_idle, - keepalives_interval=CONFIG.database.api_engine_keepalives_interval, - keepalives_count=CONFIG.database.api_engine_keepalives_count, pool_pre_ping=CONFIG.database.api_engine_pool_pre_ping, ) SessionLocal = get_db_session(CONFIG, engine=_readonly_engine) diff --git a/src/fides/config/database_settings.py b/src/fides/config/database_settings.py index 53144552453..f279a2f4d41 100644 --- a/src/fides/config/database_settings.py +++ b/src/fides/config/database_settings.py @@ -4,7 +4,7 @@ from copy import deepcopy from typing import Dict, Optional, cast -from urllib.parse import quote, quote_plus, urlencode +from urllib.parse import quote, quote_plus, unquote_plus, urlencode from pydantic import ( Field, @@ -275,6 +275,18 @@ def escape_password(cls, value: Optional[str]) -> Optional[str]: return quote_plus(value) return value + @property + def raw_password(self) -> str: + """Return password unescaped for direct driver use (psycopg2/asyncpg).""" + return unquote_plus(self.password) + + @property + def raw_readonly_password(self) -> Optional[str]: + """Return readonly password unescaped for direct driver use.""" + if self.readonly_password: + return unquote_plus(self.readonly_password) + return None + @field_validator("sync_database_uri", mode="before") @classmethod def assemble_sync_database_uri( diff --git a/tests/lib/test_engine_creators.py b/tests/lib/test_engine_creators.py new file mode 100644 index 00000000000..8a3c22cfdbd --- /dev/null +++ b/tests/lib/test_engine_creators.py @@ -0,0 +1,262 @@ +"""Tests for engine creator factories and credential helpers.""" + +import ssl +from unittest.mock import MagicMock, patch + +import pytest +from sqlalchemy import create_engine, text +from sqlalchemy.ext.asyncio import create_async_engine + +from fides.common.engine_creators import ( + _build_ssl_context, + _convert_asyncpg_params, + get_db_credentials, + get_readonly_db_credentials, + make_async_creator, + make_sync_creator, +) +from fides.config import CONFIG +from fides.config.database_settings import DatabaseSettings + + +class TestGetDbCredentials: + def test_returns_expected_fields(self) -> None: + creds = get_db_credentials() + assert set(creds.keys()) == {"host", "port", "user", "password", "dbname"} + + def test_port_is_int(self) -> None: + creds = get_db_credentials() + assert isinstance(creds["port"], int) + + def test_password_is_unescaped(self) -> None: + """raw_password should reverse the quote_plus escaping.""" + creds = get_db_credentials() + # The raw password should not contain URL-encoded characters + # unless the original password literally contains them + assert creds["password"] == CONFIG.database.raw_password + + def test_uses_test_db_in_test_mode(self) -> None: + creds = get_db_credentials() + if CONFIG.test_mode: + assert creds["dbname"] == CONFIG.database.test_db + else: + assert creds["dbname"] == CONFIG.database.db + + +class TestGetReadonlyDbCredentials: + def test_returns_none_when_not_configured(self) -> None: + if not CONFIG.database.readonly_server: + assert get_readonly_db_credentials() is None + + +class TestRawPassword: + """Verify raw_password round-trips passwords with special characters.""" + + @pytest.mark.parametrize( + "password", + [ + "simple", + "p@ssw0rd", + "pass#word", + "pass%word", + "pass/word", + "p@ss#w%rd/123", + "has spaces", + "has+plus", + ], + ) + def test_raw_password_round_trip(self, password: str) -> None: + """Constructing DatabaseSettings with special-char passwords + should produce a raw_password that matches the original.""" + settings = DatabaseSettings(password=password) + assert settings.raw_password == password + + @pytest.mark.parametrize("password", ["p@ssw0rd", "pass%word"]) + def test_raw_readonly_password_round_trip(self, password: str) -> None: + """readonly_password should also round-trip through quote_plus.""" + settings = DatabaseSettings( + readonly_server="replica", readonly_password=password + ) + assert settings.raw_readonly_password == password + + def test_raw_readonly_password_none_when_not_set(self) -> None: + settings = DatabaseSettings() + assert settings.raw_readonly_password is None + + def test_pre_encoded_password_treated_as_literal(self) -> None: + """Passwords are always treated as raw values, never as pre-encoded. + + If a user sets their password to "foo%40bar" (literally containing + the characters %, 4, 0), raw_password returns "foo%40bar" — NOT + "foo@bar". This matches the old URI-based path where escape_password + would double-encode %40 to %2540 in the URI, and psycopg2 would + decode it back to %40. + """ + settings = DatabaseSettings(password="foo%40bar") + assert settings.raw_password == "foo%40bar" + + +class TestConvertAsyncpgParams: + def test_converts_sslmode_to_ssl(self) -> None: + params = {"sslmode": "require", "other": "value"} + result = _convert_asyncpg_params(params) + assert "sslmode" not in result + assert result["ssl"] == "require" + assert result["other"] == "value" + + def test_drops_sslrootcert(self) -> None: + params = {"sslrootcert": "/path/to/cert.pem", "other": "value"} + result = _convert_asyncpg_params(params) + assert "sslrootcert" not in result + assert result["other"] == "value" + + def test_does_not_mutate_input(self) -> None: + params = {"sslmode": "require", "sslrootcert": "/path"} + _convert_asyncpg_params(params) + assert "sslmode" in params + assert "sslrootcert" in params + + def test_empty_params(self) -> None: + assert _convert_asyncpg_params({}) == {} + + +class TestBuildSslContext: + def test_returns_none_without_sslrootcert(self) -> None: + assert _build_ssl_context({}) is None + assert _build_ssl_context({"sslmode": "require"}) is None + + def test_returns_context_with_valid_sslrootcert(self, tmp_path) -> None: + """Success path: a valid CA cert produces a usable SSLContext.""" + # Generate a self-signed cert for testing + import datetime + + from cryptography import x509 + from cryptography.hazmat.primitives import hashes, serialization + from cryptography.hazmat.primitives.asymmetric import rsa + from cryptography.x509.oid import NameOID + + key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + subject = issuer = x509.Name( + [ + x509.NameAttribute(NameOID.COMMON_NAME, "test-ca"), + ] + ) + cert = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.datetime.now(datetime.timezone.utc)) + .not_valid_after( + datetime.datetime.now(datetime.timezone.utc) + + datetime.timedelta(days=1) + ) + .sign(key, hashes.SHA256()) + ) + cert_file = tmp_path / "ca.pem" + cert_file.write_bytes(cert.public_bytes(serialization.Encoding.PEM)) + + ctx = _build_ssl_context({"sslrootcert": str(cert_file)}) + assert isinstance(ctx, ssl.SSLContext) + assert ctx.verify_mode == ssl.CERT_REQUIRED + + def test_returns_none_with_invalid_cert(self, tmp_path) -> None: + cert_file = tmp_path / "bad.pem" + cert_file.write_text("not a cert") + with pytest.raises(ssl.SSLError): + _build_ssl_context({"sslrootcert": str(cert_file)}) + + +class TestMakeSyncCreator: + def test_returns_callable(self) -> None: + creator = make_sync_creator() + assert callable(creator) + + def test_creator_opens_working_connection(self) -> None: + """The sync creator should produce a real psycopg2 connection.""" + creator = make_sync_creator() + conn = creator() + try: + cur = conn.cursor() + cur.execute("SELECT 1") + assert cur.fetchone() == (1,) + finally: + conn.close() + + def test_creator_with_connect_args(self) -> None: + """connect_args like keepalives should be forwarded.""" + creator = make_sync_creator( + connect_args={"keepalives": 1, "keepalives_idle": 30} + ) + conn = creator() + try: + cur = conn.cursor() + cur.execute("SELECT 1") + assert cur.fetchone() == (1,) + finally: + conn.close() + + def test_engine_with_sync_creator(self) -> None: + """A full engine using the sync creator can execute queries.""" + creator = make_sync_creator() + engine = create_engine("postgresql+psycopg2://", creator=creator, pool_size=1) + try: + with engine.connect() as conn: + result = conn.execute(text("SELECT 1")) + assert result.scalar() == 1 + finally: + engine.dispose() + + +class TestMakeAsyncCreator: + def test_returns_callable(self) -> None: + creator = make_async_creator() + assert callable(creator) + + async def test_engine_with_async_creator(self) -> None: + """A full async engine using the async creator can execute queries.""" + creator = make_async_creator() + engine = create_async_engine( + "postgresql+asyncpg://", creator=creator, pool_size=1 + ) + try: + async with engine.connect() as conn: + result = await conn.execute(text("SELECT 1")) + assert result.scalar() == 1 + finally: + await engine.dispose() + + @patch("fides.common.engine_creators.asyncpg") + @patch("fides.common.engine_creators.AsyncAdapt_asyncpg_connection") + def test_ssl_context_not_overwritten_by_async_params( + self, mock_adapt_conn, mock_asyncpg + ) -> None: + """When both sslrootcert and sslmode are configured, the SSLContext + must not be overwritten by the raw ssl string from async_params.""" + mock_ssl_context = MagicMock(spec=ssl.SSLContext) + + with ( + patch( + "fides.common.engine_creators._build_ssl_context", + return_value=mock_ssl_context, + ), + patch( + "fides.common.engine_creators._convert_asyncpg_params", + return_value={"ssl": "require", "other": "value"}, + ), + patch( + "fides.common.engine_creators.await_only", + side_effect=lambda coro: coro, + ), + ): + creator = make_async_creator() + creator() + + # asyncpg.connect was called — check the ssl kwarg + connect_kwargs = mock_asyncpg.connect.call_args[1] + assert connect_kwargs["ssl"] is mock_ssl_context, ( + f"Expected SSLContext but got {connect_kwargs['ssl']!r} — " + "async_params overwrote the ssl_context" + ) + assert connect_kwargs["other"] == "value" diff --git a/tests/lib/test_session.py b/tests/lib/test_session.py index 8d74f94cb59..9c8865b933d 100644 --- a/tests/lib/test_session.py +++ b/tests/lib/test_session.py @@ -1,6 +1,8 @@ import pytest +from sqlalchemy import text from fides.api.db import session +from fides.common.engine_creators import make_sync_creator from fides.config import get_config @@ -20,3 +22,75 @@ def test_get_session_test_modes(self, test_mode: bool) -> None: db_engine = session.get_db_engine(config=config) config.test_mode = original_value assert db_engine + + def test_get_engine_with_creator(self) -> None: + """Engine created via creator= can execute queries.""" + creator = make_sync_creator() + engine = session.get_db_engine(creator=creator, pool_size=1) + try: + with engine.connect() as conn: + result = conn.execute(text("SELECT 1")) + assert result.scalar() == 1 + finally: + engine.dispose() + + def test_creator_with_keepalives_raises(self) -> None: + """Passing both creator and keepalives params is an error.""" + creator = make_sync_creator() + with pytest.raises( + ValueError, + match="keepalives_idle/interval/count cannot be used with creator", + ): + session.get_db_engine(creator=creator, keepalives_idle=30) + + def test_creator_with_database_uri_raises(self) -> None: + """Passing both creator and database_uri is an error.""" + creator = make_sync_creator() + with pytest.raises( + ValueError, + match="database_uri/config cannot be used with creator", + ): + session.get_db_engine( + creator=creator, database_uri="postgresql://localhost/db" + ) + + def test_creator_with_config_raises(self) -> None: + """Passing both creator and config is an error.""" + creator = make_sync_creator() + config = get_config() + with pytest.raises( + ValueError, + match="database_uri/config cannot be used with creator", + ): + session.get_db_engine(creator=creator, config=config) + + def test_config_with_keepalives(self) -> None: + """URI path with keepalives produces a working engine.""" + config = get_config() + engine = session.get_db_engine( + config=config, + pool_size=1, + keepalives_idle=30, + keepalives_interval=10, + keepalives_count=5, + ) + try: + with engine.connect() as conn: + result = conn.execute(text("SELECT 1")) + assert result.scalar() == 1 + finally: + engine.dispose() + + def test_disable_pooling(self) -> None: + """disable_pooling uses NullPool — no connections are kept.""" + from sqlalchemy.pool import NullPool + + config = get_config() + engine = session.get_db_engine(config=config, disable_pooling=True) + try: + assert isinstance(engine.pool, NullPool) + with engine.connect() as conn: + result = conn.execute(text("SELECT 1")) + assert result.scalar() == 1 + finally: + engine.dispose()