Coverage for C:\src\imod-python\imod\typing\grid.py: 91%
191 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-08 14:15 +0200
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-08 14:15 +0200
1import pickle
2import textwrap
3from typing import Callable, Mapping, Sequence
5import numpy as np
6import xarray as xr
7import xugrid as xu
8from fastcore.dispatch import typedispatch
10from imod.typing import GridDataArray, GridDataset, structured
11from imod.util.spatial import _polygonize
@typedispatch  # type: ignore [no-redef]
def zeros_like(grid: xr.DataArray, *args, **kwargs):
    """Return a structured grid of zeros like ``grid`` (delegates to ``xr.zeros_like``)."""
    make_zeros = xr.zeros_like
    return make_zeros(grid, *args, **kwargs)


@typedispatch  # type: ignore [no-redef]
def zeros_like(grid: xu.UgridDataArray, *args, **kwargs):  # noqa: F811
    """Return an unstructured grid of zeros like ``grid`` (delegates to ``xu.zeros_like``)."""
    make_zeros = xu.zeros_like
    return make_zeros(grid, *args, **kwargs)
@typedispatch  # type: ignore [no-redef]
def ones_like(grid: xr.DataArray, *args, **kwargs):
    """Return a structured grid of ones like ``grid`` (delegates to ``xr.ones_like``)."""
    make_ones = xr.ones_like
    return make_ones(grid, *args, **kwargs)


@typedispatch  # type: ignore [no-redef]
def ones_like(grid: xu.UgridDataArray, *args, **kwargs):  # noqa: F811
    """Return an unstructured grid of ones like ``grid`` (delegates to ``xu.ones_like``)."""
    make_ones = xu.ones_like
    return make_ones(grid, *args, **kwargs)
@typedispatch  # type: ignore [no-redef]
def nan_like(grid: xr.DataArray, dtype=np.float32, *args, **kwargs):
    """Return a structured grid like ``grid`` filled with NaN.

    ``dtype`` defaults to ``np.float32``. Remaining arguments are forwarded to
    ``xr.full_like``.
    """
    # Positional extras bind right after ``grid`` no matter where the star
    # expression is written, so this call is equivalent to listing them last.
    # NOTE(review): xr.full_like takes fill_value as its 2nd positional
    # parameter, so a positional extra would collide with fill_value=... — confirm
    # *args is ever used.
    return xr.full_like(grid, *args, fill_value=np.nan, dtype=dtype, **kwargs)


@typedispatch  # type: ignore [no-redef]
def nan_like(grid: xu.UgridDataArray, dtype=np.float32, *args, **kwargs):  # noqa: F811
    """Return an unstructured grid like ``grid`` filled with NaN.

    ``dtype`` defaults to ``np.float32``. Remaining arguments are forwarded to
    ``xu.full_like``.
    """
    return xu.full_like(grid, *args, fill_value=np.nan, dtype=dtype, **kwargs)
@typedispatch  # type: ignore [no-redef]
def is_unstructured(grid: xu.UgridDataArray | xu.UgridDataset) -> bool:
    """Return True: xugrid objects are unstructured grids."""
    return True


@typedispatch  # type: ignore [no-redef]
def is_unstructured(grid: xr.DataArray | xr.Dataset) -> bool:  # noqa: F811
    """Return False: xarray objects are structured grids."""
    return False


@typedispatch  # type: ignore [no-redef]
def is_unstructured(grid: object) -> bool:  # noqa: F811
    """Return False for any other object (scalars, None, numpy arrays, ...).

    Without this fallback, typedispatch hands an unmatched argument back
    unchanged, so e.g. ``is_unstructured(5)`` would be truthy. This mirrors the
    ``object`` fallbacks of ``is_spatial_2D`` and ``is_equal``.
    """
    return False
def _force_decreasing_y(structured_grid: xr.DataArray | xr.Dataset):
    """
    Return ``structured_grid`` with its y-coordinate in decreasing order.

    A monotonically increasing y index is reversed with ``isel``; a decreasing
    one is returned as is. Objects without a "y" index have nothing to order
    and are returned unchanged.

    Raises
    ------
    RuntimeError
        If the y index is neither monotonically increasing nor decreasing.
    """
    y_index = structured_grid.indexes.get("y")
    if y_index is None:
        return structured_grid
    if y_index.is_monotonic_increasing:
        return structured_grid.isel(y=slice(None, None, -1))
    if not y_index.is_monotonic_decreasing:
        # xr.Dataset has no ``name`` attribute; fall back to None so the
        # intended RuntimeError is raised instead of an AttributeError.
        grid_name = getattr(structured_grid, "name", None)
        raise RuntimeError(f"Non-monotonous y-coordinates for grid: {grid_name}.")
    return structured_grid
65def _get_first_item(objects: Sequence):
66 return next(iter(objects))
# Typedispatching doesn't work based on types of list elements, therefore resort to
# isinstance testing
def _type_dispatch_functions_on_grid_sequence(
    objects: Sequence[GridDataArray | GridDataset],
    unstructured_func: Callable,
    structured_func: Callable,
    *args,
    **kwargs,
) -> GridDataArray | GridDataset:
    """
    Type dispatch functions on sequence of grids. Functions like merging or
    concatenating.

    Parameters
    ----------
    objects : Sequence of grids
        Must be non-empty and homogeneous: every element must be an instance of
        the type of the first element.
    unstructured_func : Callable
        Called as ``unstructured_func(objects, *args, **kwargs)`` when the
        grids are xugrid objects.
    structured_func : Callable
        Called as ``structured_func(objects, *args, **kwargs)`` when the grids
        are xarray objects. Its result is passed through
        ``_force_decreasing_y``.

    Raises
    ------
    ValueError
        If ``objects`` is empty.
    TypeError
        If ``objects`` is not homogeneous, or its element type is neither an
        xugrid nor an xarray grid type.
    """
    try:
        first_object = _get_first_item(objects)
    except StopIteration:
        # Surface a clear error instead of leaking a bare StopIteration.
        raise ValueError(
            f"'{unstructured_func.__name__}' requires at least one grid, "
            "received an empty sequence"
        ) from None
    start_type = type(first_object)
    if not all(isinstance(o, start_type) for o in objects):
        unique_types = {type(o) for o in objects}
        raise TypeError(
            f"Only homogeneous sequences can be reduced, received sequence of {unique_types}"
        )
    if isinstance(first_object, (xu.UgridDataArray, xu.UgridDataset)):
        return unstructured_func(objects, *args, **kwargs)
    if isinstance(first_object, (xr.DataArray, xr.Dataset)):
        return _force_decreasing_y(structured_func(objects, *args, **kwargs))
    raise TypeError(
        f"'{unstructured_func.__name__}' not supported for type {start_type}"
    )
# Typedispatching doesn't work based on types of dict elements, therefore resort
# to manual type testing
def _type_dispatch_functions_on_dict(
    dict_of_objects: Mapping[str, GridDataArray | float | bool | int],
    unstructured_func: Callable,
    structured_func: Callable,
    *args,
    **kwargs,
):
    """
    Typedispatch function on grid and scalar variables provided in dictionary.

    Types do not need to be homogeneous as scalars and grids can be mixed. No
    mixing of structured and unstructured grids is allowed. Also allows running
    function on dictionary with purely scalars, in which case it will call to
    the xarray function.

    Both functions are called as ``func([dict_of_objects], *args, **kwargs)``.
    If ``dict_of_objects`` is None, an empty ``xr.Dataset`` is returned.

    Raises
    ------
    TypeError
        If the dictionary holds both an ``xu.UgridDataArray`` and a spatial
        (x and y) ``xr.DataArray``.
    """
    if dict_of_objects is None:
        return xr.Dataset()

    values = list(dict_of_objects.values())
    has_unstructured = any(isinstance(value, xu.UgridDataArray) for value in values)
    # Only count an xr.DataArray as a structured grid if it is spatial.
    has_structured_grid = any(
        isinstance(value, xr.DataArray) and is_spatial_2D(value) for value in values
    )
    if has_structured_grid and has_unstructured:
        raise TypeError(
            textwrap.dedent(
                """
                Received both structured grid (xr.DataArray) and xu.UgridDataArray. This
                means structured grids as well as unstructured grids were provided.
                """
            )
        )
    if has_unstructured:
        return unstructured_func([dict_of_objects], *args, **kwargs)

    return structured_func([dict_of_objects], *args, **kwargs)
def merge(
    objects: Sequence[GridDataArray | GridDataset], *args, **kwargs
) -> GridDataset:
    """Merge a homogeneous sequence of grids with ``xu.merge`` or ``xr.merge``."""
    dispatch_pair = (xu.merge, xr.merge)
    return _type_dispatch_functions_on_grid_sequence(
        objects, *dispatch_pair, *args, **kwargs
    )
def merge_partitions(
    objects: Sequence[GridDataArray | GridDataset], *args, **kwargs
) -> GridDataArray | GridDataset:
    """Merge partitioned grids with ``xu.merge_partitions`` or ``structured.merge_partitions``."""
    dispatch_pair = (xu.merge_partitions, structured.merge_partitions)
    return _type_dispatch_functions_on_grid_sequence(
        objects, *dispatch_pair, *args, **kwargs
    )
def concat(
    objects: Sequence[GridDataArray | GridDataset], *args, **kwargs
) -> GridDataArray | GridDataset:
    """Concatenate a homogeneous sequence of grids with ``xu.concat`` or ``xr.concat``."""
    dispatch_pair = (xu.concat, xr.concat)
    return _type_dispatch_functions_on_grid_sequence(
        objects, *dispatch_pair, *args, **kwargs
    )
def merge_unstructured_dataset(variables_to_merge: list[dict], *args, **kwargs):
    """
    Work around xugrid issue https://github.com/Deltares/xugrid/issues/179

    Expects exactly one dictionary (or other mapping) in list. List is used to
    have same API as xr.merge().

    Merges unstructured grids first, then manually assigns scalar variables.
    Entries that are ``xu.UgridDataArray`` are renamed to their key and merged
    with ``xu.merge(grids, *args, **kwargs)``; all other entries are assigned
    to the merged dataset afterwards.

    Raises
    ------
    ValueError
        If the list does not hold exactly one item, or the grids are not all
        on one unique grid.
    TypeError
        If the single item is not a mapping.
    """
    # ``!= 1`` (not ``> 1``): an empty list would otherwise fail below with an
    # IndexError instead of this explicit message.
    if len(variables_to_merge) != 1:
        raise ValueError(
            f"Only one dict of variables expected, got {len(variables_to_merge)}"
        )

    variables_to_merge_dict = variables_to_merge[0]

    # Accept any Mapping, matching the Mapping-typed merge_with_dictionary API.
    if not isinstance(variables_to_merge_dict, Mapping):
        raise TypeError(f"Expected dict, got {type(variables_to_merge_dict)}")

    # Separate variables into list of grids and dict of scalar variables
    grids_ls = []
    scalar_dict = {}
    for name, variable in variables_to_merge_dict.items():
        if isinstance(variable, xu.UgridDataArray):
            grids_ls.append(variable.rename(name))
        else:
            scalar_dict[name] = variable

    # Merge grids
    dataset = xu.merge(grids_ls, *args, **kwargs)

    # Temporarily work around this xugrid issue, until fixed:
    # https://github.com/Deltares/xugrid/issues/206
    unique_grid_hashes = {hash(pickle.dumps(grid)) for grid in dataset.ugrid.grids}
    if len(unique_grid_hashes) > 1:
        raise ValueError(
            "Multiple grids provided, please provide data on one unique grid"
        )
    # Possibly won't work anymore if this ever gets implemented:
    # https://github.com/Deltares/xugrid/issues/195
    dataset._grids = [dataset.grids[0]]

    # Assign scalar variables manually
    for name, variable in scalar_dict.items():
        dataset[name] = variable

    return dataset
def merge_with_dictionary(
    variables_to_merge: Mapping[str, GridDataArray | float | bool | int],
    *args,
    **kwargs,
):
    """
    Merge a mapping of grids and scalars into one dataset.

    Unstructured content is handled by ``merge_unstructured_dataset``, and
    everything else by ``xr.merge``.
    """
    handlers = (merge_unstructured_dataset, xr.merge)
    return _type_dispatch_functions_on_dict(
        variables_to_merge, *handlers, *args, **kwargs
    )
@typedispatch  # type: ignore [no-redef]
def bounding_polygon(active: xr.DataArray):
    """Return bounding polygon of active cells.

    Inactive cells are set to NaN before polygonizing; only polygons whose
    "value" equals 1.0 are kept.
    """
    masked = active.where(active, other=np.nan)
    polygons = _polygonize(masked)
    # Polygons built from NaN cells are the inactive ones; drop them.
    keep = polygons["value"] == 1.0
    return polygons.loc[keep]


@typedispatch  # type: ignore [no-redef]
def bounding_polygon(active: xu.UgridDataArray):  # noqa: F811
    """Return bounding polygon of active cells.

    The grid is clipped to the faces where ``active > 0`` and the bounding
    polygon of the clipped grid is returned.
    """
    face_dim = active.ugrid.grid.face_dimension
    # NOTE(review): [0] takes indices along the first axis only, which assumes
    # ``active`` is 1D over faces — confirm.
    selected_faces = np.where(active > 0)[0]
    clipped = active.isel({f"{face_dim}": selected_faces}, missing_dims="ignore")
    return clipped.ugrid.grid.bounding_polygon()
@typedispatch  # type: ignore [no-redef]
def is_spatial_2D(array: xr.DataArray) -> bool:
    """Return True if the array contains data in at least 2 spatial dimensions.

    Both "x" and "y" must be present as coordinates and as dimensions.
    """
    horizontal = ("x", "y")
    in_coords = all(name in array.coords for name in horizontal)
    in_dims = all(name in array.dims for name in horizontal)
    return in_coords and in_dims


@typedispatch  # type: ignore [no-redef]
def is_spatial_2D(array: xu.UgridDataArray) -> bool:  # noqa: F811
    """Return True if the array contains data associated to cell faces.

    The grid's face dimension must be present as a dimension and a coordinate.
    """
    face_dim = array.ugrid.grid.face_dimension
    return (face_dim in array.dims) and (face_dim in array.coords)


@typedispatch  # type: ignore [no-redef]
def is_spatial_2D(_: object) -> bool:  # noqa: F811
    """Return False for any object that is not a grid."""
    return False
@typedispatch  # type: ignore [no-redef]
def is_equal(array1: xu.UgridDataArray, array2: xu.UgridDataArray) -> bool:
    """Return True if both the data and the underlying ugrid topologies are equal."""
    if not array1.equals(array2):
        return False
    return array1.ugrid.grid.equals(array2.ugrid.grid)


@typedispatch  # type: ignore [no-redef]
def is_equal(array1: xr.DataArray, array2: xr.DataArray) -> bool:  # noqa: F811
    """Return True if the two structured arrays are equal (``DataArray.equals``)."""
    same = array1.equals(array2)
    return same


@typedispatch  # type: ignore [no-redef]
def is_equal(array1: object, array2: object) -> bool:  # noqa: F811
    """Return False for any other combination of argument types."""
    return False
@typedispatch  # type: ignore [no-redef]
def is_same_domain(grid1: xu.UgridDataArray, grid2: xu.UgridDataArray) -> bool:
    """Return True if coordinates and ugrid topologies of both grids are equal."""
    if not grid1.coords.equals(grid2.coords):
        return False
    return grid1.ugrid.grid.equals(grid2.ugrid.grid)


@typedispatch  # type: ignore [no-redef]
def is_same_domain(grid1: xr.DataArray, grid2: xr.DataArray) -> bool:  # noqa: F811
    """Return True if the coordinates of both structured grids are equal."""
    coords1, coords2 = grid1.coords, grid2.coords
    return coords1.equals(coords2)


@typedispatch  # type: ignore [no-redef]
def is_same_domain(grid1: object, grid2: object) -> bool:  # noqa: F811
    """Return False for any other combination of argument types."""
    return False
@typedispatch  # type: ignore [no-redef]
def get_spatial_dimension_names(grid: xr.DataArray) -> list[str]:
    """Return the names treated as spatial for a structured grid (a new list)."""
    names = ("x", "y", "layer", "dx", "dy")
    return list(names)


@typedispatch  # type: ignore [no-redef]
def get_spatial_dimension_names(grid: xu.UgridDataArray) -> list[str]:  # noqa: F811
    """Return the grid's face dimension followed by "layer"."""
    return [grid.ugrid.grid.face_dimension, "layer"]


@typedispatch  # type: ignore [no-redef]
def get_spatial_dimension_names(grid: object) -> list[str]:  # noqa: F811
    """Return an empty list for objects that are not grids."""
    return list()
@typedispatch  # type: ignore [no-redef]
def get_grid_geometry_hash(grid: xr.DataArray) -> tuple[int, int]:
    """Return hashes of the pickled x and y coordinate values of a structured grid.

    Returns a tuple ``(hash_x, hash_y)`` rather than a single int. Python salts
    the hashes of ``bytes`` per process, so these values are only comparable
    within one interpreter session.
    """
    hash_x = hash(pickle.dumps(grid["x"].values))
    hash_y = hash(pickle.dumps(grid["y"].values))
    return (hash_x, hash_y)


@typedispatch  # type: ignore [no-redef]
def get_grid_geometry_hash(grid: xu.UgridDataArray) -> tuple[int, int, int]:  # noqa: F811
    """Return hashes of node x, node y and node-face connectivity of an unstructured grid.

    Returns a tuple ``(hash_x, hash_y, hash_connectivity)``. These values are
    only comparable within one interpreter session (``bytes`` hashes are salted
    per process).
    """
    ugrid = grid.ugrid.grid
    hash_x = hash(pickle.dumps(ugrid.node_x))
    hash_y = hash(pickle.dumps(ugrid.node_y))
    hash_connectivity = hash(pickle.dumps(ugrid.node_face_connectivity))
    return (hash_x, hash_y, hash_connectivity)


@typedispatch  # type: ignore [no-redef]
def get_grid_geometry_hash(grid: object) -> tuple[int, ...]:  # noqa: F811
    """Always raise: only xr.DataArray and xu.UgridDataArray are supported.

    Raises
    ------
    ValueError
        For every call.
    """
    raise ValueError("get_grid_geometry_hash not supported for this object.")
@typedispatch  # type: ignore [no-redef]
def enforce_dim_order(grid: xr.DataArray) -> xr.DataArray:
    """Enforce dimension order to iMOD Python standard.

    Order is species, time, layer, y, x; dimensions absent from ``grid`` are
    skipped.
    """
    standard_order = ("species", "time", "layer", "y", "x")
    return grid.transpose(*standard_order, missing_dims="ignore")


@typedispatch  # type: ignore [no-redef]
def enforce_dim_order(grid: xu.UgridDataArray) -> xu.UgridDataArray:  # noqa: F811
    """Enforce dimension order to iMOD Python standard.

    Order is species, time, layer, then the grid's face dimension; dimensions
    absent from ``grid`` are skipped.
    """
    standard_order = ("species", "time", "layer", grid.ugrid.grid.face_dimension)
    return grid.transpose(*standard_order, missing_dims="ignore")
def _enforce_unstructured(obj: GridDataArray, ugrid2d=xu.Ugrid2d) -> xu.UgridDataArray:
    """
    Force obj to unstructured.

    Wraps ``obj`` as an ``xu.UgridDataArray`` on the topology ``ugrid2d``.
    NOTE(review): the default is the ``xu.Ugrid2d`` class itself, not a grid
    instance; callers in this module always pass a grid explicitly.
    """
    if isinstance(obj, xu.UgridDataArray):
        # Unwrap first: xr.DataArray() does not recognise the xugrid wrapper as
        # a DataArray and would rebuild it from its raw values, losing coords.
        obj = obj.obj
    return xu.UgridDataArray(xr.DataArray(obj), ugrid2d)
def preserve_gridtype(func):
    """
    Decorator to preserve gridtype, this is to work around the following xugrid
    behavior:

    >>> UgridDataArray() * DataArray() -> UgridDataArray
    >>> DataArray() * UgridDataArray() -> DataArray

    with this decorator:

    >>> UgridDataArray() * DataArray() -> UgridDataArray
    >>> DataArray() * UgridDataArray() -> UgridDataArray

    If any positional argument is an xugrid object, the result (or every
    element of a tuple result) is wrapped as a UgridDataArray on the grid of
    the last such argument. Keyword arguments are not inspected.
    """
    import functools

    @functools.wraps(func)
    def decorator(*args, **kwargs):
        unstructured = False
        grid = None
        for arg in args:
            # isinstance instead of is_unstructured(): the typedispatched
            # is_unstructured hands non-grid arguments (e.g. scalars) back
            # unchanged, which can be truthy and then crash on ``arg.ugrid``.
            if isinstance(arg, (xu.UgridDataArray, xu.UgridDataset)):
                unstructured = True
                grid = arg.ugrid.grid

        x = func(*args, **kwargs)

        if unstructured:
            # Multiple grids returned
            if isinstance(x, tuple):
                return tuple(_enforce_unstructured(i, grid) for i in x)
            return _enforce_unstructured(x, grid)
        return x

    return decorator