diff --git a/README.md b/README.md index 4f9db7e..d4eed16 100644 --- a/README.md +++ b/README.md @@ -178,6 +178,68 @@ Listed below are the environment variables supported by the application. | `ICAT_SERVER__CERTIFICATE_VALIDATION` | Whether to verify ICAT certificates using its internal trust store or disable certificate validation completely. | Yes | | | `ICAT_SERVER__REQUEST_TIMEOUT_SECONDS` | The maximum number of seconds that the request should wait for a response from ICAT before timing out. | Yes | | +### OIDC Configuration + +The following environment variables are only required when using OIDC authentication: + +| Environment Variable | Description | Mandatory | Default Value | +|-------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------|-----------|---------------| +| `AUTHENTICATION__OIDC_ICAT_AUTHENTICATOR` | The mnemonic of the ICAT authenticator. Usually `delegating`. | Yes | | +| `AUTHENTICATION__OIDC_ICAT_AUTHENTICATOR_TOKEN` | The secret token to pass to the ICAT authenticator. | Yes | | +| `AUTHENTICATION__OIDC_REDIRECT_URI` | Redirect URI. Required if a `client_secret` is used. | No | | +| `AUTHENTICATION__OIDC_PROVIDERS` | A dictionary of OIDC provider configurations, indexed by `provider_id`. | Yes | | + +To support multiple OIDC providers simultaneously, provider-specific config is indexed by a `provider_id`, e.g. to set the value of `DISPLAY_NAME` you would set the environment variable `AUTHENTICATION__OIDC_PROVIDERS____DISPLAY_NAME`. The actual value used for `provider_id` is not important. + +Each individual OIDC provider has the following configuration: + +| Environment Variable | Description | Mandatory | Default Value | +|-------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------|-----------|---------------| +| `DISPLAY_NAME` | The name of the OIDC provider to display in the frontend. | Yes | | +| `CONFIGURATION_URL` | The URL of the OIDC provider's configuration metadata. This will usually end with`/.well-known/openid-configuration`. | Yes | | +| `CLIENT_ID` | The `client_id` of the application registered with the OIDC provider. | Yes | | +| `CLIENT_SECRET` | The `client_secret`. If this is omitted then Authorization Code Flow with PKCE will be used (which is preferred). | No | | +| `VERIFY_CERT` | Whether to verify TLS certificates in calls to the OIDC provider. This should be `True`. | No | `True` | +| `MECHANISM` | The mechanism to prepend to the username in ICAT when using this OIDC provider. | No | | +| `SCOPE` | Which OAuth scopes to request. Must include `openid`. | No | `openid` | +| `USERNAME_CLAIM` | Which OAuth claim to use as the user's username. | No | `sub` | + +### Example OIDC Configurations + +This example uses Microsoft Single Sign-On. The format of the username will be determined by the tenant admin. +``` +AUTHENTICATION__OIDC_ICAT_AUTHENTICATOR="delegating" +AUTHENTICATION__OIDC_ICAT_AUTHENTICATOR_TOKEN="fe1be44a35eb00ab46f5" +AUTHENTICATION__OIDC_PROVIDERS__sso__DISPLAY_NAME="Microsoft SSO" +AUTHENTICATION__OIDC_PROVIDERS__sso__CONFIGURATION_URL="https://login.microsoftonline.com/73c7442c-5f40-4db0-8dd2-5b9bb94516a1/v2.0/.well-known/openid-configuration" +AUTHENTICATION__OIDC_PROVIDERS__sso__CLIENT_ID="700bfc86-e26e-4638-a1a6-f7027106857b" +``` + +This example uses ORCID. The username will be the user's ORCID Id prepended with `orcid/`, e.g. `orcid/0000-0002-1825-0097`. +Since `client_secret` is used, `AUTHENTICATION__OIDC_REDIRECT_URI` must also be set. +``` +AUTHENTICATION__OIDC_ICAT_AUTHENTICATOR="delegating" +AUTHENTICATION__OIDC_ICAT_AUTHENTICATOR_TOKEN="fe1be44a35eb00ab46f5" +AUTHENTICATION__OIDC_REDIRECT_URI="https://scigateway.example.com/login" +AUTHENTICATION__OIDC_PROVIDERS__orcid__DISPLAY_NAME="Orcid" +AUTHENTICATION__OIDC_PROVIDERS__orcid__CONFIGURATION_URL="https://orcid.org/.well-known/openid-configuration" +AUTHENTICATION__OIDC_PROVIDERS__orcid__CLIENT_ID="APP-QKUS1G0MLIOXDC57" +AUTHENTICATION__OIDC_PROVIDERS__orcid__CLIENT_SECRET="33182ac684744367edd8" +AUTHENTICATION__OIDC_PROVIDERS__orcid__MECHANISM="orcid" +``` + +This example uses Keycloak for testing with TLS certificate verification disabled. The username will be the user's email address (using custom settings for `SCOPE` and `USERNAME_CLAIM`). +``` +AUTHENTICATION__OIDC_ICAT_AUTHENTICATOR="delegating" +AUTHENTICATION__OIDC_ICAT_AUTHENTICATOR_TOKEN="fe1be44a35eb00ab46f5" +AUTHENTICATION__OIDC_PROVIDERS__keycloak__DISPLAY_NAME="Keycloak" +AUTHENTICATION__OIDC_PROVIDERS__keycloak__CONFIGURATION_URL="https://localhost:9000/realms/test-realm/.well-known/openid-configuration" +AUTHENTICATION__OIDC_PROVIDERS__keycloak__CLIENT_ID="test-client-id" +AUTHENTICATION__OIDC_PROVIDERS__keycloak__VERIFY_CERT="False" +AUTHENTICATION__OIDC_PROVIDERS__keycloak__SCOPE="openid email" +AUTHENTICATION__OIDC_PROVIDERS__keycloak__USERNAME_CLAIM="email" +``` + ### How to add or remove a JWT refresh token from the blacklist The `AUTHENTICATION__JWT_REFRESH_TOKEN_BLACKLIST` environment variable holds the list of blacklisted JWT refresh tokens diff --git a/poetry.lock b/poetry.lock index 776883e..71d0106 100644 --- a/poetry.lock +++ b/poetry.lock @@ -140,6 +140,18 @@ d = ["aiohttp (>=3.10)"] jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] uvloop = ["uvloop (>=0.15.2)"] +[[package]] +name = "cachetools" +version = "6.2.0" +description = "Extensible memoizing collections and decorators" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "cachetools-6.2.0-py3-none-any.whl", hash = "sha256:1c76a8960c0041fcc21097e357f882197c79da0dbff766e7317890a65d7d8ba6"}, + {file = "cachetools-6.2.0.tar.gz", hash = "sha256:38b328c0889450f05f5e120f56ab68c8abaf424e1275522b138ffc93253f7e32"}, +] + [[package]] name = "certifi" version = "2025.11.12" @@ -3183,4 +3195,4 @@ files = [ [metadata] lock-version = "2.1" python-versions = ">=3.11,<4.0" -content-hash = "79862b2eddbbb0e953745fba5fb6c9abff9ed541b48b068439bdfc345632d9b3" +content-hash = "b168c5f04b7683cf216b005bac2cdf6248820e8d2b26c56048e9a20a1c97265e" diff --git a/pyproject.toml b/pyproject.toml index 45827aa..918002e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ dependencies = [ "PyJWT (>=2.9,<3.0)", "cryptography (>=43.0)", "fastapi[all] (>=0.123)", + "cachetools (>=6.2)", ] [project.urls] diff --git a/scigateway_auth/common/config.py b/scigateway_auth/common/config.py index d30d7ce..43c8445 100644 --- a/scigateway_auth/common/config.py +++ b/scigateway_auth/common/config.py @@ -3,9 +3,9 @@ """ from pathlib import Path -from typing import List +from typing import List, Self -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, model_validator from pydantic_settings import BaseSettings, SettingsConfigDict @@ -29,6 +29,21 @@ class MaintenanceConfig(BaseModel): scheduled_maintenance_path: str +class OidcProviderConfig(BaseModel): + """ + Configuration model for an OIDC provider + """ + + display_name: str + configuration_url: str + client_id: str + client_secret: str = None + verify_cert: bool = True + mechanism: str = None + scope: str = "openid" + username_claim: str = "sub" + + class AuthenticationConfig(BaseModel): """ Configuration model for the authentication. @@ -43,6 +58,20 @@ class AuthenticationConfig(BaseModel): # These are the ICAT usernames of the users normally in the / form admin_users: list[str] + oidc_providers: dict[str, OidcProviderConfig] = {} + oidc_redirect_uri: str = None + oidc_icat_authenticator: str = None + oidc_icat_authenticator_token: str = None + + @model_validator(mode="after") + def validate_oidc(self) -> Self: + if self.oidc_providers: + if not self.oidc_icat_authenticator: + raise ValueError("oidc_icat_authenticator is required if oidc_providers is set") + if not self.oidc_icat_authenticator_token: + raise ValueError("oidc_icat_authenticator_token is required if oidc_providers is set") + return self + class ICATServerConfig(BaseModel): """ diff --git a/scigateway_auth/common/exceptions.py b/scigateway_auth/common/exceptions.py index 8f9c42b..dcbdade 100644 --- a/scigateway_auth/common/exceptions.py +++ b/scigateway_auth/common/exceptions.py @@ -55,3 +55,9 @@ class UserNotAdminError(Exception): """ Exception raised when a non-admin user performs an action that requires the user to be an admin. """ + + +class OidcProviderNotFoundError(Exception): + """ + Exception raised when an OIDC provider is not found + """ diff --git a/scigateway_auth/routers/authentication.py b/scigateway_auth/routers/authentication.py index c35f15c..e4301e1 100644 --- a/scigateway_auth/routers/authentication.py +++ b/scigateway_auth/routers/authentication.py @@ -7,6 +7,7 @@ from fastapi import APIRouter, Body, Cookie, Depends, HTTPException, Response, status from fastapi.responses import JSONResponse +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from scigateway_auth.common.config import config from scigateway_auth.common.exceptions import ( @@ -14,9 +15,11 @@ ICATAuthenticationError, InvalidJWTError, JWTRefreshError, + OidcProviderNotFoundError, UsernameMismatchError, ) from scigateway_auth.common.schemas import LoginDetailsPostRequestSchema +from scigateway_auth.src import oidc from scigateway_auth.src.authentication import ICATAuthenticator from scigateway_auth.src.jwt_handler import JWTHandler @@ -43,6 +46,27 @@ def get_authenticators(): ) from exc +@router.get( + path="/oidc_providers", + summary="Get a list of OIDC providers", + response_description="Returns a list of OIDC providers", +) +def get_oidc_providers() -> JSONResponse: + logger.info("Getting a list of OIDC providers") + + providers = {} + for provider_id, provider_config in config.authentication.oidc_providers.items(): + providers[provider_id] = { + "display_name": provider_config.display_name, + "configuration_url": provider_config.configuration_url, + "client_id": provider_config.client_id, + "pkce": False if provider_config.client_secret else True, + "scope": provider_config.scope, + } + + return JSONResponse(content=providers) + + @router.post( path="/login", summary="Login with ICAT mnemonic and credentials", @@ -56,28 +80,113 @@ def login( ], ) -> JSONResponse: logger.info("Authenticating a user") + + if login_details.credentials is None: + credentials = None + else: + credentials = { + "username": login_details.credentials.username.get_secret_value(), + "password": login_details.credentials.password.get_secret_value(), + } + try: - icat_session_id = ICATAuthenticator.authenticate(login_details.mnemonic, login_details.credentials) + icat_session_id = ICATAuthenticator.authenticate(login_details.mnemonic, credentials) icat_username = ICATAuthenticator.get_username(icat_session_id) + except ICATAuthenticationError as exc: + logger.exception(exc.args) + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=str(exc)) from exc + + access_token = jwt_handler.get_access_token(icat_session_id, icat_username) + refresh_token = jwt_handler.get_refresh_token(icat_username) + + response = JSONResponse(content=access_token) + response.set_cookie( + key="scigateway:refresh_token", + value=refresh_token, + max_age=config.authentication.refresh_token_validity_days * 24 * 60 * 60, + secure=True, + httponly=True, + samesite="lax", + path=f"{config.api.root_path}/refresh", + ) + return response + + +@router.post( + path="/oidc_token/{provider_id}", + summary="Get an OIDC id_token", + response_description="OIDC token endpoint response", +) +def oidc_token( + provider_id: Annotated[str, "OIDC provider id"], + code: Annotated[str, Body(description="OIDC authorization code")], +) -> JSONResponse: + logger.info("Obtaining an id_token") + + try: + token_response = oidc.get_token(provider_id, code) + except OidcProviderNotFoundError: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Unknown OIDC provider: " + provider_id, + ) from None - access_token = jwt_handler.get_access_token(icat_session_id, icat_username) - refresh_token = jwt_handler.get_refresh_token(icat_username) - - response = JSONResponse(content=access_token) - response.set_cookie( - key="scigateway:refresh_token", - value=refresh_token, - max_age=config.authentication.refresh_token_validity_days * 24 * 60 * 60, - secure=True, - httponly=True, - samesite="lax", - path=f"{config.api.root_path}/refresh", - ) - return response + return JSONResponse(content=token_response) + + +@router.post( + path="/oidc_login/{provider_id}", + summary="Login with an OIDC id token", + response_description="A JWT access token including a refresh token as an HTTP-only cookie", +) +def oidc_login( + jwt_handler: JWTHandlerDep, + provider_id: Annotated[str, "The OIDC provider id"], + bearer_token: Annotated[HTTPAuthorizationCredentials, Depends(HTTPBearer(description="OIDC id token"))], +) -> JSONResponse: + logger.info("Authenticating a user") + + id_token = bearer_token.credentials + + try: + mechanism, oidc_username = oidc.get_username(provider_id, id_token) + except OidcProviderNotFoundError: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Unknown OIDC provider: " + provider_id, + ) from None + except InvalidJWTError as exc: + logger.exception(exc.args) + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=str(exc)) from exc + + credentials = { + "mechanism": mechanism, + "username": oidc_username, + "token": config.authentication.oidc_icat_authenticator_token, + } + + try: + icat_session_id = ICATAuthenticator.authenticate(config.authentication.oidc_icat_authenticator, credentials) + icat_username = ICATAuthenticator.get_username(icat_session_id) except ICATAuthenticationError as exc: logger.exception(exc.args) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=str(exc)) from exc + access_token = jwt_handler.get_access_token(icat_session_id, icat_username) + refresh_token = jwt_handler.get_refresh_token(icat_username) + + response = JSONResponse(content=access_token) + response.set_cookie( + key="scigateway:refresh_token", + value=refresh_token, + max_age=config.authentication.refresh_token_validity_days * 24 * 60 * 60, + secure=True, + httponly=True, + samesite="lax", + path=f"{config.api.root_path}/refresh", + ) + return response + @router.post( path="/refresh", diff --git a/scigateway_auth/src/authentication.py b/scigateway_auth/src/authentication.py index bdbca0f..bc48221 100644 --- a/scigateway_auth/src/authentication.py +++ b/scigateway_auth/src/authentication.py @@ -10,7 +10,6 @@ from scigateway_auth.common.config import config from scigateway_auth.common.exceptions import ICATAuthenticationError -from scigateway_auth.common.schemas import UserCredentialsPostRequestSchema logger = logging.getLogger() @@ -21,7 +20,7 @@ class ICATAuthenticator: """ @staticmethod - def authenticate(mnemonic: str, credentials: UserCredentialsPostRequestSchema = None) -> str: + def authenticate(mnemonic: str, credentials: dict[str, str] | None = None) -> str: """ Sends an authentication request to the ICAT authenticator and returns a session ID. @@ -32,17 +31,15 @@ def authenticate(mnemonic: str, credentials: UserCredentialsPostRequestSchema = :return: The ICAT session ID. """ logger.info("Authenticating at %s with mnemonic: %s", config.icat_server.url, mnemonic) - json_payload = ( - {"plugin": "anon"} - if credentials is None - else { + + if credentials is None: + json_payload = {"plugin": "anon"} + else: + json_payload = { "plugin": mnemonic, - "credentials": [ - {"username": credentials.username.get_secret_value()}, - {"password": credentials.password.get_secret_value()}, - ], + "credentials": [{k: v} for k, v in credentials.items()], # ICAT requires this to be an array of objects } - ) + data = {"json": json.dumps(json_payload)} response = requests.post( diff --git a/scigateway_auth/src/oidc.py b/scigateway_auth/src/oidc.py new file mode 100644 index 0000000..0e3fc72 --- /dev/null +++ b/scigateway_auth/src/oidc.py @@ -0,0 +1,151 @@ +""" +OIDC module. +""" + +from cachetools.func import ttl_cache +import jwt +import requests + +from scigateway_auth.common.config import config, OidcProviderConfig +from scigateway_auth.common.exceptions import InvalidJWTError, OidcProviderNotFoundError + +# Amount of leeway (in seconds) when validating exp & iat +LEEWAY = 5 + +# Timeout for HTTP requests (in seconds) +TIMEOUT = 10 + + +def get_provider_config(provider_id: str) -> OidcProviderConfig: + """ + Get OIDC provider config with error handling. + + :param provider_id: The ID of the OIDC provider. + :raises OidcProviderNotFoundError: If there is no OIDC provider config for the given provider_id. + :return: The OIDC provider config. + """ + try: + return config.authentication.oidc_providers[provider_id] + except KeyError: + raise OidcProviderNotFoundError from None + + +@ttl_cache(ttl=(24 * 60 * 60)) +def get_well_known_config(provider_id: str) -> dict: + """ + Retreives the OIDC provider's configuration from its .well-known/openid-configuration endpoint. + Caches the response for 24 hours. + + :param provider_id: The ID of the OIDC provider. + :raises OidcProviderNotFoundError: If there is no OIDC provider config for the given provider_id. + :raises RequestException: If a HTTP request did not succeed or returned an error. + :return: The OIDC provider's configuration. + """ + provider_config = get_provider_config(provider_id) + r = requests.get(provider_config.configuration_url, verify=provider_config.verify_cert, timeout=TIMEOUT) + r.raise_for_status() + return r.json() + + +@ttl_cache(ttl=(2 * 60 * 60)) +def get_jwks(provider_id: str) -> jwt.PyJWKSet: + """ + Retreives an OIDC provider's JWK Set. + Caches the response for 2 hours. + + :param provider_id: The ID of the OIDC provider. + :raises OidcProviderNotFoundError: If there is no OIDC provider config for the given provider_id. + :raises RequestException: If a HTTP request did not succeed or returned an error. + :return: The OIDC provider's JWK Set. + """ + provider_config = get_provider_config(provider_id) + well_known_config = get_well_known_config(provider_id) + jwks_uri = well_known_config["jwks_uri"] + + r = requests.get(jwks_uri, verify=provider_config.verify_cert, timeout=TIMEOUT) + r.raise_for_status() + jwks_config = r.json() + + return jwt.PyJWKSet(jwks_config["keys"]) + + +def get_token(provider_id: str, code: str) -> dict: + """ + Call the OIDC provider's token endpoint. + + :param provider_id: The ID of the OIDC provider. + :param code: Authorization code obtained from the OIDC provider's authorization endpoint. + :raises OidcProviderNotFoundError: If there is no OIDC provider config for the given provider_id or the OIDC + provider config does not have a client_secret. + :raises RequestException: If a HTTP request did not succeed or returned an error. + :return: The reponse received from the token endpoint. + """ + provider_config = get_provider_config(provider_id) + token_endpoint = get_well_known_config(provider_id)["token_endpoint"] + + if provider_config.client_secret is None: + raise OidcProviderNotFoundError from None + + r = requests.post( + url=token_endpoint, + data={ + "grant_type": "authorization_code", + "client_id": provider_config.client_id, + "client_secret": provider_config.client_secret, + "code": code, + "redirect_uri": config.authentication.oidc_redirect_uri, + }, + verify=provider_config.verify_cert, + timeout=TIMEOUT, + ) + + r.raise_for_status() + return r.json() + + +def get_username(provider_id: str, id_token: str) -> tuple[str, str]: + """ + Verify an id_token and return the mechanism and username. + + :param provider_id: The ID of the OIDC provider. + :param id_token: An OIDC id_token. + :raises InvalidJWTError: If the id_token is invalid. + :raises OidcProviderNotFoundError: If there is no OIDC provider config for the given provider_id. + :raises RequestException: If a HTTP request did not succeed or returned an error. + :return: The mechanism for the provider, and the username obtained from the id_token. + """ + provider_config = get_provider_config(provider_id) + + try: + unverified_header = jwt.get_unverified_header(id_token) + + try: + kid = unverified_header["kid"] + key = get_jwks(provider_id)[kid] + except KeyError as exc: + raise InvalidJWTError("Invalid OIDC id_token") from exc + + # Ensure that this key can be used for signing + if key.public_key_use not in [None, "sig"]: + raise InvalidJWTError("Invalid OIDC id_token") + + payload = jwt.decode( + jwt=id_token, + key=key, + algorithms=[key.algorithm_name], + audience=provider_config.client_id, + issuer=get_well_known_config(provider_id)["issuer"], + verify=True, + options={"require": ["exp", "aud", "iss"], "verify_exp": True, "verify_aud": True, "verify_iss": True}, + leeway=LEEWAY, + ) + + try: + username = payload[provider_config.username_claim] + except KeyError: + raise InvalidJWTError("Invalid OIDC id_token") from None + + return (provider_config.mechanism, username) + + except jwt.exceptions.InvalidTokenError as exc: + raise InvalidJWTError("Invalid OIDC id_token") from exc diff --git a/test/mock_data.py b/test/mock_data.py index 9a6b793..86ab4ea 100644 --- a/test/mock_data.py +++ b/test/mock_data.py @@ -57,3 +57,38 @@ MAINTENANCE_CONFIG_PATH = "path/to/config/test_config.json" MAINTENANCE_STATE = {"test": "test"} + +JWK_PRIVATE_KEY = """-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEAm3DT0wGFXPgW6tYh8IR2NaoUAP733QVOV/PPx6hPVY04QGT8 +n+QBdUBBXAZL/qxN5ib3ChI6cMnlGM4CQDdYje9ARJr48kzfPgdCFvmp70leNs7p +Mr0LLM2TIm3c+v1k9IwG0ahIR6s3gWUlYrmj9D9BegHESh5sTg8PbhRhQjfDQ9uz +Oci7XgkDaYoHbZY/NgwOgQMMHSgKVwlJUZh/FSkSfZJ5iML5fl57Sf2V+T3VuZ42 +Nowew37pzw0z83woVVVjr8anQ51NLMQyBRAZnGgqSvog5jnXnUBzKqQJQfln456f +U9v+cePD5YnSBNaT4bv2qcVwSpXMaIV0qQv0jwIDAQABAoIBADs6hrxEPBjsv27E +ITKQhqp0ICMxBClHuE5zLJ0bWP90TBbdZBVYv+V5km4KSDLGXPhxqHhB8drAU3dc +KCdn72pF2+tIUcuh1v6/rUMr5sCy6B9iQCGBPxzXSFU6H5XTsAAuyvpgcQo+B2xa +qiAwNHUBFWXz/mNvrD0iJ/L9+QFNvZ6srpB+8RRKU+g5NTipaj9DuHg07jQH5snV +B5s1GpWJbZNLbuCJh+jvWYjKyZY7QFD5dfSEpFH1Evorifntkt+wSxAJFmvpaRFU +CRQYOPecrwZv8PvOogmFBoeMqGKLHiuF/9sx6LB6pnPVgJ62VsV5Yppr4R0dMaEm +augjSqECgYEA2PhaCh+cCgEFdAgDJJ/+wup0yNdT2TkLk2MIYUZQp2WxKN7ys3aZ +nXOohLQmYTlwmvTpiOuzq0W2iNkSXjWLbmSnnDSSpqRkyGbqQ8SEgLZNsHYZTlq7 +Yk4Eg7YSPvsLoXyIbI05aZcpzp9LnAGqRdjh6ySzvMR/rzZkW70JIi0CgYEAt2b/ +dDvItEhaXhtQNdIoSeRb23VOMfAlNgXiTR8JyD14yGW9Ehii5YJqKLFvMM1dtVUx +Q48NgRQjHYOMKI8TPr7S/AGo95ck8uL+xTQwbiGgicq4JocRu40UprR3DhFr0eX1 +Emj1nqFdqirE/ij9v3V5jqijB/HeB+PEN+VtcysCgYEAy+V4GBNkfDJBga0V5xFE +RMA4R5WzgmuNaVCjy2Zc3TM/rXz275gA/Gp4b10sxClKnRSTcsyt58J3q1rzW/1N +rsyAhtcRCfFrlLjCZjUDoEGx+KbDWVMCzXsr2ur7cpxRbcyuF/UPgx8/dqFUWKNn +9IPAq02uazLuGyYuYdfgAXECgYBLNV6OULHWVFFShArZd0v0OUP989XUHrFzvOf3 +TkIZrjOooiftktLCIT/dXh3FuoGyCbSBCtmz1AkuYjKIs1tmAKTOmPOsTHvnanSl +c+hkUT/fIZVwnzUDXzBXyGuGBljbo1xjZ01J9sxNKurLew3LhKYLfVYVvPaa76kY +bun6LwKBgAa1i/JJkut5vNYw+qRvyperjlwjJSa+NUuF+ZEQcOuwPVrrHObCDDyU +HfaCyKgpd3GLdSVocBMyGG6oV9AhpR7i/linvonah4AN3WDB4Gds2XklrdEa/SQo +Z3UPoimOaG7/z7vi8JOl2aXjCro17ZJPFxF4B9HW4EcEeTMucKKq +-----END RSA PRIVATE KEY-----""" + +JWK_PUBLIC = { + "kid": "mock-kid", + "kty": "RSA", + "n": "m3DT0wGFXPgW6tYh8IR2NaoUAP733QVOV/PPx6hPVY04QGT8n+QBdUBBXAZL/qxN5ib3ChI6cMnlGM4CQDdYje9ARJr48kzfPgdCFvmp70leNs7pMr0LLM2TIm3c+v1k9IwG0ahIR6s3gWUlYrmj9D9BegHESh5sTg8PbhRhQjfDQ9uzOci7XgkDaYoHbZY/NgwOgQMMHSgKVwlJUZh/FSkSfZJ5iML5fl57Sf2V+T3VuZ42Nowew37pzw0z83woVVVjr8anQ51NLMQyBRAZnGgqSvog5jnXnUBzKqQJQfln456fU9v+cePD5YnSBNaT4bv2qcVwSpXMaIV0qQv0jw==", # noqa: B950 + "e": "AQAB", +} diff --git a/test/pytest.ini b/test/pytest.ini index 4e23918..b992aa7 100644 --- a/test/pytest.ini +++ b/test/pytest.ini @@ -18,3 +18,16 @@ env = ICAT_SERVER__URL=http://localhost/icat ICAT_SERVER__CERTIFICATE_VALIDATION=true ICAT_SERVER__REQUEST_TIMEOUT_SECONDS=5 + + AUTHENTICATION__OIDC_PROVIDERS__mock-pkce__DISPLAY_NAME = Mock using PKCE + AUTHENTICATION__OIDC_PROVIDERS__mock-pkce__CONFIGURATION_URL = https://mock-oidc-provider/.well-known/openid-configuration + AUTHENTICATION__OIDC_PROVIDERS__mock-pkce__CLIENT_ID = 5041ee0190fa + + AUTHENTICATION__OIDC_PROVIDERS__mock-client-secret__DISPLAY_NAME = Mock using client_secret + AUTHENTICATION__OIDC_PROVIDERS__mock-client-secret__CONFIGURATION_URL = https://mock-oidc-provider/.well-known/openid-configuration + AUTHENTICATION__OIDC_PROVIDERS__mock-client-secret__CLIENT_ID = b4454c229f3e + AUTHENTICATION__OIDC_PROVIDERS__mock-client-secret__CLIENT_SECRET = 8f2ba7d6d0ad + + AUTHENTICATION__OIDC_ICAT_AUTHENTICATOR = delegating + AUTHENTICATION__OIDC_ICAT_AUTHENTICATOR_TOKEN = secret-token + AUTHENTICATION__OIDC_REDIRECT_URI = https://test-oidc-client/callback diff --git a/test/test_authentication.py b/test/test_authentication.py index d51ac85..9343e8c 100644 --- a/test/test_authentication.py +++ b/test/test_authentication.py @@ -8,7 +8,6 @@ from scigateway_auth.common.config import config from scigateway_auth.common.exceptions import ICATAuthenticationError -from scigateway_auth.common.schemas import UserCredentialsPostRequestSchema from scigateway_auth.src.authentication import ICATAuthenticator @@ -21,7 +20,7 @@ class TestICATAuthenticator: password = "test-password" # noqa: S105 mnemonic = "test-mnemonic" session_id = "test-session-id" - credentials = UserCredentialsPostRequestSchema(username=username, password=password) + credentials = {"username": username, "password": password} def create_mock_response(self, status_code: int, json_data: dict = None) -> Mock: """ diff --git a/test/test_oidc.py b/test/test_oidc.py new file mode 100644 index 0000000..d9f3e68 --- /dev/null +++ b/test/test_oidc.py @@ -0,0 +1,213 @@ +""" +Unit tests for the `oidc` module. +""" + +import time +from unittest.mock import Mock, patch + +from cryptography.hazmat.primitives import serialization +import jwt +import pytest +import requests + +from scigateway_auth.common.exceptions import InvalidJWTError, OidcProviderNotFoundError +from scigateway_auth.src import oidc +from test.mock_data import JWK_PRIVATE_KEY, JWK_PUBLIC + + +MOCK_CONFIGURATION = { + "issuer": "https://mock-oidc-provider/issuer", + "jwks_uri": "https://mock-oidc-provider/issuer/keys", + "token_endpoint": "https://mock-oidc-provider/issuer/token", +} + +MOCK_JWKS = {"keys": [JWK_PUBLIC]} + + +class MockRequestsResponse: + + def __init__(self, json): + self._json = json + + def raise_for_status(self): + pass + + def json(self): + return self._json + + +def mock_requests_get(url: str, **kwargs): + + if url == "https://mock-oidc-provider/.well-known/openid-configuration": + return MockRequestsResponse(MOCK_CONFIGURATION) + + if url == "https://mock-oidc-provider/issuer/keys": + return MockRequestsResponse(MOCK_JWKS) + + raise requests.ConnectionError + + +def mock_requests_post(url: str, **kwargs): + + if url == "https://mock-oidc-provider/issuer/token": + return MockRequestsResponse({"token_type": "Bearer", "id_token": create_token()}) + + raise requests.ConnectionError + + +def create_payload(): + return { + "sub": "test-username", + "iss": "https://mock-oidc-provider/issuer", + "aud": "5041ee0190fa", + "iat": int(time.time()), + "exp": int(time.time()) + 600, + } + + +def create_token(payload=None, headers=None): + if headers is None: + headers = {"kid": "mock-kid"} + + if payload is None: + payload = create_payload() + + private_key = serialization.load_pem_private_key(JWK_PRIVATE_KEY.encode(), None) + return jwt.encode(payload, private_key, algorithm="RS256", headers=headers) + + +class TestOidc: + """ + Test suite for the `oidc` module + """ + + @patch("requests.get") + def test_get_username(self, mock_get: Mock): + """ + Test that oidc.get_username() return the expected mechanism and username. + """ + mock_get.side_effect = mock_requests_get + + mechanism, username = oidc.get_username("mock-pkce", create_token()) + assert mechanism is None + assert username == "test-username" + + @patch("requests.get") + def test_get_username_expired(self, mock_get: Mock): + """ + Test that oidc.get_username() raises InvalidJWTError if exp is missing or expired. + """ + mock_get.side_effect = mock_requests_get + + payload = create_payload() + payload["exp"] = int(time.time()) - 10 + + with pytest.raises(InvalidJWTError): + mechanism, username = oidc.get_username("mock-pkce", create_token(payload)) + + del payload["exp"] + + with pytest.raises(InvalidJWTError): + mechanism, username = oidc.get_username("mock-pkce", create_token(payload)) + + @patch("requests.get") + def test_get_username_invalid_audience(self, mock_get: Mock): + """ + Test that oidc.get_username() raises InvalidJWTError if aud is missing or isn't the expected value. + """ + mock_get.side_effect = mock_requests_get + + payload = create_payload() + payload["aud"] = "invalid" + + with pytest.raises(InvalidJWTError): + mechanism, username = oidc.get_username("mock-pkce", create_token(payload)) + + del payload["aud"] + + with pytest.raises(InvalidJWTError): + mechanism, username = oidc.get_username("mock-pkce", create_token(payload)) + + @patch("requests.get") + def test_get_username_invalid_issuer(self, mock_get: Mock): + """ + Test that oidc.get_username() raises InvalidJWTError if iss is missing or isn't the expected value. + """ + mock_get.side_effect = mock_requests_get + + payload = create_payload() + payload["iss"] = "invalid" + + with pytest.raises(InvalidJWTError): + mechanism, username = oidc.get_username("mock-pkce", create_token(payload)) + + del payload["iss"] + + with pytest.raises(InvalidJWTError): + mechanism, username = oidc.get_username("mock-pkce", create_token(payload)) + + @patch("requests.get") + def test_get_username_missing_sub(self, mock_get: Mock): + """ + Test that oidc.get_username() raises InvalidJWTError if sub is missing. + """ + mock_get.side_effect = mock_requests_get + + payload = create_payload() + del payload["sub"] + + with pytest.raises(InvalidJWTError): + mechanism, username = oidc.get_username("mock-pkce", create_token(payload)) + + @patch("requests.get") + def test_get_username_missing_kid(self, mock_get: Mock): + """ + Test that oidc.get_username() raises InvalidJWTError if kid is missing. + """ + mock_get.side_effect = mock_requests_get + + with pytest.raises(InvalidJWTError): + mechanism, username = oidc.get_username("mock-pkce", create_token(headers={})) + + @patch("requests.get") + def test_get_username_unkonwn_key(self, mock_get: Mock): + """ + Test that oidc.get_username() raises InvalidJWTError if kid doesn't match a key from the JWKS. + """ + mock_get.side_effect = mock_requests_get + + with pytest.raises(InvalidJWTError): + mechanism, username = oidc.get_username("mock-pkce", create_token(headers={"kid": "unknown"})) + + @patch("requests.get") + def test_get_username_unkown_provider(self, mock_get: Mock): + """ + Test that oidc.get_username() raises OidcProviderNotFoundError if the provider_id is not found. + """ + mock_get.side_effect = mock_requests_get + + with pytest.raises(OidcProviderNotFoundError): + mechanism, username = oidc.get_username("unknown", create_token()) + + @patch("requests.post") + @patch("requests.get") + def test_get_token(self, mock_get: Mock, mock_post: Mock): + """ + Test that oidc.get_token() works. + """ + mock_get.side_effect = mock_requests_get + mock_post.side_effect = mock_requests_post + + oidc.get_token("mock-client-secret", "test-code") + + @patch("requests.post") + @patch("requests.get") + def test_get_token_no_client_secret(self, mock_get: Mock, mock_post: Mock): + """ + Test that oidc.get_token() raises OidcProviderNotFoundError if the OIDC provider has no client_secret. + """ + mock_get.side_effect = mock_requests_get + mock_post.side_effect = mock_requests_post + + with pytest.raises(OidcProviderNotFoundError): + oidc.get_token("mock-pkce", "test-code")