Coverage for C:\src\imod-python\imod\typing\grid.py: 89%
167 statements
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-08 13:27 +0200
« prev ^ index » next coverage.py v7.4.4, created at 2024-04-08 13:27 +0200
1import pickle
2import textwrap
3from typing import Callable, Mapping, Sequence
5import numpy as np
6import xarray as xr
7import xugrid as xu
8from fastcore.dispatch import typedispatch
10from imod.typing import GridDataArray, GridDataset, structured
11from imod.util.spatial import _polygonize
@typedispatch
def zeros_like(grid: xr.DataArray, *args, **kwargs):
    """
    Structured variant: zero-filled array like ``grid``.

    All extra positional and keyword arguments go straight to ``xr.zeros_like``.
    """
    return xr.zeros_like(grid, *args, **kwargs)


@typedispatch
def zeros_like(grid: xu.UgridDataArray, *args, **kwargs):
    """
    Unstructured variant: zero-filled array like ``grid``.

    All extra positional and keyword arguments go straight to ``xu.zeros_like``.
    """
    return xu.zeros_like(grid, *args, **kwargs)
@typedispatch
def ones_like(grid: xr.DataArray, *args, **kwargs):
    """
    Structured variant: array of ones like ``grid``.

    All extra positional and keyword arguments go straight to ``xr.ones_like``.
    """
    return xr.ones_like(grid, *args, **kwargs)


@typedispatch
def ones_like(grid: xu.UgridDataArray, *args, **kwargs):
    """
    Unstructured variant: array of ones like ``grid``.

    All extra positional and keyword arguments go straight to ``xu.ones_like``.
    """
    return xu.ones_like(grid, *args, **kwargs)
@typedispatch
def nan_like(grid: xr.DataArray, dtype=np.float32, *args, **kwargs):
    """
    Structured variant: NaN-filled array like ``grid``.

    ``dtype`` defaults to ``np.float32``; pass a floating point dtype, since NaN
    has no integer representation. Remaining arguments are forwarded to
    ``xr.full_like``.
    """
    return xr.full_like(grid, *args, fill_value=np.nan, dtype=dtype, **kwargs)


@typedispatch
def nan_like(grid: xu.UgridDataArray, dtype=np.float32, *args, **kwargs):
    """
    Unstructured variant: NaN-filled array like ``grid``.

    ``dtype`` defaults to ``np.float32``; pass a floating point dtype, since NaN
    has no integer representation. Remaining arguments are forwarded to
    ``xu.full_like``.
    """
    return xu.full_like(grid, *args, fill_value=np.nan, dtype=dtype, **kwargs)
@typedispatch
def is_unstructured(grid: xu.UgridDataArray | xu.UgridDataset) -> bool:
    """xugrid objects are, by definition, unstructured grids."""
    return True


@typedispatch
def is_unstructured(grid: xr.DataArray | xr.Dataset) -> bool:
    """Plain xarray objects represent structured (rectilinear) grids."""
    return False
def _force_decreasing_y(structured_grid: xr.DataArray | xr.Dataset):
    """
    Return ``structured_grid`` with its y-coordinate in decreasing order.

    A grid whose y index is monotonically increasing is reversed along y; a
    grid that is already monotonically decreasing is returned unchanged.

    Raises
    ------
    RuntimeError
        If the y index is neither monotonically increasing nor decreasing.
    """
    y_index = structured_grid.indexes["y"]
    if y_index.is_monotonic_increasing:
        return structured_grid.isel(y=slice(None, None, -1))
    if not y_index.is_monotonic_decreasing:
        # xr.Dataset has no ``name`` attribute; plain attribute access would
        # raise AttributeError and mask the intended RuntimeError.
        grid_name = getattr(structured_grid, "name", None)
        raise RuntimeError(f"Non-monotonous y-coordinates for grid: {grid_name}.")
    return structured_grid
65def _get_first_item(objects: Sequence):
66 return next(iter(objects))
69# Typedispatching doesn't work based on types of list elements, therefore resort to
70# isinstance testing
def _type_dispatch_functions_on_grid_sequence(
    objects: Sequence[GridDataArray | GridDataset],
    unstructured_func: Callable,
    structured_func: Callable,
    *args,
    **kwargs,
) -> GridDataArray | GridDataset:
    """
    Type dispatch functions on sequence of grids. Functions like merging or concatenating.

    Parameters
    ----------
    objects
        Non-empty sequence of grids, all of the same type.
    unstructured_func
        Called as ``unstructured_func(objects, *args, **kwargs)`` when the grids
        are ``xu.UgridDataArray`` or ``xu.UgridDataset``.
    structured_func
        Called as ``structured_func(objects, *args, **kwargs)`` when the grids
        are ``xr.DataArray`` or ``xr.Dataset``; the result is returned with a
        decreasing y-coordinate.

    Raises
    ------
    ValueError
        If ``objects`` is empty.
    TypeError
        If the grids are of mixed types, or of an unsupported type.
    """
    try:
        first_object = _get_first_item(objects)
    except StopIteration:
        # Without this, a bare StopIteration would escape to the caller.
        raise ValueError(
            "Received an empty sequence, expected at least one grid."
        ) from None

    start_type = type(first_object)
    if not all(isinstance(o, start_type) for o in objects):
        unique_types = {type(o) for o in objects}
        raise TypeError(
            f"Only homogeneous sequences can be reduced, received sequence of {unique_types}"
        )

    if isinstance(first_object, (xu.UgridDataArray, xu.UgridDataset)):
        return unstructured_func(objects, *args, **kwargs)
    if isinstance(first_object, (xr.DataArray, xr.Dataset)):
        return _force_decreasing_y(structured_func(objects, *args, **kwargs))
    raise TypeError(
        f"'{unstructured_func.__name__}' not supported for type {type(first_object)}"
    )
98# Typedispatching doesn't work based on types of dict elements, therefore resort
99# to manual type testing
def _type_dispatch_functions_on_dict(
    dict_of_objects: Mapping[str, GridDataArray | float | bool | int],
    unstructured_func: Callable,
    structured_func: Callable,
    *args,
    **kwargs,
):
    """
    Typedispatch function on grid and scalar variables provided in dictionary.

    Types do not need to be homogeneous as scalars and grids can be mixed. No
    mixing of structured and unstructured grids is allowed. Also allows running
    function on dictionary with purely scalars, in which case it will call to
    the xarray function.

    The selected function is called as ``func([dict_of_objects], *args,
    **kwargs)``. ``None`` input returns an empty ``xr.Dataset``.

    Raises
    ------
    TypeError
        If both a spatial structured grid and an ``xu.UgridDataArray`` are
        present.
    """
    if dict_of_objects is None:
        return xr.Dataset()

    values = list(dict_of_objects.values())
    # isinstance keeps this consistent with how merge_unstructured_dataset
    # separates grids from other values.
    has_unstructured = any(isinstance(value, xu.UgridDataArray) for value in values)
    # Only spatial DataArrays count as structured grids; non-spatial DataArrays
    # do not trigger the mixing check.
    has_structured_grid = any(
        isinstance(value, xr.DataArray) and is_spatial_2D(value) for value in values
    )
    if has_structured_grid and has_unstructured:
        raise TypeError(
            textwrap.dedent(
                """
                Received both structured grid (xr.DataArray) and xu.UgridDataArray. This
                means structured grids as well as unstructured grids were provided.
                """
            )
        )

    func = unstructured_func if has_unstructured else structured_func
    return func([dict_of_objects], *args, **kwargs)
def merge(
    objects: Sequence[GridDataArray | GridDataset], *args, **kwargs
) -> GridDataset:
    """
    Merge a homogeneous sequence of grids.

    Unstructured grids are merged with ``xu.merge``; structured grids with
    ``xr.merge``, after which the result is given a decreasing y-coordinate.
    Extra arguments are forwarded to the merge function.
    """
    dispatch = _type_dispatch_functions_on_grid_sequence
    return dispatch(objects, xu.merge, xr.merge, *args, **kwargs)
def merge_partitions(
    objects: Sequence[GridDataArray | GridDataset], *args, **kwargs
) -> GridDataArray | GridDataset:
    """
    Merge a homogeneous sequence of partitioned grids back into one grid.

    Unstructured partitions use ``xu.merge_partitions``; structured partitions
    use ``structured.merge_partitions``, after which the result is given a
    decreasing y-coordinate. Extra arguments are forwarded.
    """
    dispatch = _type_dispatch_functions_on_grid_sequence
    return dispatch(
        objects, xu.merge_partitions, structured.merge_partitions, *args, **kwargs
    )
def concat(
    objects: Sequence[GridDataArray | GridDataset], *args, **kwargs
) -> GridDataArray | GridDataset:
    """
    Concatenate a homogeneous sequence of grids.

    Unstructured grids use ``xu.concat``; structured grids use ``xr.concat``,
    after which the result is given a decreasing y-coordinate. Extra arguments
    (e.g. the concatenation dimension) are forwarded.
    """
    dispatch = _type_dispatch_functions_on_grid_sequence
    return dispatch(objects, xu.concat, xr.concat, *args, **kwargs)
def merge_unstructured_dataset(variables_to_merge: list[dict], *args, **kwargs):
    """
    Work around xugrid issue https://github.com/Deltares/xugrid/issues/179

    Expects exactly one dictionary (any mapping is accepted) in the list. A list
    is used to have the same API as xr.merge().

    Merges the unstructured grids (``xu.UgridDataArray`` values) first, then
    manually assigns every other value (scalars, non-grid arrays) as a variable.

    Parameters
    ----------
    variables_to_merge
        List holding a single mapping of variable name to value.
    *args, **kwargs
        Forwarded to ``xu.merge``.

    Raises
    ------
    ValueError
        If the list does not hold exactly one mapping, or if the grids do not
        all share one topology.
    TypeError
        If the list element is not a mapping.
    """
    import collections.abc

    # "!= 1" rather than "> 1": an empty list would otherwise fail with an
    # IndexError on the indexing below.
    if len(variables_to_merge) != 1:
        raise ValueError(
            f"Only one dict of variables expected, got {len(variables_to_merge)}"
        )

    variables_to_merge_dict = variables_to_merge[0]
    if not isinstance(variables_to_merge_dict, collections.abc.Mapping):
        raise TypeError(f"Expected dict, got {type(variables_to_merge_dict)}")

    # Separate variables into list of grids and dict of scalar variables
    grids_ls = []
    scalar_dict = {}
    for name, variable in variables_to_merge_dict.items():
        if isinstance(variable, xu.UgridDataArray):
            grids_ls.append(variable.rename(name))
        else:
            scalar_dict[name] = variable

    dataset = xu.merge(grids_ls, *args, **kwargs)

    # Temporarily work around this xugrid issue, until fixed:
    # https://github.com/Deltares/xugrid/issues/206
    # Grids are considered identical when their pickled bytes hash equally.
    grid_hashes = {hash(pickle.dumps(grid)) for grid in dataset.ugrid.grids}
    if len(grid_hashes) > 1:
        raise ValueError(
            "Multiple grids provided, please provide data on one unique grid"
        )
    # Possibly won't work anymore if this ever gets implemented:
    # https://github.com/Deltares/xugrid/issues/195
    dataset._grids = [dataset.grids[0]]

    # Assign scalar variables manually
    for name, variable in scalar_dict.items():
        dataset[name] = variable

    return dataset
def merge_with_dictionary(
    variables_to_merge: Mapping[str, GridDataArray | float | bool | int],
    *args,
    **kwargs,
):
    """
    Merge a mapping of grids and scalars into one dataset.

    If any value is an ``xu.UgridDataArray`` the merge is done by
    ``merge_unstructured_dataset``; otherwise (structured grids, or only
    scalars) ``xr.merge`` is used. Mixing structured and unstructured grids
    raises a TypeError, and ``None`` yields an empty ``xr.Dataset``. Extra
    arguments are forwarded to the merge function.
    """
    unstructured_merge = merge_unstructured_dataset
    structured_merge = xr.merge
    return _type_dispatch_functions_on_dict(
        variables_to_merge, unstructured_merge, structured_merge, *args, **kwargs
    )
@typedispatch
def bounding_polygon(active: xr.DataArray):
    """Return bounding polygon of active cells"""
    # Cells where ``active`` is falsy become NaN before polygonizing.
    polygons_gdf = _polygonize(active.where(active, other=np.nan))
    # Keep only polygons with value 1.0, dropping the NaN (inactive) polygons.
    return polygons_gdf.loc[polygons_gdf["value"] == 1.0]


@typedispatch
def bounding_polygon(active: xu.UgridDataArray):
    """Return bounding polygon of active cells"""
    face_dimension = active.ugrid.grid.face_dimension
    active_indices = np.where(active > 0)[0]
    # Clip the grid to the active faces, then take the polygon around them.
    active_clipped = active.isel(
        {f"{face_dimension}": active_indices}, missing_dims="ignore"
    )
    return active_clipped.ugrid.grid.bounding_polygon()
@typedispatch
def is_spatial_2D(array: xr.DataArray) -> bool:
    """Return True if the array contains data in at least 2 spatial dimensions"""
    spatial_names = ("x", "y")
    in_coords = all(name in array.coords for name in spatial_names)
    in_dims = all(name in array.dims for name in spatial_names)
    return in_coords and in_dims


@typedispatch
def is_spatial_2D(array: xu.UgridDataArray) -> bool:
    """Return True if the array contains data associated to cell faces"""
    face_dim = array.ugrid.grid.face_dimension
    return face_dim in array.dims and face_dim in array.coords


@typedispatch
def is_spatial_2D(_: object) -> bool:
    """Anything that is not a grid array is never spatial."""
    return False
@typedispatch
def is_equal(array1: xu.UgridDataArray, array2: xu.UgridDataArray) -> bool:
    """Unstructured arrays are equal when both their data and topology match."""
    if not array1.equals(array2):
        return False
    return array1.ugrid.grid.equals(array2.ugrid.grid)


@typedispatch
def is_equal(array1: xr.DataArray, array2: xr.DataArray) -> bool:
    """Structured arrays are compared with ``xr.DataArray.equals``."""
    return array1.equals(array2)


@typedispatch
def is_equal(array1: object, array2: object) -> bool:
    """Any other combination of types is reported as not equal."""
    return False
@typedispatch
def is_same_domain(grid1: xu.UgridDataArray, grid2: xu.UgridDataArray) -> bool:
    """Unstructured grids share a domain when coordinates and topology match."""
    if not grid1.coords.equals(grid2.coords):
        return False
    return grid1.ugrid.grid.equals(grid2.ugrid.grid)


@typedispatch
def is_same_domain(grid1: xr.DataArray, grid2: xr.DataArray) -> bool:
    """Structured grids share a domain when their coordinates are equal."""
    return grid1.coords.equals(grid2.coords)


@typedispatch
def is_same_domain(grid1: object, grid2: object) -> bool:
    """Any other combination of types is reported as a different domain."""
    return False
@typedispatch
def get_spatial_dimension_names(grid: xr.DataArray) -> list[str]:
    """Names treated as spatial for a structured grid (a new list per call)."""
    return list(("x", "y", "layer", "dx", "dy"))


@typedispatch
def get_spatial_dimension_names(grid: xu.UgridDataArray) -> list[str]:
    """Names treated as spatial for an unstructured grid: its face dimension and layer."""
    return [grid.ugrid.grid.face_dimension, "layer"]


@typedispatch
def get_spatial_dimension_names(grid: object) -> list[str]:
    """Non-grid objects have no spatial dimensions."""
    return list()
@typedispatch
def get_grid_geometry_hash(grid: xr.DataArray) -> tuple[int, int]:
    """
    Return hashes of the pickled x and y coordinate values of a structured grid.

    Python salts ``hash`` of bytes per process (PYTHONHASHSEED), so these
    values are only comparable within a single interpreter session.
    """
    hash_x = hash(pickle.dumps(grid["x"].values))
    hash_y = hash(pickle.dumps(grid["y"].values))
    return (hash_x, hash_y)


@typedispatch
def get_grid_geometry_hash(grid: xu.UgridDataArray) -> tuple[int, int, int]:
    """
    Return hashes of the pickled node x, node y and node-face connectivity of
    an unstructured grid.

    Python salts ``hash`` of bytes per process (PYTHONHASHSEED), so these
    values are only comparable within a single interpreter session.
    """
    topology = grid.ugrid.grid
    hash_x = hash(pickle.dumps(topology.node_x))
    hash_y = hash(pickle.dumps(topology.node_y))
    hash_connectivity = hash(pickle.dumps(topology.node_face_connectivity))
    return (hash_x, hash_y, hash_connectivity)


@typedispatch
def get_grid_geometry_hash(grid: object) -> tuple[int, ...]:
    """Fallback for unsupported objects; always raises ValueError."""
    raise ValueError("get_grid_geometry_hash not supported for this object.")