Coverage for C:\src\imod-python\imod\prepare\interpolate.py: 91%
233 statements
coverage.py v7.5.1, created at 2024-05-08 14:15 +0200
"""
This module provides linear interpolation methods on regular grids, up to three
dimensions. These functions are imported by the regrid.py module, which
incorporates them in the Regridder class.

### Introduction
The interp_ functions do most of the work, with one function per number of
dimensions (1, 2, 3). The simplest way to implement linear interpolation is as
follows:

First, for one dimension. Let node 0 be located at x0, and have a value of v0,
and let node 1 be located at x1, and have a value of v1. Any value between node
0 and 1 can then be calculated as:

v(x) = v0 + (x - x0) / (x1 - x0) * (v1 - v0)

(x - x0) / (x1 - x0) is effectively a weighting of the different cells, and is
called wx in the code below.

Of course, we can rewrite it as:

v(x) = v0 + (x - x0) * (v1 - v0) / (x1 - x0)

This is the equation of a straight line with slope a and intercept b, where b
is the value at x = x0:

b = v0
a = (delta_v / delta_x) = (v1 - v0) / (x1 - x0)

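As a minimal sketch of the two equivalent forms above (the names lerp and
lerp_slope are illustrative only and do not exist in this module):

```python
def lerp(x, x0, x1, v0, v1):
    # Weight form: wx is the normalized distance from node 0 to x.
    wx = (x - x0) / (x1 - x0)
    return v0 + wx * (v1 - v0)


def lerp_slope(x, x0, x1, v0, v1):
    # Slope-intercept form, with the line anchored at x0.
    a = (v1 - v0) / (x1 - x0)
    b = v0
    return b + a * (x - x0)


# Both forms give the same value a quarter of the way from node 0 to node 1.
value = lerp(2.5, 2.0, 4.0, 10.0, 20.0)
value_slope = lerp_slope(2.5, 2.0, 4.0, 10.0, 20.0)
```

Here wx = 0.25, so both calls return 12.5.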
Multilinear interpolation can be implemented as repeated one dimensional linear
interpolation. Take two dimensions, with four cells. The values are located as
such:

  v00 -- v01
   |      |
   |      |
  v10 -- v11

v00 is upper left, v01 is upper right, etc. Given a point (x, y) between the
drawn boundaries, wx as above, and:
wy = (y - y0) / (y1 - y0)

Then, to compute the value for (x, y), we interpolate in x twice, and in y once.
(Or vice versa.)
v0 = v00 + wx * (v01 - v00)  # interpolate in x, upper row
v1 = v10 + wx * (v11 - v10)  # interpolate in x, lower row
v = v0 + wy * (v1 - v0)  # interpolate in y

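A small sketch of the two-dimensional case, using the corner layout drawn above
(v01 to the right of v00, v10 below v00), which is also the indexing used by
_interp_2d. The function names are hypothetical. It compares repeated
one-dimensional interpolation against a single weighted sum over the four
corners:

```python
def bilinear_repeated(v00, v01, v10, v11, wx, wy):
    v0 = v00 + wx * (v01 - v00)  # upper row, along x
    v1 = v10 + wx * (v11 - v10)  # lower row, along x
    return v0 + wy * (v1 - v0)  # between the two rows, along y


def bilinear_weighted(v00, v01, v10, v11, wx, wy):
    # Each corner contributes with the product of its per-axis weights.
    return (
        (1 - wx) * (1 - wy) * v00
        + wx * (1 - wy) * v01
        + (1 - wx) * wy * v10
        + wx * wy * v11
    )


repeated = bilinear_repeated(1.0, 3.0, 5.0, 7.0, 0.5, 0.25)
weighted = bilinear_weighted(1.0, 3.0, 5.0, 7.0, 0.5, 0.25)
```

Both expressions evaluate to 3.0 for these inputs. The weighted-sum form is the
one that generalizes to skipping individual corners.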
### Nodata
All of this works perfectly, until we run into nodata values. The code below
is aimed at xarray.DataArrays, which use np.nan values as a sentinel for nodata.
NaNs pollute: 0 + nan = nan, 1 + 0 * nan = nan, etc.
In between cells that are not nodata, interpolation goes fine. However, at the
edges (say, to the left of v00 and v10), we get nan values. Note that the cell
of v00 does cover the area, but given a nan to the left, it'll end up as nodata.

This is undesirable (the relative number of nodata cells increases).
We have to catch nan values before they pollute the computations. This is
straightforward in one dimension, but involves a bothersome number of
conditionals in 2d, and certainly in 3d! An easier implementation in this case
is not by doing repeated linear interpolations, but by using an accumulator,
and an accumulating divisor. We simply add the contribution of every cell
to this accumulator, and skip it if it's nodata.

The result is the same (the algebra to prove this is left as an exercise to the
reader).

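The accumulator idea can be sketched in one dimension with plain Python. This
is an illustration only: accumulate_1d is a hypothetical name, and the compiled
interp_ functions below write into a destination array instead of returning a
value.

```python
import math


def accumulate_1d(v0, v1, wx):
    accumulator = 0.0
    divisor = 0.0
    # Visit both neighbors with their weights; skip any that is nodata.
    for value, weight in ((v0, 1.0 - wx), (v1, wx)):
        if math.isfinite(value):
            accumulator += weight * value
            divisor += weight
    if divisor > 0.0:
        # Renormalize by the weights that were actually used.
        return accumulator / divisor
    return math.nan


both_valid = accumulate_1d(10.0, 20.0, 0.25)
right_missing = accumulate_1d(10.0, math.nan, 0.25)
```

With both neighbors valid the result equals ordinary linear interpolation
(12.5). With the right neighbor missing, the remaining weight is renormalized
and the left value is returned unchanged (10.0).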
### Not quite there yet
Using the accumulator implementation, we end up creating more data, tiling over
nodata parts. Since we're always between two cells, and the implementation above
removes nodata values, non-nodata values will extend too far.

The easiest way of remedying this is by taking into account where the cell
boundaries of the original cell lie. If it's fully within a nodata cell, no
extrapolation should occur.

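A hedged one-dimensional sketch of this boundary check, mirroring the branch
structure of _interp_1d. The name interp_point and its return-value interface
are assumptions made for illustration. The flag within is True when the
destination midpoint lies inside the cell of v0, and False when it lies inside
the cell of v1:

```python
import math


def interp_point(v0, v1, wx, within):
    accumulator = 0.0
    divisor = 0.0
    for value, weight in ((v0, 1.0 - wx), (v1, wx)):
        if math.isfinite(value):
            accumulator += weight * value
            divisor += weight
    if divisor == 0.0:
        return math.nan
    # The source cell that actually contains the destination point decides:
    # if that cell is nodata, the neighbor's value must not spill into it.
    containing = v0 if within else v1
    if not math.isfinite(containing):
        return math.nan
    return accumulator / divisor


inside_nodata = interp_point(math.nan, 20.0, 0.25, True)
inside_valid = interp_point(math.nan, 20.0, 0.75, False)
```

The first call stays nodata because the point lies in the nodata cell. The
second call returns 20.0 because the point lies in the valid cell and only its
nodata neighbor is skipped.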
### Function responsibilities:
* linear_inds_weights: returns per single dimension the weights (wx, wy),
  the source indexes (column, row numbers), and whether the destination x
  lies fully within the cell of the matching source index. The source index
  is always the "starting" or "leftmost" cell. Interpolation occurs between
  this value and its neighbor to the right (index + 1).
* interp_ functions implement the functionality described above, taking the
  weights, indices, and the boolean within to compute interpolated values.
* iter_interp: an example best serves to explain: let there be a DataArray with
  dimensions time, layer, y, x. We wish to interpolate in two dimensions,
  over y and x. Time and layer are "stacked" onto a single dimension, so that
  a single loop in iter_interp suffices to interpolate over y and x, rather
  than a nested loop (in which case the number of for loops depends on the
  number of dimensions).
* jit_interp: selects the right interpolation method for the number of
  dimensions.
* make_interp: provides the closure to avoid numba overhead. Basically, we
  "inject" code before compilation, so we don't have to pass functions as an
  argument at runtime. (Numba probably inlines the function in this case.)
* nd_interp: collects the weights, reshapes the array so iter_interp will take
  it.
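
The stacking described for iter_interp can be sketched with plain NumPy
reshaping. This is an assumption-laden illustration: apply_stacked is a
hypothetical helper, and common._reshape may be implemented differently.

```python
import numpy as np


def apply_stacked(src, func_2d):
    # Collapse all leading dimensions (e.g. time, layer) into a single one,
    # keeping the last two dimensions (y, x) intact.
    stacked = src.reshape((-1,) + src.shape[-2:])
    out = np.empty_like(stacked)
    # A single loop now covers every (time, layer) combination.
    for i in range(stacked.shape[0]):
        out[i, ...] = func_2d(stacked[i, ...])
    return out.reshape(src.shape)


# Dimensions: time=2, layer=3, y=4, x=5
src = np.arange(2 * 3 * 4 * 5, dtype=float).reshape(2, 3, 4, 5)
result = apply_stacked(src, lambda a: a * 2.0)
```

The per-slice function here simply doubles the values and keeps the shape; a
regridding function would instead write into a destination with a different
(y, x) shape.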
"""

import numba
import numpy as np

from imod.prepare import common


def _linear_inds_weights_1d(src_x, dst_x):
    """
    Returns indices and weights for linear interpolation along a single dimension.
    A sentinel value of -1 is added for dst cells that are fully out of bounds.

    Parameters
    ----------
    src_x : np.array
        vertex coordinates of source
    dst_x: np.array
        vertex coordinates of destination
    """
    # Cannot interpolate "between" only one point
    if not src_x.size > 2:
        raise ValueError(
            "src_x must have more than 2 elements. "
            "Cannot interpolate with only one point"
        )

    xmin = src_x.min()
    xmax = src_x.max()

    # Compute midpoints for linear interpolation
    src_dx = np.diff(src_x)
    mid_src_x = src_x[:-1] + 0.5 * src_dx
    dst_dx = np.diff(dst_x)
    mid_dst_x = dst_x[:-1] + 0.5 * dst_dx

    # From np.searchsorted docstring:
    # Find the indices into a sorted array a such that, if the corresponding
    # elements in v were inserted before the indices, the order of a would
    # be preserved.
    i = np.searchsorted(mid_src_x, mid_dst_x) - 1
    # Out of bounds indices
    i[i < 0] = 0
    i[i > mid_src_x.size - 2] = mid_src_x.size - 2

    # -------------------------------------------------------------------------
    # Visual example: interpolate from src with 2 cells to dst 3 cells
    # The period . marks the midpoint of the cell
    # The pipe | marks the cell edge
    #
    #    |_____._____|_____._____|
    #       src_x0      src_x1
    #
    #    |___.___|___.___|___.___|
    #       x0      x1      x2
    #
    # Then normalized weight for cell x1:
    # weight = (x1 - src_x0) / (src_x1 - src_x0)
    # -------------------------------------------------------------------------

    norm_weights = (mid_dst_x - mid_src_x[i]) / (mid_src_x[i + 1] - mid_src_x[i])
    # deal with out of bounds locations
    # we place a sentinel value of -1 here
    i[mid_dst_x < xmin] = -1
    i[mid_dst_x > xmax] = -1
    # In case it's just inside of bounds, use only the value at the boundary
    norm_weights[norm_weights < 0.0] = 0.0
    norm_weights[norm_weights > 1.0] = 1.0
    # The following array is used only to deal with nodata values at the edges
    # Recall that src_x are the cell edges
    # Start inclusive (just like e.g. GDAL)
    within = (mid_dst_x >= src_x[i]) & (mid_dst_x < src_x[i + 1])
    return i, norm_weights, within


@numba.njit(cache=True)
def _interp_1d(src, dst, *inds_weights):
    """
    Parameters
    ----------
    src : np.array
    dst : np.array
    """
    # Unpack the variadic arguments
    kk, weights_x, within_x = inds_weights
    # k are indices of dst array
    for k, (ix, wx, in_x) in enumerate(zip(kk, weights_x, within_x)):
        if ix < 0:
            continue

        # Fetch the values from source array, left v0, right v1
        v0 = src[ix]
        v1 = src[ix + 1]
        # Check whether they are nodata
        v0_ok = np.isfinite(v0)
        v1_ok = np.isfinite(v1)

        # Initialize and add to accumulators
        accumulator = 0
        accumulator_divisor = 0
        if v0_ok:
            multiplier = 1 - wx
            accumulator += multiplier * v0
            accumulator_divisor += multiplier
        if v1_ok:
            multiplier = wx
            accumulator += multiplier * v1
            accumulator_divisor += multiplier

        # Check if the point to interpolate to falls fully within a nodata cell
        # if that's the case, don't use the value, but continue with the next iteration.
        # else: use the value, fill it into the destination array.
        if accumulator_divisor > 0:
            if in_x:
                if not v0_ok:
                    continue
            else:
                if not v1_ok:
                    continue
            v = accumulator / accumulator_divisor
            dst[k] = v

    return dst


@numba.njit(cache=True)
def _interp_2d(src, dst, *inds_weights):
    # Unpack the variadic arguments
    jj, weights_y, within_y, kk, weights_x, within_x = inds_weights
    # j, k are indices of dst array
    for j, (iy, wy, in_y) in enumerate(zip(jj, weights_y, within_y)):
        if iy < 0:
            continue

        for k, (ix, wx, in_x) in enumerate(zip(kk, weights_x, within_x)):
            if ix < 0:
                continue

            # Fetch the values from source array, upper left v00, lower right v11
            v00 = src[iy, ix]
            v01 = src[iy, ix + 1]
            v10 = src[iy + 1, ix]
            v11 = src[iy + 1, ix + 1]
            # Check whether they are nodata
            v00_ok = np.isfinite(v00)
            v01_ok = np.isfinite(v01)
            v10_ok = np.isfinite(v10)
            v11_ok = np.isfinite(v11)

            # Initialize and add to accumulators
            accumulator = 0
            accumulator_divisor = 0
            if v00_ok:
                multiplier = (1 - wx) * (1 - wy)
                accumulator += multiplier * v00
                accumulator_divisor += multiplier
            if v01_ok:
                multiplier = wx * (1 - wy)
                accumulator += multiplier * v01
                accumulator_divisor += multiplier
            if v10_ok:
                multiplier = (1 - wx) * wy
                accumulator += multiplier * v10
                accumulator_divisor += multiplier
            if v11_ok:
                multiplier = wx * wy
                accumulator += multiplier * v11
                accumulator_divisor += multiplier

            # Check if the point to interpolate to falls fully within a nodata cell
            # if that's the case, don't use the value, but continue with the next iteration.
            # else: use the value, fill it into the destination array.
            if accumulator_divisor > 0:
                if in_y:
                    if in_x:
                        if not v00_ok:
                            continue
                    else:
                        if not v01_ok:
                            continue
                else:
                    if in_x:
                        if not v10_ok:
                            continue
                    else:
                        if not v11_ok:
                            continue
                v = accumulator / accumulator_divisor
                dst[j, k] = v

    return dst


@numba.njit(cache=True)
def _interp_3d(src, dst, *inds_weights):
    # Unpack the variadic arguments
    (
        ii,
        weights_z,
        within_z,
        jj,
        weights_y,
        within_y,
        kk,
        weights_x,
        within_x,
    ) = inds_weights
    # i, j, k are indices of dst array
    for i, (iz, wz, in_z) in enumerate(zip(ii, weights_z, within_z)):
        if iz < 0:
            continue

        for j, (iy, wy, in_y) in enumerate(zip(jj, weights_y, within_y)):
            if iy < 0:
                continue

            for k, (ix, wx, in_x) in enumerate(zip(kk, weights_x, within_x)):
                if ix < 0:
                    continue

                # Fetch the values from source array, top upper left v000,
                # bottom lower right v111
                v000 = src[iz, iy, ix]
                v001 = src[iz, iy, ix + 1]
                v010 = src[iz, iy + 1, ix]
                v011 = src[iz, iy + 1, ix + 1]
                v100 = src[iz + 1, iy, ix]
                v101 = src[iz + 1, iy, ix + 1]
                v110 = src[iz + 1, iy + 1, ix]
                v111 = src[iz + 1, iy + 1, ix + 1]
                # Check whether they are nodata
                v000_ok = np.isfinite(v000)
                v001_ok = np.isfinite(v001)
                v010_ok = np.isfinite(v010)
                v011_ok = np.isfinite(v011)
                v100_ok = np.isfinite(v100)
                v101_ok = np.isfinite(v101)
                v110_ok = np.isfinite(v110)
                v111_ok = np.isfinite(v111)

                # Initialize and add to accumulators
                accumulator = 0
                accumulator_divisor = 0
                if v000_ok:
                    multiplier = (1 - wz) * (1 - wx) * (1 - wy)
                    accumulator += multiplier * v000
                    accumulator_divisor += multiplier
                if v001_ok:
                    multiplier = (1 - wz) * wx * (1 - wy)
                    accumulator += multiplier * v001
                    accumulator_divisor += multiplier
                if v010_ok:
                    multiplier = (1 - wz) * (1 - wx) * wy
                    accumulator += multiplier * v010
                    accumulator_divisor += multiplier
                if v011_ok:
                    multiplier = (1 - wz) * wx * wy
                    accumulator += multiplier * v011
                    accumulator_divisor += multiplier
                if v100_ok:
                    multiplier = wz * (1 - wx) * (1 - wy)
                    accumulator += multiplier * v100
                    accumulator_divisor += multiplier
                if v101_ok:
                    multiplier = wz * wx * (1 - wy)
                    accumulator += multiplier * v101
                    accumulator_divisor += multiplier
                if v110_ok:
                    multiplier = wz * (1 - wx) * wy
                    accumulator += multiplier * v110
                    accumulator_divisor += multiplier
                if v111_ok:
                    multiplier = wz * wx * wy
                    accumulator += multiplier * v111
                    accumulator_divisor += multiplier

                # Check if the point to interpolate to falls fully within a nodata cell
                # if that's the case, don't use the value, but continue with the next iteration.
                # else: use the value, fill it into the destination array.
                if accumulator_divisor > 0:
                    if in_z:
                        if in_y:
                            if in_x:
                                if not v000_ok:
                                    continue
                            else:
                                if not v001_ok:
                                    continue
                        else:
                            if in_x:
                                if not v010_ok:
                                    continue
                            else:
                                if not v011_ok:
                                    continue
                    else:
                        if in_y:
                            if in_x:
                                if not v100_ok:
                                    continue
                            else:
                                if not v101_ok:
                                    continue
                        else:
                            if in_x:
                                if not v110_ok:
                                    continue
                            else:
                                if not v111_ok:
                                    continue
                    v = accumulator / accumulator_divisor
                    dst[i, j, k] = v

    return dst


@numba.njit
def _iter_interp(iter_src, iter_dst, interp_function, *inds_weights):
    n_iter = iter_src.shape[0]
    for i in range(n_iter):
        iter_dst[i, ...] = interp_function(
            iter_src[i, ...], iter_dst[i, ...], *inds_weights
        )
    return iter_dst


def _jit_interp(ndim_interp):
    @numba.njit
    def jit_interp_1d(src, dst, *inds_weights):
        return _interp_1d(src, dst, *inds_weights)

    @numba.njit
    def jit_interp_2d(src, dst, *inds_weights):
        return _interp_2d(src, dst, *inds_weights)

    @numba.njit
    def jit_interp_3d(src, dst, *inds_weights):
        return _interp_3d(src, dst, *inds_weights)

    if ndim_interp == 1:
        jit_interp = jit_interp_1d
    elif ndim_interp == 2:
        jit_interp = jit_interp_2d
    elif ndim_interp == 3:
        jit_interp = jit_interp_3d
    else:
        raise NotImplementedError("cannot regrid over more than three dimensions")

    return jit_interp


def _make_interp(ndim_regrid):
    jit_interp = _jit_interp(ndim_regrid)

    @numba.njit
    def iter_interp(iter_src, iter_dst, *inds_weights):
        return _iter_interp(iter_src, iter_dst, jit_interp, *inds_weights)

    return iter_interp


def _nd_interp(src, dst, src_coords, dst_coords, iter_interp):
    if len(src.shape) != len(dst.shape):
        raise ValueError("shape mismatch between src and dst")
    if len(src_coords) != len(dst_coords):
        raise ValueError("coords mismatch between src and dst")
    ndim_regrid = len(src_coords)

    # Determine the indices, weights, and within flags for every regrid
    # dimension, and collect them in a single flat list
    inds_weights = []
    for src_x, dst_x in zip(src_coords, dst_coords):
        for elem in _linear_inds_weights_1d(src_x, dst_x):
            inds_weights.append(elem)

    iter_src, iter_dst = common._reshape(src, dst, ndim_regrid)
    iter_dst = iter_interp(iter_src, iter_dst, *inds_weights)

    return iter_dst.reshape(dst.shape)