Coverage for /Users/davegaeddert/Development/dropseed/plain/plain-oauth/plain/oauth/providers.py: 82%
111 statements
coverage.py v7.6.1, created at 2024-10-16 22:04 -0500
1import datetime
2import secrets
3from typing import Any
4from urllib.parse import urlencode
6from plain.auth import login as auth_login
7from plain.http import HttpRequest, Response, ResponseRedirect
8from plain.runtime import settings
9from plain.urls import reverse
10from plain.utils.crypto import get_random_string
11from plain.utils.module_loading import import_string
13from .exceptions import OAuthError, OAuthStateMismatchError
14from .models import OAuthConnection
# Session key under which the OAuth ``state`` is kept between the login
# redirect and the provider callback.
SESSION_STATE_KEY: str = "plainoauth_state"
# Session key under which the post-login redirect URL is kept.
SESSION_NEXT_KEY: str = "plainoauth_next"
class OAuthToken:
    """Container for the tokens produced by a provider.

    Returned by ``OAuthProvider.get_oauth_token`` and
    ``OAuthProvider.refresh_oauth_token``. All arguments are keyword-only;
    ``refresh_token`` defaults to ``""`` and both expiration timestamps
    default to ``None``.
    """

    def __init__(
        self,
        *,
        access_token: str,
        refresh_token: str = "",
        # String annotations keep the ``X | None`` form from being evaluated
        # at definition time, so they work on interpreters older than 3.10.
        access_token_expires_at: "datetime.datetime | None" = None,
        refresh_token_expires_at: "datetime.datetime | None" = None,
    ):
        self.access_token = access_token
        self.refresh_token = refresh_token
        self.access_token_expires_at = access_token_expires_at
        self.refresh_token_expires_at = refresh_token_expires_at
class OAuthUser:
    """A user profile as reported by an OAuth provider.

    Args:
        id: The user's ID on the provider's system.
        **user_model_fields: Any extra keyword fields (e.g. ``email``,
            ``username``); presumably applied to the local user model —
            confirm against ``OAuthConnection``.
    """

    def __init__(self, *, id: str, **user_model_fields: Any):
        self.id = id  # ID on the provider's system
        self.user_model_fields = user_model_fields

    def __str__(self) -> str:
        """Return the email, else the username, else the provider ID.

        Empty or ``None`` field values are skipped, and the result is always
        coerced to ``str``: ``__str__`` returning anything else makes
        ``str(user)`` raise ``TypeError``.
        """
        for field_name in ("email", "username"):
            value = self.user_model_fields.get(field_name)
            if value:
                return str(value)
        return str(self.id)
class OAuthProvider:
    """Base class for an OAuth authorization-code login provider.

    Subclasses set ``authorization_url`` and implement ``get_oauth_token``
    and ``get_oauth_user`` (and ``refresh_oauth_token`` if refreshing is
    supported). Instances are created by ``get_oauth_provider_instance``
    from the ``OAUTH_LOGIN_PROVIDERS`` setting.
    """

    # Provider endpoint the user is redirected to; subclasses set this or
    # override ``get_authorization_url``.
    authorization_url = ""

    def __init__(
        self,
        *,
        # Provided automatically
        provider_key: str,
        # Required as kwargs in OAUTH_LOGIN_PROVIDERS setting
        client_id: str,
        client_secret: str,
        # Not necessarily required, but commonly used
        scope: str = "",
    ):
        """Store the provider key and client credentials.

        Args:
            provider_key: Key of this provider in ``OAUTH_LOGIN_PROVIDERS``;
                also used to build the callback URL.
            client_id: OAuth client ID issued by the provider.
            client_secret: OAuth client secret issued by the provider.
            scope: Scope string sent with the authorization request.
        """
        self.provider_key = provider_key
        self.client_id = client_id
        self.client_secret = client_secret
        self.scope = scope

    def get_authorization_url_params(self, *, request: HttpRequest) -> dict:
        """Return the query parameters for the authorization redirect.

        A fresh random ``state`` is generated on every call.
        """
        return {
            "redirect_uri": self.get_callback_url(request=request),
            "client_id": self.get_client_id(),
            "scope": self.get_scope(),
            "state": self.generate_state(),
            "response_type": "code",
        }

    def refresh_oauth_token(self, *, oauth_token: OAuthToken) -> OAuthToken:
        """Exchange a refresh token for new tokens. Subclasses implement."""
        raise NotImplementedError()

    def get_oauth_token(self, *, code: str, request: HttpRequest) -> OAuthToken:
        """Exchange the callback ``code`` for tokens. Subclasses implement."""
        raise NotImplementedError()

    def get_oauth_user(self, *, oauth_token: OAuthToken) -> OAuthUser:
        """Fetch the provider's user profile. Subclasses implement."""
        raise NotImplementedError()

    def get_authorization_url(self, *, request: HttpRequest) -> str:
        """Return the provider's authorization endpoint (class attribute)."""
        return self.authorization_url

    def get_client_id(self) -> str:
        """Return the configured OAuth client ID."""
        return self.client_id

    def get_client_secret(self) -> str:
        """Return the configured OAuth client secret."""
        return self.client_secret

    def get_scope(self) -> str:
        """Return the configured scope string."""
        return self.scope

    def get_callback_url(self, *, request: HttpRequest) -> str:
        """Return the absolute URL of this provider's ``oauth:callback`` route."""
        url = reverse("oauth:callback", kwargs={"provider": self.provider_key})
        return request.build_absolute_uri(url)

    def generate_state(self) -> str:
        """Return a random 32-character string to use as the OAuth ``state``."""
        return get_random_string(length=32)

    def check_request_state(self, *, request: HttpRequest) -> None:
        """Validate the callback's ``state`` against the one stored at login.

        Raises:
            OAuthError: The callback carries an ``error`` query parameter.
            OAuthStateMismatchError: The ``state`` query parameter is missing,
                no state was stored in the session, or the two differ.
        """
        # Consume the stored state first so it is single-use even when the
        # provider reports an error.
        expected_state = request.session.pop(SESSION_STATE_KEY, None)

        if error := request.GET.get("error"):
            raise OAuthError(error)

        state = request.GET.get("state")
        if not state or not expected_state:
            # A missing value is a failed check, not a server error.
            raise OAuthStateMismatchError()

        # ``compare_digest`` raises TypeError on non-ASCII ``str`` input and
        # ``state`` comes from the query string, so compare UTF-8 bytes.
        if not secrets.compare_digest(
            state.encode("utf-8"), str(expected_state).encode("utf-8")
        ):
            raise OAuthStateMismatchError()

    def handle_login_request(
        self, *, request: HttpRequest, redirect_to: str = ""
    ) -> Response:
        """Start the flow by redirecting the user to the provider.

        The generated ``state`` (when present) and the post-login destination
        are saved in the session for ``handle_callback_request``.

        Args:
            request: The current request.
            redirect_to: Post-login destination. When empty, the ``next``
                POST parameter is used if present.

        Returns:
            A redirect to the authorization URL with the parameters appended.
        """
        authorization_url = self.get_authorization_url(request=request)
        authorization_params = self.get_authorization_url_params(request=request)

        if "state" in authorization_params:
            # Store the state in the session so we can check on callback
            request.session[SESSION_STATE_KEY] = authorization_params["state"]

        # NOTE(review): this value is later used as a redirect target without
        # validation (get_login_redirect_url) — consider restricting it to
        # same-site URLs to avoid an open redirect.
        if redirect_to:
            request.session[SESSION_NEXT_KEY] = redirect_to
        elif "next" in request.POST:
            request.session[SESSION_NEXT_KEY] = request.POST["next"]

        # Sort params for a deterministic URL; join with "&" when the
        # authorization URL already carries its own query string.
        query = urlencode(sorted(authorization_params.items()))
        separator = "&" if "?" in authorization_url else "?"
        return ResponseRedirect(f"{authorization_url}{separator}{query}")

    def handle_connect_request(
        self, *, request: HttpRequest, redirect_to: str = ""
    ) -> Response:
        """Connect an account using the same redirect flow as login."""
        return self.handle_login_request(request=request, redirect_to=redirect_to)

    def handle_disconnect_request(self, *, request: HttpRequest) -> Response:
        """Delete the current user's connection named by ``provider_user_id``.

        Propagates the ORM lookup error when no matching connection belongs
        to ``request.user``.

        Returns:
            A redirect to ``get_disconnect_redirect_url``.
        """
        provider_user_id = request.POST["provider_user_id"]
        # Scope the lookup to the requesting user so one user cannot delete
        # another user's connection by posting its provider_user_id.
        connection = OAuthConnection.objects.get(
            user=request.user,
            provider_key=self.provider_key,
            provider_user_id=provider_user_id,
        )
        connection.delete()
        redirect_url = self.get_disconnect_redirect_url(request=request)
        return ResponseRedirect(redirect_url)

    def handle_callback_request(self, *, request: HttpRequest) -> Response:
        """Finish the flow: verify state, fetch tokens/user, and log in.

        Raises:
            OAuthError: The provider returned an error or no ``code``.
            OAuthStateMismatchError: The ``state`` check failed.
        """
        self.check_request_state(request=request)

        code = request.GET.get("code")
        if not code:
            raise OAuthError("Missing authorization code in OAuth callback")

        oauth_token = self.get_oauth_token(code=code, request=request)
        oauth_user = self.get_oauth_user(oauth_token=oauth_token)

        if request.user:
            # Signed-in user: attach this provider account to them.
            connection = OAuthConnection.connect(
                user=request.user,
                provider_key=self.provider_key,
                oauth_token=oauth_token,
                oauth_user=oauth_user,
            )
        else:
            # Anonymous: presumably finds or creates the matching local user.
            connection = OAuthConnection.get_or_create_user(
                provider_key=self.provider_key,
                oauth_token=oauth_token,
                oauth_user=oauth_user,
            )

        self.login(request=request, user=connection.user)

        redirect_url = self.get_login_redirect_url(request=request)
        return ResponseRedirect(redirect_url)

    def login(self, *, request: HttpRequest, user: Any) -> None:
        """Log ``user`` into the current session."""
        auth_login(request=request, user=user)

    def get_login_redirect_url(self, *, request: HttpRequest) -> str:
        """Pop and return the stored post-login URL, defaulting to ``"/"``."""
        return request.session.pop(SESSION_NEXT_KEY, "/")

    def get_disconnect_redirect_url(self, *, request: HttpRequest) -> str:
        """Return the ``next`` POST parameter, defaulting to ``"/"``."""
        return request.POST.get("next", "/")
def get_oauth_provider_instance(*, provider_key: str) -> OAuthProvider:
    """Instantiate the provider configured under ``provider_key``.

    Looks up ``settings.OAUTH_LOGIN_PROVIDERS[provider_key]``, imports the
    class named by its ``"class"`` entry, and calls it with ``provider_key``
    plus the entry's optional ``"kwargs"`` mapping.

    Raises:
        KeyError: ``provider_key`` is not configured, or its entry lacks a
            ``"class"`` value.
    """
    providers_setting = getattr(settings, "OAUTH_LOGIN_PROVIDERS", {})
    provider_config = providers_setting[provider_key]
    provider_class = import_string(provider_config["class"])
    extra_kwargs = provider_config.get("kwargs", {})
    return provider_class(provider_key=provider_key, **extra_kwargs)
def get_provider_keys() -> list[str]:
    """Return the keys of all providers in ``settings.OAUTH_LOGIN_PROVIDERS``.

    Returns an empty list when the setting is not defined.
    """
    configured_providers = getattr(settings, "OAUTH_LOGIN_PROVIDERS", {})
    return [*configured_providers.keys()]