"""Base OAuth 1.0a and OAuth 2.0 authentication backends (``social_core.backends.oauth``)."""

from __future__ import annotations

import base64
import hashlib
from typing import TYPE_CHECKING, Any, Literal, cast

from oauthlib.oauth1 import SIGNATURE_TYPE_AUTH_HEADER
from requests_oauthlib import OAuth1

from social_core.exceptions import (
    AuthCanceled,
    AuthException,
    AuthFailed,
    AuthMissingParameter,
    AuthStateForbidden,
    AuthStateMissing,
    AuthTokenError,
    AuthUnknownError,
)
from social_core.utils import (
    constant_time_compare,
    handle_http_errors,
    parse_qs,
    url_add_parameters,
    wrap_access_token_error,
)

from .base import BaseAuth

if TYPE_CHECKING:
    from collections.abc import Mapping

    from requests import Response
    from requests.auth import AuthBase


class OAuthAuth(BaseAuth):
    """OAuth authentication backend base class.

    Settings will be inspected to get more values names that should be
    stored on extra_data field. The setting name is created following the
    pattern SOCIAL_AUTH_<uppercase current backend name>_EXTRA_DATA.

    access_token is always stored.

    URLs settings:
        AUTHORIZATION_URL       Authorization service url
        ACCESS_TOKEN_URL        Access token URL
    """

    AUTHORIZATION_URL = ""
    ACCESS_TOKEN_URL = ""
    ACCESS_TOKEN_METHOD: Literal["GET", "POST"] = "POST"
    # Body encoding for the access-token request: "form" or "json".
    ACCESS_TOKEN_PAYLOAD: Literal["form", "json"] = "form"
    REVOKE_TOKEN_URL: str = ""
    REVOKE_TOKEN_METHOD: Literal["GET", "POST", "DELETE"] = "POST"
    ID_KEY = "id"
    SCOPE_PARAMETER_NAME = "scope"
    DEFAULT_SCOPE: list[str] | None = None
    SCOPE_SEPARATOR = " "
    # Either flag enables session-backed state generation/validation;
    # REDIRECT_STATE additionally appends the state to the redirect URI as
    # the "redirect_state" query parameter.
    REDIRECT_STATE = False
    STATE_PARAMETER = False

    def extra_data(
        self,
        user,
        uid: str,
        response: dict[str, Any],
        details: dict[str, Any],
        pipeline_kwargs: dict[str, Any],
    ) -> dict[str, Any]:
        """Return access_token and extra defined names to store in extra_data field.

        The access token is taken from the provider ``response`` first and
        falls back to ``pipeline_kwargs["access_token"]``.
        """
        data = super().extra_data(user, uid, response, details, pipeline_kwargs)
        data["access_token"] = response.get("access_token") or pipeline_kwargs.get(
            "access_token"
        )
        return data

    def state_token(self):
        """Generate csrf token to include as state parameter.

        Returns a 32-character random string produced by the strategy.
        """
        return self.strategy.random_string(32)

    def get_or_create_state(self) -> str | None:
        """Return the session state for this backend, creating it if needed.

        Returns ``None`` when neither ``STATE_PARAMETER`` nor
        ``REDIRECT_STATE`` is enabled.
        """
        if self.STATE_PARAMETER or self.REDIRECT_STATE:
            # Store state in session for further request validation. The state
            # value is passed as state parameter (as specified in OAuth2 spec),
            # but also added to redirect, that way we can still verify the
            # request if the provider doesn't implement the state parameter.
            # Reuse token if any.
            name = f"{self.name}_state"
            state = self.strategy.session_get(name)
            if state is None:
                state = self.state_token()
                self.strategy.session_set(name, state)
        else:
            state = None
        return state

    def get_session_state(self):
        """Return the state value stored in the session (or ``None``)."""
        return self.strategy.session_get(f"{self.name}_state")

    def get_request_state(self):
        """Return the state sent back by the provider.

        Reads ``state`` and falls back to ``redirect_state``; if the value is
        a list, its first element is used.
        """
        request_state = self.data.get("state") or self.data.get("redirect_state")
        if request_state and isinstance(request_state, list):
            request_state = request_state[0]
        return request_state

    def validate_state(self):
        """Validate state value. Raises exception on error, returns state value if valid.

        Raises:
            AuthMissingParameter: the request carries no state.
            AuthStateMissing: no state is stored in the session.
            AuthStateForbidden: request and session states differ.
        """
        if not self.STATE_PARAMETER and not self.REDIRECT_STATE:
            return None
        state = self.get_session_state()
        request_state = self.get_request_state()
        if not request_state:
            raise AuthMissingParameter(self, "state")
        if not state:
            raise AuthStateMissing(self, "state")
        # Constant-time comparison avoids leaking the state via timing.
        if not constant_time_compare(request_state, state):
            raise AuthStateForbidden(self)
        return state

    def get_redirect_uri(self, state: str | None = None) -> str:
        """Build redirect with redirect_state parameter (when REDIRECT_STATE is set)."""
        uri = cast("str", self.redirect_uri)
        if self.REDIRECT_STATE and state:
            uri = url_add_parameters(uri, {"redirect_state": state})
        return uri

    def get_scope(self) -> list[str]:
        """Return list with needed access scope.

        Combines the ``SCOPE`` setting with ``DEFAULT_SCOPE`` unless the
        ``IGNORE_DEFAULT_SCOPE`` setting is truthy. A new list is always
        returned, so callers may extend it without mutating the configured
        setting, and tuple (or other iterable) settings are accepted.
        """
        configured = self.setting("SCOPE", []) or []
        if isinstance(configured, str):
            # A bare string is a single scope, not a sequence of characters.
            configured = [configured]
        scope = list(cast("list[str]", configured))
        if not self.setting("IGNORE_DEFAULT_SCOPE", False):
            scope.extend(self.DEFAULT_SCOPE or [])
        return scope

    def get_scope_argument(self):
        """Return ``{SCOPE_PARAMETER_NAME: joined scopes}``, or ``{}`` if no scope."""
        param = {}
        scope = self.get_scope()
        if scope:
            param[self.SCOPE_PARAMETER_NAME] = self.SCOPE_SEPARATOR.join(scope)
        return param

    def user_data(self, access_token, *args, **kwargs) -> dict[str, Any] | None:
        """Loads user data from service. Implement in subclass"""
        raise NotImplementedError

    def authorization_url(self) -> str:
        """Return the authorization URL, formatted with get_authorization_url_format()."""
        url = cast("str", self.setting("AUTHORIZATION_URL", self.AUTHORIZATION_URL))
        if format_params := self.get_authorization_url_format():
            return url.format(**format_params)
        return url

    def get_authorization_url_format(self) -> dict[str, str]:
        """Return ``str.format`` arguments for the authorization URL (none by default)."""
        return {}

    def access_token_url(self) -> str:
        """Return the access-token URL, formatted with get_access_token_url_format()."""
        url = cast("str", self.setting("ACCESS_TOKEN_URL", self.ACCESS_TOKEN_URL))
        if format_params := self.get_access_token_url_format():
            return url.format(**format_params)
        return url

    def get_access_token_url_format(self) -> dict[str, str]:
        """Return ``str.format`` arguments for the access-token URL (none by default)."""
        return {}

    def revoke_token_url(self, token, uid) -> str:
        """Return the token revocation URL (setting overrides class attribute)."""
        return cast("str", self.setting("REVOKE_TOKEN_URL", self.REVOKE_TOKEN_URL))

    def revoke_token_params(self, token, uid) -> dict[str, Any]:
        """Return parameters for the revocation request (none by default)."""
        return {}

    def revoke_token_headers(self, token, uid) -> dict[str, Any]:
        """Return headers for the revocation request (none by default)."""
        return {}

    def process_revoke_token_response(self, response):
        """Return True when the revocation response has HTTP status 200."""
        return response.status_code == 200

    def revoke_token(self, token, uid):
        """Revoke ``token`` at the provider.

        Returns the result of process_revoke_token_response(), or ``None``
        when no revocation URL is configured.
        """
        if revoke_token_url := self.revoke_token_url(token, uid):
            params = self.revoke_token_params(token, uid)
            headers = self.revoke_token_headers(token, uid)
            # NOTE: params always go in the query string; for non-GET methods
            # they are additionally sent as the request body.
            data = params if self.REVOKE_TOKEN_METHOD != "GET" else None
            response = self.request(
                revoke_token_url,
                params=params,
                headers=headers,
                data=data,
                method=self.REVOKE_TOKEN_METHOD,
            )
            return self.process_revoke_token_response(response)
        return None
class BaseOAuth1(OAuthAuth):
    """Consumer based mechanism OAuth authentication, fill the needed
    parameters to communicate properly with authentication service.

    URLs settings:
        REQUEST_TOKEN_URL       Request token URL
    """

    REQUEST_TOKEN_URL = ""
    REQUEST_TOKEN_METHOD: Literal["GET", "POST"] = "GET"
    OAUTH_TOKEN_PARAMETER_NAME = "oauth_token"
    REDIRECT_URI_PARAMETER_NAME = "redirect_uri"
    # NOTE: misspelled name kept for backwards compatibility; it is the
    # suffix of the session key holding pending unauthorized tokens.
    UNATHORIZED_TOKEN_SUFIX = "unauthorized_token_name"

    def auth_url(self) -> str:
        """Return redirect url.

        Fetches a new unauthorized (request) token, remembers it in the
        session and builds the provider authorization URL for it.
        """
        token = self.set_unauthorized_token()
        return self.oauth_authorization_request(token)

    def process_error(self, data) -> None:
        """Raise if the provider reported an ``oauth_problem``.

        Raises:
            AuthCanceled: the problem is ``user_refused``.
            AuthUnknownError: any other problem value.
        """
        if "oauth_problem" in data:
            if data["oauth_problem"] == "user_refused":
                raise AuthCanceled(self, "User refused the access")
            raise AuthUnknownError(self, f"Error was {data['oauth_problem']}")

    @handle_http_errors
    def auth_complete(self, *args, **kwargs):
        """Return user, might be logged in"""
        # Multiple unauthorized tokens are supported (see #521)
        self.process_error(self.data)
        self.validate_state()
        token = self.get_unauthorized_token()
        access_token = self.access_token(token)
        return self.do_auth(access_token, *args, **kwargs)

    @handle_http_errors
    def do_auth(self, access_token, *args, **kwargs):
        """Finish the auth process once the access_token was retrieved.

        A non-dict ``access_token`` is parsed as a query string first.
        """
        if not isinstance(access_token, dict):
            access_token = parse_qs(access_token)
        data = self.user_data(access_token)
        if data is not None and "access_token" not in data:
            data["access_token"] = access_token
        kwargs.update({"response": data, "backend": self})
        return self.strategy.authenticate(*args, **kwargs)

    def get_unauthorized_token(self):
        """Return (and consume) the stored token matching the request's oauth_token.

        Session entries may be query strings or dicts. The matching entry is
        removed from the session; the others are kept in their original order.

        Raises:
            AuthTokenError: no stored tokens, no token in the request, or no
                stored token matches.
        """
        name = self.name + self.UNATHORIZED_TOKEN_SUFIX
        unauthed_tokens = self.strategy.session_get(name, [])
        if not unauthed_tokens:
            raise AuthTokenError(self, "Missing unauthorized token")

        data_token = self.data.get(self.OAUTH_TOKEN_PARAMETER_NAME)
        if data_token is None:
            raise AuthTokenError(self, "Missing unauthorized token")

        for orig_utoken in unauthed_tokens:
            utoken = (
                orig_utoken if isinstance(orig_utoken, dict) else parse_qs(orig_utoken)
            )
            if utoken.get(self.OAUTH_TOKEN_PARAMETER_NAME) == data_token:
                # Filter with a list instead of set arithmetic: dict tokens are
                # unhashable (set() would raise TypeError) and a set would also
                # reorder / de-duplicate the remaining tokens.
                remaining = [t for t in unauthed_tokens if t != orig_utoken]
                self.strategy.session_set(name, remaining)
                return utoken
        raise AuthTokenError(self, "Incorrect tokens")

    def set_unauthorized_token(self):
        """Fetch a new unauthorized token and append it to the session list."""
        token = self.unauthorized_token()
        name = self.name + self.UNATHORIZED_TOKEN_SUFIX
        tokens = [*self.strategy.session_get(name, []), token]
        self.strategy.session_set(name, tokens)
        return token

    def request_token_extra_arguments(self) -> dict[str, str]:
        """Return extra arguments needed on request-token process.

        A copy of the ``REQUEST_TOKEN_EXTRA_ARGUMENTS`` setting is returned so
        callers can modify it without altering the configured value.
        """
        return dict(
            cast(
                "dict[str, str]",
                self.setting("REQUEST_TOKEN_EXTRA_ARGUMENTS", {}) or {},
            )
        )

    def unauthorized_token(self):
        """Return request for unauthorized token (first stage).

        Returns the decoded response body; it is parsed later with parse_qs.
        """
        # Build a fresh dict so neither the extra-arguments mapping nor any
        # settings object behind it is mutated.
        params = {
            **self.request_token_extra_arguments(),
            **self.get_scope_argument(),
        }
        key, secret = self.get_key_and_secret()
        state = self.get_or_create_state()
        response = self.request(
            self.REQUEST_TOKEN_URL,
            params=params,
            auth=OAuth1(key, secret, callback_uri=self.get_redirect_uri(state)),
            method=self.REQUEST_TOKEN_METHOD,
        )
        encoding = response.encoding or response.apparent_encoding
        if encoding:
            return response.content.decode(encoding)
        return response.content.decode()

    def oauth_authorization_request(self, token):
        """Generate OAuth request to authorize token.

        ``token`` may be a dict or a query string; its oauth_token value is
        placed in the authorization URL together with scope, extra auth
        arguments and the redirect URI.
        """
        if not isinstance(token, dict):
            token = parse_qs(token)
        params = self.auth_extra_arguments() or {}
        params.update(self.get_scope_argument())
        params[self.OAUTH_TOKEN_PARAMETER_NAME] = cast(
            "str", token.get(self.OAUTH_TOKEN_PARAMETER_NAME)
        )
        state = self.get_or_create_state()
        params[self.REDIRECT_URI_PARAMETER_NAME] = self.get_redirect_uri(state)
        return url_add_parameters(self.authorization_url(), params)

    def oauth_auth(
        self,
        token: dict | None = None,
        oauth_verifier=None,
        signature_type=SIGNATURE_TYPE_AUTH_HEADER,
    ):
        """Return an ``OAuth1`` auth object for signing requests.

        When ``token`` is given it must contain both ``oauth_token`` and
        ``oauth_token_secret``; ``oauth_verifier`` defaults to the value in
        the request data.

        Raises:
            AuthTokenError: ``token`` lacks one of the required values.
        """
        key, secret = self.get_key_and_secret()
        oauth_verifier = oauth_verifier or self.data.get("oauth_verifier")
        if token:
            resource_owner_key = token.get("oauth_token")
            resource_owner_secret = token.get("oauth_token_secret")
            if not resource_owner_key:
                raise AuthTokenError(self, "Missing oauth_token")
            if not resource_owner_secret:
                raise AuthTokenError(self, "Missing oauth_token_secret")
        else:
            resource_owner_key = None
            resource_owner_secret = None
        state = self.get_or_create_state()
        return OAuth1(
            key,
            secret,
            resource_owner_key=resource_owner_key,
            resource_owner_secret=resource_owner_secret,
            callback_uri=self.get_redirect_uri(state),
            verifier=oauth_verifier,
            signature_type=signature_type,
        )

    def oauth_request(
        self, token: dict, url: str, params=None, method: Literal["GET", "POST"] = "GET"
    ) -> Response:
        """Generate OAuth request, setups callback url"""
        return self.request(
            url, method=method, params=params, auth=self.oauth_auth(token)
        )

    def access_token(self, token: dict) -> dict[str, str]:
        """Return request for access token value (parsed query-string response)."""
        return self.get_querystring(
            self.access_token_url(),
            auth=self.oauth_auth(token),
            method=self.ACCESS_TOKEN_METHOD,
        )

    def user_data(self, access_token: dict, *args, **kwargs) -> dict[str, Any] | None:
        """Loads user data from service. Implement in subclass"""
        return {}
class BaseOAuth2(OAuthAuth):
    """Base class for OAuth2 providers.

    OAuth2 details at:
        https://datatracker.ietf.org/doc/html/rfc6749
    """

    REFRESH_TOKEN_URL: str | None = None
    REFRESH_TOKEN_METHOD: Literal["GET", "POST", "DELETE"] = "POST"
    RESPONSE_TYPE: str | None = "code"
    REDIRECT_STATE = True
    STATE_PARAMETER = True
    # When True, client credentials are sent via HTTP basic auth instead of
    # the token request body (see auth_complete_params/credentials).
    USE_BASIC_AUTH = False

    def use_basic_auth(self) -> bool:
        """Return whether client credentials go in HTTP basic auth."""
        return self.USE_BASIC_AUTH

    def auth_params(self, state: str | None = None) -> dict[str, str]:
        """Return base authorization-request parameters.

        Includes client_id and redirect_uri, plus ``state`` (when
        STATE_PARAMETER is set and a state exists) and ``response_type``
        (when RESPONSE_TYPE is set).
        """
        client_id, _client_secret = self.get_key_and_secret()
        params = {"client_id": client_id, "redirect_uri": self.get_redirect_uri(state)}
        if self.STATE_PARAMETER and state:
            params["state"] = state
        if self.RESPONSE_TYPE:
            params["response_type"] = self.RESPONSE_TYPE
        return params

    def auth_url(self) -> str:
        """Return redirect url"""
        state = self.get_or_create_state()
        params = self.auth_params(state)
        params.update(self.get_scope_argument())
        params.update(self.auth_extra_arguments())
        # when self.REDIRECT_STATE is False, redirect_uri matching is strictly enforced,
        # so match the providers value exactly.
        return url_add_parameters(
            self.authorization_url(), params, not self.REDIRECT_STATE
        )

    def auth_complete_params(self, state=None):
        """Return the authorization_code grant parameters for the token request.

        Client id/secret are included unless basic auth is used.
        """
        params = {
            "grant_type": "authorization_code",  # request auth code
            "code": self.data.get("code", ""),  # server response code
            "redirect_uri": self.get_redirect_uri(state),
        }
        if not self.use_basic_auth():
            client_id, client_secret = self.get_key_and_secret()
            params.update(
                {
                    "client_id": client_id,
                    "client_secret": client_secret,
                }
            )
        return params

    def auth_complete_credentials(self):
        """Return (key, secret) for basic auth, or ``None`` when not used."""
        if self.use_basic_auth():
            return self.get_key_and_secret()
        return None

    def auth_headers(self) -> Mapping[str, str | bytes]:
        """Return token-endpoint headers; Content-Type follows ACCESS_TOKEN_PAYLOAD."""
        return {
            "Content-Type": "application/json"
            if self.ACCESS_TOKEN_PAYLOAD == "json"
            else "application/x-www-form-urlencoded",
            "Accept": "application/json",
        }

    def extra_data(
        self,
        user,
        uid: str,
        response: dict[str, Any],
        details: dict[str, Any],
        pipeline_kwargs: dict[str, Any],
    ) -> dict[str, Any]:
        """Return access_token, token_type, and extra defined names to store in
        extra_data field"""
        data = super().extra_data(user, uid, response, details, pipeline_kwargs)
        data["token_type"] = response.get("token_type") or pipeline_kwargs.get(
            "token_type"
        )
        return data

    def request_access_token(
        self,
        url: str,
        method: Literal["GET", "POST", "DELETE"] = "GET",
        headers: Mapping[str, str | bytes] | None = None,
        data: dict | None = None,
        json: dict | None = None,
        auth: tuple[str, str] | AuthBase | None = None,
        params: dict | None = None,
    ) -> dict[Any, Any]:
        """Perform the token request and return the decoded JSON body.

        Errors raised while fetching are handled by wrap_access_token_error.
        """
        with wrap_access_token_error(self):
            return self.get_json(
                url,
                method=method,
                headers=headers,
                data=data,
                auth=auth,
                params=params,
                json=json,
            )

    def process_error(self, data) -> None:
        """Raise if ``data`` carries an OAuth2 error.

        Raises:
            AuthCanceled: the error mentions "denied"/"cancelled", or a
                ``denied`` key is present.
            AuthFailed: any other ``error`` value.
        """
        if data.get("error"):
            if "denied" in data["error"] or "cancelled" in data["error"]:
                raise AuthCanceled(self, data.get("error_description", ""))
            raise AuthFailed(self, data.get("error_description") or data["error"])

        if "denied" in data:
            raise AuthCanceled(self, data["denied"])

    @handle_http_errors
    def auth_complete(self, *args, **kwargs):
        """Completes login process, must return user instance"""
        self.process_error(self.data)
        state = self.validate_state()
        data = params = json = None
        auth_params = self.auth_complete_params(state)
        # Place the grant parameters according to method and payload mode.
        if self.ACCESS_TOKEN_METHOD == "GET":
            params = auth_params
        elif self.ACCESS_TOKEN_PAYLOAD == "json":
            json = auth_params
        else:
            data = auth_params

        response = self.request_access_token(
            self.access_token_url(),
            data=data,
            json=json,
            params=params,
            headers=self.auth_headers(),
            auth=self.auth_complete_credentials(),
            method=self.ACCESS_TOKEN_METHOD,
        )
        self.process_error(response)
        return self.do_auth(
            response["access_token"], *args, response=response, **kwargs
        )

    @handle_http_errors
    def do_auth(self, access_token, *args, **kwargs):
        """Finish the auth process once the access_token was retrieved.

        User data is merged into the ``response`` kwarg (updated in place when
        one is given) before authenticating through the strategy.
        """
        data = self.user_data(access_token, *args, **kwargs)
        response = kwargs.get("response") or {}
        response.update(data or {})
        if "access_token" not in response:
            response["access_token"] = access_token
        kwargs.update({"response": response, "backend": self})
        return self.strategy.authenticate(*args, **kwargs)

    def refresh_token_params(self, token: str, *args, **kwargs) -> dict[str, str]:
        """Return the refresh_token grant parameters including client credentials."""
        client_id, client_secret = self.get_key_and_secret()
        return {
            "refresh_token": token,
            "grant_type": "refresh_token",
            "client_id": client_id,
            "client_secret": client_secret,
        }

    def refresh_token_auth(self) -> AuthBase | None:
        """Return the auth object for the refresh request (none by default)."""
        return None

    def process_refresh_token_response(self, response, *args, **kwargs) -> dict:
        """Return the decoded JSON body of the refresh response."""
        return response.json()

    def refresh_token(self, token: str, *args, **kwargs) -> dict:
        """Exchange a refresh token for new credentials.

        Parameters are sent as the query string for GET, as a JSON body when
        ACCESS_TOKEN_PAYLOAD is "json" (matching auth_headers() and
        auth_complete()), and as a form body otherwise.
        """
        params = self.refresh_token_params(token, *args, **kwargs)
        url = self.refresh_token_url()
        method = self.REFRESH_TOKEN_METHOD
        query = data = json = None
        if method == "GET":
            query = params
        elif self.ACCESS_TOKEN_PAYLOAD == "json":
            # auth_headers() declares application/json in this mode; a form
            # body would contradict the Content-Type header.
            json = params
        else:
            data = params
        response = self.request(
            url,
            headers=self.auth_headers(),
            method=method,
            auth=self.refresh_token_auth(),
            data=data,
            json=json,
            params=query,
        )
        return self.process_refresh_token_response(response, *args, **kwargs)

    def refresh_token_url(self):
        """Return REFRESH_TOKEN_URL, falling back to the access-token URL."""
        return self.REFRESH_TOKEN_URL or self.access_token_url()

    def user_data(self, access_token: str, *args, **kwargs) -> dict[str, Any] | None:
        """Loads user data from service. Implement in subclass"""
        return {}
class BaseOAuth2PKCE(BaseOAuth2):
    """OAuth2 backend base class that adds Proof Key for Code Exchange.

    OAuth2 details at:
        https://datatracker.ietf.org/doc/html/rfc6749
    PKCE details at:
        https://datatracker.ietf.org/doc/html/rfc7636
    """

    PKCE_DEFAULT_CODE_CHALLENGE_METHOD = "S256"
    PKCE_DEFAULT_CODE_VERIFIER_LENGTH = 43
    DEFAULT_USE_PKCE = True

    def _code_verifier_session_key(self) -> str:
        """Session key under which this backend's code verifier is kept."""
        return f"{self.name}_code_verifier"

    def create_code_verifier(self):
        """Generate a random code verifier, store it in the session, return it.

        Its length comes from the PKCE_CODE_VERIFIER_LENGTH setting, falling
        back to PKCE_DEFAULT_CODE_VERIFIER_LENGTH.
        """
        length = cast(
            "int",
            self.setting(
                "PKCE_CODE_VERIFIER_LENGTH",
                default=self.PKCE_DEFAULT_CODE_VERIFIER_LENGTH,
            ),
        )
        verifier = self.strategy.random_string(length)
        self.strategy.session_set(self._code_verifier_session_key(), verifier)
        return verifier

    def get_code_verifier(self):
        """Return the code verifier stored in the session (or ``None``)."""
        return self.strategy.session_get(self._code_verifier_session_key())

    def generate_code_challenge(self, code_verifier, challenge_method):
        """Derive the code challenge for ``code_verifier``.

        ``challenge_method`` is matched case-insensitively: "plain" returns
        the verifier unchanged, "S256" returns the unpadded base64url SHA-256
        digest.

        Raises:
            AuthException: for any other method.
        """
        normalized = challenge_method.lower()
        if normalized == "plain":
            return code_verifier
        if normalized != "s256":
            raise AuthException(self, "Unsupported code challenge method.")
        digest = hashlib.sha256(code_verifier.encode()).digest()
        # base64url with the "=" padding stripped
        return base64.urlsafe_b64encode(digest).decode().replace("=", "")

    def auth_params(self, state=None):
        """Extend the OAuth2 authorization parameters with the PKCE challenge.

        Only applies when the USE_PKCE setting (default DEFAULT_USE_PKCE) is
        truthy; a new verifier is created and stored for each call.
        """
        params = super().auth_params(state=state)
        if not self.setting("USE_PKCE", default=self.DEFAULT_USE_PKCE):
            return params
        method = cast(
            "str",
            self.setting(
                "PKCE_CODE_CHALLENGE_METHOD",
                default=self.PKCE_DEFAULT_CODE_CHALLENGE_METHOD,
            ),
        )
        verifier = self.create_code_verifier()
        challenge = self.generate_code_challenge(verifier, method)
        params["code_challenge_method"] = method
        params["code_challenge"] = challenge
        return params

    def auth_complete_params(self, state=None):
        """Extend the token-request parameters with the stored code verifier."""
        params = super().auth_complete_params(state=state)
        if self.setting("USE_PKCE", default=self.DEFAULT_USE_PKCE):
            params["code_verifier"] = self.get_code_verifier()
        return params