@@ -82,6 +82,17 @@ def valid_tokens():
8282 )
8383
8484
85+ @pytest .fixture
86+ def expired_tokens ():
87+ return OAuthToken (
88+ access_token = "test_access_token" ,
89+ token_type = "Bearer" ,
90+ expires_in = 0 ,
91+ refresh_token = "test_refresh_token" ,
92+ scope = "read write" ,
93+ )
94+
95+
8596@pytest .fixture
8697def oauth_provider (client_metadata : OAuthClientMetadata , mock_storage : MockTokenStorage ):
8798 async def redirect_handler (url : str ) -> None :
@@ -259,6 +270,98 @@ def test_clear_tokens(self, oauth_provider: OAuthClientProvider, valid_tokens: O
259270 assert context .token_expiry_time is None
260271
261272
273+ class TestTokenInitialization :
274+ """Test token loading from storage during initialization."""
275+
276+ @pytest .mark .anyio
277+ async def test_initialize_sets_token_expiry_from_stored_tokens (
278+ self , oauth_provider : OAuthClientProvider , valid_tokens : OAuthToken
279+ ):
280+ """Test _initialize() sets token_expiry_time when loading tokens from storage."""
281+ context = oauth_provider .context
282+ await context .storage .set_tokens (valid_tokens )
283+
284+ # Before initialization
285+ assert oauth_provider ._initialized is False
286+ assert context .current_tokens is None
287+ assert context .token_expiry_time is None
288+
289+ # Trigger initialization by starting auth flow
290+ test_request = httpx .Request ("GET" , "https://api.example.com/v1/mcp" )
291+ auth_flow = oauth_provider .async_auth_flow (test_request )
292+
293+ # First request calls _initialize()
294+ request = await auth_flow .__anext__ ()
295+
296+ # After first request, verify tokens were loaded
297+ assert oauth_provider ._initialized is True
298+ assert oauth_provider .context .current_tokens is not None
299+ assert oauth_provider .context .current_tokens .access_token == "test_access_token"
300+
301+ # token_expiry_time should be set by update_token_expiry()
302+ assert oauth_provider .context .token_expiry_time is not None
303+
304+ # Verify token is considered valid
305+ assert oauth_provider .context .is_token_valid () is True
306+
307+ # Request should have auth header added
308+ assert request .headers ["Authorization" ] == "Bearer test_access_token"
309+
310+ # Complete the flow
311+ response = httpx .Response (200 , request = request )
312+ try :
313+ await auth_flow .asend (response )
314+ except StopAsyncIteration :
315+ pass
316+
317+ @pytest .mark .anyio
318+ async def test_initialize_with_expired_tokens_detects_expiry (
319+ self , oauth_provider : OAuthClientProvider , expired_tokens : OAuthToken
320+ ):
321+ """Test that expired tokens loaded from storage are detected as invalid."""
322+ context = oauth_provider .context
323+ await context .storage .set_tokens (expired_tokens )
324+ await context .storage .set_client_info (OAuthClientInformationFull (
325+ client_id = "test_client_id" ,
326+ client_secret = "test_client_secret" ,
327+ redirect_uris = [AnyUrl ("http://localhost:3030/callback" )],
328+ ))
329+
330+ # First request
331+ test_request = httpx .Request ("GET" , "https://api.example.com/v1/mcp" )
332+ auth_flow = oauth_provider .async_auth_flow (test_request )
333+
334+ # This should trigger a refresh attempt, not the original request
335+ refresh_request = await auth_flow .__anext__ ()
336+
337+ # Verify tokens were loaded
338+ assert context .current_tokens is not None
339+
340+ # token_expiry_time should be set by update_token_expiry()
341+ assert context .token_expiry_time is not None
342+
343+ # Token should be detected as invalid (expired)
344+ assert context .is_token_valid () is False
345+
346+ # Should be able to refresh
347+ assert context .can_refresh_token () is True
348+
349+ # Complete the flow
350+ refresh_response = httpx .Response (
351+ 200 ,
352+ content = b'{"access_token": "new_token", "token_type": "Bearer", "expires_in": 3600}' ,
353+ request = refresh_request ,
354+ )
355+ try :
356+ original_request = await auth_flow .asend (refresh_response )
357+ # Should retry original request with new token
358+ assert original_request .headers ["Authorization" ] == "Bearer new_token"
359+ final_response = httpx .Response (200 , request = original_request )
360+ await auth_flow .asend (final_response )
361+ except StopAsyncIteration :
362+ pass
363+
364+
262365class TestOAuthFlow :
263366 """Test OAuth flow methods."""
264367
0 commit comments