Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NEXT_CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## Release v0.86.0

### New Features and Improvements
* Added `custom_headers` parameter to `WorkspaceClient` and `AccountClient` to support custom HTTP headers in all API requests ([#1245](https://github.com/databricks/databricks-sdk-py/pull/1245)).

### Security

Expand Down
16 changes: 16 additions & 0 deletions databricks/sdk/__init__.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 16 additions & 0 deletions databricks/sdk/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,11 +242,27 @@ def __init__(
product=None,
product_version=None,
clock: Optional[Clock] = None,
custom_headers: Optional[Dict[str, str]] = None,
**kwargs,
):
"""Initialize a Config object.

Args:
credentials_provider: (Deprecated) Use credentials_strategy instead.
credentials_strategy: Custom credentials strategy for authentication.
product: Product name for User-Agent header.
product_version: Product version for User-Agent header.
clock: Clock instance for time-related operations.
custom_headers: Optional dictionary of custom HTTP headers to include in all API requests.
These headers will be automatically added to every request made by the client.
Request-specific headers passed to individual API calls will override these custom headers
if there is a conflict. Example: {"X-Request-ID": "123", "X-Custom-Header": "value"}
**kwargs: Additional configuration parameters.
"""
self._header_factory = None
self._inner = {}
self._user_agent_other_info = []
self._custom_headers = custom_headers or {}
if credentials_strategy and credentials_provider:
raise ValueError("When providing `credentials_strategy` field, `credential_provider` cannot be specified.")
if credentials_provider:
Expand Down
7 changes: 6 additions & 1 deletion databricks/sdk/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,16 @@ def do(
# Once we've fixed the OpenAPI spec, we can remove this
path = re.sub("^/api/2.0/fs/files//", "/api/2.0/fs/files/", path)
url = f"{self._cfg.host}{path}"

# Merge custom headers with request-specific headers
# Request-specific headers take precedence
merged_headers = {**self._cfg._custom_headers, **(headers or {})}

return self._api_client.do(
method=method,
url=url,
query=query,
headers=headers,
headers=merged_headers,
body=body,
raw=raw,
files=files,
Expand Down
111 changes: 111 additions & 0 deletions tests/test_custom_headers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
"""Test custom headers functionality"""

from unittest.mock import MagicMock, patch

import pytest

from databricks.sdk import AccountClient, WorkspaceClient
from databricks.sdk.core import Config


def test_workspace_client_custom_headers():
"""Test that WorkspaceClient passes custom headers to all requests"""
with patch("databricks.sdk.core.ApiClient") as mock_api_client:
# Create a mock for the _api_client.do method
mock_do = MagicMock(return_value={})
mock_api_client_instance = MagicMock()
mock_api_client_instance.do = mock_do
mock_api_client.return_value = mock_api_client_instance

# Create WorkspaceClient with custom headers
w = WorkspaceClient(
host="https://test.databricks.com", token="test-token", custom_headers={"X-Custom-Header": "test-value"}
)

# Verify custom headers are stored in config
assert w._config._custom_headers == {"X-Custom-Header": "test-value"}


def test_account_client_custom_headers():
"""Test that AccountClient passes custom headers to all requests"""
with patch("databricks.sdk.core.ApiClient") as mock_api_client:
mock_do = MagicMock(return_value={})
mock_api_client_instance = MagicMock()
mock_api_client_instance.do = mock_do
mock_api_client.return_value = mock_api_client_instance

# Create AccountClient with custom headers
a = AccountClient(
host="https://accounts.cloud.databricks.com",
account_id="test-account-id",
token="test-token",
custom_headers={"X-Custom-Header": "test-value"},
)

# Verify custom headers are stored in config
assert a._config._custom_headers == {"X-Custom-Header": "test-value"}


def test_config_custom_headers():
"""Test that Config stores custom headers"""
config = Config(
host="https://test.databricks.com",
token="test-token",
custom_headers={"X-Custom-Header": "test-value", "X-Another": "another-value"},
)

assert config._custom_headers == {"X-Custom-Header": "test-value", "X-Another": "another-value"}


def test_api_client_merges_custom_headers(requests_mock):
"""Test that ApiClient.do() merges custom headers with request headers"""
from databricks.sdk.core import ApiClient

# Create config with custom headers
config = Config(
host="https://test.databricks.com", token="test-token", custom_headers={"X-Custom-Header": "custom-value"}
)

# Create ApiClient
api_client = ApiClient(config)

# Mock the request
requests_mock.get(
"https://test.databricks.com/api/2.0/clusters/list",
json={"clusters": []},
)

# Make a request
api_client.do("GET", "/api/2.0/clusters/list")

# Verify the custom header was included in the request
assert requests_mock.last_request.headers["X-Custom-Header"] == "custom-value"


def test_request_headers_override_custom_headers(requests_mock):
"""Test that request-specific headers override custom headers"""
from databricks.sdk.core import ApiClient

# Create config with custom headers
config = Config(
host="https://test.databricks.com", token="test-token", custom_headers={"X-Custom-Header": "custom-value"}
)

# Create ApiClient
api_client = ApiClient(config)

# Mock the request
requests_mock.get(
"https://test.databricks.com/api/2.0/clusters/list",
json={"clusters": []},
)

# Make a request with header override
response = api_client.do("GET", "/api/2.0/clusters/list", headers={"X-Custom-Header": "overridden-value"})

# Verify the request header overrode the custom header
assert requests_mock.last_request.headers["X-Custom-Header"] == "overridden-value"


if __name__ == "__main__":
pytest.main([__file__, "-v"])
Loading