Source code for social_core.utils

from __future__ import annotations

import contextlib
import functools
import hmac
import logging
import re
import time
import unicodedata
from dataclasses import dataclass
from importlib import import_module
from typing import TYPE_CHECKING, Any, cast
from urllib.parse import parse_qs as battery_parse_qs
from urllib.parse import unquote, urlencode, urlparse, urlunparse

import requests

import social_core
from social_core.pipeline.utils import is_dict_type, to_plain_dict

from .exceptions import (
    AuthCanceled,
    AuthForbidden,
    AuthTokenError,
    AuthUnreachableProvider,
)

if TYPE_CHECKING:
    from .backends.base import BaseAuth
    from .storage import PartialMixin, UserProtocol
    from .strategy import BaseStrategy, HttpResponseProtocol

SETTING_PREFIX = "SOCIAL_AUTH"

PARTIAL_TOKEN_SESSION_NAME = "partial_pipeline_token"
PARTIAL_TOKEN_PENDING_SESSION_NAME = "partial_pipeline_pending_token"
PARTIAL_TOKEN_PENDING_REQUEST_SESSION_NAME = "partial_pipeline_pending_request"
PARTIAL_TOKEN_PENDING_CONFIRMATION_SESSION_NAME = (
    "partial_pipeline_pending_confirmation"
)
PARTIAL_PIPELINE_ALLOW_EXTERNAL_RESUME = "allow_external_resume"


social_logger = logging.getLogger("social")


@dataclass
class PartialPipelineResult:
    """Outcome of looking up a resumable partial pipeline.

    An instance with every field left at its default means there was nothing
    to resume.
    """

    # Partial pipeline to resume (populated by partial_pipeline_result only
    # when the current session owns the token).
    partial: PartialMixin | None = None
    # Response to return instead of resuming, e.g. an external-resume
    # confirmation produced by the strategy.
    response: HttpResponseProtocol | None = None
    # Set when a resume was attempted but must not go ahead.
    halt: bool = False
@dataclass
class PartialPipelineSelection:
    """Which partial-pipeline token to load, and how far it may be trusted."""

    # Token to load; None when there is nothing to resume.
    token: str | None = None
    # True when the token belongs to the current session, so the partial can
    # be resumed directly instead of requiring external-resume confirmation.
    owns_token: bool = False
    # True when resuming the token parked pending external-resume confirmation.
    pending_resume: bool = False
def module_member(name):
    """Import a module and return one of its attributes, given a dotted path.

    ``"package.module.attr"`` imports ``package.module`` and returns ``attr``.
    A name without any dot raises ``ValueError``.
    """
    module_path, attribute = name.rsplit(".", 1)
    owner = import_module(module_path)
    return getattr(owner, attribute)
def user_agent() -> str:
    """Return the User-Agent string sent with outgoing HTTP requests.

    The value is ``social-auth-`` followed by the installed package version.
    """
    version = social_core.__version__
    return f"social-auth-{version}"
def url_add_parameters(
    url: str, params: dict[str, str] | None, _unquote_query: bool = False
) -> str:
    """Return ``url`` with ``params`` merged into its query string.

    Parameters are *not* repeated. The existing query is flattened with this
    module's :func:`parse_qs`, so:

    * only the first value of a repeated key survives;
    * keys with blank values are dropped (``urllib.parse.parse_qs`` default).

    Any key that is also in ``params`` is then overwritten by the new value.
    Keys already in the URL keep their position; new keys are appended.

    :param url: URL to extend.
    :param params: Parameters to add. When empty or ``None``, ``url`` is
        returned untouched (not re-encoded).
    :param _unquote_query: Percent-decode the re-encoded query string before
        reassembling the URL.
    :returns: The rebuilt URL.
    """
    if not params:
        return url
    fragments = list(urlparse(url))
    # Index 4 of the urlparse() 6-tuple is the query component.
    query = parse_qs(fragments[4])
    query.update(params)
    fragments[4] = urlencode(query)
    if _unquote_query:
        fragments[4] = unquote(fragments[4])
    return urlunparse(fragments)
def to_setting_name(*names: str) -> str:
    """Join name parts into an UPPER_SNAKE_CASE setting name.

    Empty parts are skipped, each part is upper-cased, and hyphens become
    underscores before the parts are joined with ``_``.
    """
    normalized = (part.upper().replace("-", "_") for part in names if part)
    return "_".join(normalized)
def setting_name(*names: str) -> str:
    """Build a setting name prefixed with ``SETTING_PREFIX`` (``SOCIAL_AUTH``)."""
    return to_setting_name(SETTING_PREFIX, *names)
def sanitize_redirect(hosts: list[str], redirect_to: str | Any) -> str | None:
    """
    Given a list of hostnames and an untrusted URL to redirect to, this method
    tests it to make sure it isn't garbage/harmful and returns it, else returns
    None, similar to how it's done in django.contrib.auth.views.

    :param hosts: Allowed hostnames. Each is compared verbatim against the
        URL's netloc, so ports and userinfo must match exactly. The first
        entry is the implied host of relative URLs, so an empty list rejects
        everything.
    :param redirect_to: Candidate redirect target; non-strings are rejected.
    :returns: The whitespace-stripped URL when it is a relative URL or an
        http(s) URL on an allowed host, otherwise ``None``.
    """
    if not redirect_to or not isinstance(redirect_to, str):
        return None
    # Browsers ignore leading/trailing spaces when resolving a redirect, and
    # recent Python versions also strip leading spaces inside urlparse(). A
    # padded value such as " ///evil.com" would therefore slip past the "///"
    # check below and still be followed off-site. Validate -- and return --
    # the stripped value, so what is checked is exactly what gets used.
    redirect_to = redirect_to.strip()
    # Avoid redirect on evil URLs like ///evil.com and URLs containing
    # backslashes or control characters that browsers may normalize.
    if (
        not redirect_to
        or redirect_to.startswith("///")
        or "\\" in redirect_to
        or any(unicodedata.category(char)[0] == "C" for char in redirect_to)
    ):
        return None
    try:
        parsed_url = urlparse(redirect_to)
        # Only web schemes are allowed ("javascript:", "data:", ... refused).
        if parsed_url.scheme and parsed_url.scheme not in {"http", "https"}:
            return None
        # A scheme without a host (e.g. "http:evil.com") is rejected.
        if parsed_url.scheme and not parsed_url.netloc:
            return None
        # Don't redirect to a host that's not in the list; relative URLs are
        # attributed to the first allowed host.
        netloc = parsed_url.netloc or hosts[0]
    except (IndexError, TypeError, AttributeError, ValueError):
        return None
    if netloc in hosts:
        return redirect_to
    return None
def user_is_authenticated(user: UserProtocol | None) -> bool:
    """Report whether ``user`` counts as authenticated.

    A falsy ``user`` is never authenticated. A truthy user without an
    ``is_authenticated`` attribute is treated as authenticated. Otherwise the
    attribute is called when callable (older Django-style API) or returned
    as-is.
    """
    if not user:
        return False
    if not hasattr(user, "is_authenticated"):
        return True
    flag = user.is_authenticated
    return flag() if callable(flag) else flag
def user_is_active(user: UserProtocol | None) -> bool:
    """Report whether ``user`` counts as active.

    A falsy ``user`` is inactive. A truthy user without an ``is_active``
    attribute is treated as active. Otherwise the attribute is called when
    callable or returned as-is.
    """
    if not user:
        return False
    if not hasattr(user, "is_active"):
        return True
    active = user.is_active
    return active() if callable(active) else active
# This slugify version was borrowed from django revision a61dbd6
def slugify(value):
    """Turn ``value`` into a lowercase, hyphen-separated ASCII slug.

    Steps:

    1. Stringify and NFKD-normalize the input, dropping characters with no
       ASCII form.
    2. Remove everything except word characters, whitespace and hyphens.
    3. Trim surrounding whitespace and lower-case the result.
    4. Collapse each run of hyphens/whitespace into a single hyphen.
    """
    decomposed = unicodedata.normalize("NFKD", str(value))
    ascii_text = decomposed.encode("ascii", "ignore").decode("ascii")
    kept = re.sub(r"[^\w\s-]", "", ascii_text)
    kept = kept.strip().lower()
    return re.sub(r"[-\s]+", "-", kept)
def first(func, items):
    """Return the first element of ``items`` for which ``func`` is truthy.

    Returns ``None`` when nothing matches, including when ``items`` is empty.
    """
    return next((item for item in items if func(item)), None)
def parse_qs(value):
    """Parse a query string into a flat ``{key: value}`` dict.

    Behaves like :func:`urllib.parse.parse_qs`, except that only the first
    value of each key is kept (see :func:`drop_lists`).
    """
    parsed = battery_parse_qs(value)
    return drop_lists(parsed)
def get_querystring(url: str):
    """Return the query parameters of ``url`` as a flat dict (first value wins)."""
    query = urlparse(url).query
    return parse_qs(query)
def drop_lists(value):
    """Flatten a ``{key: [values...]}`` mapping to ``{key: first_value}``.

    Byte-string keys and values are decoded as UTF-8. Each value list must be
    non-empty (``IndexError`` otherwise).
    """

    def _as_text(item):
        return str(item, "utf-8") if isinstance(item, bytes) else item

    flattened = {}
    for raw_key, values in value.items():
        head = values[0]
        flattened[_as_text(raw_key)] = _as_text(head)
    return flattened
def _partial_pipeline_matches_request(
    backend: BaseAuth, partial: PartialMixin | None, request_data: dict[str, Any]
) -> bool:
    """Check whether ``partial`` may be resumed by ``backend`` for this request.

    Returns False when there is no partial or it was started by a different
    backend. When the request carries the backend's ID key, its value must
    equal the ``uid`` stored in the partial's kwargs. Otherwise returns True.
    """
    if not partial or partial.backend != backend.name:
        return False
    # Normally when resuming a pipeline, request_data will be empty. We only
    # need to check for a uid match if new data was provided (i.e. if current
    # request specifies the ID_KEY).
    id_key = backend.id_key()
    if id_key and id_key in request_data:
        id_from_partial = partial.kwargs.get("uid")
        id_from_request = request_data.get(id_key)
        return id_from_partial == id_from_request
    return True


def _extend_partial_pipeline(
    partial: PartialMixin,
    request_data: dict[str, Any],
    user: UserProtocol | None,
    kwargs: dict[str, Any],
) -> PartialMixin:
    """Merge the current call's kwargs into ``partial`` and return it.

    ``kwargs`` is mutated in place before being passed to
    ``partial.extend_kwargs``:

    * ``user`` is added only when truthy and not already present
      (``setdefault``);
    * ``request`` is always overwritten with ``request_data``.
    """
    if user:  # don't update user if it's None
        kwargs.setdefault("user", user)
    kwargs["request"] = request_data
    partial.extend_kwargs(kwargs)
    return partial


def _select_partial_pipeline_token(
    request_token: str | None,
    session_token: str | None,
    pending_token: str | None,
    confirmation_requested: bool,
) -> PartialPipelineSelection:
    """Choose the partial-pipeline token to load and whether this session owns it.

    Rules, checked in order:

    1. A request token equal to the session's token is owned.
    2. When confirmation of an external resume was requested and a pending
       token exists, select the request token (falling back to the pending
       token). It is owned, and marked ``pending_resume``, only if it equals
       the pending token.
    3. Any other request token is selected but not owned (a foreign token).
    4. Otherwise the session token is used; it is owned when present.
    """
    if request_token and request_token == session_token:
        return PartialPipelineSelection(token=request_token, owns_token=True)
    if confirmation_requested and pending_token:
        selected_token = request_token or pending_token
        # A mismatching token here is treated as foreign, not as the pending one.
        pending_resume = selected_token == pending_token
        return PartialPipelineSelection(
            token=selected_token,
            owns_token=pending_resume,
            pending_resume=pending_resume,
        )
    if request_token:
        return PartialPipelineSelection(token=request_token)
    return PartialPipelineSelection(token=session_token, owns_token=bool(session_token))


def _confirmed_partial_pipeline_request_data(
    backend: BaseAuth,
    request_data: dict[str, Any],
) -> dict[str, Any] | None:
    """Return the request data to resume with after an external-resume confirmation.

    Returns None when the strategy reports that the external resume was not
    confirmed. Otherwise returns the request data stashed in the session
    (under ``PARTIAL_TOKEN_PENDING_REQUEST_SESSION_NAME``, decoded with
    ``from_session_value``) merged with ``request_data``. Current request
    values win on key collisions.
    """
    if not backend.strategy.partial_pipeline_external_resume_confirmed(
        backend, request_data
    ):
        return None
    # `or {}` also covers a falsy value stored in the session (e.g. None).
    pending_request_data = backend.strategy.from_session_value(
        backend.strategy.session_get(PARTIAL_TOKEN_PENDING_REQUEST_SESSION_NAME, {})
        or {}
    )
    return {**pending_request_data, **request_data}


def _external_partial_pipeline_result(
    backend: BaseAuth,
    partial: PartialMixin,
    selected_token: str,
    request_data: dict[str, Any],
) -> PartialPipelineResult:
    """Ask for confirmation before resuming a partial this session does not own.

    Calls ``strategy.partial_pipeline_external_resume_confirmation``. If it
    returns None, the result is ``halt=True``. Otherwise the selected token
    and the request data are stored in the session for the follow-up
    (confirmation) request, and the strategy's response is returned. The
    request data is converted to a plain dict when ``is_dict_type`` accepts
    it, then passed through ``to_session_value``.
    """
    response = backend.strategy.partial_pipeline_external_resume_confirmation(
        backend, partial, request_data
    )
    if response is None:
        return PartialPipelineResult(halt=True)
    backend.strategy.session_set(PARTIAL_TOKEN_PENDING_SESSION_NAME, selected_token)
    backend.strategy.session_set(
        PARTIAL_TOKEN_PENDING_REQUEST_SESSION_NAME,
        backend.strategy.to_session_value(
            to_plain_dict(request_data) if is_dict_type(request_data) else request_data
        ),
    )
    return PartialPipelineResult(response=response)
def partial_pipeline_result(
    backend: BaseAuth,
    user: UserProtocol | None = None,
    partial_token: str | None = None,
    *args,
    **kwargs,
) -> PartialPipelineResult:
    """Resolve which partial pipeline, if any, the current request should resume.

    :param backend: Backend handling the request; its strategy supplies the
        request data, session access and partial storage.
    :param user: When truthy, added as the default ``user`` kwarg of a
        resumed partial.
    :param partial_token: Explicit token; takes precedence over the request
        parameter named by ``PARTIAL_PIPELINE_TOKEN_NAME``.
    :param args: Accepted but not used here.
    :param kwargs: Merged into the partial's kwargs when it is resumed.
    :returns: A :class:`PartialPipelineResult` that is one of:

        * empty, when no token is available;
        * carries ``partial``, when this session owns a matching token;
        * carries ``response`` or ``halt``, for an external resume;
        * ``halt=True``, when an explicitly supplied token (or a confirmation
          request) cannot be resumed.

    Side effect: when the session's own token does not lead to a matching
    partial, ``strategy.clean_partial_pipeline`` is called for it.
    """
    request_data = backend.strategy.request_data()
    partial_argument_name = backend.setting(
        "PARTIAL_PIPELINE_TOKEN_NAME", "partial_token"
    )
    request_token = cast(
        "str | None", partial_token or request_data.get(partial_argument_name)
    )
    session_token = backend.strategy.session_get(PARTIAL_TOKEN_SESSION_NAME, None)
    pending_token = backend.strategy.session_get(
        PARTIAL_TOKEN_PENDING_SESSION_NAME, None
    )
    confirmation_parameter = backend.setting(
        "PARTIAL_PIPELINE_EXTERNAL_RESUME_CONFIRMATION_PARAMETER",
        "partial_pipeline_confirm",
    )
    # A falsy setting disables the confirmation parameter entirely.
    confirmation_requested = (
        bool(confirmation_parameter) and confirmation_parameter in request_data
    )
    selection = _select_partial_pipeline_token(
        request_token=request_token,
        session_token=session_token,
        pending_token=pending_token,
        confirmation_requested=confirmation_requested,
    )
    if not selection.token:
        return PartialPipelineResult()
    # Default outcome if nothing resumable is found below: halt when a token
    # or confirmation was explicitly supplied, rather than returning empty.
    result = PartialPipelineResult(halt=bool(request_token or confirmation_requested))
    effective_request_data = request_data
    if selection.pending_resume:
        confirmed_request_data = _confirmed_partial_pipeline_request_data(
            backend, request_data
        )
        if confirmed_request_data is None:
            return PartialPipelineResult(halt=True)
        effective_request_data = confirmed_request_data
    partial: PartialMixin | None = backend.strategy.partial_load(selection.token)
    partial_matches = _partial_pipeline_matches_request(
        backend, partial, effective_request_data
    )
    if partial and partial_matches:
        if selection.owns_token:
            # Token belongs to this session: resume directly.
            result = PartialPipelineResult(
                partial=_extend_partial_pipeline(
                    partial, effective_request_data, user, kwargs
                )
            )
        elif partial.data.get(PARTIAL_PIPELINE_ALLOW_EXTERNAL_RESUME):
            # Foreign token whose partial opted into external resume: ask
            # for confirmation first.
            result = _external_partial_pipeline_result(
                backend, partial, selection.token, effective_request_data
            )
        else:
            # Foreign token without external-resume permission: refuse.
            result = PartialPipelineResult(halt=True)
    elif selection.owns_token:
        # The session's own token points at a missing or mismatching partial;
        # have the strategy clean it up.
        backend.strategy.clean_partial_pipeline(selection.token)
    return result
def partial_pipeline_data(
    backend: BaseAuth,
    user: UserProtocol | None = None,
    partial_token: str | None = None,
    *args,
    **kwargs,
) -> PartialMixin | None:
    """Shortcut returning only the resumable partial from ``partial_pipeline_result``.

    Any ``response``/``halt`` information in the full result is discarded.
    """
    outcome = partial_pipeline_result(backend, user, partial_token, *args, **kwargs)
    return outcome.partial
def build_absolute_uri(host_url: str, path: str | None = None) -> str:
    """Join ``host_url`` and an optional ``path`` into an absolute URI.

    A ``path`` that is already an absolute http(s) URL is returned unchanged.
    When both sides contribute a slash at the seam, one of them is dropped.
    The two parts are otherwise concatenated as-is.
    """
    suffix = path or ""
    if suffix.startswith(("http://", "https://")):
        return suffix
    if host_url.endswith("/") and suffix.startswith("/"):
        suffix = suffix[1:]
    return host_url + suffix
def constant_time_compare(val1: str | bytes, val2: str | bytes) -> bool:
    """Compare two values in constant time to resist timing attacks.

    ``str`` inputs are UTF-8 encoded first, so a string and its encoded bytes
    compare equal.
    """
    left = val1.encode("utf-8") if isinstance(val1, str) else val1
    right = val2.encode("utf-8") if isinstance(val2, str) else val2
    return hmac.compare_digest(left, right)
def is_url(value: str | None) -> bool:
    """Return True for http(s) URLs and absolute paths (leading ``/``)."""
    if value is None:
        return False
    return value.startswith(("http://", "https://", "/"))
def setting_url(backend: BaseAuth, *names: str | None) -> str | None:
    """Return the first URL found among ``names``.

    Each non-empty entry is either a URL itself (returned as-is) or the name
    of a backend setting whose value is returned if it looks like a URL.
    Returns ``None`` when nothing qualifies.
    """
    for candidate in filter(None, names):
        if is_url(candidate):
            return candidate
        configured = backend.setting(candidate)
        if is_url(configured):
            return configured
    return None
def handle_http_errors(func):
    """Decorator translating ``requests.HTTPError`` into social-auth exceptions.

    ``args[0]`` of the decorated call (presumably the backend ``self`` when
    decorating a backend method -- confirm at call sites) is passed to the
    raised exception. Status-code mapping:

    * 400 -> :class:`AuthCanceled`
    * 401 -> :class:`AuthForbidden`
    * 503 -> :class:`AuthUnreachableProvider`

    Any other status is logged and re-raised unchanged. An ``HTTPError``
    without a response is re-raised as-is.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except requests.HTTPError as err:
            response = err.response
            if response is None:
                # HTTPError raised manually (not via raise_for_status) may
                # carry no response; re-raise it instead of failing with an
                # AttributeError that would hide the original error.
                raise
            social_logger.exception(
                "Request failed with %d: %s",
                response.status_code,
                response.text,
            )
            status = response.status_code
            if status == 400:
                raise AuthCanceled(args[0], response=response) from err
            if status == 401:
                raise AuthForbidden(args[0]) from err
            if status == 503:
                raise AuthUnreachableProvider(args[0]) from err
            raise

    return wrapper
@contextlib.contextmanager
def wrap_access_token_error(backend: BaseAuth):
    """Context manager that turns an HTTP 401 into :class:`AuthTokenError`.

    Any other ``requests.HTTPError`` is re-raised unchanged. This includes one
    carrying no response, which previously crashed with AttributeError.
    """
    try:
        yield
    except requests.HTTPError as error:
        response = error.response
        if response is not None and response.status_code == 401:
            raise AuthTokenError(
                backend, "Invalid key/secret, perhaps expired"
            ) from error
        raise
def append_slash(url: str) -> str:
    """Return ``url`` with a trailing slash, adding one only if it is missing.

    Without the slash, :func:`urllib.parse.urljoin` replaces the last path
    segment instead of appending to it::

        >>> from urllib.parse import urljoin
        >>> urljoin('http://www.example.com/api/v3', 'user/1/')
        'http://www.example.com/api/user/1/'
        >>> urljoin('http://www.example.com/api/v3/', 'user/1/')
        'http://www.example.com/api/v3/user/1/'

    Empty strings are returned unchanged.
    """
    if url and not url.endswith("/"):
        url = f"{url}/"
    return url
def get_strategy(strategy: str, storage: str, *args, **kwargs) -> BaseStrategy:
    """Instantiate a strategy class from dotted import paths.

    Both ``strategy`` and ``storage`` are resolved with :func:`module_member`.
    The strategy is constructed with the storage class followed by any extra
    arguments.
    """
    strategy_class = module_member(strategy)
    storage_class = module_member(storage)
    return strategy_class(storage_class, *args, **kwargs)
class cache:
    """
    Cache decorator that caches the return value of a method for a specified time.

    It maintains a cache per class and method arguments, so subclasses have a
    different cache entry for the same cached method.

    Notes:

    * Falsy return values (``None``, ``{}``, ``[]``, ...) are never served from
      the cache; they are recomputed on every call.
    * If recomputing raises and a previous truthy value is stored, that stale
      value is returned instead of propagating the error.
    * The decorated function exposes ``invalidate()``, which clears every entry
      of this decorator instance.
    """

    def __init__(self, ttl: int) -> None:
        # Maximum age, in seconds, of a cached value before it is recomputed.
        self.ttl = ttl
        self.cache: dict[
            tuple[type, tuple[Any, ...], tuple[tuple[str, Any], ...]], Any
        ] = {}

    def __call__(self, fn):
        # functools.wraps keeps the decorated method's name/docstring.
        @functools.wraps(fn)
        def wrapped(this, *args, **kwargs):
            now = time.time()
            # Keyed on the concrete class, so subclasses get separate entries.
            cache_key = (this.__class__, args, tuple(sorted(kwargs.items())))
            last_updated, cached_value = self.cache.get(cache_key, (None, None))
            # Serve from the cache only a truthy value that is still fresh.
            if cached_value and last_updated and now - last_updated <= self.ttl:
                return cached_value
            try:
                cached_value = fn(this, *args, **kwargs)
            # pylint: disable-next=broad-exception-caught
            except Exception:
                # Use previously cached value when call fails, if available
                if not cached_value:
                    raise
                return cached_value
            self.cache[cache_key] = (now, cached_value)
            return cached_value

        cast("Any", wrapped).invalidate = self._invalidate
        return wrapped

    def _invalidate(self) -> None:
        self.cache.clear()