# Source code for social_core.strategy

from __future__ import annotations

import secrets
from typing import TYPE_CHECKING, Any, Protocol, cast

from .backends.utils import get_backend
from .exceptions import (
    SocialAuthImproperlyConfiguredError,
    StrategyMissingBackendError,
    StrategyMissingFeatureError,
)
from .pipeline import DEFAULT_AUTH_PIPELINE, DEFAULT_DISCONNECT_PIPELINE
from .pipeline.utils import partial_load
from .store import OpenIdSessionWrapper, OpenIdStore
from .utils import (
    PARTIAL_TOKEN_PENDING_CONFIRMATION_SESSION_NAME,
    PARTIAL_TOKEN_PENDING_REQUEST_SESSION_NAME,
    PARTIAL_TOKEN_PENDING_SESSION_NAME,
    PARTIAL_TOKEN_SESSION_NAME,
    module_member,
    setting_name,
)

if TYPE_CHECKING:
    from .backends.base import BaseAuth
    from .storage import BaseStorage, CodeMixin, PartialMixin, UserProtocol


class HttpResponseProtocol(Protocol):
    """Structural type for response objects exposing a readable ``url``.

    Strategies annotate the return value of ``redirect`` and ``html`` with
    this protocol; any object with a ``url`` string attribute satisfies it.
    """

    @property
    def url(self) -> str:
        """Return the response URL."""
class BaseTemplateStrategy:
    """Render a named template or a raw HTML string for a strategy.

    Framework integrations subclass this and implement ``render_template``
    and ``render_string``; ``render`` chooses which one to call.
    """

    def __init__(self, strategy) -> None:
        # Keep a reference to the strategy that owns this renderer.
        self.strategy = strategy

    def render(
        self,
        tpl: str | None = None,
        html: str | None = None,
        context: dict[str, Any] | None = None,
    ) -> str:
        """Render ``tpl`` when given, otherwise the raw ``html`` string.

        A falsy ``context`` is replaced by an empty dict before dispatch.

        Raises:
            ValueError: if neither ``tpl`` nor ``html`` is provided.
        """
        ctx: dict[str, Any] = context if context else {}
        if not tpl:
            # No template name: fall back to raw HTML, which is then required.
            if not html:
                raise ValueError("Missing template or html parameters")
            return self.render_string(html, ctx)
        return self.render_template(tpl, ctx)

    def render_template(self, tpl: str, context: dict[str, Any] | None) -> str:
        """Render the template named ``tpl`` with ``context`` (subclass hook)."""
        raise NotImplementedError("Implement in subclass")

    def render_string(self, html: str, context: dict[str, Any] | None) -> str:
        """Render the raw template string ``html`` with ``context`` (subclass hook)."""
        raise NotImplementedError("Implement in subclass")
class BaseStrategy:
    """Framework-agnostic base class for social auth strategies.

    A strategy gives backends and pipelines access to settings, session,
    request data and response helpers. The methods after the "Implement the
    following methods" marker below must be provided by framework-specific
    subclasses; the rest are built on top of them.
    """

    # Default alphabet for ``random_string``.
    ALLOWED_CHARS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
    # Template strategy class instantiated when ``tpl`` is not given.
    DEFAULT_TEMPLATE_STRATEGY = BaseTemplateStrategy
    # Not referenced in this class — presumably a session key used by
    # subclasses; confirm before renaming.
    SESSION_SAVE_KEY = "psa_session_id"

    def __init__(
        self,
        storage: type[BaseStorage] | None = None,
        tpl: type[BaseTemplateStrategy] | None = None,
    ) -> None:
        """Remember the storage class and build the template strategy.

        Args:
            storage: Storage class exposing ``user``, ``partial`` and
                ``code`` models. When omitted, accessing ``storage`` raises.
            tpl: Template strategy class; ``DEFAULT_TEMPLATE_STRATEGY`` is
                used when not given. It is instantiated with this strategy.
        """
        self._storage = storage
        self.tpl = (tpl or self.DEFAULT_TEMPLATE_STRATEGY)(self)

    @property
    def storage(self) -> type[BaseStorage]:
        """Return the configured storage class.

        Raises:
            StrategyMissingBackendError: if no storage was configured.
        """
        if self._storage is None:
            raise StrategyMissingBackendError
        return self._storage

    def setting(self, name: str, default=None, backend: BaseAuth | None = None):
        """Look up a setting, trying the most specific name first.

        Candidates, in order: the backend-specific name (only when
        ``backend`` is given), ``setting_name(name)``, then the bare
        ``name``. The first one ``get_setting`` resolves is returned;
        ``default`` is returned when none resolves.
        """
        names = [setting_name(name), name]
        if backend:
            names.insert(0, setting_name(backend.name, name))
        for value in names:
            try:
                return self.get_setting(value)
            except (AttributeError, KeyError):
                # Treated as "not defined under this name"; try the next one.
                pass
        return default

    def create_user(self, *args, **kwargs):
        """Create a user through the storage user model."""
        return self.storage.user.create_user(*args, **kwargs)

    def get_user(self, *args, **kwargs):
        """Fetch a user through the storage user model."""
        return self.storage.user.get_user(*args, **kwargs)

    def session_setdefault(self, name: str, value):
        """Store ``value`` under ``name`` and return what the session now holds.

        NOTE(review): unlike ``dict.setdefault``, this always calls
        ``session_set`` first, so an existing value for ``name`` is
        overwritten. Left unchanged because framework subclasses may rely on
        the unconditional write (e.g. to mark the session modified) —
        confirm before giving it true setdefault semantics.
        """
        self.session_set(name, value)
        return self.session_get(name)

    def get_session_id(self) -> str | None:
        """
        Return session ID to be used by restore_session.
        """
        return None

    def restore_session(self, session_id: str, kwargs: dict[str, Any]) -> None:
        """
        Restores session and updates kwargs to match it.

        This is only called if get_session_id returns a value.

        Raises:
            StrategyMissingFeatureError: always, in this base implementation.
        """
        raise StrategyMissingFeatureError(self.__class__.__name__, "session restore")

    def openid_session_dict(self, name: str) -> OpenIdSessionWrapper:
        """Return the OpenID session dict for ``name`` wrapped for pickling."""
        # Many frameworks are switching the session serialization from Pickle
        # to JSON to avoid code execution risks. Flask did this from Flask
        # 0.10, Django is switching to JSON by default from version 1.6.
        #
        # Sadly python-openid stores classes instances in the session which
        # fails the JSON serialization, the classes are:
        #
        #   openid.yadis.manager.YadisServiceManager
        #   openid.consumer.discover.OpenIDServiceEndpoint
        #
        # This method will return a wrapper over the session value used with
        # openid (a dict) which will automatically keep a pickled value for the
        # mentioned classes.
        return OpenIdSessionWrapper(self.session_setdefault(name, {}))

    def to_session_value(self, val):
        """Convert ``val`` for session storage (identity hook for subclasses)."""
        return val

    def from_session_value(self, val):
        """Convert a stored session value back (identity hook for subclasses)."""
        return val

    def partial_load(self, token: str) -> PartialMixin | None:
        """Load the partial pipeline identified by ``token``."""
        return partial_load(self, token)

    def clean_partial_pipeline(self, token) -> None:
        """Destroy partial ``token`` and drop the session keys that reference it.

        Each session key is cleared only when it still holds ``token``, so a
        different token stored there is left untouched. The pending request
        and confirmation keys are cleared together with the pending token.
        """
        self.storage.partial.destroy(token)
        current_token_in_session = self.session_get(PARTIAL_TOKEN_SESSION_NAME)
        if current_token_in_session == token:
            self.session_pop(PARTIAL_TOKEN_SESSION_NAME)
        pending_token_in_session = self.session_get(PARTIAL_TOKEN_PENDING_SESSION_NAME)
        if pending_token_in_session == token:
            self.session_pop(PARTIAL_TOKEN_PENDING_SESSION_NAME)
            self.session_pop(PARTIAL_TOKEN_PENDING_REQUEST_SESSION_NAME)
            self.session_pop(PARTIAL_TOKEN_PENDING_CONFIRMATION_SESSION_NAME)

    def partial_pipeline_external_resume_confirmation(
        self,
        backend: BaseAuth,
        partial: PartialMixin,
        request_data: dict[str, Any],
    ) -> HttpResponseProtocol | None:
        """Subclass hook for an external-resume confirmation response.

        The base implementation returns ``None`` (no response).
        """
        return None

    def partial_pipeline_external_resume_confirmed(
        self,
        backend: BaseAuth,
        request_data: dict[str, Any],
    ) -> bool:
        """Subclass hook reporting external-resume confirmation.

        The base implementation always returns ``False``.
        """
        return False

    def openid_store(self) -> OpenIdStore:
        """Return an ``OpenIdStore`` bound to this strategy."""
        return OpenIdStore(self)

    def get_pipeline(self, backend: BaseAuth | None = None) -> list[str]:
        """Return the ``PIPELINE`` setting, or ``DEFAULT_AUTH_PIPELINE``."""
        return cast(
            "list[str]", self.setting("PIPELINE", DEFAULT_AUTH_PIPELINE, backend)
        )

    def get_disconnect_pipeline(self, backend: BaseAuth | None = None) -> list[str]:
        """Return ``DISCONNECT_PIPELINE``, or ``DEFAULT_DISCONNECT_PIPELINE``."""
        return cast(
            "list[str]",
            self.setting("DISCONNECT_PIPELINE", DEFAULT_DISCONNECT_PIPELINE, backend),
        )

    def random_string(self, length: int = 12, chars: str = ALLOWED_CHARS) -> str:
        """Return ``length`` characters drawn from ``chars`` via ``secrets``.

        Note: the ``chars`` default is bound when the class body executes, so
        a subclass overriding ``ALLOWED_CHARS`` must pass ``chars`` explicitly.
        """
        return "".join([secrets.choice(chars) for _ in range(length)])

    def absolute_uri(self, path: str | None = None) -> str:
        """Build an absolute URI for ``path``, forcing HTTPS when configured.

        When the ``REDIRECT_IS_HTTPS`` setting is truthy, a leading
        ``http://`` scheme is upgraded to ``https://``. Only the scheme
        prefix is rewritten; ``http://`` text elsewhere in the URI (for
        example inside a ``next`` query parameter) is preserved. Previously
        ``str.replace`` rewrote every occurrence.
        """
        uri = self.build_absolute_uri(path)
        if uri and self.setting("REDIRECT_IS_HTTPS") and uri.startswith("http://"):
            uri = "https://" + uri[len("http://"):]
        return uri

    def get_language(self) -> str:
        """Return current language"""
        return ""

    def send_email_validation(
        self, backend: BaseAuth, email: str, partial_token: str | None = None
    ) -> CodeMixin:
        """Create a verification code for ``email`` and send it.

        The sender is loaded from the dotted path in the
        ``EMAIL_VALIDATION_FUNCTION`` setting. It is called as
        ``fn(strategy, backend, code, partial_token)``.

        Returns:
            The code object created by ``storage.code.make_code``.

        Raises:
            SocialAuthImproperlyConfiguredError: if the setting is missing.
        """
        email_validation = self.setting("EMAIL_VALIDATION_FUNCTION")
        if not email_validation:
            raise SocialAuthImproperlyConfiguredError(
                "EMAIL_VALIDATION_FUNCTION missing"
            )
        send_email = module_member(email_validation)
        code = self.storage.code.make_code(email)
        send_email(self, backend, code, partial_token)
        return code

    def validate_email(self, email: str, code: str) -> bool:
        """Verify ``code`` for ``email``; return whether it was accepted.

        Returns ``False`` if the code is unknown, does not match, belongs to
        another email, or was already verified. Otherwise the code is marked
        verified and ``True`` is returned.
        """
        verification_code = self.storage.code.get_code(code)
        if not verification_code or verification_code.code != code:
            return False
        if verification_code.email != email:
            return False
        if verification_code.verified:
            return False
        verification_code.verify()
        return True

    def render_html(
        self,
        tpl: str | None = None,
        html: str | None = None,
        context: dict[str, Any] | None = None,
    ) -> str:
        """Render given template or raw html with given context"""
        return self.tpl.render(tpl, html, context)

    def authenticate(
        self, backend: BaseAuth, *args, **kwargs
    ) -> UserProtocol | HttpResponseProtocol | None:
        """Trigger the authentication mechanism tied to the current framework"""
        # Inject strategy, storage and backend so the backend can use them.
        kwargs["strategy"] = self
        kwargs["storage"] = self.storage
        kwargs["backend"] = backend
        args, kwargs = self.clean_authenticate_args(*args, **kwargs)
        return backend.authenticate(*args, **kwargs)

    def clean_authenticate_args(self, *args, **kwargs):
        """Take authenticate arguments and return a "cleaned" version of them"""
        return args, kwargs

    def get_backends(self) -> list[str]:
        """Return configured backends"""
        return cast("list[str]", self.setting("AUTHENTICATION_BACKENDS", []))

    def get_backend_class(self, name: str) -> type[BaseAuth]:
        """Return a configured backend class"""
        return get_backend(self.get_backends(), name)

    def get_backend(
        self, name: str, redirect_uri: str | None = None, **kwargs
    ) -> BaseAuth:
        """Return a configured backend instance"""
        backend_class = self.get_backend_class(name)
        kwargs["redirect_uri"] = redirect_uri
        return backend_class(self, **kwargs)

    # Implement the following methods on strategies sub-classes

    def redirect(self, url: str) -> HttpResponseProtocol:
        """Return a response redirect to the given URL"""
        raise NotImplementedError("Implement in subclass")

    def get_setting(self, name: str):
        """Return value for given setting name

        ``setting`` treats AttributeError/KeyError raised here as "missing".
        """
        raise NotImplementedError("Implement in subclass")

    def html(self, content: str) -> HttpResponseProtocol:
        """Return HTTP response with given content"""
        raise NotImplementedError("Implement in subclass")

    def request_data(self, merge: bool = True):
        """Return current request data (POST or GET)

        NOTE(review): ``merge`` semantics are defined by subclasses —
        presumably whether GET and POST data are combined; confirm there.
        """
        raise NotImplementedError("Implement in subclass")

    def request_host(self) -> str:
        """Return current host value"""
        raise NotImplementedError("Implement in subclass")

    def session_get(self, name: str, default=None):
        """Return session value for given key"""
        raise NotImplementedError("Implement in subclass")

    def session_set(self, name: str, value):
        """Set session value for given key"""
        raise NotImplementedError("Implement in subclass")

    def session_pop(self, name: str):
        """Pop session value for given key"""
        raise NotImplementedError("Implement in subclass")

    def build_absolute_uri(self, path: str | None = None) -> str:
        """Build absolute URI with given (optional) path"""
        raise NotImplementedError("Implement in subclass")

    def request_is_secure(self) -> bool:
        """Is the request using HTTPS?"""
        raise NotImplementedError("Implement in subclass")

    def request_path(self) -> str:
        """path of the current request"""
        raise NotImplementedError("Implement in subclass")

    def request_port(self) -> int:
        """Port in use for this request"""
        raise NotImplementedError("Implement in subclass")

    def request_get(self):
        """Request GET data"""
        raise NotImplementedError("Implement in subclass")

    def request_post(self):
        """Request POST data"""
        raise NotImplementedError("Implement in subclass")