from __future__ import annotations
import secrets
from typing import TYPE_CHECKING, Any, Protocol, cast
from .backends.utils import get_backend
from .exceptions import (
SocialAuthImproperlyConfiguredError,
StrategyMissingBackendError,
StrategyMissingFeatureError,
)
from .pipeline import DEFAULT_AUTH_PIPELINE, DEFAULT_DISCONNECT_PIPELINE
from .pipeline.utils import partial_load
from .store import OpenIdSessionWrapper, OpenIdStore
from .utils import (
PARTIAL_TOKEN_PENDING_CONFIRMATION_SESSION_NAME,
PARTIAL_TOKEN_PENDING_REQUEST_SESSION_NAME,
PARTIAL_TOKEN_PENDING_SESSION_NAME,
PARTIAL_TOKEN_SESSION_NAME,
module_member,
setting_name,
)
if TYPE_CHECKING:
from .backends.base import BaseAuth
from .storage import BaseStorage, CodeMixin, PartialMixin, UserProtocol
class HttpResponseProtocol(Protocol):
    """Structural type for the response objects returned by strategies.

    The only member required is a read-only ``url`` property.
    """

    @property
    def url(self) -> str:
        """The response's URL (presumably the redirect target for ``redirect`` responses - confirm)."""
class BaseTemplateStrategy:
    """Base renderer turning a named template or a raw HTML string into text.

    Subclasses must provide ``render_template`` and ``render_string``.
    """

    def __init__(self, strategy) -> None:
        # The owning strategy; stored so subclasses can consult it.
        self.strategy = strategy

    def render(
        self,
        tpl: str | None = None,
        html: str | None = None,
        context: dict[str, Any] | None = None,
    ) -> str:
        """Render ``tpl`` when given, otherwise the raw ``html`` string.

        :param tpl: template name; wins over ``html`` whenever it is truthy.
        :param html: raw markup, used only when ``tpl`` is empty.
        :param context: render variables; any falsy value is replaced by ``{}``.
        :raises ValueError: if neither ``tpl`` nor ``html`` is truthy.
        """
        ctx = context if context else {}
        if tpl:
            return self.render_template(tpl, ctx)
        if html:
            return self.render_string(html, ctx)
        raise ValueError("Missing template or html parameters")

    def render_template(self, tpl: str, context: dict[str, Any] | None) -> str:
        """Render the template named ``tpl`` with ``context``; subclass hook."""
        raise NotImplementedError("Implement in subclass")

    def render_string(self, html: str, context: dict[str, Any] | None) -> str:
        """Render the raw string ``html`` with ``context``; subclass hook."""
        raise NotImplementedError("Implement in subclass")
class BaseStrategy:
    """Framework-agnostic base class for strategies.

    Concrete strategies implement the request/session/response hooks at the
    bottom of this class (``get_setting``, ``session_get``, ``redirect``, ...);
    the helpers above them are built on top of those hooks.
    """

    # Default alphabet for ``random_string``.
    ALLOWED_CHARS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
    # Template renderer class instantiated when ``tpl`` is not passed.
    DEFAULT_TEMPLATE_STRATEGY = BaseTemplateStrategy
    # NOTE(review): not referenced in this class; presumably a session key
    # used by subclasses - confirm.
    SESSION_SAVE_KEY = "psa_session_id"

    def __init__(
        self,
        storage: type[BaseStorage] | None = None,
        tpl: type[BaseTemplateStrategy] | None = None,
    ) -> None:
        """
        :param storage: storage class; required by anything using ``self.storage``.
        :param tpl: template strategy class; defaults to ``DEFAULT_TEMPLATE_STRATEGY``.
        """
        self._storage = storage
        self.tpl = (tpl or self.DEFAULT_TEMPLATE_STRATEGY)(self)

    @property
    def storage(self) -> type[BaseStorage]:
        """The configured storage class.

        :raises StrategyMissingBackendError: if no storage was configured.
        """
        if self._storage is None:
            raise StrategyMissingBackendError
        return self._storage

    def setting(self, name: str, default=None, backend: BaseAuth | None = None):
        """Look up a setting, trying the most specific name first.

        Candidates, in order: the backend-scoped name (only when ``backend``
        is given), ``setting_name(name)``, then ``name`` itself. The first one
        ``get_setting`` resolves without ``AttributeError``/``KeyError`` is
        returned; if none resolves, ``default`` is returned.
        """
        names = [setting_name(name), name]
        if backend:
            names.insert(0, setting_name(backend.name, name))
        for candidate in names:
            try:
                return self.get_setting(candidate)
            except (AttributeError, KeyError):
                # Not defined under this name; fall back to the next one.
                continue
        return default

    def create_user(self, *args, **kwargs):
        """Create a user through ``storage.user.create_user``."""
        return self.storage.user.create_user(*args, **kwargs)

    def get_user(self, *args, **kwargs):
        """Fetch a user through ``storage.user.get_user``."""
        return self.storage.user.get_user(*args, **kwargs)

    def session_setdefault(self, name: str, value):
        """Return the session value for ``name``, storing ``value`` only if unset.

        Mirrors ``dict.setdefault``: an existing value is never overwritten.
        A stored ``None`` is treated as unset, because ``session_get`` returns
        ``None`` for missing keys by default. The value is read back through
        ``session_get`` so the caller sees what the session actually holds.
        """
        if self.session_get(name) is None:
            self.session_set(name, value)
        return self.session_get(name)

    def get_session_id(self) -> str | None:
        """
        Return session ID to be used by restore_session.
        """
        return None

    def restore_session(self, session_id: str, kwargs: dict[str, Any]) -> None:
        """
        Restores session and updates kwargs to match it.
        This is only called if get_session_id returns a value.

        :raises StrategyMissingFeatureError: unless overridden by a subclass.
        """
        raise StrategyMissingFeatureError(self.__class__.__name__, "session restore")

    def openid_session_dict(self, name: str) -> OpenIdSessionWrapper:
        """Wrap the session dict stored under ``name`` (created empty if unset)."""
        # Many frameworks are switching the session serialization from Pickle
        # to JSON to avoid code execution risks. Flask did this from Flask
        # 0.10, Django is switching to JSON by default from version 1.6.
        #
        # Sadly python-openid stores classes instances in the session which
        # fails the JSON serialization, the classes are:
        #
        # openid.yadis.manager.YadisServiceManager
        # openid.consumer.discover.OpenIDServiceEndpoint
        #
        # This method will return a wrapper over the session value used with
        # openid (a dict) which will automatically keep a pickled value for the
        # mentioned classes.
        return OpenIdSessionWrapper(self.session_setdefault(name, {}))

    def to_session_value(self, val):
        """Convert ``val`` before storing it in the session (identity here)."""
        return val

    def from_session_value(self, val):
        """Convert a value read from the session (identity here)."""
        return val

    def partial_load(self, token: str) -> PartialMixin | None:
        """Load the partial pipeline identified by ``token``."""
        return partial_load(self, token)

    def clean_partial_pipeline(self, token) -> None:
        """Destroy the partial pipeline ``token`` and drop session references to it.

        The pending request/confirmation keys are removed only when the
        pending token in the session is ``token``.
        """
        self.storage.partial.destroy(token)
        if self.session_get(PARTIAL_TOKEN_SESSION_NAME) == token:
            self.session_pop(PARTIAL_TOKEN_SESSION_NAME)
        if self.session_get(PARTIAL_TOKEN_PENDING_SESSION_NAME) == token:
            self.session_pop(PARTIAL_TOKEN_PENDING_SESSION_NAME)
            self.session_pop(PARTIAL_TOKEN_PENDING_REQUEST_SESSION_NAME)
            self.session_pop(PARTIAL_TOKEN_PENDING_CONFIRMATION_SESSION_NAME)

    def partial_pipeline_external_resume_confirmation(
        self,
        backend: BaseAuth,
        partial: PartialMixin,
        request_data: dict[str, Any],
    ) -> HttpResponseProtocol | None:
        """Hook for confirming an external partial resume; no response here."""
        return None

    def partial_pipeline_external_resume_confirmed(
        self,
        backend: BaseAuth,
        request_data: dict[str, Any],
    ) -> bool:
        """Hook reporting whether an external resume was confirmed; ``False`` here."""
        return False

    def openid_store(self) -> OpenIdStore:
        """Return an ``OpenIdStore`` bound to this strategy."""
        return OpenIdStore(self)

    def get_pipeline(self, backend: BaseAuth | None = None) -> list[str]:
        """Return the ``PIPELINE`` setting, defaulting to ``DEFAULT_AUTH_PIPELINE``."""
        return cast(
            "list[str]", self.setting("PIPELINE", DEFAULT_AUTH_PIPELINE, backend)
        )

    def get_disconnect_pipeline(self, backend: BaseAuth | None = None) -> list[str]:
        """Return ``DISCONNECT_PIPELINE``, defaulting to ``DEFAULT_DISCONNECT_PIPELINE``."""
        return cast(
            "list[str]",
            self.setting("DISCONNECT_PIPELINE", DEFAULT_DISCONNECT_PIPELINE, backend),
        )

    def random_string(self, length: int = 12, chars: str = ALLOWED_CHARS) -> str:
        """Return ``length`` characters drawn from ``chars`` using ``secrets``."""
        return "".join(secrets.choice(chars) for _ in range(length))

    def absolute_uri(self, path: str | None = None) -> str:
        """Return the absolute URI for ``path``.

        When the ``REDIRECT_IS_HTTPS`` setting is truthy, a leading
        ``http://`` scheme is upgraded to ``https://``.
        """
        uri = self.build_absolute_uri(path)
        if uri and self.setting("REDIRECT_IS_HTTPS") and uri.startswith("http://"):
            # Rewrite only the scheme; ``http://`` appearing later in the URI
            # (e.g. inside a query parameter) must be left untouched.
            uri = "https://" + uri[len("http://") :]
        return uri

    def get_language(self) -> str:
        """Return current language"""
        return ""

    def send_email_validation(
        self, backend: BaseAuth, email: str, partial_token: str | None = None
    ) -> CodeMixin:
        """Create a verification code for ``email`` and send it.

        The sender is the callable named by the ``EMAIL_VALIDATION_FUNCTION``
        setting, called as ``send_email(strategy, backend, code, partial_token)``.

        :returns: the created code object.
        :raises SocialAuthImproperlyConfiguredError: if the setting is missing.
        """
        email_validation = self.setting("EMAIL_VALIDATION_FUNCTION")
        if not email_validation:
            raise SocialAuthImproperlyConfiguredError(
                "EMAIL_VALIDATION_FUNCTION missing"
            )
        send_email = module_member(email_validation)
        code = self.storage.code.make_code(email)
        send_email(self, backend, code, partial_token)
        return code

    def validate_email(self, email: str, code: str) -> bool:
        """Verify ``code`` for ``email``.

        Returns ``True`` (and marks the code verified) only when the code
        exists, matches, belongs to ``email`` and was not verified before;
        otherwise returns ``False`` without side effects.
        """
        verification_code = self.storage.code.get_code(code)
        if not verification_code or verification_code.code != code:
            return False
        if verification_code.email != email:
            return False
        if verification_code.verified:
            return False
        verification_code.verify()
        return True

    def render_html(
        self,
        tpl: str | None = None,
        html: str | None = None,
        context: dict[str, Any] | None = None,
    ) -> str:
        """Render given template or raw html with given context"""
        return self.tpl.render(tpl, html, context)

    def authenticate(
        self, backend: BaseAuth, *args, **kwargs
    ) -> UserProtocol | HttpResponseProtocol | None:
        """Trigger the authentication mechanism tied to the current
        framework.

        ``strategy``, ``storage`` and ``backend`` are injected into the
        keyword arguments, which then pass through ``clean_authenticate_args``
        before ``backend.authenticate`` is called.
        """
        kwargs["strategy"] = self
        kwargs["storage"] = self.storage
        kwargs["backend"] = backend
        args, kwargs = self.clean_authenticate_args(*args, **kwargs)
        return backend.authenticate(*args, **kwargs)

    def clean_authenticate_args(self, *args, **kwargs):
        """Take authenticate arguments and return a "cleaned" version
        of them (unchanged here)."""
        return args, kwargs

    def get_backends(self) -> list[str]:
        """Return configured backends (``AUTHENTICATION_BACKENDS``, default empty)."""
        return cast("list[str]", self.setting("AUTHENTICATION_BACKENDS", []))

    def get_backend_class(self, name: str) -> type[BaseAuth]:
        """Return a configured backend class"""
        return get_backend(self.get_backends(), name)

    def get_backend(
        self, name: str, redirect_uri: str | None = None, **kwargs
    ) -> BaseAuth:
        """Return a configured backend instance.

        ``redirect_uri`` is forwarded as a keyword argument together with
        ``kwargs`` to the backend constructor.
        """
        backend_class = self.get_backend_class(name)
        kwargs["redirect_uri"] = redirect_uri
        return backend_class(self, **kwargs)

    # Implement the following methods on strategies sub-classes

    def redirect(self, url: str) -> HttpResponseProtocol:
        """Return a response redirect to the given URL"""
        raise NotImplementedError("Implement in subclass")

    def get_setting(self, name: str):
        """Return value for given setting name"""
        raise NotImplementedError("Implement in subclass")

    def html(self, content: str) -> HttpResponseProtocol:
        """Return HTTP response with given content"""
        raise NotImplementedError("Implement in subclass")

    def request_data(self, merge: bool = True):
        """Return current request data (POST or GET)"""
        raise NotImplementedError("Implement in subclass")

    def request_host(self) -> str:
        """Return current host value"""
        raise NotImplementedError("Implement in subclass")

    def session_get(self, name: str, default=None):
        """Return session value for given key"""
        raise NotImplementedError("Implement in subclass")

    def session_set(self, name: str, value):
        """Set session value for given key"""
        raise NotImplementedError("Implement in subclass")

    def session_pop(self, name: str):
        """Pop session value for given key"""
        raise NotImplementedError("Implement in subclass")

    def build_absolute_uri(self, path: str | None = None) -> str:
        """Build absolute URI with given (optional) path"""
        raise NotImplementedError("Implement in subclass")

    def request_is_secure(self) -> bool:
        """Is the request using HTTPS?"""
        raise NotImplementedError("Implement in subclass")

    def request_path(self) -> str:
        """path of the current request"""
        raise NotImplementedError("Implement in subclass")

    def request_port(self) -> int:
        """Port in use for this request"""
        raise NotImplementedError("Implement in subclass")

    def request_get(self):
        """Request GET data"""
        raise NotImplementedError("Implement in subclass")

    def request_post(self):
        """Request POST data"""
        raise NotImplementedError("Implement in subclass")