Coverage for /Users/davegaeddert/Developer/dropseed/plain/plain-oauth/plain/oauth/providers.py: 82%
112 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-23 11:16 -0600
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-23 11:16 -0600
1import datetime
2import secrets
3from typing import Any
4from urllib.parse import urlencode
6from plain.auth import login as auth_login
7from plain.http import HttpRequest, Response, ResponseRedirect
8from plain.runtime import settings
9from plain.urls import reverse
10from plain.utils.crypto import get_random_string
11from plain.utils.module_loading import import_string
13from .exceptions import OAuthError, OAuthStateMismatchError
14from .models import OAuthConnection
# Session key holding the OAuth ``state`` value between the login redirect
# (written in OAuthProvider.handle_login_request) and the provider callback
# (popped and compared in OAuthProvider.check_request_state).
SESSION_STATE_KEY: str = "plainoauth_state"
# Session key holding the URL to redirect to after a successful login; written
# in OAuthProvider.handle_login_request and popped in get_login_redirect_url.
SESSION_NEXT_KEY: str = "plainoauth_next"
20class OAuthToken:
21 def __init__(
22 self,
23 *,
24 access_token: str,
25 refresh_token: str = "",
26 access_token_expires_at: datetime.datetime = None,
27 refresh_token_expires_at: datetime.datetime = None,
28 ):
29 self.access_token = access_token
30 self.refresh_token = refresh_token
31 self.access_token_expires_at = access_token_expires_at
32 self.refresh_token_expires_at = refresh_token_expires_at
class OAuthUser:
    """A user as reported by an OAuth provider.

    ``id`` is the user's identifier on the provider's system. Every other
    keyword argument is collected unchanged into ``user_model_fields``.
    """

    # ``**user_model_fields: Any`` annotates each *value*; the previous ``dict``
    # annotation claimed every field value was itself a dict.
    def __init__(self, *, id: str, **user_model_fields: Any):
        self.id = id  # ID on the provider's system
        self.user_model_fields = user_model_fields

    def __str__(self) -> str:
        """Return a display label: the email, else the username, else the id."""
        for field_name in ("email", "username"):
            value = self.user_model_fields.get(field_name)
            # Skip missing/empty values (e.g. a provider sending email=None) and
            # coerce to str: __str__ must return a str or str() raises TypeError.
            if value:
                return str(value)
        return str(self.id)
class OAuthProvider:
    """Base class for an OAuth login provider.

    Subclasses set ``authorization_url`` and implement ``get_oauth_token`` and
    ``get_oauth_user`` (and ``refresh_oauth_token`` if refreshing is needed);
    the base implementations raise NotImplementedError. Instances are built by
    ``get_oauth_provider_instance`` from the ``OAUTH_LOGIN_PROVIDERS`` setting.
    """

    # The provider's authorization endpoint; returned by get_authorization_url.
    authorization_url = ""

    def __init__(
        self,
        *,
        # Provided automatically
        provider_key: str,
        # Required as kwargs in OAUTH_LOGIN_PROVIDERS setting
        client_id: str,
        client_secret: str,
        # Not necessarily required, but commonly used
        scope: str = "",
    ):
        self.provider_key = provider_key
        self.client_id = client_id
        self.client_secret = client_secret
        self.scope = scope

    def get_authorization_url_params(self, *, request: HttpRequest) -> dict:
        """Return the query parameters for the authorization redirect.

        A new random ``state`` is generated on every call.
        """
        return {
            "redirect_uri": self.get_callback_url(request=request),
            "client_id": self.get_client_id(),
            "scope": self.get_scope(),
            "state": self.generate_state(),
            "response_type": "code",
        }

    def refresh_oauth_token(self, *, oauth_token: OAuthToken) -> OAuthToken:
        """Exchange a refresh token for a new OAuthToken (subclass hook)."""
        raise NotImplementedError()

    def get_oauth_token(self, *, code: str, request: HttpRequest) -> OAuthToken:
        """Exchange the callback ``code`` for an OAuthToken (subclass hook)."""
        raise NotImplementedError()

    def get_oauth_user(self, *, oauth_token: OAuthToken) -> OAuthUser:
        """Fetch the provider's user for ``oauth_token`` (subclass hook)."""
        raise NotImplementedError()

    def get_authorization_url(self, *, request: HttpRequest) -> str:
        return self.authorization_url

    def get_client_id(self) -> str:
        return self.client_id

    def get_client_secret(self) -> str:
        return self.client_secret

    def get_scope(self) -> str:
        return self.scope

    def get_callback_url(self, *, request: HttpRequest) -> str:
        """Return the absolute URL of this provider's ``oauth:callback`` route."""
        url = reverse("oauth:callback", kwargs={"provider": self.provider_key})
        return request.build_absolute_uri(url)

    def generate_state(self) -> str:
        """Return a new random 32-character ``state`` value."""
        return get_random_string(length=32)

    def check_request_state(self, *, request: HttpRequest) -> None:
        """Validate the ``state`` echoed back to the callback.

        Raises:
            OAuthError: if the provider returned an ``error`` query parameter.
            OAuthStateMismatchError: if the callback has no ``state``, the
                session holds no expected state, or the two differ.
        """
        if error := request.GET.get("error"):
            raise OAuthError(error)

        state = request.GET.get("state", "")
        # Pop unconditionally so a stored state can be checked only once; a
        # missing key is a mismatch, not a server error.
        expected_state = request.session.pop(SESSION_STATE_KEY, "")
        request.session.save()  # Make sure the pop is saved (won't save on an exception)

        # Compare bytes: compare_digest raises TypeError for non-ASCII str, and
        # the query value is attacker-controlled.
        if (
            not state
            or not expected_state
            or not secrets.compare_digest(state.encode(), expected_state.encode())
        ):
            raise OAuthStateMismatchError()

    def handle_login_request(
        self, *, request: HttpRequest, redirect_to: str = ""
    ) -> Response:
        """Redirect the user to the provider's authorization page.

        Stores the generated ``state`` and the post-login redirect target
        (``redirect_to``, else POST ``next``) in the session for the callback.
        """
        authorization_url = self.get_authorization_url(request=request)
        authorization_params = self.get_authorization_url_params(request=request)

        if "state" in authorization_params:
            # Store the state in the session so we can check on callback
            request.session[SESSION_STATE_KEY] = authorization_params["state"]

        # Store next url in session so we can get it on the callback request.
        # NOTE(review): this value is user-supplied and later redirected to
        # without host validation (possible open redirect) - confirm callers
        # sanitize it.
        if redirect_to:
            request.session[SESSION_NEXT_KEY] = redirect_to
        elif "next" in request.POST:
            request.session[SESSION_NEXT_KEY] = request.POST["next"]

        # Sort authorization params for consistency; join with "&" when the
        # configured authorization URL already carries a query string.
        query = urlencode(sorted(authorization_params.items()))
        separator = "&" if "?" in authorization_url else "?"
        return ResponseRedirect(f"{authorization_url}{separator}{query}")

    def handle_connect_request(
        self, *, request: HttpRequest, redirect_to: str = ""
    ) -> Response:
        """Start the authorization flow to connect an account (same as login)."""
        return self.handle_login_request(request=request, redirect_to=redirect_to)

    def handle_disconnect_request(self, *, request: HttpRequest) -> Response:
        """Delete the requesting user's connection named by POST ``provider_user_id``."""
        provider_user_id = request.POST["provider_user_id"]
        # Scope the lookup to the requesting user; otherwise anyone could delete
        # another account's connection by submitting its provider_user_id.
        connection = OAuthConnection.objects.get(
            user=request.user,
            provider_key=self.provider_key,
            provider_user_id=provider_user_id,
        )
        connection.delete()
        redirect_url = self.get_disconnect_redirect_url(request=request)
        return ResponseRedirect(redirect_url)

    def handle_callback_request(self, *, request: HttpRequest) -> Response:
        """Complete the flow: verify state, fetch token and user, log in, redirect."""
        self.check_request_state(request=request)

        oauth_token = self.get_oauth_token(code=request.GET["code"], request=request)
        oauth_user = self.get_oauth_user(oauth_token=oauth_token)

        if request.user:
            # Signed-in user: attach this provider account to them.
            connection = OAuthConnection.connect(
                user=request.user,
                provider_key=self.provider_key,
                oauth_token=oauth_token,
                oauth_user=oauth_user,
            )
        else:
            connection = OAuthConnection.get_or_create_user(
                provider_key=self.provider_key,
                oauth_token=oauth_token,
                oauth_user=oauth_user,
            )

        self.login(request=request, user=connection.user)

        redirect_url = self.get_login_redirect_url(request=request)
        return ResponseRedirect(redirect_url)

    def login(self, *, request: HttpRequest, user: Any) -> None:
        """Log ``user`` into the current session via plain.auth's login."""
        auth_login(request=request, user=user)

    def get_login_redirect_url(self, *, request: HttpRequest) -> str:
        """Return (and consume) the stored post-login target, defaulting to "/"."""
        # NOTE(review): not validated against the site's host - see the note in
        # handle_login_request.
        return request.session.pop(SESSION_NEXT_KEY, "/")

    def get_disconnect_redirect_url(self, *, request: HttpRequest) -> str:
        """Return POST ``next`` or "/" as the post-disconnect target."""
        # NOTE(review): user-supplied and not host-validated (possible open redirect).
        return request.POST.get("next", "/")
def get_oauth_provider_instance(*, provider_key: str) -> OAuthProvider:
    """Build the provider configured under ``provider_key``.

    Looks up ``settings.OAUTH_LOGIN_PROVIDERS[provider_key]`` (the setting
    defaults to an empty dict), imports the class named by its ``"class"``
    entry, and instantiates it with ``provider_key`` plus the optional
    ``"kwargs"`` mapping. Raises KeyError if ``provider_key`` or its
    ``"class"`` entry is not configured.
    """
    providers_config = getattr(settings, "OAUTH_LOGIN_PROVIDERS", {})
    provider_config = providers_config[provider_key]
    provider_cls = import_string(provider_config["class"])
    extra_kwargs = provider_config.get("kwargs", {})
    return provider_cls(provider_key=provider_key, **extra_kwargs)
def get_provider_keys() -> list[str]:
    """Return the keys of ``settings.OAUTH_LOGIN_PROVIDERS`` (empty if unset)."""
    configured_providers = getattr(settings, "OAUTH_LOGIN_PROVIDERS", {})
    return list(configured_providers)