1919
2020import json
2121import logging
22+ import secrets
2223from typing import cast
2324from urllib .parse import quote , urljoin
2425
@@ -53,15 +54,23 @@ class AuthManagerRefreshTokenExpiredException(Exception): # type: ignore[no-red
5354login_router = AirflowRouter (tags = ["KeycloakAuthManagerLogin" ])
5455
5556COOKIE_NAME_ID_TOKEN = "_id_token"
57+ COOKIE_NAME_OAUTH_STATE = "_oauth_state"
5658
5759
5860@login_router .get ("/login" )
5961def login (request : Request ) -> RedirectResponse :
6062 """Initiate the authentication."""
6163 client = KeycloakAuthManager .get_keycloak_client ()
6264 redirect_uri = request .url_for ("login_callback" )
63- auth_url = client .auth_url (redirect_uri = str (redirect_uri ), scope = "openid" )
64- return RedirectResponse (auth_url )
65+ state = secrets .token_urlsafe (32 )
66+ auth_url = client .auth_url (redirect_uri = str (redirect_uri ), scope = "openid" , state = state )
67+ response = RedirectResponse (auth_url )
68+ secure = bool (conf .get ("api" , "ssl_cert" , fallback = "" ))
69+ cookie_path = get_cookie_path ()
70+ response .set_cookie (
71+ COOKIE_NAME_OAUTH_STATE , state , max_age = 300 , path = cookie_path , httponly = True , secure = secure
72+ )
73+ return response
6574
6675
6776@login_router .get ("/login_callback" )
@@ -70,6 +79,10 @@ def login_callback(request: Request):
7079 code = request .query_params .get ("code" )
7180 if not code :
7281 return HTMLResponse ("Missing code" , status_code = 400 )
82+ state_q = request .query_params .get ("state" , "" )
83+ state_c = request .cookies .get (COOKIE_NAME_OAUTH_STATE , "" )
84+ if not state_q or not state_c or not secrets .compare_digest (state_q , state_c ):
85+ return HTMLResponse ("Invalid OAuth state parameter" , status_code = 403 )
7386
7487 client = KeycloakAuthManager .get_keycloak_client ()
7588 redirect_uri = request .url_for ("login_callback" )
0 commit comments