Skip to content

Commit f8dad10

Browse files
committed
Replace parse_hostname in auth utils with normalize_host_with_protocol and use it in the SEA HTTP client
Signed-off-by: Nikhil Suri <[email protected]>
1 parent 23af8e2 commit f8dad10

File tree

7 files changed

+83
-92
lines changed

7 files changed

+83
-92
lines changed

src/databricks/sql/auth/auth_utils.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,23 +7,6 @@
77
logger = logging.getLogger(__name__)
88

99

10-
def parse_hostname(hostname: str) -> str:
11-
"""
12-
Normalize the hostname to include scheme and trailing slash.
13-
14-
Args:
15-
hostname: The hostname to normalize
16-
17-
Returns:
18-
Normalized hostname with scheme and trailing slash
19-
"""
20-
if not hostname.startswith("http://") and not hostname.startswith("https://"):
21-
hostname = f"https://{hostname}"
22-
if not hostname.endswith("/"):
23-
hostname = f"{hostname}/"
24-
return hostname
25-
26-
2710
def decode_token(access_token: str) -> Optional[Dict]:
2811
"""
2912
Decode a JWT token without verification to extract claims.

src/databricks/sql/auth/token_federation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66

77
from databricks.sql.auth.authenticators import AuthProvider
88
from databricks.sql.auth.auth_utils import (
9-
parse_hostname,
109
decode_token,
1110
is_same_host,
1211
)
12+
from databricks.sql.common.url_utils import normalize_host_with_protocol
1313
from databricks.sql.common.http import HttpMethod
1414

1515
logger = logging.getLogger(__name__)
@@ -99,7 +99,7 @@ def __init__(
9999
if not http_client:
100100
raise ValueError("http_client is required for TokenFederationProvider")
101101

102-
self.hostname = parse_hostname(hostname)
102+
self.hostname = normalize_host_with_protocol(hostname)
103103
self.external_provider = external_provider
104104
self.http_client = http_client
105105
self.identity_federation_client_id = identity_federation_client_id
@@ -164,7 +164,7 @@ def _should_exchange_token(self, access_token: str) -> bool:
164164

165165
def _exchange_token(self, access_token: str) -> Token:
166166
"""Exchange the external token for a Databricks token."""
167-
token_url = f"{self.hostname.rstrip('/')}{self.TOKEN_EXCHANGE_ENDPOINT}"
167+
token_url = f"{self.hostname}{self.TOKEN_EXCHANGE_ENDPOINT}"
168168

169169
data = {
170170
"grant_type": self.TOKEN_EXCHANGE_GRANT_TYPE,

src/databricks/sql/backend/sea/utils/http_client.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from databricks.sql.common.http_utils import (
1919
detect_and_parse_proxy,
2020
)
21+
from databricks.sql.common.url_utils import normalize_host_with_protocol
2122

2223
logger = logging.getLogger(__name__)
2324

@@ -66,8 +67,9 @@ def __init__(
6667
self.auth_provider = auth_provider
6768
self.ssl_options = ssl_options
6869

69-
# Build base URL
70-
self.base_url = f"https://{server_hostname}:{self.port}"
70+
# Build base URL using url_utils for consistent normalization
71+
normalized_host = normalize_host_with_protocol(server_hostname)
72+
self.base_url = f"{normalized_host}:{self.port}"
7173

7274
# Parse URL for proxy handling
7375
parsed_url = urllib.parse.urlparse(self.base_url)

src/databricks/sql/common/url_utils.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
def normalize_host_with_protocol(host: str) -> str:
77
"""
8-
Normalize a connection hostname by ensuring it has a protocol and removing trailing slashes.
8+
Normalize a connection hostname by ensuring it has a protocol and removing trailing slashes.
99
1010
This is useful for handling cases where users may provide hostnames with or without protocols
1111
(common with dbt-databricks users copying URLs from their browser).
@@ -15,12 +15,13 @@ def normalize_host_with_protocol(host: str) -> str:
1515
and may or may not have a trailing slash
1616
1717
Returns:
18-
Normalized hostname with protocol prefix and no trailing slash
18+
Normalized hostname with protocol prefix and no trailing slashes
1919
2020
Examples:
2121
normalize_host_with_protocol("myserver.com") -> "https://myserver.com"
2222
normalize_host_with_protocol("https://myserver.com") -> "https://myserver.com"
23-
normalize_host_with_protocol("HTTPS://myserver.com") -> "https://myserver.com"
23+
normalize_host_with_protocol("HTTPS://myserver.com/") -> "https://myserver.com"
24+
normalize_host_with_protocol("http://localhost:8080/") -> "http://localhost:8080"
2425
2526
Raises:
2627
ValueError: If host is None or empty string
@@ -29,7 +30,7 @@ def normalize_host_with_protocol(host: str) -> str:
2930
if not host or not host.strip():
3031
raise ValueError("Host cannot be None or empty")
3132

32-
# Remove trailing slash
33+
# Remove trailing slashes
3334
host = host.rstrip("/")
3435

3536
# Add protocol if not present (case-insensitive check)

tests/unit/test_sea_http_client.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,40 @@ def sea_http_client(self, mock_auth_provider, ssl_options):
4444
client._pool = Mock()
4545
return client
4646

47+
@pytest.mark.parametrize(
48+
"server_hostname,port,expected_base_url",
49+
[
50+
# Basic hostname without protocol
51+
("myserver.com", 443, "https://myserver.com:443"),
52+
# Hostname with trailing slash
53+
("myserver.com/", 443, "https://myserver.com:443"),
54+
# Hostname with https:// protocol
55+
("https://myserver.com", 443, "https://myserver.com:443"),
56+
# Hostname with http:// protocol (preserved as-is)
57+
("http://myserver.com", 443, "http://myserver.com:443"),
58+
# Hostname with protocol and trailing slash
59+
("https://myserver.com/", 443, "https://myserver.com:443"),
60+
# Custom port
61+
("myserver.com", 8080, "https://myserver.com:8080"),
62+
# Protocol with custom port
63+
("https://myserver.com", 8080, "https://myserver.com:8080"),
64+
],
65+
)
66+
def test_base_url_construction(
67+
self, server_hostname, port, expected_base_url, mock_auth_provider, ssl_options
68+
):
69+
"""Test that base_url is constructed correctly from various hostname inputs."""
70+
with patch("databricks.sql.backend.sea.utils.http_client.HTTPSConnectionPool"):
71+
client = SeaHttpClient(
72+
server_hostname=server_hostname,
73+
port=port,
74+
http_path="/sql/1.0/warehouses/test",
75+
http_headers=[],
76+
auth_provider=mock_auth_provider,
77+
ssl_options=ssl_options,
78+
)
79+
assert client.base_url == expected_base_url
80+
4781
def test_get_command_type_from_path(self, sea_http_client):
4882
"""Test the _get_command_type_from_path method with various paths and methods."""
4983
# Test statement execution

tests/unit/test_token_federation.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66

77
from databricks.sql.auth.token_federation import TokenFederationProvider, Token
88
from databricks.sql.auth.auth_utils import (
9-
parse_hostname,
109
decode_token,
1110
is_same_host,
1211
)
12+
from databricks.sql.common.url_utils import normalize_host_with_protocol
1313
from databricks.sql.common.http import HttpMethod
1414

1515

@@ -78,10 +78,10 @@ def test_init_requires_http_client(self, mock_external_provider):
7878
@pytest.mark.parametrize(
7979
"input_hostname,expected",
8080
[
81-
("test.databricks.com", "https://test.databricks.com/"),
82-
("https://test.databricks.com", "https://test.databricks.com/"),
83-
("https://test.databricks.com/", "https://test.databricks.com/"),
84-
("test.databricks.com/", "https://test.databricks.com/"),
81+
("test.databricks.com", "https://test.databricks.com"),
82+
("https://test.databricks.com", "https://test.databricks.com"),
83+
("https://test.databricks.com/", "https://test.databricks.com"),
84+
("test.databricks.com/", "https://test.databricks.com"),
8585
],
8686
)
8787
def test_hostname_normalization(
@@ -305,15 +305,15 @@ class TestUtilityFunctions:
305305
@pytest.mark.parametrize(
306306
"input_hostname,expected",
307307
[
308-
("test.databricks.com", "https://test.databricks.com/"),
309-
("https://test.databricks.com", "https://test.databricks.com/"),
310-
("https://test.databricks.com/", "https://test.databricks.com/"),
311-
("test.databricks.com/", "https://test.databricks.com/"),
308+
("test.databricks.com", "https://test.databricks.com"),
309+
("https://test.databricks.com", "https://test.databricks.com"),
310+
("https://test.databricks.com/", "https://test.databricks.com"),
311+
("test.databricks.com/", "https://test.databricks.com"),
312312
],
313313
)
314-
def test_parse_hostname(self, input_hostname, expected):
315-
"""Test hostname parsing."""
316-
assert parse_hostname(input_hostname) == expected
314+
def test_normalize_hostname(self, input_hostname, expected):
315+
"""Test hostname normalization."""
316+
assert normalize_host_with_protocol(input_hostname) == expected
317317

318318
@pytest.mark.parametrize(
319319
"url1,url2,expected",

tests/unit/test_url_utils.py

Lines changed: 25 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -6,65 +6,36 @@
66
class TestNormalizeHostWithProtocol:
77
"""Tests for normalize_host_with_protocol function."""
88

9-
@pytest.mark.parametrize("input_host,expected_output", [
10-
# Hostname without protocol - should add https://
11-
("myserver.com", "https://myserver.com"),
12-
("workspace.databricks.com", "https://workspace.databricks.com"),
13-
14-
# Hostname with https:// - should not duplicate
15-
("https://myserver.com", "https://myserver.com"),
16-
("https://workspace.databricks.com", "https://workspace.databricks.com"),
17-
18-
# Hostname with http:// - should preserve
19-
("http://localhost", "http://localhost"),
20-
("http://myserver.com:8080", "http://myserver.com:8080"),
21-
22-
# Hostname with port numbers
23-
("myserver.com:443", "https://myserver.com:443"),
24-
("https://myserver.com:443", "https://myserver.com:443"),
25-
("http://localhost:8080", "http://localhost:8080"),
26-
27-
# Trailing slash - should be removed
28-
("myserver.com/", "https://myserver.com"),
29-
("https://myserver.com/", "https://myserver.com"),
30-
("http://localhost/", "http://localhost"),
31-
32-
# Case-insensitive protocol handling - should normalize to lowercase
33-
("HTTPS://myserver.com", "https://myserver.com"),
34-
("HTTP://myserver.com", "http://myserver.com"),
35-
("HttPs://workspace.databricks.com", "https://workspace.databricks.com"),
36-
("HtTp://localhost:8080", "http://localhost:8080"),
37-
("HTTPS://MYSERVER.COM", "https://MYSERVER.COM"), # Only protocol lowercased
38-
39-
# Case-insensitive with trailing slashes
40-
("HTTPS://myserver.com/", "https://myserver.com"),
41-
("HTTP://localhost:8080/", "http://localhost:8080"),
42-
("HttPs://workspace.databricks.com//", "https://workspace.databricks.com"),
43-
44-
# Mixed case protocols with ports
45-
("HTTPS://myserver.com:443", "https://myserver.com:443"),
46-
("HtTp://myserver.com:8080", "http://myserver.com:8080"),
47-
48-
# Case preservation - only protocol lowercased, hostname case preserved
49-
("HTTPS://MyServer.DataBricks.COM", "https://MyServer.DataBricks.COM"),
50-
("HttPs://CamelCase.Server.com", "https://CamelCase.Server.com"),
51-
("HTTP://UPPERCASE.COM:8080", "http://UPPERCASE.COM:8080"),
52-
])
53-
def test_normalize_host_with_protocol(self, input_host, expected_output):
9+
@pytest.mark.parametrize(
10+
"input_url,expected_output",
11+
[
12+
("myserver.com", "https://myserver.com"), # Add https://
13+
("https://myserver.com", "https://myserver.com"), # No duplicate
14+
("http://localhost:8080", "http://localhost:8080"), # Preserve http://
15+
("myserver.com:443", "https://myserver.com:443"), # With port
16+
("myserver.com/", "https://myserver.com"), # Remove trailing slash
17+
("https://myserver.com///", "https://myserver.com"), # Multiple slashes
18+
("HTTPS://MyServer.COM", "https://MyServer.COM"), # Case handling
19+
],
20+
)
21+
def test_normalize_host_with_protocol(self, input_url, expected_output):
5422
"""Test host normalization with various input formats."""
55-
result = normalize_host_with_protocol(input_host)
23+
result = normalize_host_with_protocol(input_url)
5624
assert result == expected_output
57-
58-
# Additional assertion: verify protocol is always lowercase
25+
26+
# Additional assertions: result starts with a lowercase http(s):// prefix and has no trailing slash
5927
assert result.startswith("https://") or result.startswith("http://")
28+
assert not result.endswith("/")
6029

61-
@pytest.mark.parametrize("invalid_host", [
62-
None,
63-
"",
64-
" ", # Whitespace only
65-
])
30+
@pytest.mark.parametrize(
31+
"invalid_host",
32+
[
33+
None,
34+
"",
35+
" ", # Whitespace only
36+
],
37+
)
6638
def test_normalize_host_with_protocol_raises_on_invalid_input(self, invalid_host):
6739
"""Test that function raises ValueError for None or empty host."""
6840
with pytest.raises(ValueError, match="Host cannot be None or empty"):
6941
normalize_host_with_protocol(invalid_host)
70-

0 commit comments

Comments
 (0)