Coverage for /Users/davegaeddert/Development/dropseed/plain/plain-oauth/plain/oauth/models.py: 92%

86 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-16 22:04 -0500

1from typing import TYPE_CHECKING 

2 

3from plain import models 

4from plain.auth import get_user_model 

5from plain.exceptions import ValidationError 

6from plain.models import transaction 

7from plain.models.db import IntegrityError, OperationalError, ProgrammingError 

8from plain.preflight import Error 

9from plain.runtime import settings 

10from plain.utils import timezone 

11 

12from .exceptions import OAuthUserAlreadyExistsError 

13 

14if TYPE_CHECKING: 

15 from .providers import OAuthToken, OAuthUser 

16 

17 

18# TODO preflight check for deploy that ensures all provider keys in db are also in settings? 

19 

20 

class OAuthConnection(models.Model):
    """A link between a local user and their account on an OAuth provider.

    Stores the provider's user ID together with the latest token data so the
    connection can be reused and refreshed later.
    """

    created_at = models.DateTimeField(auto_now_add=True)
    updated_at = models.DateTimeField(auto_now=True)

    user = models.ForeignKey(
        settings.AUTH_USER_MODEL,
        on_delete=models.CASCADE,
        related_name="oauth_connections",
    )

    # The key used to refer to this provider type (in settings)
    provider_key = models.CharField(max_length=100, db_index=True)

    # The unique ID of the user on the provider's system
    provider_user_id = models.CharField(max_length=100, db_index=True)

    # Token data
    access_token = models.CharField(max_length=2000)
    refresh_token = models.CharField(max_length=2000, blank=True)
    access_token_expires_at = models.DateTimeField(blank=True, null=True)
    refresh_token_expires_at = models.DateTimeField(blank=True, null=True)

    class Meta:
        constraints = [
            # A given provider account (provider + provider-side user ID)
            # can be stored at most once.
            models.UniqueConstraint(
                fields=["provider_key", "provider_user_id"],
                name="unique_oauth_provider_user_id",
            )
        ]
        ordering = ("provider_key",)

    def __str__(self):
        return f"{self.provider_key}[{self.user}:{self.provider_user_id}]"

    def refresh_access_token(self) -> None:
        """Refresh this connection's tokens through its provider and save them.

        Builds an ``OAuthToken`` from the stored token fields, passes it to the
        provider's ``refresh_oauth_token``, copies the returned token data back
        onto this instance and saves it.
        """
        # Imported lazily; the module-level import is TYPE_CHECKING-only.
        from .providers import OAuthToken, get_oauth_provider_instance

        provider_instance = get_oauth_provider_instance(provider_key=self.provider_key)
        oauth_token = OAuthToken(
            access_token=self.access_token,
            refresh_token=self.refresh_token,
            access_token_expires_at=self.access_token_expires_at,
            refresh_token_expires_at=self.refresh_token_expires_at,
        )
        refreshed_oauth_token = provider_instance.refresh_oauth_token(
            oauth_token=oauth_token
        )
        self.set_token_fields(refreshed_oauth_token)
        self.save()

    def set_token_fields(self, oauth_token: "OAuthToken"):
        """Copy token data from ``oauth_token`` onto this instance (not saved)."""
        self.access_token = oauth_token.access_token
        self.refresh_token = oauth_token.refresh_token
        self.access_token_expires_at = oauth_token.access_token_expires_at
        self.refresh_token_expires_at = oauth_token.refresh_token_expires_at

    def set_user_fields(self, oauth_user: "OAuthUser"):
        """Copy provider-side user data from ``oauth_user`` (not saved)."""
        self.provider_user_id = oauth_user.id

    def access_token_expired(self) -> bool:
        """Return True if an access-token expiry is set and already in the past."""
        return (
            self.access_token_expires_at is not None
            and self.access_token_expires_at < timezone.now()
        )

    def refresh_token_expired(self) -> bool:
        """Return True if a refresh-token expiry is set and already in the past."""
        return (
            self.refresh_token_expires_at is not None
            and self.refresh_token_expires_at < timezone.now()
        )

    @classmethod
    def get_or_create_user(
        cls, *, provider_key: str, oauth_token: "OAuthToken", oauth_user: "OAuthUser"
    ) -> "OAuthConnection":
        """Return the connection for an OAuth user, creating a local user if needed.

        If a connection already exists for ``provider_key`` and ``oauth_user.id``,
        its token fields are updated and saved. Otherwise a new user is built
        from ``oauth_user.user_model_fields``, saved, and connected, all within
        one transaction.

        Raises:
            OAuthUserAlreadyExistsError: if saving the new user raises
                IntegrityError or ValidationError.
        """
        try:
            connection = cls.objects.get(
                provider_key=provider_key,
                provider_user_id=oauth_user.id,
            )
        except cls.DoesNotExist:
            pass
        else:
            connection.set_token_fields(oauth_token)
            connection.save()
            return connection

        with transaction.atomic():
            # If email needs to be unique, then we expect
            # that to be taken care of on the user model itself
            try:
                user = get_user_model()(
                    **oauth_user.user_model_fields,
                )
                user.save()
            except (IntegrityError, ValidationError) as err:
                # Keep the underlying database/validation error as the cause.
                raise OAuthUserAlreadyExistsError() from err

            return cls.connect(
                user=user,
                provider_key=provider_key,
                oauth_token=oauth_token,
                oauth_user=oauth_user,
            )

    @classmethod
    def connect(
        cls,
        *,
        user: settings.AUTH_USER_MODEL,
        provider_key: str,
        oauth_token: "OAuthToken",
        oauth_user: "OAuthUser",
    ) -> "OAuthConnection":
        """
        Connect will either create a new connection or update an existing connection

        The existing connection is looked up by ``user``, ``provider_key`` and
        ``oauth_user.id``; user and token fields are refreshed before saving.
        """
        try:
            connection = cls.objects.get(
                user=user,
                provider_key=provider_key,
                provider_user_id=oauth_user.id,
            )
        except cls.DoesNotExist:
            # Create our own instance (not using get_or_create)
            # so that any created signals contain the token fields too
            connection = cls(
                user=user,
                provider_key=provider_key,
                provider_user_id=oauth_user.id,
            )

        connection.set_user_fields(oauth_user)
        connection.set_token_fields(oauth_token)
        connection.save()

        return connection

    @classmethod
    def check(cls, **kwargs):
        """
        A system check for ensuring that provider_keys in the database are also present in settings.

        Only runs when a ``databases`` kwarg is supplied; otherwise just the
        parent checks are returned. Note that the --database flag is required
        for this to work:
            python manage.py check --database default
        """
        errors = super().check(**kwargs)

        databases = kwargs.get("databases", None)
        if not databases:
            return errors

        from .providers import get_provider_keys

        for database in databases:
            try:
                keys_in_db = set(
                    cls.objects.using(database)
                    .values_list("provider_key", flat=True)
                    .distinct()
                )
            except (OperationalError, ProgrammingError):
                # Check runs on manage.py migrate, and the table may not exist yet
                # or it may not be installed on the particular database intentionally
                continue

            keys_in_settings = set(get_provider_keys())

            missing_keys = keys_in_db - keys_in_settings
            if missing_keys:
                errors.append(
                    Error(
                        "The following OAuth providers are in the database but not in the settings: {}".format(
                            # Sorted so the message is stable across runs.
                            ", ".join(sorted(missing_keys))
                        ),
                        id="plain.oauth.E001",
                    )
                )

        return errors