from __future__ import annotations
import base64
import time
from typing import TYPE_CHECKING, Any, Literal, cast
import requests
from social_core.exceptions import AuthConnectionError, AuthUnknownError
from social_core.registry import REGISTRY
from social_core.utils import module_member, parse_qs, social_logger, user_agent
if TYPE_CHECKING:
from collections.abc import Mapping
from requests import Response
from requests.auth import AuthBase
from social_core.storage import PartialMixin, PipelineUserProtocol, UserProtocol
from social_core.strategy import BaseStrategy, HttpResponseProtocol
[docs]
class BaseAuth:
    """An authentication backend that authenticates the user based on
    the provider response"""

    name = ""  # provider name, it's stored in database
    # NOTE(review): not read anywhere in this class; presumably consulted
    # by Django's auth machinery for inactive users -- confirm.
    supports_inactive_user = False  # Django auth
    # Fallback key used by id_key()/get_user_id() to find the user's ID;
    # an "ID_KEY" setting takes precedence over this class default.
    ID_KEY: str = ""
    # Provider fields to keep as extra data: a field name, a 2-tuple or a
    # 3-tuple ending in a bool. The meaning of the tuple elements is not
    # visible here -- confirm against the code that consumes EXTRA_DATA.
    EXTRA_DATA: list[str | tuple[str, str] | tuple[str, str, bool]] | None = None
    # Not read in this class; presumably toggles storing all response data.
    GET_ALL_EXTRA_DATA = False
    # Not read in this class; presumably requests email validation upstream.
    REQUIRES_EMAIL_VALIDATION = False
    # When True, request() adds a User-Agent header if the caller gave none.
    SEND_USER_AGENT = True
def __init__(
    self, strategy: BaseStrategy | None = None, redirect_uri: str | None = None
) -> None:
    """Attach the backend to a strategy and capture the current request.

    Args:
        strategy: Strategy to use; ``REGISTRY.default_strategy`` when None.
        redirect_uri: Callback URI, later made absolute through
            ``strategy.absolute_uri``.
    """
    if strategy is None:
        strategy = REGISTRY.default_strategy
    self.strategy: BaseStrategy = strategy
    # Raw value first; it is replaced by its absolute form below.
    self.redirect_uri = redirect_uri
    self.data = self.strategy.request_data()
    self.redirect_uri = self.strategy.absolute_uri(self.redirect_uri)
[docs]
def log_debug(self, message, *args) -> None:
    """Log ``message`` at DEBUG level on ``social_logger``, tagged with the backend name.

    ``args`` are passed through to the logger for %-style interpolation.
    """
    tagged = f"{self.name}: {message}"
    social_logger.debug(tagged, *args)
[docs]
def log_warning(self, message, *args) -> None:
    """Log ``message`` at WARNING level on ``social_logger``, tagged with the backend name.

    ``args`` are passed through to the logger for %-style interpolation.
    """
    tagged = f"{self.name}: {message}"
    social_logger.warning(tagged, *args)
[docs]
def setting(self, name: str, default=None):
    """Look up setting ``name`` through the strategy, scoped to this backend.

    Returns whatever ``strategy.setting`` yields, ``default`` included.
    """
    lookup = self.strategy.setting
    return lookup(name, default=default, backend=self)
[docs]
def start(self) -> HttpResponseProtocol:
    """Begin the login flow.

    Redirect-based providers get a redirect to :meth:`auth_url`; the others
    get the HTML produced by :meth:`auth_html`.
    """
    if not self.uses_redirect():
        return self.strategy.html(self.auth_html())
    return self.strategy.redirect(self.auth_url())
[docs]
def complete(self, *args, **kwargs) -> HttpResponseProtocol | UserProtocol | None:
    """Finish the login flow by delegating every argument to :meth:`auth_complete`."""
    outcome = self.auth_complete(*args, **kwargs)
    return outcome
[docs]
def auth_url(self) -> str:
    """Return the provider URL the user is redirected to.

    Raises:
        NotImplementedError: always; redirect-based subclasses override it.
    """
    message = "Implement in subclass"
    raise NotImplementedError(message)
[docs]
def auth_html(self) -> str:
    """Return the login HTML for providers that do not use redirects.

    Unlike :meth:`auth_url`, the base version does not raise; it returns a
    placeholder string that subclasses are expected to replace.
    """
    placeholder = "Implement in subclass"
    return placeholder
[docs]
def auth_complete(
    self, *args, **kwargs
) -> HttpResponseProtocol | UserProtocol | None:
    """Finish the provider-specific login and return the user instance.

    Raises:
        NotImplementedError: always; concrete backends must override it.
    """
    message = "Implement in subclass"
    raise NotImplementedError(message)
[docs]
def process_error(self, data) -> None:
    """Inspect a provider response for errors; the base version does nothing.

    Backends able to recognise provider-specific error payloads should
    override this and raise a suitable exception.
    """
    return None
[docs]
def authenticate(
    self, *args, **kwargs
) -> UserProtocol | HttpResponseProtocol | None:
    """Run the auth pipeline if the call targets this backend.

    Returns None unless ``backend`` (whose ``name`` matches ours),
    ``strategy`` and ``response`` are all passed as keyword arguments.
    Otherwise the strategy, redirect URI and request data are refreshed,
    ``is_new`` defaults to False, and the strategy's pipeline is run.
    """
    # Requiring the Social Auth response as a keyword keeps this from
    # matching the username/password calling convention of authenticate.
    if "backend" not in kwargs:
        return None
    if kwargs["backend"].name != self.name:
        return None
    if "strategy" not in kwargs or "response" not in kwargs:
        return None

    self.strategy = kwargs.get("strategy") or self.strategy
    self.redirect_uri = kwargs.get("redirect_uri") or self.redirect_uri
    self.data = self.strategy.request_data()
    kwargs.setdefault("is_new", False)
    steps = self.strategy.get_pipeline(self)
    clean_args, clean_kwargs = self.strategy.clean_authenticate_args(
        *args, **kwargs
    )
    return self.pipeline(steps, *clean_args, **clean_kwargs)
[docs]
def pipeline(
    self, pipeline, pipeline_index: int = 0, *args, **kwargs
) -> UserProtocol | HttpResponseProtocol | None:
    """Run ``pipeline`` and turn its outcome into a user or a response.

    A non-dict outcome is returned unchanged (typed as an HTTP response).
    For a dict outcome, its ``user`` entry is returned; a truthy user is
    first annotated with ``social_user`` and a boolean ``is_new``.
    """
    result = self.run_pipeline(pipeline, pipeline_index, *args, **kwargs)
    if isinstance(result, dict):
        user = cast("UserProtocol | None", result.get("user"))
        if user:
            annotated = cast("PipelineUserProtocol", user)
            annotated.social_user = cast("Any", result.get("social"))
            annotated.is_new = bool(result.get("is_new"))
        return user
    return cast("HttpResponseProtocol", result)
[docs]
def disconnect(self, *args, **kwargs) -> dict:
    """Run the strategy's disconnect pipeline for this backend.

    ``name`` (the backend name) and ``user_storage`` (the strategy's user
    storage) are injected into the keyword arguments, overriding any
    caller-supplied values.
    """
    steps = self.strategy.get_disconnect_pipeline(self)
    injected = {"name": self.name, "user_storage": self.strategy.storage.user}
    return self.run_pipeline(steps, *args, **{**kwargs, **injected})
[docs]
def run_pipeline(
    self, pipeline: list[str], pipeline_index=0, *args, **kwargs
) -> dict:
    """Execute the named pipeline steps in order, threading state through.

    The state starts as a copy of ``kwargs``, defaulted with ``strategy``,
    ``backend`` (popped from the entry keyed by this backend's name, else
    ``self``), ``request`` and ``details``. A non-int or out-of-range
    ``pipeline_index`` restarts from the first step. Each step is resolved
    with ``module_member`` and called with ``*args`` and the state; a falsy
    result counts as ``{}``, a non-dict result is returned immediately, and
    a dict result is merged into the state, which is returned at the end.
    """
    state = dict(kwargs)
    state.setdefault("strategy", self.strategy)
    state.setdefault("backend", state.pop(self.name, None) or self)
    state.setdefault("request", self.strategy.request_data())
    state.setdefault("details", {})

    first = pipeline_index
    if not (isinstance(first, int) and 0 <= first < len(pipeline)):
        first = 0

    for position, dotted_path in enumerate(pipeline[first:], start=first):
        state["pipeline_index"] = position
        step = module_member(dotted_path)
        outcome = step(*args, **state) or {}
        if not isinstance(outcome, dict):
            return outcome
        state.update(outcome)
    return state
[docs]
def auth_allowed(self, response, details):
    """Return True if the user should be allowed to authenticate.

    When ``WHITELISTED_EMAILS`` or ``WHITELISTED_DOMAINS`` is configured,
    the email in ``details`` must equal a whitelisted address or have a
    whitelisted domain (both compared case-insensitively); a value with no
    ``@`` is rejected. Without a whitelist, or when ``details`` has no
    email, authentication is allowed.

    Args:
        response: Raw provider response (unused by the base implementation).
        details: User details dict; only its ``"email"`` key is read.

    Returns:
        bool: Whether authentication may proceed.
    """
    # ``or []`` treats a whitelist explicitly configured as None like an
    # unset one, instead of failing with TypeError while iterating it.
    configured_emails = self.setting("WHITELISTED_EMAILS", []) or []
    configured_domains = self.setting("WHITELISTED_DOMAINS", []) or []
    emails = {email.lower() for email in cast("list[str]", configured_emails)}
    domains = {domain.lower() for domain in cast("list[str]", configured_domains)}

    email = details.get("email")
    if not email or not (emails or domains):
        # NOTE(review): with a whitelist configured but no email in
        # ``details`` access is still granted; this preserves existing
        # behavior but may deserve an explicit decision.
        return True

    email = email.lower()
    # Domain is everything after the first "@", as before.
    _, at_sign, domain = email.partition("@")
    if not at_sign:
        return False
    return email in emails or domain in domains
[docs]
def id_key(self) -> str:
    """Return the user-ID key: the ``ID_KEY`` setting if truthy, else the class default."""
    configured = self.setting("ID_KEY")
    return configured if configured else self.ID_KEY
[docs]
def get_user_id(self, details, response):
    """Return the user's unique ID, preferring ``details`` over ``response``.

    The key comes from :meth:`id_key`. A missing or falsy value in
    ``details`` falls back to ``response.get(key)``.
    """
    key = self.id_key()
    from_details = details.get(key) if details else None
    return from_details or response.get(key)
def get_user_details(self, response) -> dict[str, Any]:
    """Translate the provider ``response`` into the internal details dict.

    Subclasses must return a mapping of this shape::

        {'username': <username if any>,
         'email': <user email if any>,
         'fullname': <user full name if any>,
         'first_name': <user first name if any>,
         'last_name': <user last name if any>}

    Raises:
        NotImplementedError: always in the base class.
    """
    message = "Implement in subclass"
    raise NotImplementedError(message)
[docs]
def get_user_names(self, fullname="", first_name="", last_name=""):
    """Return a normalized ``(fullname, first_name, last_name)`` triple.

    ``None`` values count as empty strings and surrounding whitespace is
    ignored. If only ``fullname`` is known, it is split at the first space
    into first and last name. If ``fullname`` is missing, it is rebuilt
    from the two name parts.

    Args:
        fullname: Full display name, possibly None.
        first_name: Given name, possibly None.
        last_name: Family name, possibly None.

    Returns:
        tuple[str, str, str]: Stripped full, first and last names.
    """
    # Normalize up front so None or padding cannot skew the split below:
    # " John Doe" must give first_name "John", not "".
    fullname = (fullname or "").strip()
    first_name = (first_name or "").strip()
    last_name = (last_name or "").strip()
    if fullname and not (first_name or last_name):
        # partition() leaves last_name empty for single-word names.
        first_name, _, last_name = fullname.partition(" ")
    fullname = fullname or f"{first_name} {last_name}"
    return fullname.strip(), first_name.strip(), last_name.strip()
[docs]
def get_user(self, user_id):
    """Return the user with ID ``user_id``, looked up through the strategy.

    This is the hook called by ``django.contrib.auth.middleware``.
    """
    strategy = self.strategy
    return strategy.get_user(user_id)
[docs]
def continue_pipeline(
    self, partial: PartialMixin
) -> UserProtocol | HttpResponseProtocol | None:
    """Resume a halted pipeline at ``partial.next_step``.

    The stored ``partial.args``/``partial.kwargs`` are replayed through
    ``strategy.authenticate``.
    """
    authenticate = self.strategy.authenticate
    return authenticate(
        self, *partial.args, pipeline_index=partial.next_step, **partial.kwargs
    )
[docs]
def uses_redirect(self) -> bool:
    """Report whether the login flow starts with a redirect.

    The base backend always answers True. Backends returning False are
    started through :meth:`auth_html` instead (see :meth:`start`).
    """
    redirect_based = True
    return redirect_based
[docs]
def request(  # noqa: PLR0913
    self,
    url: str,
    *,
    method: Literal["GET", "POST", "DELETE"] = "GET",
    headers: Mapping[str, str | bytes] | None = None,
    data: dict | None = None,
    json: dict | None = None,
    auth: tuple[str, str] | AuthBase | None = None,
    params: dict | None = None,
    timeout: float | None = None,
) -> Response:
    """Send an HTTP request to the provider and return the response.

    Proxies come from the ``PROXIES`` setting and TLS verification from
    ``VERIFY_SSL`` (default True). When ``timeout`` is None it falls back
    to ``REQUESTS_TIMEOUT``, then ``URLOPEN_TIMEOUT``, then 5.0. If
    ``SEND_USER_AGENT`` is set and the caller supplied no User-Agent header
    (in any letter case), ``USER_AGENT`` or the library default is added.

    Raises:
        AuthConnectionError: when ``requests`` raises ``ConnectionError``.
        requests.HTTPError: for error status codes, via ``raise_for_status``.
    """
    headers = {} if headers is None else dict(headers)
    proxies = self.setting("PROXIES")
    verify = self.setting("VERIFY_SSL", True)
    if timeout is None:
        timeout = (
            self.setting("REQUESTS_TIMEOUT")
            or self.setting("URLOPEN_TIMEOUT")
            or 5.0
        )
    # Header names are case-insensitive: requests merges them that way and
    # the later key wins, so a caller's "user-agent" would otherwise be
    # silently replaced by our default.
    if self.SEND_USER_AGENT and not any(
        str(header_name).lower() == "user-agent" for header_name in headers
    ):
        headers["User-Agent"] = self.setting("USER_AGENT") or user_agent()
    try:
        response = requests.request(
            method,
            url,
            headers=headers,
            data=data,
            json=json,
            auth=auth,
            params=params,
            timeout=timeout,
            proxies=proxies,
            verify=verify,
        )
    except requests.ConnectionError as err:
        raise AuthConnectionError(self, str(err)) from err
    response.raise_for_status()
    return response
[docs]
def get_json(  # noqa: PLR0913
    self,
    url: str,
    method: Literal["GET", "POST", "DELETE"] = "GET",
    headers: Mapping[str, str | bytes] | None = None,
    data: dict | None = None,
    json: dict | None = None,
    auth: tuple[str, str] | AuthBase | None = None,
    params: dict | None = None,
    timeout: float | None = None,
) -> dict[Any, Any]:
    """Perform :meth:`request` and return the decoded JSON body.

    Every argument after ``url`` is forwarded to :meth:`request` by keyword.
    """
    options = {
        "method": method,
        "headers": headers,
        "data": data,
        "json": json,
        "auth": auth,
        "params": params,
        "timeout": timeout,
    }
    response = self.request(url, **options)
    return response.json()
[docs]
def get_querystring(  # noqa: PLR0913
    self,
    url: str,
    method: Literal["GET", "POST", "DELETE"] = "GET",
    headers: Mapping[str, str | bytes] | None = None,
    data: dict | None = None,
    json: dict | None = None,
    auth: tuple[str, str] | AuthBase | None = None,
    params: dict | None = None,
    timeout: float | None = None,
) -> dict[str, str]:
    """Perform :meth:`request` and parse the body as a query string.

    The parameters mirror :meth:`get_json`. They are named explicitly
    because :meth:`request` accepts everything after ``url`` only by
    keyword, so blindly forwarding positional arguments raised TypeError.

    Returns:
        dict[str, str]: The response text parsed with ``parse_qs``.
    """
    response = self.request(
        url,
        method=method,
        headers=headers,
        data=data,
        json=json,
        auth=auth,
        params=params,
        timeout=timeout,
    )
    return parse_qs(response.text)
[docs]
def get_key_and_secret(self) -> tuple[str, str]:
    """Return the ``(key, secret)`` pair for this provider from settings.

    The order is part of the contract: the consumer key comes first, then
    the consumer secret (settings ``KEY`` and ``SECRET``).
    """
    key = cast("str", self.setting("KEY"))
    secret = cast("str", self.setting("SECRET"))
    return key, secret
[docs]
def get_key_and_secret_basic_auth(self) -> bytes:
    """Build an HTTP Basic authorization value from KEY and SECRET.

    Returns:
        ``b"Basic "`` followed by the base64 encoding of ``key:secret``
        (the pair encoded with ``str.encode``'s default UTF-8).
    """
    key, secret = self.get_key_and_secret()
    token = base64.b64encode(f"{key}:{secret}".encode())
    return b"".join((b"Basic ", token))