Coverage for C:\src\imod-python\imod\schemata.py: 91%
250 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-08 13:27 +0200
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-08 13:27 +0200
1"""
2Schemata to help validation of input.
4This code is based on: https://github.com/carbonplan/xarray-schema
6which has the following MIT license:
8 MIT License
10 Copyright (c) 2021 carbonplan
12 Permission is hereby granted, free of charge, to any person obtaining a copy
13 of this software and associated documentation files (the "Software"), to deal
14 in the Software without restriction, including without limitation the rights
15 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
16 copies of the Software, and to permit persons to whom the Software is
17 furnished to do so, subject to the following conditions:
19 The above copyright notice and this permission notice shall be included in all
20 copies or substantial portions of the Software.
22 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
23 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
24 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
25 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
26 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
27 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
28 SOFTWARE.
30In the future, we may be able to replace this module by whatever the best
31validation xarray library becomes.
32"""
34import abc
35import operator
36from functools import partial
37from typing import Any, Callable, Dict, Optional, Tuple, TypeAlias, Union
39import numpy as np
40import scipy
41import xarray as xr
42import xugrid as xu
43from numpy.typing import DTypeLike # noqa: F401
45from imod.typing import GridDataArray, ScalarAsDataArray
# Type aliases used in the schema signatures below.
DimsT = Union[str, None]
# A shape has one entry per dimension, so the tuple must be variable-length;
# without the trailing ``...`` the alias would only describe 1-tuples.
ShapeT = Tuple[Union[int, None], ...]
ChunksT = Union[bool, Dict[str, Union[int, None]]]
# Lookup table from the textual comparison operators accepted by the schemata
# to the matching functions of the ``operator`` module.
OPERATORS: Dict[str, Callable[[Any, Any], Any]] = {
    "<": operator.lt,
    "<=": operator.le,
    "==": operator.eq,
    "!=": operator.ne,
    ">=": operator.ge,
    ">": operator.gt,
}
def partial_operator(op, value):
    """
    Build a one-argument comparison against a fixed ``value``.

    Calling the returned object with ``a`` evaluates
    ``OPERATORS[op](a, value)``.

    Parameters
    ----------
    op : str
        Key into ``OPERATORS``, looked up when the result is called.
    value : Any
        Right-hand side of the comparison.
    """

    # functools.partial can only pre-fill leading positional arguments and the
    # operator functions take no keywords, so the fixed value is bound first
    # and the two arguments are swapped back inside the wrapped function.
    # https://stackoverflow.com/a/37468215
    def _swapped(fixed, variable):
        return OPERATORS[op](variable, fixed)

    return partial(_swapped, value)
def scalar_None(obj):
    """
    Tell whether ``obj`` is a zero-dimensional DataArray holding only null.

    Such an array is the default value for optional variables. Anything that
    is neither an ``xr.DataArray`` nor an ``xu.UgridDataArray`` yields
    ``False``.
    """
    if isinstance(obj, (xr.DataArray, xu.UgridDataArray)):
        has_no_dims = len(obj.shape) == 0
        everything_null = (~obj.notnull()).all()
        return has_no_dims & everything_null
    return False
def align_other_obj_with_coords(
    obj: GridDataArray, other_obj: GridDataArray
) -> Tuple[xr.DataArray, xr.DataArray]:
    """
    Align ``obj`` and ``other_obj`` with a left join.

    Before aligning, every coordinate of ``obj`` that is a dimension of
    ``other_obj`` but not yet a dimension of ``obj`` is promoted to a
    dimension with ``expand_dims``. Avoids issues like:
    https://github.com/Deltares/imod-python/issues/830

    Returns
    -------
    tuple
        The result of ``xr.align(obj, other_obj, join="left")``.
    """
    for name in obj.coords.keys():
        only_a_coord = name in other_obj.dims and name not in obj.dims
        if only_a_coord:
            obj = obj.expand_dims(name)
    # Note:
    # xr.align forces xu.UgridDataArray to xr.DataArray. Keep that in mind
    # in further data processing.
    aligned = xr.align(obj, other_obj, join="left")
    return aligned
class ValidationError(Exception):
    """Raised by a schema when the validated object does not comply."""
class BaseSchema(abc.ABC):
    """
    Abstract base class of the schemata; subclasses implement ``validate``.
    """

    @abc.abstractmethod
    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """
        Validate ``obj``.

        The schemata in this module signal a failed validation by raising
        ``ValidationError``; ``kwargs`` carries the other named objects a
        schema may need to look up.
        """

    def __or__(self, other):
        """
        Combine two schemata with ``|``, which allows us to write:

            DimsSchema("layer", "y", "x") | DimsSchema("layer")

        and get a SchemaUnion back.
        """
        union = SchemaUnion(self, other)
        return union
# Alias for annotating "some schema". A bound TypeVar alternative is kept
# here, commented out:
# SchemaType = TypeVar("SchemaType", bound=BaseSchema)
SchemaType: TypeAlias = BaseSchema
class SchemaUnion:
    """
    Union of schemata: validation passes when at least one member passes.

    Used to validate multiple options.
    """

    def __init__(self, *args):
        """
        Store the member schemata.

        Raises
        ------
        TypeError
            If the members are not all of the same type.
        """
        distinct_types = {type(schema) for schema in args}
        if len(distinct_types) > 1:
            raise TypeError("schemata in a union should have the same type")
        self.schemata = tuple(args)

    def validate(self, obj: Any, **kwargs):
        """
        Run every member schema on ``obj``.

        Raises
        ------
        ValidationError
            If every member failed; the message lists each member's error.
        """
        failures = []
        for option in self.schemata:
            try:
                option.validate(obj, **kwargs)
            except ValidationError as error:
                failures.append(error)

        if len(failures) < len(self.schemata):
            # At least one option accepted the object.
            return

        details = "\n\t" + "\n\t".join(map(str, failures))
        raise ValidationError(f"No option succeeded:{details}")

    def __or__(self, other):
        """Return a new union holding the current members plus ``other``."""
        members = (*self.schemata, other)
        return SchemaUnion(*members)
class DTypeSchema(BaseSchema):
    """
    Validate the dtype of a DataArray.
    """

    def __init__(self, dtype: DTypeLike) -> None:
        """
        Parameters
        ----------
        dtype : DTypeLike
            Expected dtype. The abstract numpy scalar types listed below are
            stored unchanged; anything else is converted with ``np.dtype``.
        """
        abstract_types = (
            np.floating,
            np.integer,
            np.signedinteger,
            np.unsignedinteger,
            np.generic,
        )
        self.dtype = dtype if dtype in abstract_types else np.dtype(dtype)

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """
        Check that ``obj.dtype`` is a subdtype of the expected dtype.

        A scalar None DataArray is accepted without further checks.

        Raises
        ------
        ValidationError
            If ``np.issubdtype(obj.dtype, self.dtype)`` is False.
        """
        if scalar_None(obj):
            return

        if np.issubdtype(obj.dtype, self.dtype):
            return
        raise ValidationError(f"dtype {obj.dtype} != {self.dtype}")
class DimsSchema(BaseSchema):
    """
    Validate the dimensions of a DataArray.

    The entries ``"{face_dim}"`` and ``"{edge_dim}"`` act as placeholders that
    are replaced by the grid's face or edge dimension name when a
    ``xu.UgridDataArray`` is validated.
    """

    def __init__(self, *dims: DimsT) -> None:
        self.dims = dims

    def _fill_in_face_dim(self, obj: Union[xr.DataArray, xu.UgridDataArray]):
        """
        Return dims with a filled in face (or edge) dim if necessary.

        Only the first placeholder kind that is present gets replaced; the
        face placeholder takes precedence over the edge placeholder.
        """
        if isinstance(obj, xu.UgridDataArray):
            placeholders = (
                ("{face_dim}", "face_dimension"),
                ("{edge_dim}", "edge_dimension"),
            )
            for placeholder, grid_attr in placeholders:
                if placeholder in self.dims:
                    grid_dim = getattr(obj.ugrid.grid, grid_attr)
                    return tuple(
                        grid_dim if dim == placeholder else dim for dim in self.dims
                    )
        return self.dims

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """
        Validate dimensions.

        The expected and the actual dimensions are compared as whole tuples,
        so both the names and their order have to match.

        Raises
        ------
        ValidationError
            If the dimensions of ``obj`` differ from the expected ones.
        """
        # Force to tuple for error message print
        expected = tuple(self._fill_in_face_dim(obj))
        actual = tuple(obj.dims)
        if expected == actual:
            return
        raise ValidationError(f"dim mismatch: expected {expected}, got {actual}")
class EmptyIndexesSchema(BaseSchema):
    """
    Verify indexes, check if no dims with zero size are included. Skips
    unstructured grid dimensions.
    """

    def __init__(self) -> None:
        pass

    def get_dims_to_validate(self, obj: Union[xr.DataArray, xu.UgridDataArray]):
        """
        Return the dimensions of ``obj`` whose indexes should be checked.

        For a ``xu.UgridDataArray`` the dimensions of the grid are left out.
        """
        if not isinstance(obj, xu.UgridDataArray):
            return list(obj.dims)

        # Remove the grid dims (e.g. the face dim) from the list to validate,
        # as they have no ``indexes`` attribute.
        grid_dims = obj.ugrid.grid.dimensions
        return [dim for dim in obj.dims if dim not in grid_dims]

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """
        Raises
        ------
        ValidationError
            For the first checked dimension whose index has length zero.
        """
        for dim in self.get_dims_to_validate(obj):
            index_length = len(obj.indexes[dim])
            if index_length == 0:
                raise ValidationError(f"provided dimension {dim} with size 0")
class IndexesSchema(EmptyIndexesSchema):
    """
    Verify indexes, check if no dims with zero size are included and that
    indexes are monotonic. Skips unstructured grid dimensions.
    """

    def __init__(self) -> None:
        pass

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """
        Raises
        ------
        ValidationError
            If a checked dimension is empty, if "y" is not monotonically
            decreasing, or if any other checked dimension is not
            monotonically increasing.
        """
        # First make sure none of the indexes is empty.
        super().validate(obj)

        for dim in self.get_dims_to_validate(obj):
            index = obj.indexes[dim]
            if dim == "y":
                if not index.is_monotonic_decreasing:
                    raise ValidationError(
                        f"coord {dim} which is not monotonically decreasing"
                    )
            elif not index.is_monotonic_increasing:
                raise ValidationError(
                    f"coord {dim} which is not monotonically increasing"
                )
class ShapeSchema(BaseSchema):
    """
    Validate the shape of a DataArray.
    """

    def __init__(self, shape: ShapeT) -> None:
        """
        Validate shape.

        Parameters
        ----------
        shape : ShapeT
            Shape of the DataArray. `None` may be used as a wildcard value.
        """
        self.shape = shape

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """
        Compare the number of dimensions first, then every axis length.

        Raises
        ------
        ValidationError
            If the number of dimensions differs, or if an axis with a
            non-None expected length has a different actual length.
        """
        expected_ndim = len(self.shape)
        actual_ndim = len(obj.shape)
        if expected_ndim != actual_ndim:
            # Report the schema's count under "shape" and the array's count
            # under "da.ndim", so each number matches its label.
            raise ValidationError(
                f"number of dimensions in shape ({expected_ndim}) != da.ndim ({actual_ndim})"
            )

        for i, (actual, expected) in enumerate(zip(obj.shape, self.shape)):
            if expected is not None and actual != expected:
                raise ValidationError(
                    f"shape mismatch in axis {i}: {actual} != {expected}"
                )
class CompatibleSettingsSchema(BaseSchema):
    """
    Validate that a setting is only switched on when another setting has a
    given value.
    """

    def __init__(self, other: ScalarAsDataArray, other_value: bool) -> None:
        """
        Validate if settings are compatible

        Parameters
        ----------
        other :
            Used as key to look up the other setting in the keyword arguments
            passed to ``validate``.
        other_value : bool
            Value the other setting is compared against.
        """
        self.other = other
        self.other_value = other_value

    def validate(self, obj: ScalarAsDataArray, **kwargs) -> None:
        """
        Raises
        ------
        ValidationError
            If ``obj`` is truthy while the other setting does not equal
            ``other_value`` everywhere.
        """
        other_obj = kwargs[self.other]
        if scalar_None(obj) or scalar_None(other_obj):
            return

        other_matches = np.all(other_obj == self.other_value)
        if not obj:
            # Setting is switched off, nothing to be incompatible with.
            return
        if not other_matches:
            raise ValidationError(
                f"Incompatible setting: {self.other} should be {self.other_value}"
            )
class CoordsSchema(BaseSchema):
    """
    Validate presence of coords.

    Parameters
    ----------
    coords : tuple of str
        Names of the coordinates expected on the DataArray.
    require_all_keys : bool, default True
        When True, fail if any name in ``coords`` is absent from the
        DataArray.
    allow_extra_keys : bool, default True
        When False, fail if the DataArray has coordinates that are not
        listed in ``coords``.
    """

    def __init__(
        self,
        coords: Tuple[str, ...],
        require_all_keys: bool = True,
        allow_extra_keys: bool = True,
    ) -> None:
        self.coords = coords
        self.require_all_keys = require_all_keys
        self.allow_extra_keys = allow_extra_keys

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """
        Compare the coordinate names of ``obj`` with the expected names.

        Raises
        ------
        ValidationError
            If names are missing while ``require_all_keys`` is True, or if
            extra names are present while ``allow_extra_keys`` is False.
        """
        present = set(obj.coords.keys())
        expected = set(self.coords)

        if self.require_all_keys:
            missing_keys = expected - present
            if missing_keys:
                raise ValidationError(f"coords has missing keys: {missing_keys}")

        if not self.allow_extra_keys:
            extra_keys = present - expected
            if extra_keys:
                raise ValidationError(f"coords has extra keys: {extra_keys}")

        # No unconditional per-key presence check here: it would reject
        # missing keys even when ``require_all_keys`` is False, and it is
        # already covered by the set difference when the flag is True.
class OtherCoordsSchema(BaseSchema):
    """
    Validate whether coordinates match those of other.

    Parameters
    ----------
    other : str
        Key of the other object in the keyword arguments of ``validate``.
    require_all_keys : bool, default True
        Passed on to ``CoordsSchema``.
    allow_extra_keys : bool, default True
        Passed on to ``CoordsSchema``.
    """

    def __init__(
        self,
        other: str,
        require_all_keys: bool = True,
        allow_extra_keys: bool = True,
    ):
        self.other = other
        self.require_all_keys = require_all_keys
        self.allow_extra_keys = allow_extra_keys

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """
        Validate ``obj`` with a ``CoordsSchema`` built from the coordinate
        names of the other object.
        """
        reference = kwargs[self.other]
        reference_coords = list(reference.coords.keys())
        delegate = CoordsSchema(
            reference_coords,
            self.require_all_keys,
            self.allow_extra_keys,
        )
        return delegate.validate(obj)
class ValueSchema(BaseSchema, abc.ABC):
    """
    Base class for AllValueSchema or AnyValueSchema.

    Parameters
    ----------
    operator : str
        Key into ``OPERATORS``, e.g. ``">"``.
    other : Any
        Value to compare with. The subclasses treat a string as the key of
        another object in the keyword arguments of ``validate``.
    ignore : tuple of (varname, operator, value), optional
        Cells where ``kwargs[varname] <operator> value`` holds are reported
        by ``get_explicitly_ignored``.
    """

    def __init__(
        self,
        operator: str,
        other: Any,
        ignore: Optional[Tuple[str, str, Any]] = None,
    ):
        self.operator = OPERATORS[operator]
        self.operator_str = operator
        self.other = other
        self.ignore_varname = ignore[0] if ignore else None
        self.to_ignore = partial_operator(ignore[1], ignore[2]) if ignore else None

    def get_explicitly_ignored(self, kwargs: Dict) -> Any:
        """
        Get cells that should be explicitly ignored by the schema.

        Returns ``False`` when no ``ignore`` was configured.
        """
        if not self.to_ignore:
            return False
        return self.to_ignore(kwargs[self.ignore_varname])
class AllValueSchema(ValueSchema):
    """
    Validate whether all values pass a condition.

    E.g. if operator is ">":

        assert (values > threshold).all()
    """

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """
        Raises
        ------
        ValidationError
            If some cell fails the comparison and is neither NaN (in ``obj``
            or in the other object) nor explicitly ignored.
        """
        # A string refers to another object by name.
        other_obj = kwargs[self.other] if isinstance(self.other, str) else self.other

        if scalar_None(obj) or scalar_None(other_obj):
            return

        explicitly_ignored = self.get_explicitly_ignored(kwargs)

        # Cells to skip are set to True, so they count as complying.
        skipped = np.isnan(obj) | np.isnan(other_obj) | explicitly_ignored

        complies = self.operator(obj, other_obj) | skipped
        if complies.all():
            return
        raise ValidationError(
            f"not all values comply with criterion: {self.operator_str} {self.other}"
        )
class AnyValueSchema(ValueSchema):
    """
    Validate whether any value passes a condition.

    E.g. if operator is ">":

        assert (values > threshold).any()
    """

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """
        Raises
        ------
        ValidationError
            If no cell passes the comparison. Cells that are NaN (in ``obj``
            or in the other object) or explicitly ignored never count as
            passing.
        """
        if isinstance(self.other, str):
            other_obj = kwargs[self.other]
        else:
            other_obj = self.other

        if scalar_None(obj) or scalar_None(other_obj):
            return

        explicitly_ignored = self.get_explicitly_ignored(kwargs)

        # True where a cell must not contribute to the result.
        ignore = np.isnan(obj) | np.isnan(other_obj) | explicitly_ignored

        # Ignored cells are forced to False. OR-ing a mask of the valid cells
        # into the condition would let every non-NaN cell pass regardless of
        # the comparison, making the check vacuous.
        condition = self.operator(obj, other_obj) & ~ignore
        if not condition.any():
            raise ValidationError(
                f"not a single value complies with criterion: {self.operator_str} {self.other}"
            )
482def _notnull(obj):
483 """
484 Helper function; does the same as xr.DataArray.notnull. This function is to
485 avoid an issue where xr.DataArray.notnull() returns ordinary numpy arrays
486 for instances of xu.UgridDataArray.
487 """
489 return ~np.isnan(obj)
class NoDataSchema(BaseSchema):
    """
    Base class for the schemata that inspect where data is present.

    Parameters
    ----------
    is_notnull : callable or tuple of (operator, value)
        Callable returning True where data is present. A tuple is turned
        into such a callable with ``partial_operator``.
    """

    def __init__(
        self,
        is_notnull: Union[Callable, Tuple[str, Any]] = _notnull,
    ):
        if not isinstance(is_notnull, tuple):
            self.is_notnull = is_notnull
        else:
            comparison, threshold = is_notnull
            self.is_notnull = partial_operator(comparison, threshold)
class AllNoDataSchema(NoDataSchema):
    """
    Fails when all data is NoData.
    """

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """
        Raises
        ------
        ValidationError
            If ``is_notnull`` is False for every value of ``obj``.
        """
        has_data = self.is_notnull(obj)
        nothing_valid = ~has_data.any()
        if nothing_valid:
            raise ValidationError("all nodata")
class AnyNoDataSchema(NoDataSchema):
    """
    Fails when any data is NoData.
    """

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """
        Raises
        ------
        ValidationError
            If ``is_notnull`` is False for at least one value of ``obj``.
        """
        has_data = self.is_notnull(obj)
        something_missing = ~has_data.all()
        if something_missing:
            raise ValidationError("found a nodata value")
class NoDataComparisonSchema(BaseSchema):
    """
    Base class for IdentityNoDataSchema and AllInsideNoDataSchema.

    Parameters
    ----------
    other : str
        Key of the other object in the keyword arguments of ``validate``.
    is_notnull : callable or tuple of (operator, value)
        Tells where the validated object holds data. A tuple is turned into
        a callable with ``partial_operator``.
    is_other_notnull : callable or tuple of (operator, value)
        Same, for the other object.
    """

    def __init__(
        self,
        other: str,
        is_notnull: Union[Callable, Tuple[str, Any]] = _notnull,
        is_other_notnull: Union[Callable, Tuple[str, Any]] = _notnull,
    ):
        self.other = other

        if not isinstance(is_notnull, tuple):
            self.is_notnull = is_notnull
        else:
            comparison, threshold = is_notnull
            self.is_notnull = partial_operator(comparison, threshold)

        if not isinstance(is_other_notnull, tuple):
            self.is_other_notnull = is_other_notnull
        else:
            comparison, threshold = is_other_notnull
            self.is_other_notnull = partial_operator(comparison, threshold)
class IdentityNoDataSchema(NoDataComparisonSchema):
    """
    Checks that the NoData values are located at exactly the same locations.

    Tests only if all dimensions of the other object are present in the
    object. So "stage" with `{time, layer, y, x}` is tested against "idomain"
    with `{layer, y, x}`, but "k" with `{layer}` is not tested against
    "idomain" with `{layer, y, x}`.
    """

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """
        Raises
        ------
        ValidationError
            If a location holds data in exactly one of the two objects.
        """
        other_obj = kwargs[self.other]

        # Only test if object has all dimensions in other object.
        if not set(other_obj.dims) <= set(obj.dims):
            return

        has_data = self.is_notnull(obj)
        other_has_data = self.is_other_notnull(other_obj)
        if (has_data ^ other_has_data).any():
            raise ValidationError(f"nodata is not aligned with {self.other}")
class AllInsideNoDataSchema(NoDataComparisonSchema):
    """
    Checks that all notnull values all occur within the notnull values of other.
    """

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """
        Raises
        ------
        ValidationError
            If ``obj`` holds data at a location where the other object does
            not.
        """
        other_obj = kwargs[self.other]
        valid = self.is_notnull(obj)
        other_valid = self.is_other_notnull(other_obj)

        # Align the two boolean masks. Passing the raw ``other_obj`` here
        # would discard ``other_valid`` and apply ``~`` to raw data below.
        valid, other_valid = align_other_obj_with_coords(valid, other_valid)

        if (valid & ~other_valid).any():
            raise ValidationError(f"data values found at nodata values of {self.other}")
class ActiveCellsConnectedSchema(BaseSchema):
    """
    Check if active cells are connected, to avoid isolated islands which can
    cause convergence issues, if they don't have a head boundary condition, but
    do have a specified flux.

    Parameters
    ----------
    is_notnull : callable or tuple of (operator, value)
        Tells which cells are active. A tuple is turned into a callable with
        ``partial_operator``.

    Note
    ----
    This schema only works for structured grids.
    """

    def __init__(
        self,
        is_notnull: Union[Callable, Tuple[str, Any]] = _notnull,
    ):
        if isinstance(is_notnull, tuple):
            op, value = is_notnull
            self.is_notnull = partial_operator(op, value)
        else:
            self.is_notnull = is_notnull

    def validate(self, obj: GridDataArray, **kwargs) -> None:
        """
        Raises
        ------
        NotImplementedError
            If ``obj`` is a ``xu.UgridDataArray``.
        ValidationError
            If the active cells form more than one labelled region.
        """
        if isinstance(obj, xu.UgridDataArray):
            # TODO: https://deltares.github.io/xugrid/api/xugrid.UgridDataArrayAccessor.connected_components.html
            # Instances have no ``__name__``; take the name from the class so
            # that the intended NotImplementedError is raised.
            raise NotImplementedError(
                f"Schema {type(self).__name__} only works for structured grids, received xu.UgridDataArray."
            )

        active = self.is_notnull(obj)

        # label() returns the labelled array and the number of regions found.
        _, nlabels = scipy.ndimage.label(active)
        if nlabels > 1:
            raise ValidationError(
                f"{nlabels} disconnected areas detected in model domain"
            )