Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from __future__ import annotations

import inspect
import logging
import os
from collections.abc import Callable
from functools import wraps
Expand All @@ -27,7 +28,10 @@
EKSWorkloadIdentity,
WebIdentity,
)
from keycardai.mcp.server.auth.client_factory import ClientFactory, DefaultClientFactory
from keycardai.mcp.server.auth.client_factory import (
ClientFactory,
DefaultClientFactory,
)
from keycardai.mcp.server.exceptions import (
AuthProviderConfigurationError,
AuthProviderInternalError,
Expand All @@ -38,6 +42,104 @@
from keycardai.oauth import AsyncClient, Client
from keycardai.oauth.http.auth import NoneAuth
from keycardai.oauth.types.models import TokenExchangeRequest, TokenResponse
from keycardai.oauth.utils.jwt import extract_scopes, get_claims

logger = logging.getLogger(__name__)

# Define custom INTROSPECT log level for detailed token debugging
# INTROSPECT (5) is more detailed than DEBUG (10) - use for sensitive token introspection
INTROSPECT = 5
logging.addLevelName(INTROSPECT, "INTROSPECT")

def introspect(self, message, *args, **kwargs):
"""Log at INTROSPECT level - most detailed debugging including token info."""
if self.isEnabledFor(INTROSPECT):
self._log(INTROSPECT, message, args, **kwargs)

# Add introspect method to Logger class
logging.Logger.introspect = introspect

# Configure logger to respect KEYCARD_LOG_LEVEL environment variable
_log_level = os.getenv("KEYCARD_LOG_LEVEL", "").upper()
if _log_level in ("INTROSPECT", "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"):
if _log_level == "INTROSPECT":
logger.setLevel(INTROSPECT)
else:
logger.setLevel(getattr(logging, _log_level))

if not logger.handlers:
_handler = logging.StreamHandler()
# Match FastMCP's logging format for consistency
# Format: [MM/DD/YY HH:MM:SS] LEVEL Message filename:line
_handler.setFormatter(
logging.Formatter(
"[%(asctime)s] %(levelname)-8s %(message)s %(filename)s:%(lineno)d",
datefmt="%m/%d/%y %H:%M:%S"
)
)
logger.addHandler(_handler)


def get_token_debug_info(access_token: str) -> dict[str, Any]:
"""Extract non-sensitive debugging information from a JWT access token.

This function safely extracts only non-sensitive claims from a JWT token
for debugging and logging purposes. It does NOT verify the token signature
and only returns issuer, audience, and scope information.

**Important:** This is a debug function that NEVER raises exceptions. If token
parsing fails, it returns an error indicator in the result dict instead.

**Security Note:** This function is designed to be safe for logging/debugging
in production environments. It explicitly excludes sensitive information like:
- The actual token string
- Subject (user identifier)
- Custom claims that might contain PII

Args:
access_token: JWT access token string (without Bearer prefix)

Returns:
Dictionary with non-sensitive token information:
- issuer (str): Token issuer (if present)
- audience (str | list[str]): Token audience (if present)
- expires_at (int): Token expiration time as Unix timestamp (if present)
- issued_at (int): Token issuance time as Unix timestamp (if present)
- scopes (list[str]): List of scopes from the token (if present)
- error (str): Error message if token parsing failed

Example:
>>> token = "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9..."
>>> debug_info = get_token_debug_info(token)
>>> logger.info(f"Token info: {debug_info}")
# Success: {"issuer": "https://auth.example.com", "audience": "api.example.com", "expires_at": 1700000000, "issued_at": 1699996400, "scopes": ["read", "write"]}
# Failure: {"error": "Failed to parse token"}
"""
try:
claims = get_claims(access_token)

debug_info: dict[str, Any] = {}

if "iss" in claims:
debug_info["issuer"] = claims["iss"]

if "aud" in claims:
debug_info["audience"] = claims["aud"]

if "exp" in claims:
debug_info["expires_at"] = claims["exp"]

if "iat" in claims:
debug_info["issued_at"] = claims["iat"]

scopes = extract_scopes(claims)
if scopes:
debug_info["scopes"] = scopes

return debug_info

except Exception:
return {"error": "Failed to parse token"}


class AccessContext:
Expand Down Expand Up @@ -577,25 +679,33 @@ async def wrapper(*args, **kwargs) -> Any:
runtime_context=True
)

_resource_list = [resources] if isinstance(resources, str) else resources
logger.debug(f"Starting token exchange for resources: {_resource_list}")

_access_context = AccessContext()
try:
_user_token = get_access_token()
if not _user_token:
logger.warning(f"No authentication token available for {func.__name__}")
_set_error({
"error": "No authentication token available. Please ensure you're properly authenticated.",
}, None, _access_context, _ctx)
return await _call_func(is_async_func, func, *args, **kwargs)
logger.introspect(f"User token retrieved: {get_token_debug_info(_user_token.token)}")
except Exception as e:
logger.error("Failed to get access token")
_set_error({
"error": "Failed to get access token from context. Ensure the Context parameter is properly annotated.",
"raw_error": str(e),
}, None, _access_context, _ctx)
return await _call_func(is_async_func, func, *args, **kwargs)
_resource_list = [resources] if isinstance(resources, str) else resources

_access_tokens = {}
for resource in _resource_list:
logger.debug(f"Exchanging token for resource: {resource}")
try:
if self.application_credential:
logger.debug(f"Using application credential: {type(self.application_credential).__name__}")
# auth_info context is used by application credential implementation
# to prepare correct assertions in the token exchange request
_auth_info = {
Expand All @@ -615,17 +725,24 @@ async def wrapper(*args, **kwargs) -> Any:
resource=resource,
subject_token_type="urn:ietf:params:oauth:token-type:access_token",
)

_token_response = await self.client.exchange_token(_token_exchange_request)

_access_tokens[resource] = _token_response
logger.debug(f"Token exchange successful for {resource}")
logger.introspect(f"Token details for {resource}: {get_token_debug_info(_token_response.access_token)}")
except Exception as e:
logger.error(f"Token exchange failed for {resource}")
_set_error({
"error": f"Token exchange failed for {resource}: {e}",
"raw_error": str(e),
}, resource, _access_context, _ctx)
return await _call_func(is_async_func, func, *args, **kwargs)

logger.debug(f"All token exchanges completed. Setting access context with {len(_access_tokens)} token(s)")
_access_context.set_bulk_tokens(_access_tokens)
_ctx.set_state("keycardai", _access_context)
logger.debug(f"Executing decorated function: {func.__name__}")
return await _call_func(is_async_func, func, *args, **kwargs)
return wrapper
return decorator
Loading