from __future__ import annotations

import base64
import datetime
from calendar import timegm
from json import loads
from typing import TYPE_CHECKING, Any, Literal, cast

import jwt
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from jwt import (
    ExpiredSignatureError,
    InvalidAudienceError,
    InvalidTokenError,
    PyJWTError,
)
from jwt.utils import base64url_decode

from social_core.backends.oauth import BaseOAuth2PKCE
from social_core.exceptions import (
    AuthInvalidParameter,
    AuthMissingParameter,
    AuthTokenError,
)
from social_core.utils import cache

if TYPE_CHECKING:
    from collections.abc import Mapping

    from jwt.types import Options
    from requests.auth import AuthBase

    from social_core.strategy import BaseStrategy
class OpenIdConnectAssociation:
    """Use Association model to save the nonce by force.

    Only ``handle`` (the nonce) and ``assoc_type`` (the OAuth ``state``) carry
    meaning for the OpenID Connect flow. The remaining attributes are filled
    in only so the object has the same shape as a regular association when it
    is handed to the association storage; they are not read by this module.
    """

    def __init__(self, handle, secret="", issued=0, lifetime=0, assoc_type="") -> None:
        self.handle = handle  # the nonce
        self.secret = secret.encode()  # unused; kept as bytes like a real association
        self.issued = issued  # unused
        self.lifetime = lifetime  # unused
        self.assoc_type = assoc_type  # the OAuth state
class OpenIdConnectAuth(BaseOAuth2PKCE):
    """
    Base class for Open ID Connect backends.

    Currently only the code response type is supported.

    It can also be directly instantiated as a generic OIDC backend.
    To use it you will need to set at minimum:

    SOCIAL_AUTH_OIDC_OIDC_ENDPOINT = 'https://.....'  # endpoint without /.well-known/openid-configuration
    SOCIAL_AUTH_OIDC_KEY = '<client_id>'
    SOCIAL_AUTH_OIDC_SECRET = '<client_secret>'
    SOCIAL_AUTH_OIDC_USE_PKCE = True  # optional, enables PKCE for this backend
    """

    name = "oidc"
    # Override OIDC_ENDPOINT in your subclass to enable autoconfig of OIDC
    OIDC_ENDPOINT: str | None = None
    # Maximum accepted age, in seconds, of the ID token "iat" claim
    ID_TOKEN_MAX_AGE = 600
    DEFAULT_SCOPE = ["openid", "profile", "email"]
    EXTRA_DATA = ["id_token", "refresh_token", ("sub", "id")]
    REDIRECT_STATE = False
    REVOKE_TOKEN_METHOD: Literal["GET", "POST", "DELETE"] = "GET"
    ID_KEY = "sub"
    USERNAME_KEY = "preferred_username"
    EMAIL_KEY = "email"
    FIRST_NAME_KEY = "given_name"
    LAST_NAME_KEY = "family_name"
    FULLNAME_KEY = "name"
    JWT_ALGORITHMS = ["RS256"]
    JWT_DECODE_OPTIONS: Options = {}
    JWT_LEEWAY: float = 1.0  # seconds
    VALIDATE_AT_HASH: bool = True
    CUSTOM_AT_HASH_ALGO: str | None = None

    # When these options are unspecified, server will choose via openid autoconfiguration
    ID_TOKEN_ISSUER = ""
    ACCESS_TOKEN_URL = ""
    AUTHORIZATION_URL = ""
    REVOKE_TOKEN_URL = ""
    USERINFO_URL = ""
    JWKS_URI = ""
    TOKEN_ENDPOINT_AUTH_METHOD = ""

    # Optional parameters for Authentication Request
    DISPLAY: str | None = None
    PROMPT: str | None = None
    MAX_AGE: int | None = None
    UI_LOCALES: str | None = None
    ID_TOKEN_HINT: str | None = None
    LOGIN_HINT: str | None = None
    ACR_VALUES: str | None = None

    PKCE_DEFAULT_CODE_CHALLENGE_METHOD = "S256"
    DEFAULT_USE_PKCE = False

    def __init__(
        self, strategy: BaseStrategy | None = None, redirect_uri: str | None = None
    ) -> None:
        super().__init__(strategy, redirect_uri=redirect_uri)
        # Claims of the validated ID token; populated by request_access_token().
        self.id_token: dict[str, Any] | None = None

    def get_setting_config(
        self, setting_name: str, oidc_name: str, default: str
    ) -> str:
        """Resolve a provider URL/issuer value.

        The explicit setting ``setting_name`` (falling back to ``default``)
        wins. When it is empty, the value is looked up under ``oidc_name`` in
        the provider's discovery document.

        :raises AuthMissingParameter: if neither source yields a string.
        """
        value = self.setting(setting_name, default)
        if not value:
            value = self.oidc_config().get(oidc_name)
        if not isinstance(value, str):
            raise AuthMissingParameter(self, setting_name)
        return value

    def authorization_url(self) -> str:
        """Return the authorization endpoint URL."""
        return self.get_setting_config(
            "AUTHORIZATION_URL", "authorization_endpoint", self.AUTHORIZATION_URL
        )

    def access_token_url(self) -> str:
        """Return the token endpoint URL."""
        return self.get_setting_config(
            "ACCESS_TOKEN_URL", "token_endpoint", self.ACCESS_TOKEN_URL
        )

    def revoke_token_url(self, token, uid) -> str:
        """Return the revocation endpoint URL (``token``/``uid`` are unused)."""
        return self.get_setting_config(
            "REVOKE_TOKEN_URL", "revocation_endpoint", self.REVOKE_TOKEN_URL
        )

    def id_token_issuer(self) -> str:
        """Return the expected ``iss`` claim of ID tokens."""
        return self.get_setting_config(
            "ID_TOKEN_ISSUER", "issuer", self.ID_TOKEN_ISSUER
        )

    def userinfo_url(self) -> str:
        """Return the UserInfo endpoint URL."""
        return self.get_setting_config(
            "USERINFO_URL", "userinfo_endpoint", self.USERINFO_URL
        )

    def jwks_uri(self) -> str:
        """Return the URL of the provider's JSON Web Key Set."""
        return self.get_setting_config("JWKS_URI", "jwks_uri", self.JWKS_URI)

    def use_basic_auth(self) -> bool:
        """Decide whether ``client_secret_basic`` is the token endpoint auth method.

        An explicit ``TOKEN_ENDPOINT_AUTH_METHOD`` setting decides. Otherwise
        basic auth is used when the discovery document lists no supported
        methods or lists ``client_secret_basic`` among them.
        """
        method = self.setting(
            "TOKEN_ENDPOINT_AUTH_METHOD", self.TOKEN_ENDPOINT_AUTH_METHOD
        )
        if method:
            return method == "client_secret_basic"
        methods = self.oidc_config().get("token_endpoint_auth_methods_supported", [])
        return not methods or "client_secret_basic" in methods

    def oidc_endpoint(self) -> str:
        """Return the issuer base URL used for discovery."""
        return cast("str", self.setting("OIDC_ENDPOINT", self.OIDC_ENDPOINT))

    @cache(ttl=86400)
    def oidc_config(self) -> dict[Any, Any]:
        """Fetch the provider discovery document (cached for 86400 s)."""
        return self.get_json(f"{self.oidc_endpoint()}/.well-known/openid-configuration")

    @cache(ttl=86400)
    def get_jwks_keys(self):
        """Return the provider's JWKS keys (cached for 86400 s)."""
        # NOTE: the client secret could be appended here as an ``oct`` key so
        # that HMAC-signed ID tokens verify, e.g.
        # ``keys.append({"key": client_secret, "kty": "oct"})``; not enabled.
        return self.get_remote_jwks_keys()

    def get_remote_jwks_keys(self):
        """Download the JWKS document and return its ``keys`` list."""
        response = self.request(self.jwks_uri())
        return loads(response.text)["keys"]

    def auth_params(self, state=None):  # noqa: C901, PLR0912
        """Return extra arguments needed on auth process.

        Adds a freshly stored ``nonce`` plus every optional OpenID Connect
        request parameter that is configured (``display``, ``prompt``,
        ``max_age``, ``ui_locales``, ``id_token_hint``, ``login_hint``,
        ``acr_values``).

        :raises AuthMissingParameter: for an empty or unknown ``display``.
        :raises AuthInvalidParameter: for any other invalid optional value.
        """
        params = super().auth_params(state)
        params["nonce"] = self.get_and_store_nonce(self.authorization_url(), state)

        display = self.setting("DISPLAY", default=self.DISPLAY)
        if display is not None:
            if not display:
                raise AuthMissingParameter(
                    self, "OpenID Connect display value cannot be empty string."
                )
            if display not in ("page", "popup", "touch", "wap"):
                raise AuthMissingParameter(
                    self, f"Invalid OpenID Connect display value: {display}"
                )
            params["display"] = display

        prompt = self.setting("PROMPT", default=self.PROMPT)
        if prompt is not None:
            prompt_tokens = prompt.split()
            # Rejects both "" and whitespace-only values.
            if not prompt_tokens:
                raise AuthInvalidParameter(self, "prompt")
            for prompt_token in prompt_tokens:
                if prompt_token not in ("none", "login", "consent", "select_account"):
                    raise AuthInvalidParameter(self, "prompt")
            # OIDC Core 3.1.2.1: "none" must not be combined with other values.
            if "none" in prompt_tokens and len(prompt_tokens) > 1:
                raise AuthInvalidParameter(self, "prompt")
            params["prompt"] = prompt

        max_age = self.setting("MAX_AGE", default=self.MAX_AGE)
        if max_age is not None:
            if max_age < 0:
                raise AuthInvalidParameter(self, "max_age")
            params["max_age"] = max_age

        # Free-form optional parameters: omitted when None, rejected when empty.
        for setting_name, param_name, default in (
            ("UI_LOCALES", "ui_locales", self.UI_LOCALES),
            ("ID_TOKEN_HINT", "id_token_hint", self.ID_TOKEN_HINT),
            ("LOGIN_HINT", "login_hint", self.LOGIN_HINT),
            ("ACR_VALUES", "acr_values", self.ACR_VALUES),
        ):
            value = self.setting(setting_name, default=default)
            if value is None:
                continue
            if not value:
                raise AuthInvalidParameter(self, param_name)
            params[param_name] = value
        return params

    def get_and_store_nonce(self, url, state):
        """Generate a 64-character nonce, store it keyed by ``url`` and return it."""
        nonce = self.strategy.random_string(64)
        association = OpenIdConnectAssociation(nonce, assoc_type=state)
        self.strategy.storage.association.store(url, association)
        return nonce

    def get_nonce(self, nonce):
        """Return the stored association for ``nonce``, or None if unknown."""
        try:
            return self.strategy.storage.association.get(
                server_url=self.authorization_url(), handle=nonce
            )[0]
        except IndexError:
            return None

    def remove_nonce(self, nonce_id) -> None:
        """Delete the stored nonce so it cannot be replayed."""
        self.strategy.storage.association.remove([nonce_id])

    def validate_claims(self, id_token) -> None:
        """Validate the ID token claims that pyjwt does not check.

        Checks ``nbf`` (with the configured ``JWT_LEEWAY``), that ``iat`` is
        present and no older than ``ID_TOKEN_MAX_AGE`` seconds, and that the
        ``nonce`` matches one stored by get_and_store_nonce(). A matching nonce
        is consumed.

        :raises AuthTokenError: when any of these checks fails.
        """
        utc_timestamp = timegm(datetime.datetime.now(datetime.timezone.utc).timetuple())

        # Use the same leeway jwt.decode() applied; otherwise a token it just
        # accepted could be rejected here because of sub-second clock skew.
        leeway = self.setting("JWT_LEEWAY", self.JWT_LEEWAY)
        if isinstance(leeway, datetime.timedelta):
            leeway = leeway.total_seconds()
        if "nbf" in id_token and utc_timestamp + leeway < id_token["nbf"]:
            raise AuthTokenError(self, "Incorrect id_token: nbf")

        # "iat" is mandatory in OIDC but jwt.decode() does not require it.
        if "iat" not in id_token:
            raise AuthTokenError(self, "Incorrect id_token: iat")
        # Verify the token was issued within the last ID_TOKEN_MAX_AGE seconds
        iat_leeway = self.setting("ID_TOKEN_MAX_AGE", self.ID_TOKEN_MAX_AGE)
        if utc_timestamp > id_token["iat"] + iat_leeway:
            raise AuthTokenError(self, "Incorrect id_token: iat")

        # Validate the nonce to ensure the request was not modified
        nonce = id_token.get("nonce")
        if not nonce:
            raise AuthTokenError(self, "Incorrect id_token: nonce")
        nonce_obj = self.get_nonce(nonce)
        if not nonce_obj:
            raise AuthTokenError(self, "Incorrect id_token: nonce")
        self.remove_nonce(nonce_obj.id)

    def find_valid_key(self, id_token):
        """Return the JWKS key whose signature verifies ``id_token``, or None.

        If the token header names a ``kid`` that is not among the cached keys,
        the JWKS cache is invalidated and refetched once. The returned dict is
        a copy that always has an ``alg`` entry (defaulting to the first
        configured JWT algorithm); cached JWKS entries are left untouched.
        Keys that pyjwt cannot load are skipped.

        :raises jwt.PyJWTError: if the token itself cannot be decoded.
        """
        kid = jwt.get_unverified_header(id_token).get("kid")
        keys = self.get_jwks_keys()
        if kid is not None and not any(kid == key.get("kid") for key in keys):
            # The key id is unknown to the cached set (e.g. the provider rotated
            # its keys), so reload the JWKS.
            self.get_jwks_keys.invalidate()  # pyright: ignore[reportAttributeAccessIssue]
            keys = self.get_jwks_keys()

        # The header decode above already validated the token's segments.
        message, encoded_sig = id_token.rsplit(".", 1)
        signing_input = message.encode("utf-8")
        decoded_sig = base64url_decode(encoded_sig.encode("utf-8"))

        for key in keys:
            if kid is not None and kid != key.get("kid"):
                continue
            # Copy so the cached JWKS entries are never mutated.
            candidate = dict(key)
            if "alg" not in candidate:
                candidate["alg"] = cast(
                    "list[str]", self.setting("JWT_ALGORITHMS", self.JWT_ALGORITHMS)
                )[0]
            try:
                jwk = jwt.PyJWK(candidate)
            except PyJWTError:
                # Unsupported key type/algorithm; try the next key.
                continue
            if jwk.Algorithm.verify(signing_input, jwk.key, decoded_sig):
                return candidate
        return None

    def validate_and_return_id_token(self, id_token, access_token):
        """
        Validates the id_token according to the steps at
        http://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation.

        Returns the decoded claims.

        :raises AuthTokenError: on any signature, claim or at_hash failure,
            including a malformed token.
        """
        client_id, _client_secret = self.get_key_and_secret()
        try:
            key = self.find_valid_key(id_token)
        except InvalidTokenError as error:
            # e.g. a malformed header or signature segment
            raise AuthTokenError(self, str(error)) from error
        except PyJWTError as error:
            raise AuthTokenError(self, "Invalid signature") from error
        if not key:
            raise AuthTokenError(self, "Signature verification failed")

        rsakey = jwt.PyJWK(key)
        try:
            claims = jwt.decode(
                id_token,
                rsakey.key,
                algorithms=self.setting("JWT_ALGORITHMS", self.JWT_ALGORITHMS),
                audience=client_id,
                issuer=self.id_token_issuer(),
                options=cast(
                    "Options",
                    self.setting("JWT_DECODE_OPTIONS", self.JWT_DECODE_OPTIONS),
                ),
                leeway=cast("int", self.setting("JWT_LEEWAY", self.JWT_LEEWAY)),
            )
        except ExpiredSignatureError as error:
            raise AuthTokenError(self, "Signature has expired") from error
        except InvalidAudienceError as error:
            # compatibility with jose error message
            raise AuthTokenError(self, "Token error: Invalid audience") from error
        except InvalidTokenError as error:
            raise AuthTokenError(self, str(error)) from error
        except PyJWTError as error:
            raise AuthTokenError(self, "Invalid signature") from error

        # pyjwt does not validate OIDC claims
        # see https://github.com/jpadilla/pyjwt/pull/296
        if not self.validate_at_hash(claims, access_token, key):
            raise AuthTokenError(self, "Invalid access token")
        self.validate_claims(claims)
        return claims

    def request_access_token(
        self,
        url: str,
        method: Literal["GET", "POST", "DELETE"] = "GET",
        headers: Mapping[str, str | bytes] | None = None,
        data: dict | None = None,
        json: dict | None = None,
        auth: tuple[str, str] | AuthBase | None = None,
        params: dict | None = None,
    ) -> dict[Any, Any]:
        """
        Retrieve the access token. Also, validate the id_token and
        store it (temporarily) on ``self.id_token``.

        :raises AuthTokenError: if ``id_token`` or ``access_token`` is missing
            from the response, or the ID token fails validation.
        """
        response = super().request_access_token(
            url,
            method=method,
            headers=headers,
            data=data,
            json=json,
            auth=auth,
            params=params,
        )
        for parameter in ("id_token", "access_token"):
            if parameter not in response:
                raise AuthTokenError(
                    self,
                    f"Missing {parameter} in OpenID Connect token response",
                )
        self.id_token = self.validate_and_return_id_token(
            response["id_token"], response["access_token"]
        )
        return response

    def user_data(self, access_token: str, *args, **kwargs) -> dict[str, Any] | None:
        """Fetch UserInfo with ``access_token`` and check its ``sub``."""
        return self.validate_userinfo_sub(
            self.get_json(
                self.userinfo_url(),
                headers={"Authorization": f"Bearer {access_token}"},
            )
        )

    def validate_userinfo_sub(
        self, userinfo: dict[str, Any] | None
    ) -> dict[str, Any] | None:
        """Validate that UserInfo belongs to the validated ID token subject.

        A response that is None or has no ``sub`` is returned unchanged.

        :raises AuthTokenError: if ``sub`` differs from the ID token's ``sub``.
        """
        if userinfo is None or userinfo.get("sub") is None:
            return userinfo
        id_token_sub = self.id_token.get("sub") if self.id_token is not None else None
        if userinfo["sub"] != id_token_sub:
            raise AuthTokenError(self, "Invalid UserInfo sub")
        return userinfo

    def get_user_id(self, details, response):
        """Prefer the validated ID token's ``sub`` when the id key is ``sub``."""
        if self.id_key() == "sub" and self.id_token is not None:
            return self.id_token.get("sub")
        return super().get_user_id(details, response)

    def get_user_details(self, response):
        """Map the configured claim names to user detail fields.

        Each value is taken from ``response`` when the key is present there,
        otherwise from the validated ID token claims, otherwise None.
        """
        username_key = self.setting("USERNAME_KEY", self.USERNAME_KEY)
        email_key = self.setting("EMAIL_KEY", self.EMAIL_KEY)
        first_name_key = self.setting("FIRST_NAME_KEY", self.FIRST_NAME_KEY)
        last_name_key = self.setting("LAST_NAME_KEY", self.LAST_NAME_KEY)
        fullname_key = self.setting("FULLNAME_KEY", self.FULLNAME_KEY)

        def get_value(key):
            if key in response:
                return response.get(key)
            if self.id_token is not None:
                return self.id_token.get(key)
            return None

        return {
            "username": get_value(username_key),
            "email": get_value(email_key),
            "fullname": get_value(fullname_key),
            "first_name": get_value(first_name_key),
            "last_name": get_value(last_name_key),
        }

    def validate_at_hash(self, claims, access_token, key):
        """
        Validate the 'at_hash' claim according to OpenID Connect specs.
        See: https://openid.net/specs/openid-connect-core-1_0.html#CodeIDToken

        ``VALIDATE_AT_HASH`` and ``CUSTOM_AT_HASH_ALGO`` are read through
        ``self.setting`` (class attributes are the defaults), like the other
        options of this backend. Returns True when validation is disabled or
        the claim is absent.
        """
        if not self.setting("VALIDATE_AT_HASH", self.VALIDATE_AT_HASH):
            return True
        if "at_hash" not in claims:
            return True
        calculated_hash = self.calc_at_hash(
            access_token,
            key["alg"],
            self.setting("CUSTOM_AT_HASH_ALGO", self.CUSTOM_AT_HASH_ALGO),
        )
        return claims["at_hash"] == calculated_hash

    @staticmethod
    def calc_at_hash(access_token, algorithm, custom_at_hash_algo: str | None = None):
        """
        Calculates "at_hash" claim which is not done by pyjwt.

        The access token is hashed with the hash of the JWS ``algorithm``
        (or with ``custom_at_hash_algo``, a ``cryptography`` hash class name
        such as ``"sha256"``, for non-standard tokens). The left half of the
        digest is then base64url-encoded without padding.

        See https://pyjwt.readthedocs.io/en/stable/usage.html#oidc-login-flow
        See https://github.com/python-social-auth/social-core/issues/1306

        :raises NotImplementedError: for an unknown custom hash algorithm.
        """
        token_bytes = access_token.encode("utf-8")
        if not custom_at_hash_algo:
            alg_obj = jwt.get_algorithm_by_name(algorithm)
            digest = alg_obj.compute_hash_digest(token_bytes)
        else:
            algo_class = getattr(hashes, custom_at_hash_algo.upper(), None)
            # Only accept real hash classes from the module, not arbitrary attributes.
            if not (
                isinstance(algo_class, type)
                and issubclass(algo_class, hashes.HashAlgorithm)
            ):
                raise NotImplementedError(
                    f"Unsupported custom at hash algorithm: {custom_at_hash_algo}"
                )
            try:
                hash_algorithm = algo_class()
            except TypeError as error:
                # e.g. variable-length hashes that need a digest size argument
                raise NotImplementedError(
                    f"Unsupported custom at hash algorithm: {custom_at_hash_algo}"
                ) from error
            hasher = hashes.Hash(hash_algorithm, backend=default_backend())
            hasher.update(token_bytes)
            digest = hasher.finalize()
        half = digest[: (len(digest) // 2)]
        return base64.urlsafe_b64encode(half).decode("utf-8").rstrip("=")