Coverage for /Users/davegaeddert/Developer/dropseed/plain/plain-models/plain/models/constraints.py: 52%
233 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-23 11:16 -0600
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-23 11:16 -0600
1import warnings
2from enum import Enum
3from types import NoneType
5from plain.exceptions import FieldError, ValidationError
6from plain.models.db import DEFAULT_DB_ALIAS, connections
7from plain.models.expressions import Exists, ExpressionList, F, OrderBy
8from plain.models.indexes import IndexExpression
9from plain.models.lookups import Exact
10from plain.models.query_utils import Q
11from plain.models.sql.query import Query
12from plain.utils.deprecation import RemovedInDjango60Warning
# Public names re-exported by ``from ... import *``.
__all__ = [
    "BaseConstraint",
    "CheckConstraint",
    "Deferrable",
    "UniqueConstraint",
]
class BaseConstraint:
    """Base class for model constraints.

    Concrete subclasses implement SQL generation (``constraint_sql``,
    ``create_sql``, ``remove_sql``) and Python-side ``validate``; this class
    holds the shared name / violation-error handling plus ``deconstruct`` and
    ``clone``.
    """

    # Template used when no violation_error_message is supplied; "%(name)s"
    # is filled with the constraint name by get_violation_error_message().
    default_violation_error_message = "Constraint “%(name)s” is violated."
    violation_error_code = None
    violation_error_message = None

    # RemovedInDjango60Warning: When the deprecation ends, replace with:
    # def __init__(
    #     self, *, name, violation_error_code=None, violation_error_message=None
    # ):
    def __init__(
        self, *args, name=None, violation_error_code=None, violation_error_message=None
    ):
        # RemovedInDjango60Warning: the name may still arrive positionally, so
        # it only counts as missing when neither form was used.
        if not args and name is None:
            raise TypeError(
                f"{self.__class__.__name__}.__init__() missing 1 required keyword-only "
                f"argument: 'name'"
            )
        self.name = name
        if violation_error_code is not None:
            self.violation_error_code = violation_error_code
        self.violation_error_message = (
            self.default_violation_error_message
            if violation_error_message is None
            else violation_error_message
        )

        # RemovedInDjango60Warning: the legacy positional form is
        # (name, violation_error_message); truthy values override the above.
        if not args:
            return
        warnings.warn(
            f"Passing positional arguments to {self.__class__.__name__} is "
            f"deprecated.",
            RemovedInDjango60Warning,
            stacklevel=2,
        )
        for attr, value in zip(("name", "violation_error_message"), args):
            if value:
                setattr(self, attr, value)

    @property
    def contains_expressions(self):
        """Whether the constraint is built from expressions (never, by default)."""
        return False

    # Hooks every concrete constraint has to provide.

    def constraint_sql(self, model, schema_editor):
        raise NotImplementedError("This method must be implemented by a subclass.")

    def create_sql(self, model, schema_editor):
        raise NotImplementedError("This method must be implemented by a subclass.")

    def remove_sql(self, model, schema_editor):
        raise NotImplementedError("This method must be implemented by a subclass.")

    def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
        raise NotImplementedError("This method must be implemented by a subclass.")

    def get_violation_error_message(self):
        """Return the violation message with ``%(name)s`` interpolated."""
        return self.violation_error_message % {"name": self.name}

    def deconstruct(self):
        """Return ``(path, args, kwargs)`` from which the constraint can be rebuilt.

        The module part of the path is rewritten from
        ``plain.models.constraints`` to ``plain.models``. The violation message
        is only included when it differs from the default, and the code only
        when it is set.
        """
        cls = self.__class__
        path = f"{cls.__module__}.{cls.__name__}".replace(
            "plain.models.constraints", "plain.models"
        )
        kwargs = {"name": self.name}
        message = self.violation_error_message
        if message is not None and message != self.default_violation_error_message:
            kwargs["violation_error_message"] = message
        if self.violation_error_code is not None:
            kwargs["violation_error_code"] = self.violation_error_code
        return (path, (), kwargs)

    def clone(self):
        """Return a new, equivalent constraint built from ``deconstruct()``."""
        _path, args, kwargs = self.deconstruct()
        return self.__class__(*args, **kwargs)
class CheckConstraint(BaseConstraint):
    """A constraint requiring each row to satisfy a boolean condition.

    ``check`` must be a ``Q`` instance or another expression whose
    ``conditional`` attribute is truthy; anything else raises TypeError.
    """

    def __init__(
        self, *, check, name, violation_error_code=None, violation_error_message=None
    ):
        self.check = check
        is_conditional = getattr(check, "conditional", False)
        if not is_conditional:
            raise TypeError(
                "CheckConstraint.check must be a Q instance or boolean expression."
            )
        super().__init__(
            name=name,
            violation_error_code=violation_error_code,
            violation_error_message=violation_error_message,
        )

    def _get_check_sql(self, model, schema_editor):
        # Compile ``self.check`` to a SQL fragment with its parameters quoted
        # and inlined by the schema editor.
        connection = schema_editor.connection
        query = Query(model=model, alias_cols=False)
        where_node = query.build_where(self.check)
        sql, params = where_node.as_sql(
            query.get_compiler(connection=connection), connection
        )
        quoted_params = tuple(map(schema_editor.quote_value, params))
        return sql % quoted_params

    def constraint_sql(self, model, schema_editor):
        return schema_editor._check_sql(
            self.name, self._get_check_sql(model, schema_editor)
        )

    def create_sql(self, model, schema_editor):
        return schema_editor._create_check_sql(
            model, self.name, self._get_check_sql(model, schema_editor)
        )

    def remove_sql(self, model, schema_editor):
        return schema_editor._delete_check_sql(model, self.name)

    def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
        """Raise ValidationError if *instance*'s values fail ``self.check``.

        A FieldError raised while evaluating the check is swallowed, which
        skips validation.
        """
        field_values = instance._get_field_value_map(meta=model._meta, exclude=exclude)
        try:
            if Q(self.check).check(field_values, using=using):
                return
            raise ValidationError(
                self.get_violation_error_message(), code=self.violation_error_code
            )
        except FieldError:
            pass

    def __repr__(self):
        code_part = (
            ""
            if self.violation_error_code is None
            else f" violation_error_code={self.violation_error_code!r}"
        )
        message = self.violation_error_message
        message_part = (
            ""
            if message is None or message == self.default_violation_error_message
            else f" violation_error_message={message!r}"
        )
        return (
            f"<{self.__class__.__qualname__}: check={self.check} "
            f"name={self.name!r}{code_part}{message_part}>"
        )

    def __eq__(self, other):
        if not isinstance(other, CheckConstraint):
            return super().__eq__(other)
        return (
            self.name == other.name
            and self.check == other.check
            and self.violation_error_code == other.violation_error_code
            and self.violation_error_message == other.violation_error_message
        )

    def deconstruct(self):
        """Extend the base deconstruction with the ``check`` keyword."""
        path, args, base_kwargs = super().deconstruct()
        return path, args, {**base_kwargs, "check": self.check}
class Deferrable(Enum):
    """Deferral modes accepted by ``UniqueConstraint(deferrable=...)``."""

    DEFERRED = "deferred"
    IMMEDIATE = "immediate"

    def __repr__(self):
        # Show members as ``Deferrable.DEFERRED`` instead of Enum's default
        # ``<Deferrable.DEFERRED: 'deferred'>``.
        return f"{type(self).__qualname__}.{self.name}"
class UniqueConstraint(BaseConstraint):
    """Require a combination of fields or expressions to be unique.

    Give either ``fields`` (a sequence of field names) or positional
    ``expressions`` (bare strings are wrapped in ``F()``), not both.
    ``condition`` (a ``Q``) limits which rows the constraint covers;
    ``deferrable`` (a ``Deferrable``), ``include`` (extra field names) and
    ``opclasses`` (one per field) are passed to the schema editor's SQL
    helpers.
    """

    def __init__(
        self,
        *expressions,
        fields=(),
        name=None,
        condition=None,
        deferrable=None,
        include=None,
        opclasses=(),
        violation_error_code=None,
        violation_error_message=None,
    ):
        """Validate the option combination and store normalized values.

        Raises:
            ValueError: if ``name`` is missing; if neither or both of
                ``fields``/``expressions`` are given; if ``deferrable`` is
                combined with ``condition``, ``include``, ``opclasses`` or
                ``expressions``; if ``opclasses`` is used with expressions or
                its length differs from ``fields``; or if an option has the
                wrong type.
        """
        if not name:
            raise ValueError("A unique constraint must be named.")
        if not expressions and not fields:
            raise ValueError(
                "At least one field or expression is required to define a "
                "unique constraint."
            )
        if expressions and fields:
            raise ValueError(
                "UniqueConstraint.fields and expressions are mutually exclusive."
            )
        if not isinstance(condition, NoneType | Q):
            raise ValueError("UniqueConstraint.condition must be a Q instance.")
        if condition and deferrable:
            raise ValueError("UniqueConstraint with conditions cannot be deferred.")
        if include and deferrable:
            raise ValueError("UniqueConstraint with include fields cannot be deferred.")
        if opclasses and deferrable:
            raise ValueError("UniqueConstraint with opclasses cannot be deferred.")
        if expressions and deferrable:
            raise ValueError("UniqueConstraint with expressions cannot be deferred.")
        if expressions and opclasses:
            raise ValueError(
                "UniqueConstraint.opclasses cannot be used with expressions. "
                "Use a custom OpClass() instead."
            )
        if not isinstance(deferrable, NoneType | Deferrable):
            raise ValueError(
                "UniqueConstraint.deferrable must be a Deferrable instance."
            )
        if not isinstance(include, NoneType | list | tuple):
            raise ValueError("UniqueConstraint.include must be a list or tuple.")
        if not isinstance(opclasses, list | tuple):
            raise ValueError("UniqueConstraint.opclasses must be a list or tuple.")
        if opclasses and len(fields) != len(opclasses):
            raise ValueError(
                "UniqueConstraint.fields and UniqueConstraint.opclasses must "
                "have the same number of elements."
            )
        self.fields = tuple(fields)
        self.condition = condition
        self.deferrable = deferrable
        self.include = tuple(include) if include else ()
        self.opclasses = opclasses
        # Bare strings are shorthand for field references.
        self.expressions = tuple(
            F(expression) if isinstance(expression, str) else expression
            for expression in expressions
        )
        super().__init__(
            name=name,
            violation_error_code=violation_error_code,
            violation_error_message=violation_error_message,
        )

    @property
    def contains_expressions(self):
        """True when the constraint was defined with expressions."""
        return bool(self.expressions)

    def _get_condition_sql(self, model, schema_editor):
        """Compile ``condition`` to SQL with quoted params inlined; None if unset."""
        if self.condition is None:
            return None
        query = Query(model=model, alias_cols=False)
        where = query.build_where(self.condition)
        compiler = query.get_compiler(connection=schema_editor.connection)
        sql, params = where.as_sql(compiler, schema_editor.connection)
        return sql % tuple(schema_editor.quote_value(p) for p in params)

    def _get_index_expressions(self, model, schema_editor):
        """Resolve ``expressions`` into an ExpressionList; None if there are none."""
        if not self.expressions:
            return None
        index_expressions = []
        for expression in self.expressions:
            index_expression = IndexExpression(expression)
            index_expression.set_wrapper_classes(schema_editor.connection)
            index_expressions.append(index_expression)
        return ExpressionList(*index_expressions).resolve_expression(
            Query(model, alias_cols=False),
        )

    def _get_include_columns(self, model):
        """Return the database column names of the ``include`` fields."""
        return [
            model._meta.get_field(field_name).column for field_name in self.include
        ]

    def constraint_sql(self, model, schema_editor):
        fields = [model._meta.get_field(field_name) for field_name in self.fields]
        include = self._get_include_columns(model)
        condition = self._get_condition_sql(model, schema_editor)
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._unique_sql(
            model,
            fields,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=self.opclasses,
            expressions=expressions,
        )

    def create_sql(self, model, schema_editor):
        fields = [model._meta.get_field(field_name) for field_name in self.fields]
        include = self._get_include_columns(model)
        condition = self._get_condition_sql(model, schema_editor)
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._create_unique_sql(
            model,
            fields,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=self.opclasses,
            expressions=expressions,
        )

    def remove_sql(self, model, schema_editor):
        condition = self._get_condition_sql(model, schema_editor)
        include = self._get_include_columns(model)
        expressions = self._get_index_expressions(model, schema_editor)
        return schema_editor._delete_unique_sql(
            model,
            self.name,
            condition=condition,
            deferrable=self.deferrable,
            include=include,
            opclasses=self.opclasses,
            expressions=expressions,
        )

    def __repr__(self):
        return "<{}:{}{}{}{}{}{}{}{}{}>".format(
            self.__class__.__qualname__,
            "" if not self.fields else f" fields={repr(self.fields)}",
            "" if not self.expressions else f" expressions={repr(self.expressions)}",
            f" name={repr(self.name)}",
            "" if self.condition is None else f" condition={self.condition}",
            "" if self.deferrable is None else f" deferrable={self.deferrable!r}",
            "" if not self.include else f" include={repr(self.include)}",
            "" if not self.opclasses else f" opclasses={repr(self.opclasses)}",
            (
                ""
                if self.violation_error_code is None
                else f" violation_error_code={self.violation_error_code!r}"
            ),
            (
                ""
                if self.violation_error_message is None
                or self.violation_error_message == self.default_violation_error_message
                else f" violation_error_message={self.violation_error_message!r}"
            ),
        )

    def __eq__(self, other):
        if isinstance(other, UniqueConstraint):
            return (
                self.name == other.name
                and self.fields == other.fields
                and self.condition == other.condition
                and self.deferrable == other.deferrable
                and self.include == other.include
                and self.opclasses == other.opclasses
                and self.expressions == other.expressions
                and self.violation_error_code == other.violation_error_code
                and self.violation_error_message == other.violation_error_message
            )
        return super().__eq__(other)

    def deconstruct(self):
        """Return ``(path, expressions, kwargs)``; only non-empty options are kept."""
        path, args, kwargs = super().deconstruct()
        if self.fields:
            kwargs["fields"] = self.fields
        if self.condition:
            kwargs["condition"] = self.condition
        if self.deferrable:
            kwargs["deferrable"] = self.deferrable
        if self.include:
            kwargs["include"] = self.include
        if self.opclasses:
            kwargs["opclasses"] = self.opclasses
        return path, self.expressions, kwargs

    def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
        """Raise ValidationError if another row conflicts with *instance*.

        The lookup excludes the instance's own pk when
        ``instance._state.adding`` is false. Validation is skipped (returns
        None) when a constrained field is listed in *exclude* or, for
        ``fields`` constraints, when any value is None (or "" on backends
        that interpret empty strings as NULL). For conditional constraints,
        a FieldError raised during the check also skips validation.
        """
        queryset = model._default_manager.using(using)
        if self.fields:
            lookup_kwargs = {}
            for field_name in self.fields:
                if exclude and field_name in exclude:
                    return
                field = model._meta.get_field(field_name)
                lookup_value = getattr(instance, field.attname)
                if lookup_value is None or (
                    lookup_value == ""
                    and connections[using].features.interprets_empty_strings_as_nulls
                ):
                    # A composite constraint containing NULL value cannot cause
                    # a violation since NULL != NULL in SQL.
                    return
                lookup_kwargs[field.name] = lookup_value
            queryset = queryset.filter(**lookup_kwargs)
        else:
            # Ignore constraints with excluded fields.
            if exclude:
                for expression in self.expressions:
                    if hasattr(expression, "flatten"):
                        for expr in expression.flatten():
                            if isinstance(expr, F) and expr.name in exclude:
                                return
                    elif isinstance(expression, F) and expression.name in exclude:
                        return
            # Substitute the instance's values for field references so each
            # expression can be compared against the stored rows.
            replacements = {
                F(field): value
                for field, value in instance._get_field_value_map(
                    meta=model._meta, exclude=exclude
                ).items()
            }
            expressions = []
            for expr in self.expressions:
                # Ignore ordering.
                if isinstance(expr, OrderBy):
                    expr = expr.expression
                expressions.append(Exact(expr, expr.replace_expressions(replacements)))
            queryset = queryset.filter(*expressions)
        model_class_pk = instance._get_pk_val(model._meta)
        if not instance._state.adding and model_class_pk is not None:
            queryset = queryset.exclude(pk=model_class_pk)
        if not self.condition:
            if queryset.exists():
                if self.expressions:
                    raise ValidationError(
                        self.get_violation_error_message(),
                        code=self.violation_error_code,
                    )
                # When fields are defined, use the unique_error_message() for
                # backward compatibility.
                for constraint_model, constraints in instance.get_constraints():
                    for constraint in constraints:
                        if constraint is self:
                            raise ValidationError(
                                instance.unique_error_message(
                                    constraint_model, self.fields
                                ),
                            )
                # This exact constraint object is not among the instance's
                # declared constraints (e.g. a clone() or an ad hoc
                # constraint), yet a conflicting row exists: report the
                # generic violation rather than silently passing.
                raise ValidationError(
                    self.get_violation_error_message(),
                    code=self.violation_error_code,
                )
        else:
            against = instance._get_field_value_map(meta=model._meta, exclude=exclude)
            try:
                if (self.condition & Exists(queryset.filter(self.condition))).check(
                    against, using=using
                ):
                    raise ValidationError(
                        self.get_violation_error_message(),
                        code=self.violation_error_code,
                    )
            except FieldError:
                pass