Coverage for /Users/davegaeddert/Development/dropseed/plain/plain-oauth/plain/oauth/providers.py: 82%

111 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-16 22:04 -0500

import datetime
import secrets
from typing import Any, Optional
from urllib.parse import urlencode

from plain.auth import login as auth_login
from plain.http import HttpRequest, Response, ResponseRedirect
from plain.runtime import settings
from plain.urls import reverse
from plain.utils.crypto import get_random_string
from plain.utils.module_loading import import_string

from .exceptions import OAuthError, OAuthStateMismatchError
from .models import OAuthConnection

15 

# Session key under which the login request stores the generated OAuth
# ``state`` so the callback request can verify it.
SESSION_STATE_KEY: str = "plainoauth_state"
# Session key under which the login request stores the post-login
# redirect URL that the callback request later consumes.
SESSION_NEXT_KEY: str = "plainoauth_next"

18 

19 

class OAuthToken:
    """Token data obtained from an OAuth provider.

    Only ``access_token`` is required. ``refresh_token`` defaults to "" and
    both expiry timestamps default to None when the provider does not
    supply them.
    """

    def __init__(
        self,
        *,
        access_token: str,
        refresh_token: str = "",
        # The defaults are None, so the annotations must admit None explicitly.
        access_token_expires_at: Optional[datetime.datetime] = None,
        refresh_token_expires_at: Optional[datetime.datetime] = None,
    ):
        self.access_token = access_token
        self.refresh_token = refresh_token
        self.access_token_expires_at = access_token_expires_at
        self.refresh_token_expires_at = refresh_token_expires_at

    def __repr__(self) -> str:
        # Deliberately leave the token values out so they cannot leak into
        # logs or tracebacks; the expiry timestamps are safe to show.
        return (
            f"{type(self).__name__}("
            f"access_token_expires_at={self.access_token_expires_at!r}, "
            f"refresh_token_expires_at={self.refresh_token_expires_at!r})"
        )

33 

34 

class OAuthUser:
    """A user as reported by an OAuth provider.

    ``id`` is the user's ID on the provider's system. Every other keyword
    argument is collected into ``user_model_fields`` (for example ``email``
    or ``username``).
    """

    def __init__(self, *, id: str, **user_model_fields: Any):
        self.id = id  # ID on the provider's system
        self.user_model_fields = user_model_fields

    def __str__(self) -> str:
        """Return the email, else the username, else the provider ID.

        Empty or None values fall through to the next candidate. The result
        is coerced to ``str`` because ``__str__`` must return a string.
        """
        for field in ("email", "username"):
            value = self.user_model_fields.get(field)
            if value:
                return str(value)
        return str(self.id)

46 

47 

class OAuthProvider:
    """Base class for an OAuth authorization-code login provider.

    Subclasses set ``authorization_url`` (or override
    ``get_authorization_url``) and implement ``get_oauth_token``,
    ``get_oauth_user`` and, if supported, ``refresh_oauth_token``.
    Instances are created by ``get_oauth_provider_instance`` from the
    ``OAUTH_LOGIN_PROVIDERS`` setting.
    """

    # Provider authorization endpoint; the query string is appended to it.
    authorization_url = ""

    def __init__(
        self,
        *,
        # Provided automatically
        provider_key: str,
        # Required as kwargs in OAUTH_LOGIN_PROVIDERS setting
        client_id: str,
        client_secret: str,
        # Not necessarily required, but commonly used
        scope: str = "",
    ):
        self.provider_key = provider_key
        self.client_id = client_id
        self.client_secret = client_secret
        self.scope = scope

    def get_authorization_url_params(self, *, request: HttpRequest) -> dict:
        """Return the query parameters for the authorization redirect.

        Includes a new ``state`` value from ``generate_state()``.
        """
        return {
            "redirect_uri": self.get_callback_url(request=request),
            "client_id": self.get_client_id(),
            "scope": self.get_scope(),
            "state": self.generate_state(),
            "response_type": "code",
        }

    def refresh_oauth_token(self, *, oauth_token: OAuthToken) -> OAuthToken:
        """Return a refreshed token. Must be implemented by subclasses."""
        raise NotImplementedError()

    def get_oauth_token(self, *, code: str, request: HttpRequest) -> OAuthToken:
        """Exchange an authorization ``code`` for a token. Subclasses implement."""
        raise NotImplementedError()

    def get_oauth_user(self, *, oauth_token: OAuthToken) -> OAuthUser:
        """Return the provider's user for ``oauth_token``. Subclasses implement."""
        raise NotImplementedError()

    def get_authorization_url(self, *, request: HttpRequest) -> str:
        return self.authorization_url

    def get_client_id(self) -> str:
        return self.client_id

    def get_client_secret(self) -> str:
        return self.client_secret

    def get_scope(self) -> str:
        return self.scope

    def get_callback_url(self, *, request: HttpRequest) -> str:
        """Return the absolute URL of this provider's ``oauth:callback`` route."""
        url = reverse("oauth:callback", kwargs={"provider": self.provider_key})
        return request.build_absolute_uri(url)

    def generate_state(self) -> str:
        """Return a new 32-character value from ``get_random_string``."""
        return get_random_string(length=32)

    def check_request_state(self, *, request: HttpRequest) -> None:
        """Verify the callback's ``state`` against the one stored at login.

        Raises:
            OAuthError: the callback carries an ``error`` query parameter.
            OAuthStateMismatchError: ``state`` is missing from the query
                string or the session, or the two values differ.
        """
        if error := request.GET.get("error"):
            raise OAuthError(error)

        state = request.GET.get("state", "")
        # Pop unconditionally so a stored state can be checked at most once.
        expected_state = request.session.pop(SESSION_STATE_KEY, "")
        # A missing value on either side is a mismatch, not a KeyError (500).
        if not state or not expected_state:
            raise OAuthStateMismatchError()
        # Compare bytes: compare_digest raises TypeError for non-ASCII str,
        # and ``state`` comes straight from the (untrusted) query string.
        if not secrets.compare_digest(state.encode(), expected_state.encode()):
            raise OAuthStateMismatchError()

    def handle_login_request(
        self, *, request: HttpRequest, redirect_to: str = ""
    ) -> Response:
        """Redirect the user to the provider's authorization endpoint.

        Stores ``state`` (when present in the params) and the post-login
        URL (``redirect_to``, else POST ``next``) in the session for
        ``handle_callback_request``.
        """
        authorization_url = self.get_authorization_url(request=request)
        authorization_params = self.get_authorization_url_params(request=request)

        if "state" in authorization_params:
            # Store the state in the session so we can check on callback
            request.session[SESSION_STATE_KEY] = authorization_params["state"]

        # Store next url in session so we can get it on the callback request
        next_url = redirect_to or request.POST.get("next", "")
        if next_url:
            request.session[SESSION_NEXT_KEY] = next_url
        else:
            # Drop a value left over from an earlier, abandoned login attempt.
            request.session.pop(SESSION_NEXT_KEY, None)

        # Sort authorization params for consistency
        query = urlencode(sorted(authorization_params.items()))
        # Use "&" when the configured URL already carries a query string.
        separator = "&" if "?" in authorization_url else "?"
        return ResponseRedirect(authorization_url + separator + query)

    def handle_connect_request(
        self, *, request: HttpRequest, redirect_to: str = ""
    ) -> Response:
        """Connect another provider account; uses the same flow as login."""
        return self.handle_login_request(request=request, redirect_to=redirect_to)

    def handle_disconnect_request(self, *, request: HttpRequest) -> Response:
        """Delete the requesting user's connection for POST ``provider_user_id``."""
        provider_user_id = request.POST["provider_user_id"]
        # Scope the lookup to request.user so one account cannot delete
        # another account's connection by submitting its provider user ID.
        connection = OAuthConnection.objects.get(
            user=request.user,
            provider_key=self.provider_key,
            provider_user_id=provider_user_id,
        )
        connection.delete()
        redirect_url = self.get_disconnect_redirect_url(request=request)
        return ResponseRedirect(redirect_url)

    def handle_callback_request(self, *, request: HttpRequest) -> Response:
        """Finish the flow: verify state, fetch token and user, then log in.

        If ``request.user`` is set, the provider account is connected to
        it; otherwise ``OAuthConnection.get_or_create_user`` supplies the user.

        Raises:
            OAuthError: the provider reported an error or sent no ``code``.
            OAuthStateMismatchError: see ``check_request_state``.
        """
        self.check_request_state(request=request)

        code = request.GET.get("code")
        if not code:
            raise OAuthError("Missing authorization code")

        oauth_token = self.get_oauth_token(code=code, request=request)
        oauth_user = self.get_oauth_user(oauth_token=oauth_token)

        if request.user:
            connection = OAuthConnection.connect(
                user=request.user,
                provider_key=self.provider_key,
                oauth_token=oauth_token,
                oauth_user=oauth_user,
            )
        else:
            connection = OAuthConnection.get_or_create_user(
                provider_key=self.provider_key,
                oauth_token=oauth_token,
                oauth_user=oauth_user,
            )

        self.login(request=request, user=connection.user)

        redirect_url = self.get_login_redirect_url(request=request)
        return ResponseRedirect(redirect_url)

    def login(self, *, request: HttpRequest, user: Any) -> None:
        """Log ``user`` in via ``plain.auth.login``."""
        auth_login(request=request, user=user)

    def get_login_redirect_url(self, *, request: HttpRequest) -> str:
        """Consume the stored next URL; "/" if it is absent or not a safe path."""
        next_url = request.session.pop(SESSION_NEXT_KEY, "/")
        return self._safe_redirect_url(next_url)

    def get_disconnect_redirect_url(self, *, request: HttpRequest) -> str:
        """Return POST ``next``; "/" if it is absent or not a safe path."""
        return self._safe_redirect_url(request.POST.get("next", "/"))

    @staticmethod
    def _safe_redirect_url(url: Any, default: str = "/") -> str:
        """Return ``url`` if it is a same-site relative path, else ``default``.

        ``next`` values can come from user-submitted POST data, so following
        them unchecked is an open redirect. Only paths that start with a
        single "/" are accepted. The following are rejected:

        - "//host" (scheme-relative URLs);
        - backslashes, which some browsers treat as "/";
        - control characters, which browsers strip (so "/\\t/host" becomes
          "//host").
        """
        if not isinstance(url, str):
            return default
        if not url.startswith("/") or url.startswith("//"):
            return default
        if "\\" in url or any(ord(ch) < 32 or ord(ch) == 127 for ch in url):
            return default
        return url

184 

185 

def get_oauth_provider_instance(*, provider_key: str) -> OAuthProvider:
    """Build the provider configured under ``provider_key``.

    Looks up ``settings.OAUTH_LOGIN_PROVIDERS[provider_key]`` and imports the
    class named by its ``"class"`` dotted path. The class is called with
    ``provider_key`` plus the entry's optional ``"kwargs"`` mapping.

    Raises:
        KeyError: ``provider_key`` is not configured, or its entry lacks
            ``"class"``.
    """
    providers_config = getattr(settings, "OAUTH_LOGIN_PROVIDERS", {})
    provider_config = providers_config[provider_key]
    cls = import_string(provider_config["class"])
    return cls(provider_key=provider_key, **provider_config.get("kwargs", {}))

192 

193 

def get_provider_keys() -> list[str]:
    """Return the configured provider keys from ``OAUTH_LOGIN_PROVIDERS``.

    Returns an empty list when the setting is not defined.
    """
    configured = getattr(settings, "OAUTH_LOGIN_PROVIDERS", {})
    return list(configured)