diff --git a/shapeshifter_uftp/client/base_client.py b/shapeshifter_uftp/client/base_client.py
index a09c367..31a7542 100644
--- a/shapeshifter_uftp/client/base_client.py
+++ b/shapeshifter_uftp/client/base_client.py
@@ -11,6 +11,7 @@
from ..exceptions import ClientTransportException
from ..logging import logger
from ..uftp import PayloadMessage, PayloadMessageResponse, SignedMessage
+from shapeshifter_uftp.token_manager import AuthTokenManager
class ShapeshifterClient:
@@ -34,6 +35,7 @@ def __init__(
recipient_domain: str,
recipient_endpoint: str = None,
recipient_signing_key: str = None,
+ oauth_token_manager: AuthTokenManager = None,
):
"""
Shapeshifter client class that allows you to initiate messages to a different party.
@@ -55,6 +57,7 @@ def __init__(
self.recipient_domain = recipient_domain
self.recipient_endpoint = recipient_endpoint
self.recipient_signing_key = recipient_signing_key
+ self.oauth_token_manager = oauth_token_manager
# The outgoing queue and scheduler are used when queueing
# messages for delivery later. This allows the Shapeshifter
@@ -114,11 +117,23 @@ def _send_message(self, message: PayloadMessage) -> PayloadMessageResponse:
logger.debug(f"Sending message to {self.recipient_endpoint}:")
logger.debug(serialized_message)
+ # Find the right headers to use for the request. If we have
+ # an OAuth2 token manager, we will use that to get the
+ # request headers. If not, we will use the basic Content-Type
+ try:
+ if self.oauth_token_manager:
+ header = self.oauth_token_manager.get_request_headers()
+ else:
+ header = {"Content-Type": "text/xml; charset=utf-8"}
+ except Exception as e:
+ logger.warning(f"Failed to get OAuth2 headers, falling back to basic headers: {e}")
+ header = {"Content-Type": "text/xml; charset=utf-8"}
+
# Send the request to the relevant endpoint
response = requests.post(
self.recipient_endpoint,
data=serialized_message,
- headers={"Content-Type": "text/xml; charset=utf-8"},
+ headers=header,
timeout=self.request_timeout,
)
if response.status_code != 200:
diff --git a/shapeshifter_uftp/service/base_service.py b/shapeshifter_uftp/service/base_service.py
index 30c3d01..d6e7634 100644
--- a/shapeshifter_uftp/service/base_service.py
+++ b/shapeshifter_uftp/service/base_service.py
@@ -25,6 +25,7 @@
PayloadMessageResponse,
SignedMessage,
)
+from ..token_manager import AuthTokenManager
class ShapeshifterService():
@@ -46,11 +47,15 @@ def __init__(
self,
sender_domain,
signing_key,
+ oauth_token_endpoint: str = None,
+ oauth_client_id: str = None,
+ oauth_client_secret: str = None,
+ token_refresh_buffer: int = 30,
key_lookup_function=None,
endpoint_lookup_function=None,
host: str = "0.0.0.0",
port: int = 8080,
- path="/shapeshifter/api/v3/message",
+ path="/shapeshifter/api/v3/message"
):
"""
:param sender_domain: our sender domain (FQDN) that the recipient uses to look us up.
@@ -64,6 +69,9 @@ def __init__(
:param host: the host to bind the server to (usually 127.0.0.1 or 0.0.0.0)
:param port: the port to bind the server to (default: 8080)
:param path: the URL path that the server listens on (default: /shapeshifter/api/v3/message)
+ :param oauth_token_endpoint: the OAuth2 token endpoint to use for obtaining access tokens
+ :param oauth_client_id: the OAuth2 client ID to use for obtaining access tokens
+ :param oauth_client_secret: the OAuth2 client secret to use for obtaining access tokens
"""
# Set the sender domain, which is used
@@ -87,6 +95,18 @@ def __init__(
# The FastAPI web app takes care of routing messages to the
# (one) endpoint, and by virtue of FastAPI-XML convert the
# python-friendly objects into XML and vice versa.
+
+ # Create Auth Manager for OAuth2 Client Credentials flow (if configured)
+ if oauth_token_endpoint and oauth_client_id and oauth_client_secret:
+ self.auth_token_manager = AuthTokenManager(
+ oauth_token_endpoint=oauth_token_endpoint,
+ oauth_client_id=oauth_client_id,
+ oauth_client_secret=oauth_client_secret,
+ token_refresh_buffer=token_refresh_buffer
+ )
+ else:
+ self.auth_token_manager = None
+
self.app = FastAPI(default_response_class=XmlAppResponse)
self.app.router.route_class = XmlRoute
self.app.router.add_api_route(
@@ -249,7 +269,8 @@ def _get_client(self, recipient_domain, recipient_role):
signing_key = self.signing_key,
recipient_domain = recipient_domain,
recipient_endpoint = recipient_endpoint,
- recipient_signing_key = recipient_signing_key
+ recipient_signing_key = recipient_signing_key,
+ oauth_token_manager = self.auth_token_manager
)
def __enter__(self):
diff --git a/shapeshifter_uftp/token_manager.py b/shapeshifter_uftp/token_manager.py
new file mode 100644
index 0000000..04e4427
--- /dev/null
+++ b/shapeshifter_uftp/token_manager.py
@@ -0,0 +1,122 @@
+from datetime import datetime, timezone, timedelta
+
+import requests
+
+from .logging import logger
+
+from typing import Optional
+from threading import Lock
+
+class AuthTokenManager:
+ """
+ A token manager that can be used to manage tokens for the Shapeshifter client.
+ It handles OAuth2 Client Credentials flow to obtain and refresh tokens.
+ This class is thread-safe and ensures that tokens are refreshed only when necessary.
+ It provides a method to get request headers with the Bearer token included.
+ """
+ request_timeout: int = 30
+
+ def __init__(self,
+ oauth_token_endpoint: str,
+ oauth_client_id: str,
+ oauth_client_secret: str,
+ token_refresh_buffer: int = 30):
+ self.oauth_token_endpoint = oauth_token_endpoint
+ self.oauth_client_id = oauth_client_id
+ self.oauth_client_secret = oauth_client_secret
+ self.token_refresh_buffer = token_refresh_buffer
+ self._access_token: Optional[str] = None
+ self._token_expires_at: Optional[datetime] = None
+ self._token_lock = Lock()
+
+ def _is_oauth_configured(self) -> bool:
+ """Check if OAuth2 is properly configured."""
+ return all([
+ self.oauth_token_endpoint,
+ self.oauth_client_id,
+ self.oauth_client_secret
+ ])
+
+ def _is_token_valid(self) -> bool:
+ """Check if the current token is valid and not close to expiring."""
+ if not self._access_token or not self._token_expires_at:
+ return False
+
+ buffer_time = datetime.now(timezone.utc) + timedelta(seconds=self.token_refresh_buffer)
+ return self._token_expires_at > buffer_time
+
+ def _obtain_bearer_token(self) -> str:
+ """
+ Obtain a Bearer token using OAuth2 Client Credentials flow.
+
+ :return: Access token string
+ :raises: Exception if token acquisition fails
+ """
+ if not self._is_oauth_configured():
+ raise ValueError("OAuth2 not configured. Please provide oauth_token_endpoint, oauth_client_id, and oauth_client_secret.")
+
+ token_data = {
+ 'grant_type': 'client_credentials',
+ 'client_id': self.oauth_client_id,
+ 'client_secret': self.oauth_client_secret
+ }
+
+ headers = {
+ 'Content-Type': 'application/x-www-form-urlencoded'
+ }
+
+ try:
+ response = requests.post(
+ self.oauth_token_endpoint,
+ data=token_data,
+ headers=headers,
+ timeout=self.request_timeout
+ )
+ response.raise_for_status()
+
+ token_response = response.json()
+ access_token = token_response['access_token']
+ expires_in = token_response.get('expires_in', 300)
+
+ self._token_expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in)
+
+ logger.info(f"Successfully obtained OAuth2 token, expires at {self._token_expires_at}")
+ return access_token
+
+ except requests.exceptions.RequestException as e:
+ logger.error(f"Failed to obtain OAuth2 token: {e}")
+ raise
+ except KeyError as e:
+ logger.error(f"Invalid token response format, missing key: {e}")
+ raise
+
+ def _get_valid_token(self) -> Optional[str]:
+ """
+ Get a valid Bearer token, refreshing if necessary.
+ Thread-safe implementation.
+
+ :return: Valid access token or None if OAuth2 not configured
+ """
+ if not self._is_oauth_configured():
+ return None
+
+ with self._token_lock:
+ if not self._is_token_valid():
+ logger.debug("Token invalid or expired, obtaining new token")
+ self._access_token = self._obtain_bearer_token()
+
+ return self._access_token
+
+ def get_request_headers(self) -> dict:
+ """
+ Get headers for HTTP requests, including Bearer token if configured.
+
+ :return: Dictionary of headers
+ """
+ headers = {"Content-Type": "text/xml; charset=utf-8"}
+
+ token = self._get_valid_token()
+ if token:
+ headers["Authorization"] = f"Bearer {token}"
+
+ return headers
\ No newline at end of file
diff --git a/test/helpers/services.py b/test/helpers/services.py
index dea0598..72a7929 100644
--- a/test/helpers/services.py
+++ b/test/helpers/services.py
@@ -36,7 +36,7 @@ def key_lookup_function(domain, role):
return CRO_PUBLIC_KEY
elif domain == "dso.dev":
return DSO_PUBLIC_KEY
-
+
class DummyAgrService(ShapeshifterAgrService):
diff --git a/test/test_oauth_token_unit.py b/test/test_oauth_token_unit.py
new file mode 100644
index 0000000..41f6d43
--- /dev/null
+++ b/test/test_oauth_token_unit.py
@@ -0,0 +1,472 @@
+import pytest
+import requests
+from unittest.mock import Mock, patch, MagicMock
+from datetime import datetime, timezone, timedelta
+import json
+import time
+
+from shapeshifter_uftp.token_manager import AuthTokenManager
+from shapeshifter_uftp.service.base_service import ShapeshifterService
+from shapeshifter_uftp.client.base_client import ShapeshifterClient
+from shapeshifter_uftp.uftp import PayloadMessage
+
+
+class TestAuthTokenManager:
+ """Test suite for AuthTokenManager class"""
+
+ @pytest.fixture
+ def token_manager(self):
+ """Fixture for AuthTokenManager instance"""
+ return AuthTokenManager(
+ oauth_token_endpoint="https://test.example.com/oauth2/token",
+ oauth_client_id="test_client_id",
+ oauth_client_secret="test_client_secret",
+ token_refresh_buffer=30
+ )
+
+ def test_init(self, token_manager):
+ """Test AuthTokenManager initialization"""
+ assert token_manager.oauth_token_endpoint == "https://test.example.com/oauth2/token"
+ assert token_manager.oauth_client_id == "test_client_id"
+ assert token_manager.oauth_client_secret == "test_client_secret"
+ assert token_manager.token_refresh_buffer == 30
+ assert token_manager._access_token is None
+ assert token_manager._token_expires_at is None
+
+ def test_is_oauth_configured(self, token_manager):
+ """Test OAuth2 configuration check"""
+ assert token_manager._is_oauth_configured() is True
+
+ # Test with missing configuration
+ incomplete_manager = AuthTokenManager("", "", "")
+ assert incomplete_manager._is_oauth_configured() is False
+
+ def test_is_token_valid_no_token(self, token_manager):
+ """Test token validity when no token exists"""
+ assert token_manager._is_token_valid() is False
+
+ def test_is_token_valid_expired(self, token_manager):
+ """Test token validity when token is expired"""
+ token_manager._access_token = "test_token"
+ token_manager._token_expires_at = datetime.now(timezone.utc) - timedelta(seconds=60)
+ assert token_manager._is_token_valid() is False
+
+ def test_is_token_valid_near_expiry(self, token_manager):
+ """Test token validity when token is near expiry (within buffer)"""
+ token_manager._access_token = "test_token"
+ token_manager._token_expires_at = datetime.now(timezone.utc) + timedelta(seconds=20)
+ assert token_manager._is_token_valid() is False
+
+ def test_is_token_valid_good_token(self, token_manager):
+ """Test token validity when token is valid"""
+ token_manager._access_token = "test_token"
+ token_manager._token_expires_at = datetime.now(timezone.utc) + timedelta(seconds=300) # 5 minutes
+ assert token_manager._is_token_valid() is True
+
+ @patch('requests.post')
+ def test_obtain_bearer_token_success(self, mock_post, token_manager):
+ """Test successful token acquisition"""
+ mock_response = Mock()
+ mock_response.raise_for_status.return_value = None
+ mock_response.json.return_value = {
+ 'access_token': 'test_access_token_123',
+ 'expires_in': 300
+ }
+ mock_post.return_value = mock_response
+
+ token = token_manager._obtain_bearer_token()
+
+ assert token == 'test_access_token_123'
+ assert token_manager._token_expires_at is not None
+
+ mock_post.assert_called_once_with(
+ "https://test.example.com/oauth2/token",
+ data={
+ 'grant_type': 'client_credentials',
+ 'client_id': 'test_client_id',
+ 'client_secret': 'test_client_secret'
+ },
+ headers={'Content-Type': 'application/x-www-form-urlencoded'},
+ timeout=30
+ )
+
+ @patch('requests.post')
+ def test_obtain_bearer_token_http_error(self, mock_post, token_manager):
+ """Test token acquisition with HTTP error"""
+ mock_response = Mock()
+ mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError("401 Unauthorized")
+ mock_post.return_value = mock_response
+
+ with pytest.raises(requests.exceptions.HTTPError):
+ token_manager._obtain_bearer_token()
+
+ @patch('requests.post')
+ def test_obtain_bearer_token_invalid_response(self, mock_post, token_manager):
+ """Test token acquisition with invalid JSON response"""
+ mock_response = Mock()
+ mock_response.raise_for_status.return_value = None
+ mock_response.json.return_value = {'invalid': 'response'} # Missing access_token
+ mock_post.return_value = mock_response
+
+ with pytest.raises(KeyError):
+ token_manager._obtain_bearer_token()
+
+ @patch.object(AuthTokenManager, '_obtain_bearer_token')
+ def test_get_valid_token_refresh_needed(self, mock_obtain, token_manager):
+ """Test getting valid token when refresh is needed"""
+ mock_obtain.return_value = 'new_token_123'
+
+ # Set up expired token
+ token_manager._access_token = 'old_token'
+ token_manager._token_expires_at = datetime.now(timezone.utc) - timedelta(seconds=60)
+
+ token = token_manager._get_valid_token()
+
+ assert token == 'new_token_123'
+ mock_obtain.assert_called_once()
+
+ def test_get_valid_token_no_refresh_needed(self, token_manager):
+ """Test getting valid token when no refresh is needed"""
+ # Set up valid token
+ token_manager._access_token = 'valid_token'
+ token_manager._token_expires_at = datetime.now(timezone.utc) + timedelta(seconds=300)
+
+ with patch.object(token_manager, '_obtain_bearer_token') as mock_obtain:
+ token = token_manager._get_valid_token()
+
+ assert token == 'valid_token'
+ mock_obtain.assert_not_called()
+
+ @patch.object(AuthTokenManager, '_get_valid_token')
+ def test_get_request_headers_with_token(self, mock_get_token, token_manager):
+ """Test getting request headers with Bearer token"""
+ mock_get_token.return_value = 'test_bearer_token'
+
+ headers = token_manager.get_request_headers()
+
+ expected_headers = {
+ "Content-Type": "text/xml; charset=utf-8",
+ "Authorization": "Bearer test_bearer_token"
+ }
+ assert headers == expected_headers
+
+ @patch.object(AuthTokenManager, '_get_valid_token')
+ def test_get_request_headers_no_token(self, mock_get_token, token_manager):
+ """Test getting request headers without Bearer token"""
+ mock_get_token.return_value = None
+
+ headers = token_manager.get_request_headers()
+
+ expected_headers = {"Content-Type": "text/xml; charset=utf-8"}
+ assert headers == expected_headers
+ assert "Authorization" not in headers
+
+
+class TestShapeshifterServiceOAuth:
+ """Test suite for ShapeshifterService OAuth2 integration"""
+
+ def test_service_with_oauth_config(self):
+ """Test service initialization with OAuth2 configuration"""
+ service = ShapeshifterService(
+ sender_domain="test.example.com",
+ signing_key="test_signing_key",
+ oauth_token_endpoint="https://oauth.example.com/token",
+ oauth_client_id="test_client",
+ oauth_client_secret="test_secret",
+ token_refresh_buffer=60
+ )
+
+ assert service.auth_token_manager is not None
+ assert service.auth_token_manager.oauth_token_endpoint == "https://oauth.example.com/token"
+ assert service.auth_token_manager.oauth_client_id == "test_client"
+ assert service.auth_token_manager.oauth_client_secret == "test_secret"
+ assert service.auth_token_manager.token_refresh_buffer == 60
+
+ def test_service_without_oauth_config(self):
+ """Test service initialization without OAuth2 configuration"""
+ service = ShapeshifterService(
+ sender_domain="test.example.com",
+ signing_key="test_signing_key",
+ oauth_token_endpoint=None,
+ oauth_client_id=None,
+ oauth_client_secret=None
+ )
+
+ assert service.auth_token_manager is None
+
+ @patch('shapeshifter_uftp.service.base_service.client_map')
+ @patch('shapeshifter_uftp.service.base_service.transport')
+ def test_get_client_with_oauth(self, mock_transport, mock_client_map):
+ """Test _get_client method passes OAuth token manager"""
+ # Setup mocks
+ mock_client_class = Mock()
+ mock_client_map.__getitem__.return_value = mock_client_class
+ mock_transport.get_endpoint.return_value = "https://recipient.example.com/api"
+ mock_transport.get_key.return_value = "recipient_public_key"
+
+ # Create service with OAuth2
+ service = ShapeshifterService(
+ sender_domain="test.example.com",
+ signing_key="test_signing_key",
+ oauth_token_endpoint="https://oauth.example.com/token",
+ oauth_client_id="test_client",
+ oauth_client_secret="test_secret"
+ )
+ service.sender_role = "AGR" # Set sender role for client_map lookup
+
+ # Call _get_client
+ client = service._get_client("recipient.example.com", "DSO")
+
+ # Verify client was created with oauth_token_manager
+ mock_client_class.assert_called_once()
+ call_kwargs = mock_client_class.call_args[1]
+ assert call_kwargs['oauth_token_manager'] == service.auth_token_manager
+
+
+class TestShapeshifterClientOAuth:
+ """Test suite for ShapeshifterClient OAuth2 integration"""
+
+ @pytest.fixture
+ def mock_token_manager(self):
+ """Fixture for mock token manager"""
+ return Mock(spec=AuthTokenManager)
+
+ @pytest.fixture
+ def client_with_oauth(self, mock_token_manager):
+ """Fixture for client with OAuth2"""
+ client = ShapeshifterClient(
+ sender_domain="sender.example.com",
+ recipient_domain="recipient.example.com",
+ recipient_endpoint="https://recipient.example.com/api",
+ signing_key="test_signing_key",
+ recipient_signing_key="recipient_public_key",
+ oauth_token_manager=mock_token_manager
+ )
+ client.sender_role = "test_sender"
+ client.recipient_role = "test_recipient"
+ return client
+
+ @pytest.fixture
+ def client_without_oauth(self):
+ """Fixture for client without OAuth2"""
+ client = ShapeshifterClient(
+ sender_domain="sender.example.com",
+ recipient_domain="recipient.example.com",
+ recipient_endpoint="https://recipient.example.com/api",
+ signing_key="test_signing_key",
+ recipient_signing_key="recipient_public_key"
+ )
+ client.sender_role = "test_sender"
+ client.recipient_role = "test_recipient"
+ return client
+
+ @pytest.fixture
+ def mock_payload_message(self):
+ """Fixture for mock PayloadMessage"""
+ mock_message = Mock(spec=PayloadMessage)
+ # Set required attributes that _send_message expects
+ mock_message.__class__.__name__ = "TestMessage"
+ mock_message.version = None
+ mock_message.sender_domain = None
+ mock_message.recipient_domain = None
+ mock_message.time_stamp = None
+ mock_message.message_id = None
+ mock_message.conversation_id = None
+ return mock_message
+
+ @patch('requests.post')
+ @patch('shapeshifter_uftp.client.base_client.transport')
+ def test_send_message_with_oauth_success(self, mock_transport, mock_post, client_with_oauth, mock_token_manager, mock_payload_message):
+ """Test _send_message with successful OAuth2 token"""
+ # Setup mocks
+ mock_token_manager.get_request_headers.return_value = {
+ "Content-Type": "text/xml; charset=utf-8",
+ "Authorization": "Bearer test_token_123"
+ }
+
+ mock_transport.seal_message.return_value = "sealed_message"
+ mock_transport.to_xml.return_value = "message"
+ mock_transport.parser.from_bytes.return_value = Mock(body="sealed_response")
+ mock_transport.unseal_message.return_value = "unsealed_response"
+
+ mock_response = Mock()
+ mock_response.status_code = 200
+ mock_response.content = b"response"
+ mock_post.return_value = mock_response
+
+ # Create a mock message
+ # mock_message = Mock(spec=PayloadMessage)
+ # mock_message.__class__.__name__ = "TestMessage"
+
+ # Call _send_message
+ result = client_with_oauth._send_message(mock_payload_message)
+
+ # Verify OAuth2 headers were used
+ mock_token_manager.get_request_headers.assert_called_once()
+ mock_post.assert_called_once()
+
+ # Check that the Authorization header was included
+ call_kwargs = mock_post.call_args[1]
+ assert call_kwargs['headers']['Authorization'] == "Bearer test_token_123"
+
+ @patch('requests.post')
+ @patch('shapeshifter_uftp.client.base_client.transport')
+ def test_send_message_oauth_failure_fallback(self, mock_transport, mock_post, client_with_oauth, mock_token_manager, mock_payload_message):
+ """Test _send_message falls back to basic headers when OAuth2 fails"""
+ # Setup OAuth2 to fail
+ mock_token_manager.get_request_headers.side_effect = Exception("OAuth2 failed")
+
+ mock_transport.seal_message.return_value = "sealed_message"
+ mock_transport.to_xml.return_value = "message"
+ mock_transport.parser.from_bytes.return_value = Mock(body="sealed_response")
+ mock_transport.unseal_message.return_value = "unsealed_response"
+
+ mock_response = Mock()
+ mock_response.status_code = 200
+ mock_response.content = b"response"
+ mock_post.return_value = mock_response
+
+ # # Create a mock message
+ # mock_message = Mock()
+ # mock_message.__class__.__name__ = "TestMessage"
+
+ # Call _send_message
+ result = client_with_oauth._send_message(mock_payload_message)
+
+ # Verify fallback headers were used
+ mock_post.assert_called_once()
+ call_kwargs = mock_post.call_args[1]
+ assert call_kwargs['headers'] == {"Content-Type": "text/xml; charset=utf-8"}
+ assert "Authorization" not in call_kwargs['headers']
+
+ @patch('requests.post')
+ @patch('shapeshifter_uftp.client.base_client.transport')
+ def test_send_message_without_oauth(self, mock_transport, mock_post, client_without_oauth, mock_payload_message):
+ """Test _send_message without OAuth2 token manager"""
+ mock_transport.seal_message.return_value = "sealed_message"
+ mock_transport.to_xml.return_value = "message"
+ mock_transport.parser.from_bytes.return_value = Mock(body="sealed_response")
+ mock_transport.unseal_message.return_value = "unsealed_response"
+
+ mock_response = Mock()
+ mock_response.status_code = 200
+ mock_response.content = b"response"
+ mock_post.return_value = mock_response
+
+ # Call _send_message
+ result = client_without_oauth._send_message(mock_payload_message)
+
+ # Verify basic headers were used
+ mock_post.assert_called_once()
+ call_kwargs = mock_post.call_args[1]
+ assert call_kwargs['headers'] == {"Content-Type": "text/xml; charset=utf-8"}
+ assert "Authorization" not in call_kwargs['headers']
+
+
+class TestTokenRefreshLogic:
+ """Test token refresh timing and logic"""
+
+ @pytest.fixture
+ def token_manager(self):
+ return AuthTokenManager(
+ oauth_token_endpoint="https://test.example.com/oauth2/token",
+ oauth_client_id="test_client_id",
+ oauth_client_secret="test_client_secret",
+ token_refresh_buffer=30
+ )
+
+ @pytest.mark.parametrize("expires_in_seconds,refresh_buffer,should_refresh", [
+ (300, 30, False), # 5 minutes left, 30s buffer - no refresh needed
+ (25, 30, True), # 25 seconds left, 30s buffer - refresh needed
+ (30, 30, True), # Exactly at buffer - refresh needed
+ (31, 30, False), # Just over buffer - no refresh needed
+ (0, 30, True), # Expired - refresh needed
+ ])
+ def test_token_refresh_timing(self, token_manager, expires_in_seconds, refresh_buffer, should_refresh):
+ """Test token refresh timing with different scenarios"""
+ token_manager.token_refresh_buffer = refresh_buffer
+ token_manager._access_token = "test_token"
+ token_manager._token_expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in_seconds)
+
+ result = token_manager._is_token_valid()
+ assert result != should_refresh # should_refresh means token is NOT valid
+
+
+class TestErrorHandling:
+ """Test error handling scenarios"""
+
+ @pytest.fixture
+ def token_manager(self):
+ return AuthTokenManager(
+ oauth_token_endpoint="https://test.example.com/oauth2/token",
+ oauth_client_id="test_client_id",
+ oauth_client_secret="test_client_secret",
+ token_refresh_buffer=30
+ )
+
+ @patch('requests.post')
+ def test_network_timeout(self, mock_post, token_manager):
+ """Test handling of network timeouts"""
+ mock_post.side_effect = requests.exceptions.Timeout("Request timed out")
+
+ with pytest.raises(requests.exceptions.Timeout):
+ token_manager._obtain_bearer_token()
+
+ @patch('requests.post')
+ def test_connection_error(self, mock_post, token_manager):
+ """Test handling of connection errors"""
+ mock_post.side_effect = requests.exceptions.ConnectionError("Connection failed")
+
+ with pytest.raises(requests.exceptions.ConnectionError):
+ token_manager._obtain_bearer_token()
+
+ @patch('requests.post')
+ def test_invalid_json_response(self, mock_post, token_manager):
+ """Test handling of invalid JSON response"""
+ mock_response = Mock()
+ mock_response.raise_for_status.return_value = None
+ mock_response.json.side_effect = json.JSONDecodeError("Invalid JSON", "", 0)
+ mock_post.return_value = mock_response
+
+ with pytest.raises(json.JSONDecodeError):
+ token_manager._obtain_bearer_token()
+
+
+# Pytest configuration
+@pytest.fixture(scope="session")
+def setup_logging():
+ """Setup logging for tests"""
+ import logging
+ logging.basicConfig(
+ level=logging.DEBUG,
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+ )
+
+
+# Parametrized tests for different scenarios
+@pytest.mark.parametrize("endpoint,client_id,secret,expected_configured", [
+ ("https://oauth.example.com/token", "client123", "secret456", True),
+ ("", "client123", "secret456", False),
+ ("https://oauth.example.com/token", "", "secret456", False),
+ ("https://oauth.example.com/token", "client123", "", False),
+ (None, None, None, False),
+])
+def test_oauth_configuration_scenarios(endpoint, client_id, secret, expected_configured):
+ """Test different OAuth2 configuration scenarios"""
+ if endpoint is None:
+ # Handle case where service is created without OAuth2 params
+ service = ShapeshifterService(
+ sender_domain="test.example.com",
+ signing_key="test_key",
+ oauth_token_endpoint=None,
+ oauth_client_id=None,
+ oauth_client_secret=None
+ )
+ assert (service.auth_token_manager is not None) == expected_configured
+ else:
+ token_manager = AuthTokenManager(
+ oauth_token_endpoint=endpoint,
+ oauth_client_id=client_id,
+ oauth_client_secret=secret
+ )
+ assert token_manager._is_oauth_configured() == expected_configured
\ No newline at end of file