Coverage for C:\src\imod-python\imod\prepare\interpolate.py: 91%

233 statements  

coverage.py v7.5.1, created at 2024-05-08 14:15 +0200

"""
This module provides linear interpolation methods on regular grids, up to three
dimensions. These functions are imported by the regrid.py module, which
incorporates them in the Regridder class.

### Introduction
The interp_ functions do most of the work. There is one function per number of
dimensions (1, 2, 3). The simplest way to implement linear interpolation is as
follows:

First, for one dimension. Let node 0 be located at x0, and have a value of v0,
and let node 1 be located at x1, and have a value of v1. Any value between node
0 and 1 can then be calculated as:

v(x) = v0 + (x - x0) / (x1 - x0) * (v1 - v0)

(x - x0) / (x1 - x0) is effectively a weighting of the different cells, and is
called wx in the code below.

Of course, we can rewrite it as:

v(x) = v0 + (x - x0) * (v1 - v0) / (x1 - x0)

Which is equal to the equation of a straight line with slope a and intercept b
(taking x0 as the origin):

b = v0
a = (delta_v / delta_x) = (v1 - v0) / (x1 - x0)
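
As a minimal sketch of the one-dimensional case in plain Python (the names lerp
and line are illustrative only and are not part of this module), the weighted
form and the slope/intercept form give the same value:

```python
def lerp(x, x0, x1, v0, v1):
    # wx is the weight of node 1: 0.0 at x0 and 1.0 at x1.
    wx = (x - x0) / (x1 - x0)
    return v0 + wx * (v1 - v0)


def line(x, x0, x1, v0, v1):
    # The same interpolation written as a straight line through (x0, v0).
    a = (v1 - v0) / (x1 - x0)  # slope
    b = v0  # intercept, taking x0 as the origin
    return b + a * (x - x0)


print(lerp(2.5, 2.0, 3.0, 10.0, 20.0))  # 15.0
print(line(2.5, 2.0, 3.0, 10.0, 20.0))  # 15.0
```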


Multilinear interpolation can be implemented as repeated one dimensional linear
interpolation. Take two dimensions, with four cells. The values are located as
such:

  v00 -- v01
   |      |
   |      |
  v10 -- v11

v00 is upper left, v01 is upper right, etc. Given a point (x, y) between the
drawn boundaries, wx as above, and:
wy = (y - y0) / (y1 - y0)

Then, to compute the value for (x, y), we interpolate in x twice, and in y once.
(Or vice versa.)
v0 = v00 + wx * (v01 - v00)  # interpolate in x, upper row
v1 = v10 + wx * (v11 - v10)  # interpolate in x, lower row
v = v0 + wy * (v1 - v0)  # interpolate in y
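
A small sketch of the two-dimensional case, assuming the layout drawn above
(v00 and v01 share a row, so they differ in x). The function names are
hypothetical. It checks that repeated 1D interpolation matches the sum of the
four products of weights that the 2D code in this module accumulates:

```python
def lerp(w, a, b):
    # 1D linear interpolation with a normalized weight w in [0, 1].
    return a + w * (b - a)


def bilinear_repeated(wx, wy, v00, v01, v10, v11):
    v0 = lerp(wx, v00, v01)  # upper row, interpolate in x
    v1 = lerp(wx, v10, v11)  # lower row, interpolate in x
    return lerp(wy, v0, v1)  # between the rows, interpolate in y


def bilinear_weighted(wx, wy, v00, v01, v10, v11):
    # Every corner contributes its value times the product of its weights.
    return (
        (1 - wx) * (1 - wy) * v00
        + wx * (1 - wy) * v01
        + (1 - wx) * wy * v10
        + wx * wy * v11
    )


print(bilinear_repeated(0.25, 0.5, 0.0, 4.0, 8.0, 12.0))  # 5.0
print(bilinear_weighted(0.25, 0.5, 0.0, 4.0, 8.0, 12.0))  # 5.0
```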


### Nodata
All of this works perfectly, until we run into nodata values. The code below
is aimed at xarray.DataArrays, which use np.nan values as a sentinel for nodata.
NaNs pollute: 0 + nan = nan, 1 + 0 * nan = nan, etc.
In between cells that are not nodata, interpolation goes fine. However, at the
edges (say, to the left of v00 and v10), we get nan values. Note that the cell
of v00 does cover the area, but given a nan to the left, it'll end up as nodata.

This is undesirable (the relative number of nodata cells increases).
We have to catch nan values before they pollute the computations. This is
straightforward in one dimension, but involves a bothersome number of
conditionals in 2d, and certainly in 3d! An easier implementation in this case
is not by doing repeated linear interpolations, but by using an accumulator,
and an accumulating divisor. We simply add the contribution of every cell
to this accumulator, and skip it if it's nodata.

The result is the same (the algebra to prove this is left as an exercise to the
reader).
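
The accumulator idea can be sketched in one dimension with the standard library
only. This is an illustration under simplifying assumptions, not the numba
implementation below, and interp_accumulate is a hypothetical name:

```python
import math


def interp_accumulate(v0, v1, wx):
    # Add the weighted contribution of every finite value; skip nodata (nan).
    accumulator = 0.0
    accumulator_divisor = 0.0
    for value, weight in ((v0, 1.0 - wx), (v1, wx)):
        if math.isfinite(value):
            accumulator += weight * value
            accumulator_divisor += weight
    if accumulator_divisor > 0.0:
        # Renormalize by the weights that were actually used.
        return accumulator / accumulator_divisor
    return math.nan


nan = math.nan
print(interp_accumulate(10.0, 20.0, 0.25))  # 12.5, same as plain interpolation
print(interp_accumulate(nan, 20.0, 0.25))  # 20.0, the nan is skipped
print(10.0 + 0.25 * (nan - 10.0))  # nan, the naive formula is polluted
```

When both values are finite the divisor equals 1, so the result reduces to the
ordinary linear interpolation.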


### Not quite there yet
Using the accumulator implementation, we end up creating more data, tiling over
nodata parts. Since we're always between two cells, and the implementation above
removes nodata values, non-nodata values will extend too far.

The easiest way of remedying this is by taking into account where the cell
boundaries of the original cell lie. If the destination point lies fully within
a nodata cell, no extrapolation should occur.
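
A one-dimensional sketch of this boundary check, mirroring the logic of the 1D
function in this module but written in plain Python. The function name and the
within_left flag are illustrative assumptions: within_left is True when the
destination point lies inside the left source cell, and False when it lies
inside the right one.

```python
import math


def interp_1d_point(v0, v1, wx, within_left):
    v0_ok = math.isfinite(v0)
    v1_ok = math.isfinite(v1)

    accumulator = 0.0
    accumulator_divisor = 0.0
    if v0_ok:
        accumulator += (1.0 - wx) * v0
        accumulator_divisor += 1.0 - wx
    if v1_ok:
        accumulator += wx * v1
        accumulator_divisor += wx

    if accumulator_divisor <= 0.0:
        return math.nan
    # The cell that contains the destination point must itself hold data;
    # otherwise the neighboring value would be extended into a nodata cell.
    owner_ok = v0_ok if within_left else v1_ok
    if not owner_ok:
        return math.nan
    return accumulator / accumulator_divisor


nan = math.nan
print(interp_1d_point(10.0, nan, 0.25, True))  # 10.0, point lies in the data cell
print(interp_1d_point(10.0, nan, 0.75, False))  # nan, point lies in the nodata cell
```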


### Function responsibilities
* linear_inds_weights: returns, per single dimension, the weights (wx, wy),
  the source indexes (column, row numbers), and whether the destination x
  lies fully within the cell of the matching source index. The source index
  is always the "starting" or "leftmost" cell. Interpolation occurs between
  this value and its neighbor to the right (index + 1).
* interp_ functions: implement the functionality described above, taking the
  weights, indices, and the boolean within to compute interpolated values.
* iter_interp: an example best serves to explain: let there be a DataArray with
  dimensions time, layer, y, x. We wish to interpolate in two dimensions,
  over y and x. Time and layer are "stacked" onto a single dimension, so that
  a single loop in iter_interp suffices to interpolate over y and x, rather
  than a nested loop (in which case the number of for loops depends on
  the number of dimensions).
* jit_interp: selects the right interpolation method for the number of
  dimensions.
* make_interp: provides the closure to avoid numba overhead. Basically, we
  "inject" code before compilation, so we don't have to pass functions as an
  argument at runtime. (Numba probably inlines the function in this case.)
* nd_interp: collects the weights, reshapes the array so iter_interp will take
  it.
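
The closure pattern used by make_interp and iter_interp can be sketched without
numba. Everything here is a hypothetical stand-in: take_kernel is a trivial
kernel that only copies values by index, and the kernels dictionary replaces
the explicit if/elif selection. The point is only that the kernel is chosen
once and captured, rather than passed along on every call.

```python
def take_kernel(src, dst, inds):
    # Stand-in for an interp_ function: copy src values at the given indices.
    for k, ix in enumerate(inds):
        dst[k] = src[ix]
    return dst


def make_interp(ndim, kernels):
    # Select the kernel once, outside the loop over stacked slices.
    if ndim not in kernels:
        raise NotImplementedError("cannot regrid over more than three dimensions")
    kernel = kernels[ndim]

    def iter_interp(iter_src, iter_dst, *inds_weights):
        # kernel comes from the enclosing scope; it is not an argument here.
        for i in range(len(iter_src)):
            iter_dst[i] = kernel(iter_src[i], iter_dst[i], *inds_weights)
        return iter_dst

    return iter_interp


# Two stacked slices (for example time * layer = 2), each with three cells.
iter_src = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
iter_dst = [[0.0, 0.0], [0.0, 0.0]]
iter_interp = make_interp(1, {1: take_kernel})
print(iter_interp(iter_src, iter_dst, [0, 2]))  # [[1.0, 3.0], [4.0, 6.0]]
```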

"""

import numba
import numpy as np

from imod.prepare import common


def _linear_inds_weights_1d(src_x, dst_x):
    """
    Returns indices and weights for linear interpolation along a single dimension.
    A sentinel value of -1 is added for dst cells that are fully out of bounds.

    Parameters
    ----------
    src_x : np.array
        vertex coordinates of source
    dst_x: np.array
        vertex coordinates of destination
    """
    # Cannot interpolate "between" only one point:
    # two cell midpoints require at least three vertices
    if not src_x.size > 2:
        raise ValueError(
            "src_x must have more than 2 elements. "
            "Cannot interpolate with only one point"
        )

    xmin = src_x.min()
    xmax = src_x.max()

    # Compute midpoints for linear interpolation
    src_dx = np.diff(src_x)
    mid_src_x = src_x[:-1] + 0.5 * src_dx
    dst_dx = np.diff(dst_x)
    mid_dst_x = dst_x[:-1] + 0.5 * dst_dx

    # From np.searchsorted docstring:
    # Find the indices into a sorted array a such that, if the corresponding
    # elements in v were inserted before the indices, the order of a would
    # be preserved.
    i = np.searchsorted(mid_src_x, mid_dst_x) - 1
    # Out of bounds indices
    i[i < 0] = 0
    i[i > mid_src_x.size - 2] = mid_src_x.size - 2

    # -------------------------------------------------------------------------
    # Visual example: interpolate from src with 2 cells to dst 3 cells
    # The period . marks the midpoint of the cell
    # The pipe | marks the cell edge
    #
    #    |_____._____|_____._____|
    #       src_x0      src_x1
    #
    #    |___.___|___.___|___.___|
    #        x0      x1      x2
    #
    # Then normalized weight for cell x1:
    # weight = (x1 - src_x0) / (src_x1 - src_x0)
    # -------------------------------------------------------------------------

    norm_weights = (mid_dst_x - mid_src_x[i]) / (mid_src_x[i + 1] - mid_src_x[i])
    # deal with out of bounds locations
    # we place a sentinel value of -1 here
    i[mid_dst_x < xmin] = -1
    i[mid_dst_x > xmax] = -1
    # In case it's just inside of bounds, use only the value at the boundary
    norm_weights[norm_weights < 0.0] = 0.0
    norm_weights[norm_weights > 1.0] = 1.0
    # The following array is used only to deal with nodata values at the edges
    # Recall that src_x are the cell edges
    # Start inclusive (just like e.g. GDAL)
    within = (mid_dst_x >= src_x[i]) & (mid_dst_x < src_x[i + 1])
    return i, norm_weights, within


@numba.njit(cache=True)
def _interp_1d(src, dst, *inds_weights):
    """
    Parameters
    ----------
    src : np.array
    dst : np.array
    """
    # Unpack the variadic arguments
    kk, weights_x, within_x = inds_weights
    # k are indices of dst array
    for k, (ix, wx, in_x) in enumerate(zip(kk, weights_x, within_x)):
        if ix < 0:
            continue

        # Fetch the values from source array, left v0, right v1
        v0 = src[ix]
        v1 = src[ix + 1]
        # Check whether they are nodata
        v0_ok = np.isfinite(v0)
        v1_ok = np.isfinite(v1)

        # Initialize and add to accumulators
        accumulator = 0
        accumulator_divisor = 0
        if v0_ok:
            multiplier = 1 - wx
            accumulator += multiplier * v0
            accumulator_divisor += multiplier
        if v1_ok:
            multiplier = wx
            accumulator += multiplier * v1
            accumulator_divisor += multiplier

        # Check if the point to interpolate to falls fully within a nodata cell
        # if that's the case, don't use the value, but continue with the next iteration.
        # else: use the value, fill it into the destination array.
        if accumulator_divisor > 0:
            if in_x:
                if not v0_ok:
                    continue
            else:
                if not v1_ok:
                    continue
            v = accumulator / accumulator_divisor
            dst[k] = v

    return dst


@numba.njit(cache=True)
def _interp_2d(src, dst, *inds_weights):
    # Unpack the variadic arguments
    jj, weights_y, within_y, kk, weights_x, within_x = inds_weights
    # j, k are indices of dst array
    for j, (iy, wy, in_y) in enumerate(zip(jj, weights_y, within_y)):
        if iy < 0:
            continue

        for k, (ix, wx, in_x) in enumerate(zip(kk, weights_x, within_x)):
            if ix < 0:
                continue

            # Fetch the values from source array, upper left v00, lower right v11
            v00 = src[iy, ix]
            v01 = src[iy, ix + 1]
            v10 = src[iy + 1, ix]
            v11 = src[iy + 1, ix + 1]
            # Check whether they are nodata
            v00_ok = np.isfinite(v00)
            v01_ok = np.isfinite(v01)
            v10_ok = np.isfinite(v10)
            v11_ok = np.isfinite(v11)

            # Initialize and add to accumulators
            accumulator = 0
            accumulator_divisor = 0
            if v00_ok:
                multiplier = (1 - wx) * (1 - wy)
                accumulator += multiplier * v00
                accumulator_divisor += multiplier
            if v01_ok:
                multiplier = wx * (1 - wy)
                accumulator += multiplier * v01
                accumulator_divisor += multiplier
            if v10_ok:
                multiplier = (1 - wx) * wy
                accumulator += multiplier * v10
                accumulator_divisor += multiplier
            if v11_ok:
                multiplier = wx * wy
                accumulator += multiplier * v11
                accumulator_divisor += multiplier

            # Check if the point to interpolate to falls fully within a nodata cell
            # if that's the case, don't use the value, but continue with the next iteration.
            # else: use the value, fill it into the destination array.
            if accumulator_divisor > 0:
                if in_y:
                    if in_x:
                        if not v00_ok:
                            continue
                    else:
                        if not v01_ok:
                            continue
                else:
                    if in_x:
                        if not v10_ok:
                            continue
                    else:
                        if not v11_ok:
                            continue
                v = accumulator / accumulator_divisor
                dst[j, k] = v

    return dst


@numba.njit(cache=True)
def _interp_3d(src, dst, *inds_weights):
    # Unpack the variadic arguments
    (
        ii,
        weights_z,
        within_z,
        jj,
        weights_y,
        within_y,
        kk,
        weights_x,
        within_x,
    ) = inds_weights
    # i, j, k are indices of dst array
    for i, (iz, wz, in_z) in enumerate(zip(ii, weights_z, within_z)):
        if iz < 0:
            continue

        for j, (iy, wy, in_y) in enumerate(zip(jj, weights_y, within_y)):
            if iy < 0:
                continue

            for k, (ix, wx, in_x) in enumerate(zip(kk, weights_x, within_x)):
                if ix < 0:
                    continue

                # Fetch the values from source array, top upper left v000,
                # bottom lower right v111
                v000 = src[iz, iy, ix]
                v001 = src[iz, iy, ix + 1]
                v010 = src[iz, iy + 1, ix]
                v011 = src[iz, iy + 1, ix + 1]
                v100 = src[iz + 1, iy, ix]
                v101 = src[iz + 1, iy, ix + 1]
                v110 = src[iz + 1, iy + 1, ix]
                v111 = src[iz + 1, iy + 1, ix + 1]
                # Check whether they are nodata
                v000_ok = np.isfinite(v000)
                v001_ok = np.isfinite(v001)
                v010_ok = np.isfinite(v010)
                v011_ok = np.isfinite(v011)
                v100_ok = np.isfinite(v100)
                v101_ok = np.isfinite(v101)
                v110_ok = np.isfinite(v110)
                v111_ok = np.isfinite(v111)

                # Initialize and add to accumulators
                accumulator = 0
                accumulator_divisor = 0
                if v000_ok:
                    multiplier = (1 - wz) * (1 - wx) * (1 - wy)
                    accumulator += multiplier * v000
                    accumulator_divisor += multiplier
                if v001_ok:
                    multiplier = (1 - wz) * wx * (1 - wy)
                    accumulator += multiplier * v001
                    accumulator_divisor += multiplier
                if v010_ok:
                    multiplier = (1 - wz) * (1 - wx) * wy
                    accumulator += multiplier * v010
                    accumulator_divisor += multiplier
                if v011_ok:
                    multiplier = (1 - wz) * wx * wy
                    accumulator += multiplier * v011
                    accumulator_divisor += multiplier
                if v100_ok:
                    multiplier = wz * (1 - wx) * (1 - wy)
                    accumulator += multiplier * v100
                    accumulator_divisor += multiplier
                if v101_ok:
                    multiplier = wz * wx * (1 - wy)
                    accumulator += multiplier * v101
                    accumulator_divisor += multiplier
                if v110_ok:
                    multiplier = wz * (1 - wx) * wy
                    accumulator += multiplier * v110
                    accumulator_divisor += multiplier
                if v111_ok:
                    multiplier = wz * wx * wy
                    accumulator += multiplier * v111
                    accumulator_divisor += multiplier

                # Check if the point to interpolate to falls fully within a nodata cell
                # if that's the case, don't use the value, but continue with the next iteration.
                # else: use the value, fill it into the destination array.
                if accumulator_divisor > 0:
                    if in_z:
                        if in_y:
                            if in_x:
                                if not v000_ok:
                                    continue
                            else:
                                if not v001_ok:
                                    continue
                        else:
                            if in_x:
                                if not v010_ok:
                                    continue
                            else:
                                if not v011_ok:
                                    continue
                    else:
                        if in_y:
                            if in_x:
                                if not v100_ok:
                                    continue
                            else:
                                if not v101_ok:
                                    continue
                        else:
                            if in_x:
                                if not v110_ok:
                                    continue
                            else:
                                if not v111_ok:
                                    continue
                    v = accumulator / accumulator_divisor
                    dst[i, j, k] = v

    return dst


@numba.njit
def _iter_interp(iter_src, iter_dst, interp_function, *inds_weights):
    n_iter = iter_src.shape[0]
    for i in range(n_iter):
        iter_dst[i, ...] = interp_function(
            iter_src[i, ...], iter_dst[i, ...], *inds_weights
        )
    return iter_dst


def _jit_interp(ndim_interp):
    @numba.njit
    def jit_interp_1d(src, dst, *inds_weights):
        return _interp_1d(src, dst, *inds_weights)

    @numba.njit
    def jit_interp_2d(src, dst, *inds_weights):
        return _interp_2d(src, dst, *inds_weights)

    @numba.njit
    def jit_interp_3d(src, dst, *inds_weights):
        return _interp_3d(src, dst, *inds_weights)

    if ndim_interp == 1:
        jit_interp = jit_interp_1d
    elif ndim_interp == 2:
        jit_interp = jit_interp_2d
    elif ndim_interp == 3:
        jit_interp = jit_interp_3d
    else:
        raise NotImplementedError("cannot regrid over more than three dimensions")

    return jit_interp


def _make_interp(ndim_regrid):
    jit_interp = _jit_interp(ndim_regrid)

    @numba.njit
    def iter_interp(iter_src, iter_dst, *inds_weights):
        return _iter_interp(iter_src, iter_dst, jit_interp, *inds_weights)

    return iter_interp


def _nd_interp(src, dst, src_coords, dst_coords, iter_interp):
    if len(src.shape) != len(dst.shape):
        raise ValueError("shape mismatch between src and dst")
    if len(src_coords) != len(dst_coords):
        raise ValueError("coords mismatch between src and dst")
    ndim_regrid = len(src_coords)

    # Determine the indices, weights, and within flags for every regrid
    # dimension, and collect them in a single flat list
    inds_weights = []
    for src_x, dst_x in zip(src_coords, dst_coords):
        for elem in _linear_inds_weights_1d(src_x, dst_x):
            inds_weights.append(elem)

    iter_src, iter_dst = common._reshape(src, dst, ndim_regrid)
    iter_dst = iter_interp(iter_src, iter_dst, *inds_weights)

    return iter_dst.reshape(dst.shape)