"""Base authentication backend (``social_core.backends.base``)."""

from __future__ import annotations

import base64
import time
from typing import TYPE_CHECKING, Any, Literal, cast

import requests

from social_core.exceptions import AuthConnectionError, AuthUnknownError
from social_core.registry import REGISTRY
from social_core.utils import module_member, parse_qs, social_logger, user_agent

if TYPE_CHECKING:
    from collections.abc import Mapping

    from requests import Response
    from requests.auth import AuthBase

    from social_core.storage import PartialMixin, PipelineUserProtocol, UserProtocol
    from social_core.strategy import BaseStrategy, HttpResponseProtocol


class BaseAuth:
    """An authentication backend that authenticates the user based on the
    provider response."""

    name = ""  # provider name, it's stored in database
    supports_inactive_user = False  # Django auth
    ID_KEY: str = ""
    EXTRA_DATA: list[str | tuple[str, str] | tuple[str, str, bool]] | None = None
    GET_ALL_EXTRA_DATA = False
    REQUIRES_EMAIL_VALIDATION = False
    SEND_USER_AGENT = True

    def __init__(
        self, strategy: BaseStrategy | None = None, redirect_uri: str | None = None
    ) -> None:
        """Bind the backend to a strategy and resolve the redirect URI.

        Args:
            strategy: Strategy to use; ``REGISTRY.default_strategy`` when None.
            redirect_uri: Redirect URI, passed through
                ``strategy.absolute_uri`` before being stored.
        """
        self.strategy: BaseStrategy = (
            strategy if strategy is not None else REGISTRY.default_strategy
        )
        self.redirect_uri = redirect_uri
        self.data = self.strategy.request_data()
        self.redirect_uri = self.strategy.absolute_uri(self.redirect_uri)

    def log_debug(self, message, *args) -> None:
        """Log ``message`` at debug level, prefixed with the backend name.

        ``args`` are forwarded unchanged to ``social_logger.debug``.
        """
        social_logger.debug(f"{self.name}: {message}", *args)

    def log_warning(self, message, *args) -> None:
        """Log ``message`` at warning level, prefixed with the backend name.

        ``args`` are forwarded unchanged to ``social_logger.warning``.
        """
        social_logger.warning(f"{self.name}: {message}", *args)

    def setting(self, name: str, default=None):
        """Return setting value from strategy, looked up for this backend."""
        return self.strategy.setting(name, default=default, backend=self)

    def start(self) -> HttpResponseProtocol:
        """Begin the auth flow.

        Redirects to ``auth_url()`` when ``uses_redirect()`` is true,
        otherwise renders the HTML returned by ``auth_html()``.
        """
        if self.uses_redirect():
            return self.strategy.redirect(self.auth_url())
        return self.strategy.html(self.auth_html())

    def complete(self, *args, **kwargs) -> HttpResponseProtocol | UserProtocol | None:
        """Finish the auth flow; delegates to ``auth_complete``."""
        return self.auth_complete(*args, **kwargs)

    def auth_url(self) -> str:
        """Must return redirect URL to auth provider"""
        raise NotImplementedError("Implement in subclass")

    def auth_html(self) -> str:
        """Must return login HTML content returned by provider"""
        return "Implement in subclass"

    def auth_complete(
        self, *args, **kwargs
    ) -> HttpResponseProtocol | UserProtocol | None:
        """Completes login process, must return user instance"""
        raise NotImplementedError("Implement in subclass")

    def process_error(self, data) -> None:
        """Hook to process provider response errors.

        Default implementation is a no-op. Backends that can detect
        provider-specific error payloads should override this method and
        raise an appropriate exception when needed.
        """

    def authenticate(
        self, *args, **kwargs
    ) -> UserProtocol | HttpResponseProtocol | None:
        """Authenticate user using social credentials

        Authentication is made if this is the correct backend, backend
        verification is made by kwargs inspection for current backend
        name presence.
        """
        # Validate backend and arguments. Require that the Social Auth
        # response be passed in as a keyword argument, to make sure we
        # don't match the username/password calling conventions of
        # authenticate.
        if (
            "backend" not in kwargs
            or kwargs["backend"].name != self.name
            or "strategy" not in kwargs
            or "response" not in kwargs
        ):
            return None

        self.strategy = kwargs.get("strategy") or self.strategy
        self.redirect_uri = kwargs.get("redirect_uri") or self.redirect_uri
        self.data = self.strategy.request_data()
        kwargs.setdefault("is_new", False)
        pipeline = self.strategy.get_pipeline(self)
        args, kwargs = self.strategy.clean_authenticate_args(*args, **kwargs)
        return self.pipeline(pipeline, *args, **kwargs)

    def pipeline(
        self, pipeline, pipeline_index: int = 0, *args, **kwargs
    ) -> UserProtocol | HttpResponseProtocol | None:
        """Run ``pipeline`` and convert its outcome into a user or a response.

        A non-dict result from ``run_pipeline`` (a step interrupted the flow)
        is returned unchanged. Otherwise the ``user`` entry is returned; when
        present, it gets ``social_user`` and ``is_new`` attributes copied from
        the pipeline output.
        """
        out = self.run_pipeline(pipeline, pipeline_index, *args, **kwargs)
        if not isinstance(out, dict):
            return cast("HttpResponseProtocol", out)
        user = cast("UserProtocol | None", out.get("user"))
        if user:
            pipeline_user = cast("PipelineUserProtocol", user)
            pipeline_user.social_user = cast("Any", out.get("social"))
            pipeline_user.is_new = bool(out.get("is_new"))
        return user

    def disconnect(self, *args, **kwargs) -> dict:
        """Run the disconnect pipeline for this backend.

        ``name`` and ``user_storage`` are injected into the step kwargs.
        The start index is passed explicitly (``pipeline_index`` keyword,
        default 0) so that the first positional argument is forwarded to
        the steps instead of being bound to ``run_pipeline``'s
        ``pipeline_index`` parameter and then discarded.
        """
        pipeline = self.strategy.get_disconnect_pipeline(self)
        kwargs["name"] = self.name
        kwargs["user_storage"] = self.strategy.storage.user
        pipeline_index = kwargs.pop("pipeline_index", 0)
        return self.run_pipeline(pipeline, pipeline_index, *args, **kwargs)

    def run_pipeline(
        self, pipeline: list[str], pipeline_index=0, *args, **kwargs
    ) -> dict:
        """Execute the dotted-path steps of ``pipeline`` in order.

        Each step is resolved with ``module_member`` and called with ``args``
        plus the accumulated keyword state. A dict result is merged into that
        state; a falsy result is treated as ``{}``; any other result stops the
        pipeline and is returned as-is.

        Returns:
            The accumulated keyword state, or the first non-dict step result.
        """
        out = kwargs.copy()
        out.setdefault("strategy", self.strategy)
        # A kwarg keyed by this backend's name is always removed from the
        # step state; it becomes "backend" (falling back to self) unless a
        # "backend" kwarg was already supplied.
        out.setdefault("backend", out.pop(self.name, None) or self)
        out.setdefault("request", self.strategy.request_data())
        out.setdefault("details", {})

        # A non-int or out-of-range start index restarts from the first step.
        if (
            not isinstance(pipeline_index, int)
            or pipeline_index < 0
            or pipeline_index >= len(pipeline)
        ):
            pipeline_index = 0

        for idx, name in enumerate(pipeline[pipeline_index:]):
            out["pipeline_index"] = pipeline_index + idx
            func = module_member(name)
            result = func(*args, **out) or {}
            if not isinstance(result, dict):
                return result
            out.update(result)
        return out

    def extra_data(
        self,
        user: UserProtocol | None,
        uid: str,
        response: dict[str, Any],
        details: dict[str, Any],
        pipeline_kwargs: dict[str, Any],
    ) -> dict[str, Any]:
        """Return default extra data to store in extra_data field.

        Every key of ``response`` is copied when GET_ALL_EXTRA_DATA is
        enabled (class attribute or setting). Otherwise the entries of
        ``EXTRA_DATA`` (class attribute) followed by the ``EXTRA_DATA``
        setting are used. Each entry is ``name``, ``(name,)``,
        ``(name, alias)`` or ``(name, alias, discard)``; lists are accepted
        in place of tuples. With ``discard`` true, falsy values are skipped.
        The setting may be a list, a tuple, a single string, or None.

        Raises:
            AuthUnknownError: an entry has an unsupported length.
        """
        data: dict[str, Any] = {
            # store the last time authentication took place
            "auth_time": int(time.time())
        }
        extra_data_entries: list[Any]
        if self.GET_ALL_EXTRA_DATA or self.setting("GET_ALL_EXTRA_DATA", False):
            extra_data_entries = list(response.keys())
        else:
            configured = cast(
                "list[str | tuple[str, str] | tuple[str, str, bool]] | str",
                self.setting("EXTRA_DATA", []) or [],
            )
            if isinstance(configured, str):
                # A bare string would otherwise be iterated char by char.
                configured = [configured]
            # list()/unpacking accepts tuples too; ``list + tuple`` would not.
            extra_data_entries = [*(self.EXTRA_DATA or []), *configured]

        for entry in extra_data_entries:
            if isinstance(entry, list):
                entry = tuple(cast("list[str]", entry))
            discard = False
            if isinstance(entry, str):
                name = alias = entry
            elif len(entry) == 3:
                name, alias, discard = entry
            elif len(entry) == 2:
                name, alias = entry
            elif len(entry) == 1:
                name = alias = entry[0]
            else:
                raise AuthUnknownError(self, f"Invalid EXTRA_DATA item: {entry!r}")
            # Response value wins; details is consulted by name, then by alias.
            value = response.get(name, details.get(name, details.get(alias)))
            if discard and not value:
                continue
            data[alias] = value
        return data

    def auth_allowed(self, response, details):
        """Return True if the user should be allowed to authenticate, by
        default check if email is whitelisted (if there's a whitelist).

        Comparison is case-insensitive. A missing/None whitelist setting is
        treated as empty. When a whitelist exists, an email without "@" is
        rejected.
        """
        emails = [
            email.lower()
            for email in cast(
                "list[str]", self.setting("WHITELISTED_EMAILS", []) or []
            )
        ]
        domains = [
            domain.lower()
            for domain in cast(
                "list[str]", self.setting("WHITELISTED_DOMAINS", []) or []
            )
        ]
        email = details.get("email")
        allowed = True
        if email and (emails or domains):
            email = email.lower()
            # Split at the FIRST "@": everything after it is the domain, so
            # "a@x.com@ok.com" yields "x.com@ok.com" and never matches "ok.com".
            parts = email.split("@", 1)
            if len(parts) != 2:
                allowed = False
            else:
                domain = parts[1]
                allowed = email in emails or domain in domains
        return allowed

    def id_key(self) -> str:
        """Return the ID_KEY to use for this backend, checking settings first."""
        return self.setting("ID_KEY") or self.ID_KEY

    def get_user_id(self, details, response):
        """Return a unique ID for the current user, by default from server
        response or details.

        A truthy value in ``details`` under ``id_key()`` takes precedence
        over ``response``.
        """
        id_key = self.id_key()
        if details:
            user_id = details.get(id_key)
            if user_id:
                return user_id
        return response.get(id_key)

    def get_user_details(self, response) -> dict[str, Any]:
        """Must return user details in a known internal struct:
        {'username': <username if any>,
         'email': <user email if any>,
         'fullname': <user full name if any>,
         'first_name': <user first name if any>,
         'last_name': <user last name if any>}
        """
        raise NotImplementedError("Implement in subclass")

    def get_user_names(self, fullname="", first_name="", last_name=""):
        """Return ``(fullname, first_name, last_name)``, each stripped.

        None values are treated as "". When only ``fullname`` is given, it
        is split at the first space into first and last name (a single word
        becomes the first name). When ``fullname`` is empty, it is built from
        the first and last names.
        """
        # Avoid None values
        fullname = fullname or ""
        first_name = first_name or ""
        last_name = last_name or ""
        if fullname and not (first_name or last_name):
            try:
                # Strip before splitting: a leading space would otherwise
                # yield an empty first name and put the whole name in last.
                first_name, last_name = fullname.strip().split(" ", 1)
            except ValueError:
                first_name = first_name or fullname or ""
                last_name = last_name or ""
        fullname = fullname or f"{first_name} {last_name}"
        return fullname.strip(), first_name.strip(), last_name.strip()

    def get_user(self, user_id):
        """
        Return user with given ID from the User model used by this backend.
        This is called by django.contrib.auth.middleware.
        """
        return self.strategy.get_user(user_id)

    def continue_pipeline(
        self, partial: PartialMixin
    ) -> UserProtocol | HttpResponseProtocol | None:
        """Continue previous halted pipeline"""
        return self.strategy.authenticate(
            self, *partial.args, pipeline_index=partial.next_step, **partial.kwargs
        )

    def auth_extra_arguments(self) -> dict[str, str]:
        """Return extra arguments needed on auth process.

        Configured AUTH_EXTRA_ARGUMENTS are not overridden by request data by
        default. Set AUTH_EXTRA_ARGUMENTS_OVERRIDE_ALLOWLIST to an iterable of
        configured extra-argument keys (or a single key string) that may be
        replaced by matching request data values. A None setting is treated
        as empty. The returned dict is a fresh copy.
        """
        # dict() copies any mapping and tolerates a None setting via ``or``.
        extra_arguments = dict(self.setting("AUTH_EXTRA_ARGUMENTS", {}) or {})
        override_allowlist = (
            self.setting("AUTH_EXTRA_ARGUMENTS_OVERRIDE_ALLOWLIST", ()) or ()
        )
        if isinstance(override_allowlist, str):
            override_allowlist = (override_allowlist,)
        extra_arguments.update(
            (key, self.data[key])
            for key in override_allowlist
            if key in extra_arguments and key in self.data
        )
        return extra_arguments

    def uses_redirect(self) -> bool:
        """Return True if this provider uses redirect url method,
        otherwise return false."""
        return True

    def request(  # noqa: PLR0913
        self,
        url: str,
        *,
        method: Literal["GET", "POST", "DELETE"] = "GET",
        headers: Mapping[str, str | bytes] | None = None,
        data: dict | None = None,
        json: dict | None = None,
        auth: tuple[str, str] | AuthBase | None = None,
        params: dict | None = None,
        timeout: float | None = None,
    ) -> Response:
        """Perform an HTTP request using the backend's network settings.

        PROXIES and VERIFY_SSL settings are applied. The timeout defaults to
        REQUESTS_TIMEOUT, then URLOPEN_TIMEOUT, then 5.0 seconds. A
        User-Agent header is added when SEND_USER_AGENT is set and the
        caller supplied none (header names compared case-insensitively).

        Raises:
            AuthConnectionError: on ``requests.ConnectionError``.
            requests.HTTPError: via ``raise_for_status`` on error statuses.
        """
        headers = {} if headers is None else dict(headers)
        proxies = self.setting("PROXIES")
        verify = self.setting("VERIFY_SSL", True)
        if timeout is None:
            timeout = (
                self.setting("REQUESTS_TIMEOUT")
                or self.setting("URLOPEN_TIMEOUT")
                or 5.0
            )

        # HTTP header names are case-insensitive; a caller-supplied
        # "user-agent" must not receive a competing "User-Agent" entry.
        if self.SEND_USER_AGENT and not any(
            key.lower() == "user-agent" for key in headers
        ):
            headers["User-Agent"] = self.setting("USER_AGENT") or user_agent()

        try:
            response = requests.request(
                method,
                url,
                headers=headers,
                data=data,
                json=json,
                auth=auth,
                params=params,
                timeout=timeout,
                proxies=proxies,
                verify=verify,
            )
        except requests.ConnectionError as err:
            raise AuthConnectionError(self, str(err)) from err
        response.raise_for_status()
        return response

    def get_json(  # noqa: PLR0913
        self,
        url: str,
        method: Literal["GET", "POST", "DELETE"] = "GET",
        headers: Mapping[str, str | bytes] | None = None,
        data: dict | None = None,
        json: dict | None = None,
        auth: tuple[str, str] | AuthBase | None = None,
        params: dict | None = None,
        timeout: float | None = None,
    ) -> dict[Any, Any]:
        """Perform ``request`` and return the JSON-decoded response body."""
        return self.request(
            url,
            method=method,
            headers=headers,
            data=data,
            json=json,
            auth=auth,
            params=params,
            timeout=timeout,
        ).json()

    def get_querystring(self, url, *args, **kwargs) -> dict[str, str]:
        """Perform ``request`` and parse the body with ``parse_qs``."""
        return parse_qs(self.request(url, *args, **kwargs).text)

    def get_key_and_secret(self) -> tuple[str, str]:
        """Return tuple with Consumer Key and Consumer Secret for current
        service provider. Must return (key, secret), order *must* be respected.
        """
        return cast("str", self.setting("KEY")), cast("str", self.setting("SECRET"))

    def get_key_and_secret_basic_auth(self) -> bytes:
        """Generate HTTP Basic Authentication header value from KEY and SECRET.

        Returns:
            Basic authentication value in the format
            b"Basic <base64-encoded-credentials>"
        """
        key, secret = self.get_key_and_secret()
        credentials = f"{key}:{secret}".encode()
        encoded = base64.b64encode(credentials)
        return b"Basic " + encoded