Coverage for C:\src\imod-python\imod\mf6\utilities\mask.py: 92%
65 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-08 14:15 +0200
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-08 14:15 +0200
1import numbers
3import numpy as np
4from xarray.core.utils import is_scalar
6from imod.mf6.auxiliary_variables import (
7 expand_transient_auxiliary_variables,
8 remove_expanded_auxiliary_variables_from_dataset,
9)
10from imod.mf6.interfaces.imodel import IModel
11from imod.mf6.interfaces.ipackage import IPackage
12from imod.mf6.interfaces.isimulation import ISimulation
13from imod.typing.grid import GridDataArray, get_spatial_dimension_names, is_same_domain
def _mask_all_models(
    simulation: ISimulation,
    mask: GridDataArray,
):
    """
    Apply a mask to all flow and transport models of a simulation.

    Every model of type ``"gwf6"`` (flow) and ``"gwt6"`` (transport) is masked
    by calling its ``mask_all_packages`` method with ``mask``.

    Parameters
    ----------
    simulation: ISimulation
        Simulation whose models are masked. It must not have been split.
    mask: GridDataArray
        Masking domain. Its coordinates may only be spatial dimension names,
        and every model's domain must match it.

    Raises
    ------
    ValueError
        If ``mask`` has a coordinate that is not a spatial dimension, if the
        simulation has already been split, or if any model's domain differs
        from ``mask``. All checks are done before the first model is masked,
        so none of these errors leaves the simulation partially masked.
    """
    spatial_dims = get_spatial_dimension_names(mask)
    if any(coord not in spatial_dims for coord in mask.coords):
        raise ValueError("unexpected coordinate dimension in masking domain")

    if simulation.is_split():
        raise ValueError(
            "masking can only be applied to simulations that have not been split. Apply masking before splitting."
        )

    flowmodels = list(simulation.get_models_of_type("gwf6").keys())
    transportmodels = list(simulation.get_models_of_type("gwt6").keys())
    modelnames = flowmodels + transportmodels

    # Validate every model's domain up front. Checking inside the masking loop
    # would mask the models that come before a mismatching one and only then
    # raise, leaving the simulation half-masked.
    if not all(is_same_domain(simulation[name].domain, mask) for name in modelnames):
        raise ValueError(
            "masking can only be applied to simulations when all the models in the simulation use the same grid."
        )

    for name in modelnames:
        simulation[name].mask_all_packages(mask)
def _mask_all_packages(
    model: IModel,
    mask: GridDataArray,
):
    """
    Mask every package of a model in place.

    Each package is replaced by the result of its own ``mask`` method. The
    model's ``purge_empty_packages`` method is called afterwards (presumably
    to drop packages left without data by the masking; confirm against the
    model implementation).

    Parameters
    ----------
    model: IModel
        Model whose packages are replaced by their masked counterparts.
    mask: GridDataArray
        Masking domain. Its coordinates may only be spatial dimension names.

    Raises
    ------
    ValueError
        If ``mask`` has a coordinate that is not a spatial dimension.
    """
    allowed_dims = get_spatial_dimension_names(mask)
    for coord_name in mask.coords:
        if coord_name not in allowed_dims:
            raise ValueError("unexpected coordinate dimension in masking domain")

    for package_name, package in model.items():
        masked_package = package.mask(mask)
        model[package_name] = masked_package

    model.purge_empty_packages()
def _mask(package: IPackage, mask: GridDataArray) -> IPackage:
    """
    Return a new package of the same type with ``mask`` applied.

    For every data variable of ``package.dataset``:
    - variables for which ``_skip_masking_variable`` returns True are passed
      through unchanged;
    - all other variables are masked with ``_mask_spatial_var``.

    The collected variables are passed as keyword arguments to the package's
    constructor.

    When the package has auxiliary data fields, their expanded variables are
    temporarily removed from the package's own dataset while masking. They are
    re-expanded afterwards, also when masking raises, so the input package is
    restored either way.

    Parameters
    ----------
    package: IPackage
        Package to mask.
    mask: GridDataArray
        Masking domain.

    Returns
    -------
    IPackage
        A new instance of ``type(package)`` built from the masked variables.
    """
    if len(package.auxiliary_data_fields) > 0:
        remove_expanded_auxiliary_variables_from_dataset(package)
    try:
        masked = {}
        for var in package.dataset.data_vars.keys():
            if _skip_masking_variable(package, var, package.dataset[var]):
                masked[var] = package.dataset[var]
            else:
                masked[var] = _mask_spatial_var(package, var, mask)
    finally:
        # Restore the caller's package even if masking a variable failed;
        # otherwise its dataset would be left with the auxiliary variables
        # collapsed.
        if len(package.auxiliary_data_fields) > 0:
            expand_transient_auxiliary_variables(package)
    return type(package)(**masked)
69def _skip_masking_variable(package: IPackage, var: str, da: GridDataArray) -> bool:
70 if (
71 package._skip_masking_dataarray(var)
72 or len(da.dims) == 0
73 or set(da.coords).issubset(["layer"])
74 ):
75 return True
76 if is_scalar(da.values[()]):
77 return True
78 spatial_dims = ["x", "y", "mesh2d_nFaces", "layer"]
79 if not np.any([coord in spatial_dims for coord in da.coords]):
80 return True
81 return False
def _mask_spatial_var(self, var: str, mask: GridDataArray) -> GridDataArray:
    """
    Mask data variable ``var`` of package ``self`` where the mask is not > 0.

    The mask is first reconciled with the variable's layering by
    ``_adjust_mask_for_unlayered_data``. Cells where the mask is > 0 keep
    their value. The other cells are filled depending on the dtype:
    - integer data: ``0``, except ``idomain``, which takes the mask's own value;
    - real (float) data: xarray's default ``where`` fill value (NaN).

    Despite its name, ``self`` is the package whose ``dataset`` is read; this
    is a module-level function, not a method.

    Raises
    ------
    TypeError
        If the variable's dtype is neither integer nor real.
    """
    data = self.dataset[var]
    active = _adjust_mask_for_unlayered_data(data, mask)
    scalar_type = data.dtype.type

    is_integer = issubclass(scalar_type, numbers.Integral)
    if not is_integer and not issubclass(scalar_type, numbers.Real):
        raise TypeError(
            f"Expected dtype float or integer. Received instead: {data.dtype}"
        )

    if not is_integer:
        return data.where(active > 0)

    # Integer arrays cannot hold NaN, so an explicit fill value is needed.
    fill_value = active if var == "idomain" else 0
    return data.where(active > 0, other=fill_value)
101def _adjust_mask_for_unlayered_data(
102 da: GridDataArray, mask: GridDataArray
103) -> GridDataArray:
104 """
105 Some arrays are not layered while the mask is layered (for example the
106 top array in dis or disv packaged). In that case we use the top layer of
107 the mask to perform the masking. If layer is not a dataset dimension,
108 but still a dataset coordinate, we limit the mask to the relevant layer
109 coordinate(s).
110 """
111 array_mask = mask
112 if "layer" in da.coords and "layer" not in da.dims:
113 array_mask = mask.sel(layer=da.coords["layer"])
114 if "layer" not in da.coords and "layer" in array_mask.coords:
115 array_mask = mask.isel(layer=0)
117 return array_mask