Skip to content

Commit 31c1cd6

Browse files
committed
Add tests that complement test_token_validity_check
1 parent c9cae41 commit 31c1cd6

File tree

1 file changed

+103
-0
lines changed

1 file changed

+103
-0
lines changed

tests/client/test_auth.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,17 @@ def valid_tokens():
8282
)
8383

8484

85+
@pytest.fixture
86+
def expired_tokens():
87+
return OAuthToken(
88+
access_token="test_access_token",
89+
token_type="Bearer",
90+
expires_in=0,
91+
refresh_token="test_refresh_token",
92+
scope="read write",
93+
)
94+
95+
8596
@pytest.fixture
8697
def oauth_provider(client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage):
8798
async def redirect_handler(url: str) -> None:
@@ -259,6 +270,98 @@ def test_clear_tokens(self, oauth_provider: OAuthClientProvider, valid_tokens: O
259270
assert context.token_expiry_time is None
260271

261272

273+
class TestTokenInitialization:
274+
"""Test token loading from storage during initialization."""
275+
276+
@pytest.mark.anyio
277+
async def test_initialize_sets_token_expiry_from_stored_tokens(
278+
self, oauth_provider: OAuthClientProvider, valid_tokens: OAuthToken
279+
):
280+
"""Test _initialize() sets token_expiry_time when loading tokens from storage."""
281+
context = oauth_provider.context
282+
await context.storage.set_tokens(valid_tokens)
283+
284+
# Before initialization
285+
assert oauth_provider._initialized is False
286+
assert context.current_tokens is None
287+
assert context.token_expiry_time is None
288+
289+
# Trigger initialization by starting auth flow
290+
test_request = httpx.Request("GET", "https://api.example.com/v1/mcp")
291+
auth_flow = oauth_provider.async_auth_flow(test_request)
292+
293+
# First request calls _initialize()
294+
request = await auth_flow.__anext__()
295+
296+
# After first request, verify tokens were loaded
297+
assert oauth_provider._initialized is True
298+
assert oauth_provider.context.current_tokens is not None
299+
assert oauth_provider.context.current_tokens.access_token == "test_access_token"
300+
301+
# token_expiry_time should be set by update_token_expiry()
302+
assert oauth_provider.context.token_expiry_time is not None
303+
304+
# Verify token is considered valid
305+
assert oauth_provider.context.is_token_valid() is True
306+
307+
# Request should have auth header added
308+
assert request.headers["Authorization"] == "Bearer test_access_token"
309+
310+
# Complete the flow
311+
response = httpx.Response(200, request=request)
312+
try:
313+
await auth_flow.asend(response)
314+
except StopAsyncIteration:
315+
pass
316+
317+
@pytest.mark.anyio
318+
async def test_initialize_with_expired_tokens_detects_expiry(
319+
self, oauth_provider: OAuthClientProvider, expired_tokens: OAuthToken
320+
):
321+
"""Test that expired tokens loaded from storage are detected as invalid."""
322+
context = oauth_provider.context
323+
await context.storage.set_tokens(expired_tokens)
324+
await context.storage.set_client_info(OAuthClientInformationFull(
325+
client_id="test_client_id",
326+
client_secret="test_client_secret",
327+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
328+
))
329+
330+
# First request
331+
test_request = httpx.Request("GET", "https://api.example.com/v1/mcp")
332+
auth_flow = oauth_provider.async_auth_flow(test_request)
333+
334+
# This should trigger a refresh attempt, not the original request
335+
refresh_request = await auth_flow.__anext__()
336+
337+
# Verify tokens were loaded
338+
assert context.current_tokens is not None
339+
340+
# token_expiry_time should be set by update_token_expiry()
341+
assert context.token_expiry_time is not None
342+
343+
# Token should be detected as invalid (expired)
344+
assert context.is_token_valid() is False
345+
346+
# Should be able to refresh
347+
assert context.can_refresh_token() is True
348+
349+
# Complete the flow
350+
refresh_response = httpx.Response(
351+
200,
352+
content=b'{"access_token": "new_token", "token_type": "Bearer", "expires_in": 3600}',
353+
request=refresh_request,
354+
)
355+
try:
356+
original_request = await auth_flow.asend(refresh_response)
357+
# Should retry original request with new token
358+
assert original_request.headers["Authorization"] == "Bearer new_token"
359+
final_response = httpx.Response(200, request=original_request)
360+
await auth_flow.asend(final_response)
361+
except StopAsyncIteration:
362+
pass
363+
364+
262365
class TestOAuthFlow:
263366
"""Test OAuth flow methods."""
264367

0 commit comments

Comments
 (0)