Coverage for /Users/davegaeddert/Developer/dropseed/plain/plain-oauth/plain/oauth/models.py: 92%
86 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-23 11:16 -0600
1from typing import TYPE_CHECKING
3from plain import models
4from plain.auth import get_user_model
5from plain.exceptions import ValidationError
6from plain.models import transaction
7from plain.models.db import IntegrityError, OperationalError, ProgrammingError
8from plain.preflight import Error
9from plain.runtime import settings
10from plain.utils import timezone
12from .exceptions import OAuthUserAlreadyExistsError
14if TYPE_CHECKING:
15 from .providers import OAuthToken, OAuthUser
18# TODO preflight check for deploy that ensures all provider keys in db are also in settings?
class OAuthConnection(models.Model):
    """
    Link between a local user and their account on an OAuth provider.

    Each row records which provider (``provider_key``) and which account on
    that provider (``provider_user_id``) belong to ``user``, together with the
    most recently stored token data. The (provider_key, provider_user_id) pair
    is unique, so a given provider account can be linked to at most one row.
    """

    created_at = models.DateTimeField(auto_now_add=True)
    updated_at = models.DateTimeField(auto_now=True)

    user = models.ForeignKey(
        settings.AUTH_USER_MODEL,
        on_delete=models.CASCADE,
        related_name="oauth_connections",
    )

    # The key used to refer to this provider type (in settings)
    provider_key = models.CharField(max_length=100, db_index=True)

    # The unique ID of the user on the provider's system
    provider_user_id = models.CharField(max_length=100, db_index=True)

    # Token data. A NULL expiry is reported as "not expired" by the
    # access_token_expired()/refresh_token_expired() helpers below.
    access_token = models.CharField(max_length=2000)
    refresh_token = models.CharField(max_length=2000, blank=True)
    access_token_expires_at = models.DateTimeField(blank=True, null=True)
    refresh_token_expires_at = models.DateTimeField(blank=True, null=True)

    class Meta:
        constraints = [
            models.UniqueConstraint(
                fields=["provider_key", "provider_user_id"],
                name="unique_oauth_provider_user_id",
            )
        ]
        ordering = ("provider_key",)

    def __str__(self) -> str:
        return f"{self.provider_key}[{self.user}:{self.provider_user_id}]"

    def refresh_access_token(self) -> None:
        """
        Exchange the stored tokens for fresh ones and persist them.

        Rebuilds an ``OAuthToken`` from the stored token fields, passes it to
        the provider instance's ``refresh_oauth_token``, copies the returned
        token back onto this connection and saves it.
        """
        # Local import: .providers is only imported under TYPE_CHECKING at
        # module level, presumably to avoid an import cycle.
        from .providers import OAuthToken, get_oauth_provider_instance

        provider_instance = get_oauth_provider_instance(provider_key=self.provider_key)
        oauth_token = OAuthToken(
            access_token=self.access_token,
            refresh_token=self.refresh_token,
            access_token_expires_at=self.access_token_expires_at,
            refresh_token_expires_at=self.refresh_token_expires_at,
        )
        refreshed_oauth_token = provider_instance.refresh_oauth_token(
            oauth_token=oauth_token
        )
        self.set_token_fields(refreshed_oauth_token)
        self.save()

    def set_token_fields(self, oauth_token: "OAuthToken") -> None:
        """Copy all token values from ``oauth_token`` onto this instance (does not save)."""
        self.access_token = oauth_token.access_token
        self.refresh_token = oauth_token.refresh_token
        self.access_token_expires_at = oauth_token.access_token_expires_at
        self.refresh_token_expires_at = oauth_token.refresh_token_expires_at

    def set_user_fields(self, oauth_user: "OAuthUser") -> None:
        """Copy the provider-side identity from ``oauth_user`` onto this instance (does not save)."""
        self.provider_user_id = oauth_user.id

    def access_token_expired(self) -> bool:
        """Return True if an access-token expiry is recorded and already in the past."""
        return (
            self.access_token_expires_at is not None
            and self.access_token_expires_at < timezone.now()
        )

    def refresh_token_expired(self) -> bool:
        """Return True if a refresh-token expiry is recorded and already in the past."""
        return (
            self.refresh_token_expires_at is not None
            and self.refresh_token_expires_at < timezone.now()
        )

    @classmethod
    def get_or_create_user(
        cls, *, provider_key: str, oauth_token: "OAuthToken", oauth_user: "OAuthUser"
    ) -> "OAuthConnection":
        """
        Return the connection for this provider account, creating a user if needed.

        If a connection already exists for (``provider_key``, ``oauth_user.id``),
        its token fields are refreshed from ``oauth_token`` and it is saved.
        Otherwise a new user is built from ``oauth_user.user_model_fields`` and
        connected via ``connect``; user creation and connection happen inside a
        single transaction so a failure does not leave a half-created user.

        Raises:
            OAuthUserAlreadyExistsError: saving the new user failed with an
                IntegrityError or ValidationError (e.g. a unique field such as
                email already belongs to another user).
        """
        try:
            connection = cls.objects.get(
                provider_key=provider_key,
                provider_user_id=oauth_user.id,
            )
        except cls.DoesNotExist:
            pass
        else:
            # Known provider account: just store the latest tokens.
            connection.set_token_fields(oauth_token)
            connection.save()
            return connection

        with transaction.atomic():
            # If email needs to be unique, then we expect
            # that to be taken care of on the user model itself
            try:
                user = get_user_model()(
                    **oauth_user.user_model_fields,
                )
                user.save()
            except (IntegrityError, ValidationError) as e:
                # Chain the original error so the underlying cause stays visible.
                raise OAuthUserAlreadyExistsError() from e

            return cls.connect(
                user=user,
                provider_key=provider_key,
                oauth_token=oauth_token,
                oauth_user=oauth_user,
            )

    @classmethod
    def connect(
        cls,
        *,
        user: settings.AUTH_USER_MODEL,
        provider_key: str,
        oauth_token: "OAuthToken",
        oauth_user: "OAuthUser",
    ) -> "OAuthConnection":
        """
        Connect will either create a new connection or update an existing connection

        An existing row is looked up by (user, provider_key, oauth_user.id); if
        none exists a new unsaved instance is built. In both cases the user and
        token fields are copied in and the row is saved.

        NOTE(review): if this provider account is already linked to a *different*
        user, the lookup misses and the save will presumably violate the
        unique_oauth_provider_user_id constraint (IntegrityError) — confirm
        callers handle that.
        """
        try:
            connection = cls.objects.get(
                user=user,
                provider_key=provider_key,
                provider_user_id=oauth_user.id,
            )
        except cls.DoesNotExist:
            # Create our own instance (not using get_or_create)
            # so that any created signals contain the token fields too
            connection = cls(
                user=user,
                provider_key=provider_key,
                provider_user_id=oauth_user.id,
            )

        connection.set_user_fields(oauth_user)
        connection.set_token_fields(oauth_token)
        connection.save()

        return connection

    @classmethod
    def check(cls, **kwargs) -> list:
        """
        A system check for ensuring that provider_keys in the database are also present in settings.

        Note that the --database flag is required for this to work:
            python manage.py check --database default

        Databases where the table cannot be queried are skipped. Each database
        with unknown provider keys yields one ``plain.oauth.E001`` error, listing
        the offending keys in sorted order so the message is deterministic.
        """
        errors = super().check(**kwargs)

        databases = kwargs.get("databases", None)
        if not databases:
            return errors

        from .providers import get_provider_keys

        # Configured providers do not depend on the database; read them once.
        keys_in_settings = set(get_provider_keys())

        for database in databases:
            try:
                keys_in_db = set(
                    cls.objects.using(database)
                    .values_list("provider_key", flat=True)
                    .distinct()
                )
            except (OperationalError, ProgrammingError):
                # Check runs on manage.py migrate, and the table may not exist yet
                # or it may not be installed on the particular database intentionally
                continue

            unknown_keys = keys_in_db - keys_in_settings
            if unknown_keys:
                errors.append(
                    Error(
                        "The following OAuth providers are in the database but not in the settings: {}".format(
                            ", ".join(sorted(unknown_keys))
                        ),
                        id="plain.oauth.E001",
                    )
                )

        return errors