# Source code for social_core.storage

"""Models mixins for Social Auth"""

from __future__ import annotations

import base64
import re
import uuid
from abc import abstractmethod
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Protocol, cast

from openid.association import Association as OpenIdAssociation

from .exceptions import InvalidExpiryValue, MissingBackend

if TYPE_CHECKING:
    from collections.abc import Callable

    from social_core.backends.base import BaseAuth
    from social_core.strategy import BaseStrategy

# Matches one or more consecutive non-ASCII characters.
NO_ASCII_REGEX = re.compile(r"[^\x00-\x7F]+")
# Matches one or more consecutive characters that are not word characters
# or one of ". @ + _ -" (the characters allowed in a cleaned username).
NO_SPECIAL_REGEX = re.compile(r"[^\w.@+_-]+", flags=re.UNICODE)


class UserProtocol(Protocol):
    """Structural type for the user objects this module works with."""

    @property
    def id(self, /) -> int:
        """Primary key of the user."""

    @property
    def username(self, /) -> str:
        """Username of the user."""

    @property
    def is_active(self, /) -> bool | Callable[[], bool]:
        """Active flag, either a plain bool or a zero-argument callable."""

    @property
    def is_authenticated(self, /) -> bool | Callable[[], bool]:
        """Authenticated flag, either a plain bool or a zero-argument callable."""
class PipelineUserProtocol(UserProtocol, Protocol):
    """``UserProtocol`` plus the attributes attached while the pipeline runs."""

    # Set in BaseAuth.pipeline
    social_user: UserMixin | None
    is_new: bool
class UserMixin:
    """Provider-independent behavior for a user's social-auth association.

    Storage backends subclass this, provide ``provider``, ``uid``, ``user`` and
    ``extra_data``, and implement ``save`` plus the classmethods below that
    raise ``NotImplementedError``.
    """

    # Consider tokens that expire in 5 seconds as already expired
    ACCESS_TOKEN_EXPIRED_THRESHOLD = 5
    provider = ""
    uid: str
    user: UserProtocol
    extra_data: dict[str, Any]

    @abstractmethod
    def save(self):
        """Persist this instance."""
        # NOTE: UserMixin does not use ABCMeta, so this marker documents the
        # requirement but is not enforced when instantiating subclasses.

    def get_backend(self, strategy: BaseStrategy) -> type[BaseAuth]:
        """Return the backend class registered for ``self.provider``."""
        return strategy.get_backend_class(self.provider)

    def get_backend_instance(self, strategy: BaseStrategy) -> BaseAuth | None:
        """Return a backend instance for ``self.provider``, or None if missing."""
        try:
            return strategy.get_backend(self.provider)
        except MissingBackend:
            return None

    @property
    def access_token(self) -> str | None:
        """Return access_token stored in extra_data or None"""
        return self.extra_data.get("access_token")

    def refresh_token(self, strategy: BaseStrategy, *args, **kwargs) -> None:
        """Refresh the stored tokens through the backend, when supported.

        Uses the stored ``refresh_token`` (falling back to ``access_token``).
        The backend's ``extra_data`` for the refresh response is passed to
        ``set_extra_data`` and the instance is saved if that reports a change.
        Does nothing when there is no token, or the backend is missing or has
        no callable ``refresh_token``.
        """
        token = self.extra_data.get("refresh_token") or self.extra_data.get(
            "access_token"
        )
        backend = self.get_backend_instance(strategy)
        refresh_token = getattr(backend, "refresh_token", None) if backend else None
        if token and callable(refresh_token):
            # refresh_token is only callable when backend was truthy; the
            # assert narrows the Optional for type checkers.
            assert backend is not None
            response = cast("dict[str, Any]", refresh_token(token, *args, **kwargs))
            extra_data = backend.extra_data(
                self.user, self.uid, response, self.extra_data or {}, {}
            )
            if self.set_extra_data(extra_data):
                self.save()

    def _compute_expiration_from_timestamp(
        self, value: int | str, field_name: str = "expires"
    ) -> timedelta:
        """Compute expiration timedelta from an absolute timestamp.

        Raises:
            InvalidExpiryValue: if ``value`` is not an integer, or is outside
                the range ``datetime.fromtimestamp`` can represent.
        """
        try:
            timestamp = int(value)
        except (ValueError, TypeError, OverflowError) as e:
            # OverflowError: int(float("inf"))
            raise InvalidExpiryValue(field_name, value) from e
        try:
            now = datetime.now(timezone.utc)
            expiry_time = datetime.fromtimestamp(timestamp, tz=timezone.utc)
            return expiry_time - now
        except (OSError, OverflowError, ValueError) as e:
            # fromtimestamp() raises OverflowError (not ValueError) for values
            # that do not fit the platform's time_t.
            raise InvalidExpiryValue(field_name, value) from e

    def _compute_expiration_from_relative(
        self, value: int | str, field_name: str = "expires"
    ) -> timedelta:
        """Compute expiration timedelta from relative seconds.

        The seconds are counted from ``extra_data["auth_time"]`` when that is
        a usable timestamp; otherwise they are counted from now.

        Raises:
            InvalidExpiryValue: if ``value`` is not an integer or is too large
                to be represented as a ``timedelta``.
        """
        try:
            seconds = int(value)
            delta = timedelta(seconds=seconds)
        except (ValueError, TypeError, OverflowError) as e:
            # OverflowError: int(inf), or |seconds| beyond timedelta's range
            raise InvalidExpiryValue(field_name, value) from e
        auth_time = self.extra_data.get("auth_time")
        if auth_time:
            try:
                auth_timestamp = int(auth_time)
            except (ValueError, TypeError, OverflowError):
                # Invalid auth_time value, fall back to treating as seconds from now
                pass
            else:
                try:
                    now = datetime.now(timezone.utc)
                    reference = datetime.fromtimestamp(auth_timestamp, tz=timezone.utc)
                    return (reference + delta) - now
                except (OSError, OverflowError, ValueError):
                    # auth_time out of range (or auth_time + delta overflows
                    # datetime), fall back to treating as seconds from now
                    pass
        # If no auth_time or invalid auth_time, treat as seconds from now
        return delta

    def expiration_timedelta(self) -> timedelta | None:
        """Return provider session live seconds.

        Returns a timedelta ready to use with session.set_expiry(). If the
        provider returns a timestamp instead of session seconds to live, the
        timedelta is inferred from the current time (using UTC timezone).

        Handles three types of expiration data:
        - expires_on: Always treated as absolute timestamp
        - expires_in: Always treated as relative seconds from auth_time
        - expires: Uses heuristic (>63072000 = 2 years) to distinguish
          timestamp vs relative

        Returns None when ``extra_data`` is empty or has none of these keys.
        """
        if not self.extra_data:
            return None
        # Check for expires_on (absolute timestamp)
        expires_on = self.extra_data.get("expires_on")
        if expires_on is not None:
            return self._compute_expiration_from_timestamp(expires_on, "expires_on")
        # Check for expires_in (relative seconds from auth_time)
        expires_in = self.extra_data.get("expires_in")
        if expires_in is not None:
            return self._compute_expiration_from_relative(expires_in, "expires_in")
        # Check for expires (use heuristic to determine type)
        return self._handle_expires_field()

    def _handle_expires_field(self) -> timedelta | None:
        """Handle the generic expires field using heuristic."""
        expires = self.extra_data.get("expires")
        if expires is None:
            return None
        try:
            expires_int = int(expires)
        except (ValueError, TypeError, OverflowError) as e:
            raise InvalidExpiryValue("expires", expires) from e
        # Use 2 years (63072000 seconds) as threshold to distinguish
        # absolute timestamps from relative seconds.
        # Most tokens expire in hours/days/months; timestamps are much larger.
        TIMESTAMP_THRESHOLD = 63072000
        if expires_int > TIMESTAMP_THRESHOLD:
            # Likely an absolute timestamp, try treating as expires_on
            return self._compute_expiration_from_timestamp(expires_int, "expires")
        # Treat as relative seconds (like expires_in)
        return self._compute_expiration_from_relative(expires_int, "expires")

    def expiration_datetime(self):
        """Backward compatible alias of ``expiration_timedelta``."""
        return self.expiration_timedelta()

    def access_token_expired(self) -> bool:
        """Return True if the access token is expired or about to expire.

        A token whose remaining lifetime is at most
        ``ACCESS_TOKEN_EXPIRED_THRESHOLD`` seconds counts as expired. A token
        with no known expiration is never considered expired.
        """
        expiration = self.expiration_timedelta()
        # Compare against None explicitly: timedelta(0) is falsy, so a plain
        # truthiness test would report a token expiring right now as valid.
        return (
            expiration is not None
            and expiration.total_seconds() <= self.ACCESS_TOKEN_EXPIRED_THRESHOLD
        )

    def get_access_token(self, strategy: BaseStrategy) -> str | None:
        """Returns a valid access token."""
        if self.access_token_expired():
            self.refresh_token(strategy)
        return self.access_token

    def set_extra_data(self, extra_data: dict[str, Any] | None = None) -> bool:
        """Merge ``extra_data`` into the stored data.

        Existing non-string data is updated in place; otherwise it is
        replaced. Returns True if anything was applied, False when
        ``extra_data`` is empty or equal to the stored data.
        """
        if extra_data and self.extra_data != extra_data:
            if self.extra_data and not isinstance(self.extra_data, str):
                self.extra_data.update(extra_data)
            else:
                self.extra_data = extra_data
            return True
        return False

    @classmethod
    def clean_username(cls, value: str) -> str:
        """Clean username removing any unsupported character"""
        value = NO_ASCII_REGEX.sub("", value)
        return NO_SPECIAL_REGEX.sub("", value)

    @classmethod
    def changed(cls, user: UserProtocol) -> None:
        """The given user instance is ready to be saved"""
        raise NotImplementedError("Implement in subclass")

    @classmethod
    def get_username(cls, user: UserProtocol) -> str:
        """Return the username for given user"""
        raise NotImplementedError("Implement in subclass")

    @classmethod
    def user_model(cls) -> type[UserProtocol]:
        """Return the user model"""
        raise NotImplementedError("Implement in subclass")

    @classmethod
    def username_max_length(cls) -> int:
        """Return the max length for username"""
        raise NotImplementedError("Implement in subclass")

    @classmethod
    def allowed_to_disconnect(
        cls, user: UserProtocol, backend_name: str, association_id=None
    ) -> bool:
        """Return if it's safe to disconnect the social account for the given user"""
        raise NotImplementedError("Implement in subclass")

    @classmethod
    def disconnect(cls, entry):
        """Disconnect the social account for the given user"""
        raise NotImplementedError("Implement in subclass")

    @classmethod
    def user_exists(cls, *args, **kwargs) -> bool:
        """
        Return True/False if a User instance exists with the given arguments.
        Arguments are directly passed to filter() manager method.
        """
        raise NotImplementedError("Implement in subclass")

    @classmethod
    def create_user(cls, *args, **kwargs):
        """Create a user instance"""
        raise NotImplementedError("Implement in subclass")

    @classmethod
    def get_user(cls, pk):
        """Return user instance for given id"""
        raise NotImplementedError("Implement in subclass")

    @classmethod
    def get_users_by_email(cls, email: str):
        """Return users instances for given email address"""
        raise NotImplementedError("Implement in subclass")

    @classmethod
    def get_social_auth(cls, provider: str, uid: str):
        """Return UserSocialAuth for given provider and uid"""
        raise NotImplementedError("Implement in subclass")

    @classmethod
    def get_social_auth_for_user(
        cls,
        user: UserProtocol,
        provider: str | None = None,
        # pylint: disable-next=redefined-builtin
        id: int | None = None,  # noqa: A002
    ):
        """Return all the UserSocialAuth instances for given user"""
        raise NotImplementedError("Implement in subclass")

    @classmethod
    def create_social_auth(cls, user: UserProtocol, uid: str, provider: str):
        """Create a UserSocialAuth instance for given user"""
        raise NotImplementedError("Implement in subclass")
class NonceMixin:
    """Storage interface for one-use numbers (nonces).

    Concrete storages implement the three classmethods.
    """

    server_url = ""
    timestamp = 0
    salt = ""

    @classmethod
    def use(cls, server_url: str, timestamp, salt: str):
        """Create a Nonce instance"""
        raise NotImplementedError("Implement in subclass")

    @classmethod
    def get(cls, server_url: str, salt: str):
        """Retrieve a Nonce instance"""
        raise NotImplementedError("Implement in subclass")

    @classmethod
    def delete(cls, nonce):
        """Delete a Nonce instance"""
        raise NotImplementedError("Implement in subclass")
class AssociationMixin:
    """OpenId account association storage interface.

    ``secret`` is stored base64-encoded (as ``str`` or ``bytes``) and is
    decoded when converted to an ``openid`` ``Association``.
    """

    server_url = ""
    handle = ""
    secret: str | bytes = ""
    issued = 0
    lifetime = 0
    assoc_type = ""

    @classmethod
    def oids(cls, server_url, handle=None):
        """Return ``(id, Association)`` pairs for ``server_url``.

        Results are optionally filtered by ``handle`` and ordered with the
        most recently issued association first.
        """
        filters = {"server_url": server_url}
        if handle is not None:
            filters["handle"] = handle
        pairs = [
            (stored.id, cls.openid_association(stored))
            for stored in cls.get(**filters)
        ]
        # list.sort is stable, exactly like sorted()
        pairs.sort(key=lambda pair: pair[1].issued, reverse=True)
        return pairs

    @classmethod
    def openid_association(cls, assoc):
        """Convert a stored association into an ``openid`` Association."""
        raw_secret = assoc.secret
        if isinstance(raw_secret, bytes):
            secret_bytes = raw_secret
        else:
            secret_bytes = raw_secret.encode()
        return OpenIdAssociation(
            assoc.handle,
            base64.decodebytes(secret_bytes),
            assoc.issued,
            assoc.lifetime,
            assoc.assoc_type,
        )

    @classmethod
    def store(cls, server_url, association):
        """Create an Association instance"""
        raise NotImplementedError("Implement in subclass")

    @classmethod
    def get(cls, server_url: str | None = None, handle: str | None = None):
        """Get an Association instance"""
        raise NotImplementedError("Implement in subclass")

    @classmethod
    def remove(cls, ids_to_delete):
        """Remove an Association instance"""
        raise NotImplementedError("Implement in subclass")
class CodeMixin:
    """Storage interface for email verification codes."""

    email = ""
    code = ""
    verified = False

    @abstractmethod
    def save(self): ...

    def verify(self) -> None:
        """Mark this code as verified and persist it."""
        self.verified = True
        self.save()

    @classmethod
    def generate_code(cls):
        """Return a fresh random code (32 lowercase hex characters)."""
        return uuid.uuid4().hex

    @classmethod
    def make_code(cls, email: str) -> CodeMixin:
        """Create, save and return a new unverified code for ``email``."""
        instance = cls()
        instance.email = email
        instance.code = cls.generate_code()
        instance.verified = False
        instance.save()
        return instance

    @classmethod
    def get_code(cls, code):
        """Look up a stored code instance (implemented by the storage)."""
        raise NotImplementedError("Implement in subclass")
class PartialMixin:
    """Storage interface for partial pipeline data.

    ``data`` holds the pipeline's ``args`` and ``kwargs`` entries, exposed
    through the properties below. ``token`` identifies the stored partial.
    """

    token = ""
    # NOTE(review): this class-level dict is shared by every instance that
    # never assigns its own ``data``; the ``args``/``kwargs`` setters and
    # ``extend_kwargs`` would then mutate it in place. ``prepare()`` always
    # assigns ``data`` on the instance.
    data: dict[str, Any] = {}
    next_step: int
    backend = ""

    @property
    def args(self):
        """Positional arguments stored in ``data`` (``[]`` if absent)."""
        return self.data.get("args", [])

    @args.setter
    def args(self, value) -> None:
        self.data["args"] = value

    @abstractmethod
    def save(self): ...

    @property
    def kwargs(self):
        """Keyword arguments stored in ``data`` (``{}`` if absent)."""
        return self.data.get("kwargs", {})

    @kwargs.setter
    def kwargs(self, value) -> None:
        self.data["kwargs"] = value

    def extend_kwargs(self, values) -> None:
        """Merge ``values`` into the stored keyword arguments.

        Creates the ``kwargs`` entry when it is missing instead of raising
        ``KeyError``, matching the ``kwargs`` getter's ``{}`` default.
        """
        self.data.setdefault("kwargs", {}).update(values)

    @classmethod
    def generate_token(cls) -> str:
        """Return a fresh random token (32 lowercase hex characters)."""
        return uuid.uuid4().hex

    @classmethod
    def load(cls, token: str) -> PartialMixin | None:
        """Load a stored partial by token (implemented by the storage)."""
        raise NotImplementedError("Implement in subclass")

    @classmethod
    def destroy(cls, token: str):
        """Delete a stored partial by token (implemented by the storage)."""
        raise NotImplementedError("Implement in subclass")

    @classmethod
    def prepare(
        cls, backend: str, next_step: int, data: dict[str, Any]
    ) -> PartialMixin:
        """Build an unsaved partial with a newly generated token.

        ``data`` is assigned as-is (not copied).
        """
        partial = cls()
        partial.backend = backend
        partial.next_step = next_step
        partial.data = data
        partial.token = cls.generate_token()
        return partial

    @classmethod
    def store(cls, partial: PartialMixin) -> PartialMixin:
        """Save ``partial`` and return it."""
        partial.save()
        return partial
class BaseStorage:
    """Groups the storage classes used by social-core.

    Integrations subclass this and point each attribute at their concrete
    implementation of the matching mixin.
    """

    user = UserMixin
    nonce = NonceMixin
    association = AssociationMixin
    code = CodeMixin
    partial = PartialMixin

    @classmethod
    def is_integrity_error(cls, exception) -> bool:
        """Check if given exception flags an integrity error in the DB"""
        raise NotImplementedError("Implement in subclass")