Source code for social_core.backends.open_id_connect

from __future__ import annotations

import base64
import datetime
from calendar import timegm
from json import loads
from typing import TYPE_CHECKING, Any, Literal, cast

import jwt
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from jwt import (
    ExpiredSignatureError,
    InvalidAudienceError,
    InvalidTokenError,
    PyJWTError,
)
from jwt.utils import base64url_decode

from social_core.backends.oauth import BaseOAuth2PKCE
from social_core.exceptions import (
    AuthInvalidParameter,
    AuthMissingParameter,
    AuthTokenError,
)
from social_core.utils import cache

if TYPE_CHECKING:
    from collections.abc import Mapping

    from jwt.types import Options
    from requests.auth import AuthBase

    from social_core.strategy import BaseStrategy


class OpenIdConnectAssociation:
    """
    Association-shaped record used to persist an OpenID Connect nonce.

    Only ``handle`` (the nonce) and ``assoc_type`` (the state) carry data;
    ``secret``, ``issued`` and ``lifetime`` are unused placeholders.
    """

    def __init__(self, handle, secret="", issued=0, lifetime=0, assoc_type="") -> None:
        # Fields that carry data: the nonce and the state it is tied to.
        self.handle = handle
        self.assoc_type = assoc_type

        # Placeholders that are never read back; ``secret`` is stored as bytes.
        self.secret = secret.encode()
        self.issued = issued
        self.lifetime = lifetime
class OpenIdConnectAuth(BaseOAuth2PKCE):
    """
    Base class for Open ID Connect backends.
    Currently only the code response type is supported.

    It can also be directly instantiated as a generic OIDC backend.
    To use it you will need to set at minimum:

    SOCIAL_AUTH_OIDC_OIDC_ENDPOINT = 'https://.....'  # endpoint without /.well-known/openid-configuration
    SOCIAL_AUTH_OIDC_KEY = '<client_id>'
    SOCIAL_AUTH_OIDC_SECRET = '<client_secret>'
    SOCIAL_AUTH_OIDC_USE_PKCE = True  # optional, enables PKCE for this backend
    """

    name = "oidc"
    # Override OIDC_ENDPOINT in your subclass to enable autoconfig of OIDC
    OIDC_ENDPOINT: str | None = None
    # Maximum accepted age of an ID token, added to its ``iat`` claim
    ID_TOKEN_MAX_AGE = 600
    DEFAULT_SCOPE = ["openid", "profile", "email"]
    EXTRA_DATA = ["id_token", "refresh_token", ("sub", "id")]
    REDIRECT_STATE = False
    REVOKE_TOKEN_METHOD: Literal["GET", "POST", "DELETE"] = "GET"
    ID_KEY = "sub"
    USERNAME_KEY = "preferred_username"
    EMAIL_KEY = "email"
    FIRST_NAME_KEY = "given_name"
    LAST_NAME_KEY = "family_name"
    FULLNAME_KEY = "name"
    JWT_ALGORITHMS = ["RS256"]
    JWT_DECODE_OPTIONS: Options = {}
    JWT_LEEWAY: float = 1.0  # seconds
    VALIDATE_AT_HASH: bool = True
    CUSTOM_AT_HASH_ALGO: str | None = None
    # When these options are unspecified, server will choose via openid autoconfiguration
    ID_TOKEN_ISSUER = ""
    ACCESS_TOKEN_URL = ""
    AUTHORIZATION_URL = ""
    REVOKE_TOKEN_URL = ""
    USERINFO_URL = ""
    JWKS_URI = ""
    TOKEN_ENDPOINT_AUTH_METHOD = ""

    # Optional parameters for Authentication Request
    DISPLAY: str | None = None
    PROMPT: str | None = None
    MAX_AGE: int | None = None
    UI_LOCALES: str | None = None
    ID_TOKEN_HINT: str | None = None
    LOGIN_HINT: str | None = None
    ACR_VALUES: str | None = None

    PKCE_DEFAULT_CODE_CHALLENGE_METHOD = "S256"
    DEFAULT_USE_PKCE = False

    def __init__(
        self, strategy: BaseStrategy | None = None, redirect_uri: str | None = None
    ) -> None:
        """Initialise the backend; no ID token has been validated yet."""
        super().__init__(strategy, redirect_uri=redirect_uri)
        # Claims of the last validated ID token (set by request_access_token).
        self.id_token = None

    def get_setting_config(
        self, setting_name: str, oidc_name: str, default: str
    ) -> str:
        """
        Resolve a configuration value from settings or the discovery document.

        The setting ``setting_name`` (falling back to ``default``) wins when
        it is truthy; otherwise the ``oidc_name`` entry of ``oidc_config()``
        is used.

        Raises:
            AuthMissingParameter: if the resolved value is not a string.
        """
        value = self.setting(setting_name, default)
        if not value:
            value = self.oidc_config().get(oidc_name)

        if not isinstance(value, str):
            raise AuthMissingParameter(self, setting_name)
        return value

    def authorization_url(self) -> str:
        """Return the authorization endpoint URL."""
        return self.get_setting_config(
            "AUTHORIZATION_URL", "authorization_endpoint", self.AUTHORIZATION_URL
        )

    def access_token_url(self) -> str:
        """Return the token endpoint URL."""
        return self.get_setting_config(
            "ACCESS_TOKEN_URL", "token_endpoint", self.ACCESS_TOKEN_URL
        )

    def revoke_token_url(self, token, uid) -> str:
        """Return the revocation endpoint URL (``token``/``uid`` are unused)."""
        return self.get_setting_config(
            "REVOKE_TOKEN_URL", "revocation_endpoint", self.REVOKE_TOKEN_URL
        )

    def id_token_issuer(self) -> str:
        """Return the issuer expected in ID tokens."""
        return self.get_setting_config(
            "ID_TOKEN_ISSUER", "issuer", self.ID_TOKEN_ISSUER
        )

    def userinfo_url(self) -> str:
        """Return the UserInfo endpoint URL."""
        return self.get_setting_config(
            "USERINFO_URL", "userinfo_endpoint", self.USERINFO_URL
        )

    def jwks_uri(self) -> str:
        """Return the URL of the provider's JWKS document."""
        return self.get_setting_config("JWKS_URI", "jwks_uri", self.JWKS_URI)

    def use_basic_auth(self) -> bool:
        """
        Tell whether the ``client_secret_basic`` auth method should be used.

        An explicit ``TOKEN_ENDPOINT_AUTH_METHOD`` decides on its own.
        Otherwise the answer is true when the discovery document lists no
        ``token_endpoint_auth_methods_supported`` or lists
        ``client_secret_basic`` among them.
        """
        method = self.setting(
            "TOKEN_ENDPOINT_AUTH_METHOD", self.TOKEN_ENDPOINT_AUTH_METHOD
        )
        if method:
            return method == "client_secret_basic"
        methods = self.oidc_config().get("token_endpoint_auth_methods_supported", [])
        return not methods or "client_secret_basic" in methods

    def oidc_endpoint(self) -> str:
        """Return the configured ``OIDC_ENDPOINT`` (class attribute as default)."""
        return cast("str", self.setting("OIDC_ENDPOINT", self.OIDC_ENDPOINT))

    @cache(ttl=86400)
    def oidc_config(self) -> dict[Any, Any]:
        """
        Fetch ``<oidc_endpoint>/.well-known/openid-configuration`` as JSON.

        The result is memoised by the ``cache`` decorator (ttl=86400,
        presumably seconds -- confirm against ``social_core.utils.cache``).
        """
        return self.get_json(f"{self.oidc_endpoint()}/.well-known/openid-configuration")

    @cache(ttl=86400)
    def get_jwks_keys(self):
        """
        Return the provider's JWKS keys, memoised by the ``cache`` decorator.

        NOTE: only the remote keys are returned; the client secret is not
        appended as an ``oct`` key here.
        """
        return self.get_remote_jwks_keys()

    def get_remote_jwks_keys(self):
        """Download the JWKS document and return the list under ``"keys"``."""
        response = self.request(self.jwks_uri())
        return loads(response.text)["keys"]

    def auth_params(self, state=None):  # noqa: C901, PLR0912
        """
        Return extra arguments needed on auth process.

        Adds a freshly stored ``nonce`` and every optional Authentication
        Request parameter that is configured (i.e. not ``None``).

        Raises:
            AuthMissingParameter: for an empty or unknown ``display`` value.
            AuthInvalidParameter: for an empty/invalid ``prompt``, a negative
                ``max_age`` or an empty ``ui_locales``, ``id_token_hint``,
                ``login_hint`` or ``acr_values``.
        """
        params = super().auth_params(state)
        params["nonce"] = self.get_and_store_nonce(self.authorization_url(), state)

        display = self.setting("DISPLAY", default=self.DISPLAY)
        if display is not None:
            if not display:
                raise AuthMissingParameter(
                    self, "OpenID Connect display value cannot be empty string."
                )
            if display not in ("page", "popup", "touch", "wap"):
                raise AuthMissingParameter(
                    self, f"Invalid OpenID Connect display value: {display}"
                )
            params["display"] = display

        prompt = self.setting("PROMPT", default=self.PROMPT)
        if prompt is not None:
            if not prompt:
                raise AuthInvalidParameter(self, "prompt")
            # ``prompt`` is a whitespace separated list of known tokens.
            for prompt_token in prompt.split():
                if prompt_token not in ("none", "login", "consent", "select_account"):
                    raise AuthInvalidParameter(self, "prompt")
            params["prompt"] = prompt

        max_age = self.setting("MAX_AGE", default=self.MAX_AGE)
        if max_age is not None:
            if max_age < 0:
                raise AuthInvalidParameter(self, "max_age")
            params["max_age"] = max_age

        # The remaining optional parameters share one rule: pass the value
        # through when set, reject it when set but empty.  The setting name
        # and class attribute are the upper-cased parameter name.
        for param_name in ("ui_locales", "id_token_hint", "login_hint", "acr_values"):
            setting_name = param_name.upper()
            value = self.setting(setting_name, default=getattr(self, setting_name))
            if value is None:
                continue
            if not value:
                raise AuthInvalidParameter(self, param_name)
            params[param_name] = value

        return params

    def get_and_store_nonce(self, url, state):
        """Create a 64-character nonce, store it for ``url``/``state``, return it."""
        # Create a nonce
        nonce = self.strategy.random_string(64)
        # Store the nonce
        association = OpenIdConnectAssociation(nonce, assoc_type=state)
        self.strategy.storage.association.store(url, association)
        return nonce

    def get_nonce(self, nonce):
        """Return the stored association for ``nonce`` or ``None`` if absent."""
        try:
            return self.strategy.storage.association.get(
                server_url=self.authorization_url(), handle=nonce
            )[0]
        except IndexError:
            return None

    def remove_nonce(self, nonce_id) -> None:
        """Delete the stored association identified by ``nonce_id``."""
        self.strategy.storage.association.remove([nonce_id])

    def validate_claims(self, id_token) -> None:
        """
        Check the ``nbf``, ``iat`` and ``nonce`` claims of decoded ID token claims.

        A matching stored nonce is removed, so it can be used only once.

        Raises:
            AuthTokenError: if ``nbf`` lies in the future, ``iat`` is missing
                or older than ``ID_TOKEN_MAX_AGE``, or the nonce is missing
                or unknown.
        """
        utc_timestamp = timegm(datetime.datetime.now(datetime.timezone.utc).timetuple())

        if "nbf" in id_token and utc_timestamp < id_token["nbf"]:
            raise AuthTokenError(self, "Incorrect id_token: nbf")

        # ``iat`` is not guaranteed to be present after decoding, so a token
        # without it must be rejected here instead of failing with KeyError.
        issued_at = id_token.get("iat")
        if issued_at is None:
            raise AuthTokenError(self, "Incorrect id_token: iat")

        # Verify the token was issued within the last ID_TOKEN_MAX_AGE seconds
        max_age = self.setting("ID_TOKEN_MAX_AGE", self.ID_TOKEN_MAX_AGE)
        if utc_timestamp > issued_at + max_age:
            raise AuthTokenError(self, "Incorrect id_token: iat")

        # Validate the nonce to ensure the request was not modified
        nonce = id_token.get("nonce")
        if not nonce:
            raise AuthTokenError(self, "Incorrect id_token: nonce")

        nonce_obj = self.get_nonce(nonce)
        if not nonce_obj:
            raise AuthTokenError(self, "Incorrect id_token: nonce")
        self.remove_nonce(nonce_obj.id)

    def find_valid_key(self, id_token):
        """
        Return the JWKS key whose signature check passes for ``id_token``.

        Keys are filtered by the token's ``kid`` header when it has one;
        without a ``kid`` every key is tried.  Returns ``None`` when no key
        verifies the signature.
        """
        kid = jwt.get_unverified_header(id_token).get("kid")

        keys = self.get_jwks_keys()
        if kid is not None and not any(kid == key.get("kid") for key in keys):
            # The key id is not among the cached keys (e.g. the provider
            # rotated its keys): drop the cache and fetch the JWKS again.
            self.get_jwks_keys.invalidate()  # pyright: ignore[reportAttributeAccessIssue]
            keys = self.get_jwks_keys()

        for key in keys:
            if kid is not None and kid != key.get("kid"):
                continue

            if "alg" not in key:
                # Default to the first configured algorithm.  The entry is
                # updated in place because validate_at_hash reads key["alg"]
                # from the key returned here.
                key["alg"] = cast(
                    "list[str]", self.setting("JWT_ALGORITHMS", self.JWT_ALGORITHMS)
                )[0]
            rsakey = jwt.PyJWK(key)
            message, encoded_sig = id_token.rsplit(".", 1)
            decoded_sig = base64url_decode(encoded_sig.encode("utf-8"))
            if rsakey.Algorithm.verify(
                message.encode("utf-8"), rsakey.key, decoded_sig
            ):
                return key
        return None

    def validate_and_return_id_token(self, id_token, access_token):
        """
        Validates the id_token according to the steps at
        http://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation.

        Returns the decoded claims.

        Raises:
            AuthTokenError: if no key verifies the signature, decoding fails,
                the ``at_hash`` does not match or claim validation fails.
        """
        client_id, _client_secret = self.get_key_and_secret()

        key = self.find_valid_key(id_token)

        if not key:
            raise AuthTokenError(self, "Signature verification failed")

        rsakey = jwt.PyJWK(key)

        try:
            claims = jwt.decode(
                id_token,
                rsakey.key,
                algorithms=self.setting("JWT_ALGORITHMS", self.JWT_ALGORITHMS),
                audience=client_id,
                issuer=self.id_token_issuer(),
                options=cast(
                    "Options",
                    self.setting("JWT_DECODE_OPTIONS", self.JWT_DECODE_OPTIONS),
                ),
                leeway=cast("int", self.setting("JWT_LEEWAY", self.JWT_LEEWAY)),
            )
        except ExpiredSignatureError as error:
            raise AuthTokenError(self, "Signature has expired") from error
        except InvalidAudienceError as error:
            # compatibility with jose error message
            raise AuthTokenError(self, "Token error: Invalid audience") from error
        except InvalidTokenError as error:
            raise AuthTokenError(self, str(error)) from error
        except PyJWTError as error:
            raise AuthTokenError(self, "Invalid signature") from error

        # pyjwt does not validate OIDC claims
        # see https://github.com/jpadilla/pyjwt/pull/296
        if not self.validate_at_hash(claims, access_token, key):
            raise AuthTokenError(self, "Invalid access token")

        self.validate_claims(claims)

        return claims

    def request_access_token(
        self,
        url: str,
        method: Literal["GET", "POST", "DELETE"] = "GET",
        headers: Mapping[str, str | bytes] | None = None,
        data: dict | None = None,
        json: dict | None = None,
        auth: tuple[str, str] | AuthBase | None = None,
        params: dict | None = None,
    ) -> dict[Any, Any]:
        """
        Retrieve the access token. Also, validate the id_token and
        store it (temporarily).

        Raises:
            AuthTokenError: if the response lacks ``id_token`` or
                ``access_token``, or the ID token fails validation.
        """
        response = super().request_access_token(
            url,
            method=method,
            headers=headers,
            data=data,
            json=json,
            auth=auth,
            params=params,
        )
        for parameter in ("id_token", "access_token"):
            if parameter not in response:
                raise AuthTokenError(
                    self,
                    f"Missing {parameter} in OpenID Connect token response",
                )
        self.id_token = self.validate_and_return_id_token(
            response["id_token"], response["access_token"]
        )
        return response

    def user_data(self, access_token: str, *args, **kwargs) -> dict[str, Any] | None:
        """Fetch UserInfo using ``access_token`` as Bearer token and check its ``sub``."""
        return self.validate_userinfo_sub(
            self.get_json(
                self.userinfo_url(),
                headers={"Authorization": f"Bearer {access_token}"},
            )
        )

    def validate_userinfo_sub(
        self, userinfo: dict[str, Any] | None
    ) -> dict[str, Any] | None:
        """
        Validate that UserInfo belongs to the validated ID token subject.

        UserInfo that is ``None`` or carries no ``sub`` is returned unchecked.

        Raises:
            AuthTokenError: if the UserInfo ``sub`` differs from the ID
                token ``sub`` (or no ID token has been stored).
        """
        if userinfo is None or userinfo.get("sub") is None:
            return userinfo

        id_token_sub = self.id_token.get("sub") if self.id_token is not None else None
        if userinfo["sub"] != id_token_sub:
            raise AuthTokenError(self, "Invalid UserInfo sub")
        return userinfo

    def get_user_id(self, details, response):
        """
        Return the user id, preferring the validated ID token's ``sub``.

        The ID token is used only when ``id_key()`` is ``"sub"`` and a token
        has been stored; otherwise the parent implementation decides.
        """
        if self.id_key() == "sub" and self.id_token is not None:
            return self.id_token.get("sub")
        return super().get_user_id(details, response)

    def get_user_details(self, response):
        """
        Build the user details dict from ``response`` and the ID token.

        Each value is read from ``response`` when the configured key is
        present there, else from the stored ID token, else it is ``None``.
        """
        username_key = self.setting("USERNAME_KEY", self.USERNAME_KEY)
        email_key = self.setting("EMAIL_KEY", self.EMAIL_KEY)
        first_name_key = self.setting("FIRST_NAME_KEY", self.FIRST_NAME_KEY)
        last_name_key = self.setting("LAST_NAME_KEY", self.LAST_NAME_KEY)
        fullname_key = self.setting("FULLNAME_KEY", self.FULLNAME_KEY)

        def get_value(key):
            if key in response:
                return response.get(key)
            if self.id_token is not None:
                return self.id_token.get(key)
            return None

        return {
            "username": get_value(username_key),
            "email": get_value(email_key),
            "fullname": get_value(fullname_key),
            "first_name": get_value(first_name_key),
            "last_name": get_value(last_name_key),
        }

    def validate_at_hash(self, claims, access_token, key):
        """
        Validate the 'at_hash' claim according to OpenID Connect specs.

        See: https://openid.net/specs/openid-connect-core-1_0.html#CodeIDToken

        Returns ``True`` when validation is disabled through
        ``VALIDATE_AT_HASH``, when the claims carry no ``at_hash``, or when
        the claim equals the hash computed from ``access_token``.
        """
        if not self.VALIDATE_AT_HASH:
            return True

        if "at_hash" not in claims:
            return True

        expected_hash = claims["at_hash"]
        calculated_hash = self.calc_at_hash(
            access_token, key["alg"], self.CUSTOM_AT_HASH_ALGO
        )
        return expected_hash == calculated_hash

    @staticmethod
    def calc_at_hash(access_token, algorithm, custom_at_hash_algo: str | None = None):
        """
        Calculates "at_hash" claim which is not done by pyjwt.

        Custom "at_hash" algorithm is used for non-standard token.

        The result is the URL-safe base64 encoding, without ``=`` padding,
        of the first half of the access token digest.  The digest comes from
        the JWT ``algorithm`` unless ``custom_at_hash_algo`` names a class of
        ``cryptography.hazmat.primitives.hashes`` (case-insensitive).

        Raises:
            NotImplementedError: if ``custom_at_hash_algo`` is not a known
                hash class name.

        See https://pyjwt.readthedocs.io/en/stable/usage.html#oidc-login-flow
        See https://github.com/python-social-auth/social-core/issues/1306
        """
        if not custom_at_hash_algo:
            alg_obj = jwt.get_algorithm_by_name(algorithm)
            digest = alg_obj.compute_hash_digest(access_token.encode("utf-8"))
        else:
            algo_class = getattr(hashes, custom_at_hash_algo.upper(), None)
            if algo_class is None:
                raise NotImplementedError(
                    f"Unsupported custom at hash algorithm: {custom_at_hash_algo}"
                )
            hasher = hashes.Hash(algo_class(), backend=default_backend())
            hasher.update(access_token.encode("utf-8"))
            digest = hasher.finalize()

        # at_hash uses only the left-most half of the digest.
        left_half = digest[: (len(digest) // 2)]
        return base64.urlsafe_b64encode(left_half).decode("utf-8").rstrip("=")