Coverage for C:\src\imod-python\imod\typing\structured.py: 92% (128 statements)
# %%

import itertools
from collections import defaultdict
from typing import DefaultDict, List, Set, Tuple

import dask
import dask.array
import numpy as np
import xarray as xr

# %%


def check_dtypes(das: List[xr.DataArray]) -> None:
    """Check whether the dtypes of all arrays are the same."""
    dtypes = set(da.dtype for da in das)
    if len(dtypes) != 1:
        raise TypeError(f"DataArrays do not match in dtype: {dtypes}")
    return


def _is_nonunique_dimsize(sizes: Set[int]) -> bool:
    return len(sizes) != 1


def check_sizes(sizes: DefaultDict[str, Set[int]], attribute: str) -> None:
    """Utility for checking a dict of dimension names and sizes. Skips x and y."""
    sizes.pop("x", None)
    sizes.pop("y", None)
    conflicting = {k: v for k, v in sizes.items() if _is_nonunique_dimsize(v)}
    if conflicting:
        message = (
            f"DataArrays do not match in {attribute} along dimension(s):\n"
            + "\n".join([f" {k}: {v}" for k, v in conflicting.items()])
        )
        raise ValueError(message)
    return


def check_dims(das: List[xr.DataArray]) -> None:
    all_dims = set(da.dims for da in das)
    if len(all_dims) != 1:
        raise ValueError(
            f"All DataArrays should have exactly the same dimensions. Found: {all_dims}"
        )
    last_dims = das[0].dims[-2:]
    if last_dims != ("y", "x"):
        raise ValueError(f'Last dimensions must be ("y", "x"). Found: {last_dims}')
    check_dim_sizes(das)


def check_dim_sizes(das: List[xr.DataArray]) -> None:
    """Check whether all non-xy dims are equally sized."""
    sizes = defaultdict(set)
    for da in das:
        for key, value in da.sizes.items():
            sizes[key].add(value)
    check_sizes(sizes, "size")
    return


def check_coords(das: List[xr.DataArray]) -> None:
    def drop_xy(coords) -> xr.Coordinates:
        coords = dict(coords)
        coords.pop("y")
        coords.pop("x")
        return xr.Coordinates(coords)

    first_coords = drop_xy(das[0].coords)
    disjoint = [
        i + 1
        for i, da in enumerate(das[1:])
        if not first_coords.equals(drop_xy(da.coords))
    ]
    if disjoint:
        raise ValueError(
            f"Non x-y coordinates do not match for partition 0 with partitions: {disjoint}"
        )
    return


def check_chunk_sizes(das: List[xr.DataArray]) -> None:
    """Check whether all chunks are equal on non-xy dims."""
    chunks = [da.chunks for da in das]
    # Materialize into a list: a generator would be exhausted by all(),
    # making the subsequent any() check unreliable.
    nones = [item is None for item in chunks]
    allnone = all(nones)
    if allnone:
        return
    if any(nones):
        raise ValueError("Some DataArrays are chunked, while others are not.")

    sizes = defaultdict(set)
    for da in das:
        for key, value in zip(da.dims, da.chunks):
            sizes[key].add(value)
    check_sizes(sizes, "chunks")
    return


def merge_arrays(
    arrays: List[np.ndarray],
    ixs: List[np.ndarray],
    iys: List[np.ndarray],
    yx_shape: Tuple[int, int],
) -> np.ndarray:
106 """
107 Merge the arrays in the last two (y, x) dimensions.
109 Parameters
110 ----------
111 arrays: list of N np.ndarray
112 ixs: list of N np.ndarray of int
113 The i-th element are the x indices of the i-th array into the merged
114 array.
115 iys: list of N np.ndarray of int
116 The i-th element are the y indices of the i-th array into the merged
117 array.
118 yx_shape: tuple of int
119 The number of rows and columns of the merged array.
121 Returns
122 -------
123 merged: np.ndarray
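
    Examples
    --------
    Illustrative sketch (values and offsets chosen for demonstration only):
    place two 2 x 2 partitions side by side in a 2 x 4 merged array.

    >>> import numpy as np
    >>> left = np.full((2, 2), 1.0)
    >>> right = np.full((2, 2), 2.0)
    >>> merge_arrays([left, right], ixs=[0, 2], iys=[0, 0], yx_shape=(2, 4))
    array([[1., 1., 2., 2.],
           [1., 1., 2., 2.]])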
124 """
    first = arrays[0]
    shape = first.shape[:-2] + yx_shape
    out = np.full(shape, np.nan, dtype=first.dtype)
    for a, ix, iy in zip(arrays, ixs, iys):
        ysize, xsize = a.shape[-2:]
        # Create view of partition, see:
        # https://numpy.org/doc/stable/user/basics.indexing.html#slicing-and-striding
        out_partition_view = out[..., iy : iy + ysize, ix : ix + xsize]
        # Assign active values to view (updates `out` inplace)
        out_partition_view[...] = np.where(~np.isnan(a), a, out_partition_view)
    return out


def _unique_coords(das: List[xr.DataArray], dim: str) -> np.ndarray:
    """Collect the unique coords of `dim` in a list of DataArrays."""
    return np.unique(np.concatenate([da.coords[dim].values for da in das]))


def _is_nonequidistant_coord(da: xr.DataArray, dim: str) -> bool:
    return (dim in da.coords) and (da.coords[dim].size != 1)


def _merge_nonequidistant_coords(
    das: List[xr.DataArray], coordname: str, indices: List[np.ndarray], nsize: int
) -> np.ndarray:
    dtype = das[0].coords[coordname].dtype
    out = np.full((nsize,), np.nan, dtype=dtype)
    for da, index in zip(das, indices):
        coords = da.coords[coordname]
        out[index : index + coords.size] = coords.values
    return out


def _merge_partitions(das: List[xr.DataArray]) -> xr.DataArray:
    # Do some input checking
    check_dtypes(das)
    check_dims(das)
    check_chunk_sizes(das)
    check_coords(das)

    # Create the x and y coordinates of the merged grid.
    x = _unique_coords(das, "x")
    y = _unique_coords(das, "y")
    nrow = y.size
    ncol = x.size
    # Compute the indices for where the different subdomain parts belong
    # in the merged grid.
    ixs = [np.searchsorted(x, da.x.values[0], side="left") for da in das]
    iys = [nrow - np.searchsorted(y, da.y.values[0], side="right") for da in das]
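    # For example (hypothetical values): if the merged, ascending x is
    # [0.5, 1.5, 2.5, 3.5], a partition whose first (smallest) x is 2.5 gets
    # ix = 2. For y, partitions are ordered top-down (descending), so with
    # ascending y = [0.5, 1.5, 2.5, 3.5] (nrow = 4), a partition whose first
    # (largest) y is 1.5 gets iy = 4 - 2 = 2, i.e. it starts two rows below
    # the top of the merged grid.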
    yx_shape = (nrow, ncol)

    # Collect coordinates
    first = das[0]
    coords = dict(first.coords)
    coords["x"] = x
    coords["y"] = y[::-1]
    if _is_nonequidistant_coord(first, "dx"):
        coords["dx"] = ("x", _merge_nonequidistant_coords(das, "dx", ixs, ncol))
    if _is_nonequidistant_coord(first, "dy"):
        coords["dy"] = ("y", _merge_nonequidistant_coords(das, "dy", iys, nrow))

    arrays = [da.data for da in das]
    if first.chunks is None:
        # If the data is in memory, merge all at once.
        data = merge_arrays(arrays, ixs, iys, yx_shape)
    else:
        # Iterate over the chunks of the dask array. Collect the chunks
        # from every partition and merge them, chunk by chunk.
        # The delayed merged results are stored as a flat list. These can
        # be directly concatenated into a new dask array if chunking occurs
        # on only the first dimension (e.g. time), but not if chunks exist
        # in multiple dimensions (e.g. time and layer).
        #
        # dask.array.block() is capable of concatenating over multiple
        # dimensions if we feed it a nested list of lists of dask arrays.
        # This is more easily represented by a numpy array of objects
        # (dask arrays), since numpy has nice tooling for reshaping.
        #
        # Normally, we'd append to a list, then convert to a numpy array and
        # reshape. However, numpy attempts to join a list of dask arrays into
        # a single large numpy array when initialized. This behavior is not
        # triggered when setting individual elements of the array, so we
        # create the numpy array in advance and set its elements.
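        #
        # For instance (hypothetical chunking): with chunks along "time" and
        # "layer" and block_shape == (2, 3), merged_blocks holds 2 * 3 = 6
        # delayed merge results; reshaping to (2, 3, 1, 1) and feeding the
        # nested list to dask.array.block() stitches them into a single
        # (time, layer, y, x) dask array with unchunked y and x.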

        block_shape = das[0].data.blocks.shape[:-2]
        merged_blocks = np.empty(np.prod(block_shape), dtype=object)
        dimension_ranges = [range(size) for size in block_shape]
        for i, index in enumerate(itertools.product(*dimension_ranges)):
            # This is a workaround for python 3.10
            # FUTURE: can be rewritten to arr.blocks[*index, ...] in python 3.11
            index_with_ellipsis = tuple(index) + (...,)
            # arr.blocks provides us access to the chunks of the array.
            arrays_to_merge = [arr.blocks[index_with_ellipsis] for arr in arrays]
            delayed_merged = dask.delayed(merge_arrays)(
                arrays_to_merge, ixs, iys, yx_shape
            )
            dask_merged = dask.array.from_delayed(
                delayed_merged,
                shape=arrays_to_merge[0].shape[:-2] + yx_shape,
                dtype=first.dtype,
            )
            merged_blocks[i] = dask_merged

        # After merging, the xy chunks are always (1, 1)
        reshaped = merged_blocks.reshape(block_shape + (1, 1))
        data = dask.array.block(reshaped.tolist())

    return xr.DataArray(
        data=data,
        coords=coords,
        dims=first.dims,
    )


def merge_partitions(
    das: List[xr.DataArray | xr.Dataset],
) -> xr.Dataset:
    first_item = das[0]
    if isinstance(first_item, xr.Dataset):
        unique_keys = {key for da in das for key in da.keys()}
        merged_ls = []
        for key in unique_keys:
            merged_ls.append(_merge_partitions([da[key] for da in das]).rename(key))
        return xr.merge(merged_ls)
    elif isinstance(first_item, xr.DataArray):
        # Store the name so the merged DataArray can be converted to a Dataset
        name = first_item.name
        return _merge_partitions(das).to_dataset(name=name)
    else:
        raise TypeError(
            f"Expected type: xr.DataArray or xr.Dataset, got {type(first_item)}"
        )
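

# %%
# Illustrative usage sketch (added example, not part of the original module):
# merge two single-variable partitions that sit side by side along x. The
# coordinate values and the "head" variable name are made up for demonstration.
if __name__ == "__main__":
    left = xr.DataArray(
        np.full((2, 2), 1.0),
        coords={"y": [1.5, 0.5], "x": [0.5, 1.5]},
        dims=("y", "x"),
        name="head",
    )
    right = xr.DataArray(
        np.full((2, 2), 2.0),
        coords={"y": [1.5, 0.5], "x": [2.5, 3.5]},
        dims=("y", "x"),
        name="head",
    )
    merged = merge_partitions([left, right])
    # Expect a Dataset with a single "head" variable of shape (y=2, x=4).
    print(merged["head"].shape)  # (2, 4)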