Skip to content

Commit 71ba1e1

Browse files
committed
more changes
1 parent 55b09a8 commit 71ba1e1

File tree

3 files changed

+20
-8
lines changed

3 files changed

+20
-8
lines changed

databricks/sdk/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,8 @@ def client_type(self) -> ClientType:
406406
return ClientType.WORKSPACE
407407

408408
if host_type == HostType.UNIFIED:
409+
if self.workspace_id:
410+
return ClientType.WORKSPACE
409411
if self.account_id:
410412
return ClientType.ACCOUNT
411413
# Legacy workspace hosts don't have a workspace_id until AFTER the auth is resolved.

databricks/sdk/credentials_provider.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -431,9 +431,9 @@ def _oidc_credentials_provider(
431431

432432
# Determine the audience for token exchange
433433
audience = cfg.token_audience
434-
if audience is None and cfg.client_type == ClientType.ACCOUNT:
434+
if audience is None and cfg.account_id:
435435
audience = cfg.account_id
436-
if audience is None and cfg.client_type != ClientType.ACCOUNT:
436+
if audience is None and not cfg.account_id:
437437
audience = cfg.oidc_endpoints.token_endpoint
438438

439439
# Try to get an OIDC token. If no supplier returns a token, we cannot use this authentication mode.
@@ -590,7 +590,7 @@ def token() -> oauth.Token:
590590
def refreshed_headers() -> Dict[str, str]:
591591
credentials.refresh(request)
592592
headers = {"Authorization": f"Bearer {credentials.token}"}
593-
if cfg.client_type == ClientType.ACCOUNT:
593+
if cfg.account_id:
594594
gcp_credentials.refresh(request)
595595
headers["X-Databricks-GCP-SA-Access-Token"] = gcp_credentials.token
596596
return headers
@@ -631,7 +631,7 @@ def token() -> oauth.Token:
631631
def refreshed_headers() -> Dict[str, str]:
632632
id_creds.refresh(request)
633633
headers = {"Authorization": f"Bearer {id_creds.token}"}
634-
if cfg.client_type == ClientType.ACCOUNT:
634+
if cfg.account_id:
635635
gcp_impersonated_credentials.refresh(request)
636636
headers["X-Databricks-GCP-SA-Access-Token"] = gcp_impersonated_credentials.token
637637
return headers

tests/test_config.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -361,16 +361,26 @@ def test_is_account_client_backward_compatibility():
361361
assert config_account.is_account_client
362362

363363

364-
def test_is_account_client_raises_on_unified_host():
365-
"""Test that is_account_client raises ValueError when used with unified hosts."""
364+
def test_is_account_client_on_unified_host():
365+
"""Test that is_account_client returns truthiness of account_id for unified hosts."""
366366
config = Config(
367367
host="https://unified.databricks.com",
368368
experimental_is_unified_host=True,
369369
workspace_id="test-workspace",
370370
token="test-token",
371371
)
372-
with pytest.raises(ValueError, match="is_account_client cannot be used with unified hosts"):
373-
_ = config.is_account_client
372+
# Should be falsy since account_id is not set
373+
assert not config.is_account_client
374+
375+
# With account_id set, should be truthy
376+
config_with_account = Config(
377+
host="https://unified.databricks.com",
378+
experimental_is_unified_host=True,
379+
workspace_id="test-workspace",
380+
account_id="test-account",
381+
token="test-token",
382+
)
383+
assert config_with_account.is_account_client
374384

375385

376386
def test_oidc_endpoints_unified_workspace(mocker, requests_mock):

0 commit comments

Comments
 (0)