Coverage for /Users/davegaeddert/Developer/dropseed/plain/plain-oauth/plain/oauth/providers.py: 82%

112 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-23 11:16 -0600

import datetime
import secrets
from typing import Any, Optional
from urllib.parse import urlencode

from plain.auth import login as auth_login
from plain.http import HttpRequest, Response, ResponseRedirect
from plain.runtime import settings
from plain.urls import reverse
from plain.utils.crypto import get_random_string
from plain.utils.module_loading import import_string

from .exceptions import OAuthError, OAuthStateMismatchError
from .models import OAuthConnection

15 

# Session key holding the OAuth "state" value between the login redirect and
# the provider callback (written in handle_login_request, popped in
# check_request_state).
SESSION_STATE_KEY: str = "plainoauth_state"
# Session key holding where to send the user after a successful callback
# (written in handle_login_request, popped in get_login_redirect_url).
SESSION_NEXT_KEY: str = "plainoauth_next"

18 

19 

class OAuthToken:
    """Token bundle produced by a provider's get_oauth_token / refresh_oauth_token.

    All arguments are keyword-only. The expiry timestamps are optional; ``None``
    means no expiry was supplied (presumably the provider did not report one --
    confirm per provider).
    """

    def __init__(
        self,
        *,
        access_token: str,
        refresh_token: str = "",
        access_token_expires_at: Optional[datetime.datetime] = None,
        refresh_token_expires_at: Optional[datetime.datetime] = None,
    ) -> None:
        self.access_token = access_token
        self.refresh_token = refresh_token
        self.access_token_expires_at = access_token_expires_at
        self.refresh_token_expires_at = refresh_token_expires_at

33 

34 

class OAuthUser:
    """A user as reported by an OAuth provider.

    ``id`` is the user's ID on the provider's system. Any other keyword
    arguments are kept in ``user_model_fields`` (e.g. ``email``, ``username``).
    """

    def __init__(self, *, id: str, **user_model_fields: Any) -> None:
        self.id = id  # ID on the provider's system
        self.user_model_fields = user_model_fields

    def __str__(self) -> str:
        """Return the email, else the username, else the provider ID.

        Empty or ``None`` values are skipped so that ``str()`` never returns a
        non-string (which would raise TypeError) or a blank label.
        """
        for field_name in ("email", "username"):
            value = self.user_model_fields.get(field_name)
            if value:
                return str(value)
        return str(self.id)

46 

47 

class OAuthProvider:
    """Base class for an OAuth 2 authorization-code login provider.

    Subclasses set ``authorization_url`` and implement ``get_oauth_token``,
    ``get_oauth_user`` and, if supported, ``refresh_oauth_token``. Instances
    are created by ``get_oauth_provider_instance`` from the
    ``OAUTH_LOGIN_PROVIDERS`` setting.
    """

    # Provider endpoint the browser is redirected to in handle_login_request.
    authorization_url = ""

    def __init__(
        self,
        *,
        # Provided automatically
        provider_key: str,
        # Required as kwargs in OAUTH_LOGIN_PROVIDERS setting
        client_id: str,
        client_secret: str,
        # Not necessarily required, but commonly used
        scope: str = "",
    ) -> None:
        self.provider_key = provider_key
        self.client_id = client_id
        self.client_secret = client_secret
        self.scope = scope

    def get_authorization_url_params(self, *, request: HttpRequest) -> dict:
        """Return the query parameters for the authorization redirect.

        Includes a freshly generated ``state``; handle_login_request stores it
        in the session so check_request_state can verify the callback.
        """
        return {
            "redirect_uri": self.get_callback_url(request=request),
            "client_id": self.get_client_id(),
            "scope": self.get_scope(),
            "state": self.generate_state(),
            "response_type": "code",
        }

    def refresh_oauth_token(self, *, oauth_token: OAuthToken) -> OAuthToken:
        """Return a refreshed token for ``oauth_token`` (subclass hook)."""
        raise NotImplementedError()

    def get_oauth_token(self, *, code: str, request: HttpRequest) -> OAuthToken:
        """Turn the callback's authorization ``code`` into a token (subclass hook)."""
        raise NotImplementedError()

    def get_oauth_user(self, *, oauth_token: OAuthToken) -> OAuthUser:
        """Fetch the provider-side user for ``oauth_token`` (subclass hook)."""
        raise NotImplementedError()

    def get_authorization_url(self, *, request: HttpRequest) -> str:
        """Return the authorization endpoint; override for per-request URLs."""
        return self.authorization_url

    def get_client_id(self) -> str:
        return self.client_id

    def get_client_secret(self) -> str:
        return self.client_secret

    def get_scope(self) -> str:
        return self.scope

    def get_callback_url(self, *, request: HttpRequest) -> str:
        """Return the absolute URL of this provider's ``oauth:callback`` route."""
        url = reverse("oauth:callback", kwargs={"provider": self.provider_key})
        return request.build_absolute_uri(url)

    def generate_state(self) -> str:
        """Return a new random 32-character ``state`` value."""
        return get_random_string(length=32)

    def check_request_state(self, *, request: HttpRequest) -> None:
        """Validate the callback's ``state`` against the value saved at login.

        Raises:
            OAuthError: the provider redirected back with an ``error`` param.
            OAuthStateMismatchError: ``state`` is missing from the callback,
                missing from the session, or does not match.
        """
        if error := request.GET.get("error"):
            raise OAuthError(error)

        state = request.GET.get("state", "")
        # Pop unconditionally so a stored state can only be checked once.
        expected_state = request.session.pop(SESSION_STATE_KEY, "")
        request.session.save()  # Make sure the pop is saved (won't save on an exception)

        # compare_digest raises TypeError for non-ASCII str and the callback's
        # state is attacker-controlled, so compare UTF-8 bytes. An empty value
        # on either side must never count as a match.
        if (
            not state
            or not expected_state
            or not secrets.compare_digest(state.encode(), expected_state.encode())
        ):
            raise OAuthStateMismatchError()

    def handle_login_request(
        self, *, request: HttpRequest, redirect_to: str = ""
    ) -> Response:
        """Redirect the browser to the provider's authorization endpoint.

        Stores ``state`` and the post-login target (``redirect_to``, else
        POST ``next``) in the session for the callback request.
        """
        authorization_url = self.get_authorization_url(request=request)
        authorization_params = self.get_authorization_url_params(request=request)

        if "state" in authorization_params:
            # Store the state in the session so we can check on callback
            request.session[SESSION_STATE_KEY] = authorization_params["state"]

        # Store next url in session so we can get it on the callback request
        if redirect_to:
            request.session[SESSION_NEXT_KEY] = redirect_to
        elif "next" in request.POST:
            request.session[SESSION_NEXT_KEY] = request.POST["next"]

        # Sort authorization params for consistency
        sorted_authorization_params = sorted(authorization_params.items())
        # Append to any query string the authorization URL already carries
        # instead of producing "...?a=b?c=d".
        separator = "&" if "?" in authorization_url else "?"
        redirect_url = (
            authorization_url + separator + urlencode(sorted_authorization_params)
        )
        return ResponseRedirect(redirect_url)

    def handle_connect_request(
        self, *, request: HttpRequest, redirect_to: str = ""
    ) -> Response:
        """Start a connect flow; identical to the login redirect."""
        return self.handle_login_request(request=request, redirect_to=redirect_to)

    def handle_disconnect_request(self, *, request: HttpRequest) -> Response:
        """Delete the requesting user's connection for POST ``provider_user_id``.

        The lookup is scoped to ``request.user`` so a POST cannot remove a
        connection belonging to another user; a non-matching lookup raises
        whatever ``OAuthConnection.objects.get`` raises (presumably
        DoesNotExist -- confirm).
        """
        provider_user_id = request.POST["provider_user_id"]
        connection = OAuthConnection.objects.get(
            user=request.user,
            provider_key=self.provider_key,
            provider_user_id=provider_user_id,
        )
        connection.delete()
        redirect_url = self.get_disconnect_redirect_url(request=request)
        return ResponseRedirect(redirect_url)

    def handle_callback_request(self, *, request: HttpRequest) -> Response:
        """Finish the flow: verify state, fetch token and user, log in, redirect."""
        self.check_request_state(request=request)

        oauth_token = self.get_oauth_token(code=request.GET["code"], request=request)
        oauth_user = self.get_oauth_user(oauth_token=oauth_token)

        if request.user:
            # Already signed in: attach this provider account to the current user.
            connection = OAuthConnection.connect(
                user=request.user,
                provider_key=self.provider_key,
                oauth_token=oauth_token,
                oauth_user=oauth_user,
            )
        else:
            # Not signed in: resolve the user via the provider account
            # (presumably looks up or creates one -- see OAuthConnection).
            connection = OAuthConnection.get_or_create_user(
                provider_key=self.provider_key,
                oauth_token=oauth_token,
                oauth_user=oauth_user,
            )

        self.login(request=request, user=connection.user)

        redirect_url = self.get_login_redirect_url(request=request)
        return ResponseRedirect(redirect_url)

    def login(self, *, request: HttpRequest, user: Any) -> None:
        """Log ``user`` into the current session via plain.auth."""
        auth_login(request=request, user=user)

    def get_login_redirect_url(self, *, request: HttpRequest) -> str:
        """Pop and return the stored post-login target, defaulting to "/"."""
        # NOTE(review): this value may originate from POST "next" without host
        # validation -- confirm callers restrict it to avoid open redirects.
        return request.session.pop(SESSION_NEXT_KEY, "/")

    def get_disconnect_redirect_url(self, *, request: HttpRequest) -> str:
        """Return POST ``next`` after disconnecting, defaulting to "/"."""
        # NOTE(review): "next" is used without host validation -- confirm.
        return request.POST.get("next", "/")

185 

186 

def get_oauth_provider_instance(*, provider_key: str) -> OAuthProvider:
    """Build the provider configured under ``provider_key``.

    Reads ``OAUTH_LOGIN_PROVIDERS[provider_key]`` from settings, imports the
    dotted path in its ``"class"`` entry, and instantiates it with
    ``provider_key`` plus the entry's optional ``"kwargs"`` mapping.

    Raises KeyError if the provider (or its "class" entry) is not configured.
    """
    configured_providers = getattr(settings, "OAUTH_LOGIN_PROVIDERS", {})
    provider_config = configured_providers[provider_key]
    provider_class = import_string(provider_config["class"])
    return provider_class(
        provider_key=provider_key, **provider_config.get("kwargs", {})
    )

193 

194 

def get_provider_keys() -> list[str]:
    """Return the key of every provider in the OAUTH_LOGIN_PROVIDERS setting."""
    configured_providers = getattr(settings, "OAUTH_LOGIN_PROVIDERS", {})
    return list(configured_providers)