Coverage for C:\src\imod-python\imod\prepare\regrid.py: 95%

253 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-08 13:27 +0200

1""" 

2Module that provides a class to do a variety of regridding operations, up to 

3three dimensions. 

4 

5Before regridding, the dimensions over which regridding should occur are 

6inferred, using the functions in the imod.prepare.common module. In case 

7multiple dimensions are present, the data is reshaped such that a single loop 

8will regrid them all. 

9 

10For example: let there be a DataArray with dimensions time, layer, y, and x. We 

11wish to regrid using an area weighted mean, over x and y. This means values 

12across times and layers are not aggregated together. In this case, the array is 

13reshaped into a 3D array, rather than a 4D array. Time and layer are stacked 

14into this first dimension together, so that a single loop suffices (see 

15common._reshape and _iter_regrid). 

16 

17Functions can be incorporated into the multidimensional regridding. This is done 

18by making use of numba closures, since there's an overhead to passing function 

19objects directly. In this case, the function is simply compiled into the 

20specific regridding method, without additional overhead. 

21 

22The regrid methods _regrid_{n}d are quite straightforward. Using the indices 

23and weights that have been gathered by _weights_1d, these methods fetch the 

24values from the source array (src), and pass them on to the aggregation method. 

25The single aggregated value is then filled into the destination array (dst). 

26""" 

27 

28from collections import namedtuple 

29 

30import dask 

31import numba 

32import numpy as np 

33import xarray as xr 

34 

35from imod.prepare import common, interpolate 

36 

37_RegridInfo = namedtuple( 

38 typename="_RegridInfo", 

39 field_names=[ 

40 "matching_dims", 

41 "regrid_dims", 

42 "add_dims", 

43 "dst_shape", 

44 "dst_dims", 

45 "dst_da_coords", 

46 "src_coords_regrid", 

47 "dst_coords_regrid", 

48 ], 

49) 

50 

51 

@numba.njit(cache=True)
def _regrid_1d(src, dst, values, weights, method, *inds_weights):
    """
    numba compiled function to regrid in one dimension

    Parameters
    ----------
    src : np.array
        Source array, indexed with a single index.
    dst : np.array
        Destination array; filled in place and returned.
    values : np.array
        Scratch buffer receiving the source values of one destination cell.
    weights : np.array
        Scratch buffer receiving the matching weights.
    method : numba.njit'ed function
        Aggregation function, called as ``method(values, weights)``.
    *inds_weights
        ``(dst indices, blocks of src indices, blocks of weights)``.

    Returns
    -------
    dst : np.array
    """
    dst_inds_x, src_inds_x, src_weights_x = inds_weights
    # Every entry of dst_inds_x is an index into dst; the block at the same
    # position holds the src indices and weights contributing to that cell.
    for pos_x, k in enumerate(dst_inds_x):
        inds_x = src_inds_x[pos_x]
        wts_x = src_weights_x[pos_x]

        # Gather the values and weights of this block in the scratch buffers
        n_filled = 0
        for ix, wx in zip(inds_x, wts_x):
            # A negative index ends the block
            if ix < 0:
                break
            values[n_filled] = src[ix]
            weights[n_filled] = wx
            n_filled += 1

        # Reduce the gathered entries to a single value
        dst[k] = method(values[:n_filled], weights[:n_filled])

        # Clear the used part of the buffers for the next cell
        values[:n_filled] = 0
        weights[:n_filled] = 0

    return dst

89 

90 

@numba.njit(cache=True)
def _regrid_2d(src, dst, values, weights, method, *inds_weights):
    """
    numba compiled function to regrid in two dimensions

    Parameters
    ----------
    src : np.array
        Source array, indexed as ``src[iy, ix]``.
    dst : np.array
        Destination array, written as ``dst[j, k]``; filled in place and
        returned.
    values : np.array
        Scratch buffer receiving the source values of one destination cell.
    weights : np.array
        Scratch buffer receiving the matching weights.
    method : numba.njit'ed function
        Aggregation function, called as ``method(values, weights)``.
    *inds_weights
        For y, then for x: ``(dst indices, blocks of src indices, blocks of
        weights)``.

    Returns
    -------
    dst : np.array
    """
    (
        dst_inds_y,
        src_inds_y,
        src_weights_y,
        dst_inds_x,
        src_inds_x,
        src_weights_x,
    ) = inds_weights

    # j, k index dst; the blocks at the same position hold the src indices
    # and weights contributing along that axis.
    for pos_y, j in enumerate(dst_inds_y):
        inds_y = src_inds_y[pos_y]
        wts_y = src_weights_y[pos_y]
        for pos_x, k in enumerate(dst_inds_x):
            inds_x = src_inds_x[pos_x]
            wts_x = src_weights_x[pos_x]

            # Gather every (y, x) combination of the block; the weight of a
            # combination is the product of the per-axis weights.
            n_filled = 0
            for iy, wy in zip(inds_y, wts_y):
                # A negative index ends the block
                if iy < 0:
                    break
                for ix, wx in zip(inds_x, wts_x):
                    if ix < 0:
                        break
                    values[n_filled] = src[iy, ix]
                    weights[n_filled] = wy * wx
                    n_filled += 1

            # Reduce the gathered entries to a single value
            dst[j, k] = method(values[:n_filled], weights[:n_filled])

            # Clear the used part of the buffers for the next cell
            values[:n_filled] = 0.0
            weights[:n_filled] = 0.0

    return dst

135 

136 

@numba.njit(cache=True)
def _regrid_3d(src, dst, values, weights, method, *inds_weights):
    """
    numba compiled function to regrid in three dimensions

    Parameters
    ----------
    src : np.array
        Source array, indexed as ``src[iz, iy, ix]``.
    dst : np.array
        Destination array, written as ``dst[i, j, k]``; filled in place and
        returned.
    values : np.array
        Scratch buffer receiving the source values of one destination cell.
    weights : np.array
        Scratch buffer receiving the matching weights.
    method : numba.njit'ed function
        Aggregation function, called as ``method(values, weights)``.
    *inds_weights
        For z, then y, then x: ``(dst indices, blocks of src indices, blocks
        of weights)``.

    Returns
    -------
    dst : np.array
    """
    (
        dst_inds_z,
        src_inds_z,
        src_weights_z,
        dst_inds_y,
        src_inds_y,
        src_weights_y,
        dst_inds_x,
        src_inds_x,
        src_weights_x,
    ) = inds_weights

    # i, j, k index dst; the blocks at the same position hold the src indices
    # and weights contributing along that axis.
    for pos_z, i in enumerate(dst_inds_z):
        inds_z = src_inds_z[pos_z]
        wts_z = src_weights_z[pos_z]
        for pos_y, j in enumerate(dst_inds_y):
            inds_y = src_inds_y[pos_y]
            wts_y = src_weights_y[pos_y]
            for pos_x, k in enumerate(dst_inds_x):
                inds_x = src_inds_x[pos_x]
                wts_x = src_weights_x[pos_x]

                # Gather every (z, y, x) combination of the block; the weight
                # of a combination is the product of the per-axis weights.
                n_filled = 0
                for iz, wz in zip(inds_z, wts_z):
                    # A negative index ends the block
                    if iz < 0:
                        break
                    for iy, wy in zip(inds_y, wts_y):
                        if iy < 0:
                            break
                        for ix, wx in zip(inds_x, wts_x):
                            if ix < 0:
                                break
                            values[n_filled] = src[iz, iy, ix]
                            weights[n_filled] = wz * wy * wx
                            n_filled += 1

                # Reduce the gathered entries to a single value
                dst[i, j, k] = method(values[:n_filled], weights[:n_filled])

                # Clear the used part of the buffers for the next cell
                values[:n_filled] = 0.0
                weights[:n_filled] = 0.0

    return dst

197 

198 

@numba.njit
def _iter_regrid(iter_src, iter_dst, alloc_len, regrid_function, *inds_weights):
    """
    Apply ``regrid_function`` to every slice along the first axis.

    Parameters
    ----------
    iter_src : np.array
        Source data; the first axis is iterated over.
    iter_dst : np.array
        Destination data; the first axis is iterated over. Filled in place
        and returned.
    alloc_len : int
        Length of the temporary ``values`` and ``weights`` buffers.
    regrid_function : numba compiled function
        Called as ``regrid_function(src, dst, values, weights,
        *inds_weights)``.
    *inds_weights
        Forwarded unchanged to ``regrid_function``.

    Returns
    -------
    iter_dst : np.array
    """
    # The scratch buffers are allocated once and shared by all iterations
    values = np.zeros(alloc_len)
    weights = np.zeros(alloc_len)
    for n in range(iter_src.shape[0]):
        iter_dst[n, ...] = regrid_function(
            iter_src[n, ...], iter_dst[n, ...], values, weights, *inds_weights
        )
    return iter_dst

210 

211 

def _jit_regrid(jit_method, ndim_regrid):
    """
    Bind the compiled aggregation method to the regrid function for
    ``ndim_regrid`` dimensions.

    The closure avoids the numba overhead of passing a function as argument:
    https://numba.pydata.org/numba-doc/dev/user/faq.html#can-i-pass-a-function-as-an-argument-to-a-jitted-function

    Parameters
    ----------
    jit_method : numba.njit'ed function
        Aggregation function taking ``(values, weights)``.
    ndim_regrid : int
        Number of dimensions to regrid over: 1, 2, or 3.

    Returns
    -------
    jit_regrid : numba compiled function
        Takes ``(src, dst, values, weights, *inds_weights)``.

    Raises
    ------
    NotImplementedError
        If ``ndim_regrid`` is not 1, 2, or 3.
    """
    regrid_by_ndim = {1: _regrid_1d, 2: _regrid_2d, 3: _regrid_3d}
    if ndim_regrid not in regrid_by_ndim:
        raise NotImplementedError("cannot regrid over more than three dimensions")
    regrid_nd = regrid_by_ndim[ndim_regrid]

    @numba.njit
    def jit_regrid(src, dst, values, weights, *inds_weights):
        return regrid_nd(src, dst, values, weights, jit_method, *inds_weights)

    return jit_regrid

241 

242 

def _make_regrid(method, ndim_regrid):
    """
    Build the iterating regrid function for one aggregation method.

    The closure avoids the numba overhead of passing a function as argument:
    https://numba.pydata.org/numba-doc/dev/user/faq.html#can-i-pass-a-function-as-an-argument-to-a-jitted-function

    Parameters
    ----------
    method : function
        Aggregation function; handed to ``numba.njit``.
    ndim_regrid : int
        Number of dimensions to regrid over.

    Returns
    -------
    iter_regrid : numba compiled function
        Takes ``(iter_src, iter_dst, alloc_len, *inds_weights)``.
    """
    # Compile the external method and bind it to the n-dimensional regrid
    # function.
    jit_regrid = _jit_regrid(numba.njit(method, cache=True), ndim_regrid)

    # Wrap the iteration over the leading axis around that specific function
    @numba.njit
    def iter_regrid(iter_src, iter_dst, alloc_len, *inds_weights):
        return _iter_regrid(iter_src, iter_dst, alloc_len, jit_regrid, *inds_weights)

    return iter_regrid

259 

260 

261def _nd_regrid(src, dst, src_coords, dst_coords, iter_regrid, use_relative_weights): 

262 """ 

263 Regrids an ndarray up to maximum 3 dimensions. 

264 Dimensionality of regridding is determined by the the length of src_coords 

265 (== len(dst_coords)), which has to match with the provide iter_regrid 

266 function. 

267 

268 Parameters 

269 ---------- 

270 src : np.array 

271 dst : np.array 

272 src_coords : tuple of np.array 

273 dst_coords : tuple of np.array 

274 iter_regrid : function, numba compiled 

275 """ 

276 if len(src.shape) != len(dst.shape): 

277 raise ValueError("shape mismatch between src and dst") 

278 if len(src_coords) != len(dst_coords): 

279 raise ValueError("coords mismatch between src and dst") 

280 ndim_regrid = len(src_coords) 

281 

282 # Determine weights for every regrid dimension, and alloc_len, 

283 # the maximum number of src cells that may end up in a single dst cell 

284 inds_weights = [] 

285 alloc_len = 1 

286 for src_x, dst_x in zip(src_coords, dst_coords): 

287 size, i_w = common._weights_1d(src_x, dst_x, use_relative_weights) 

288 for elem in i_w: 

289 inds_weights.append(elem) 

290 alloc_len *= size 

291 

292 iter_src, iter_dst = common._reshape(src, dst, ndim_regrid) 

293 iter_dst = iter_regrid(iter_src, iter_dst, alloc_len, *inds_weights) 

294 

295 return iter_dst.reshape(dst.shape) 

296 

297 

class Regridder:
    """
    Object to repeatedly regrid similar objects. Compiles once on first call,
    can then be repeatedly called without JIT compilation overhead.

    Attributes
    ----------
    method : str, function
        The method to use for regridding. Default available methods are:
        ``{"nearest", "multilinear", "mean", "harmonic_mean", "geometric_mean",
        "sum", "minimum", "maximum", "mode", "median", "conductance"}``
    ndim_regrid : int, optional
        The number of dimensions over which to regrid. If not provided,
        ``ndim_regrid`` will be inferred. It serves to prevent regridding over an
        unexpected number of dimensions; say you want to regrid over only two
        dimensions. Due to an input error in the coordinates of ``like``, three
        dimensions may be inferred in the first ``.regrid`` call. An error will
        be raised if ndim_regrid does not match the number of inferred
        dimensions. Default value is None.
    use_relative_weights : bool, optional
        Whether to use relative weights in the regridding method or not.
        Relative weights are defined as: cell_overlap / source_cellsize, for
        every axis.

        This argument should only be used if you are providing your own
        ``method`` as a function, where the function requires relative, rather
        than absolute weights (the provided ``conductance`` method requires
        relative weights, for example). Default value is False.
    extra_overlap : integer, optional
        In case of chunked regridding, how many cells of additional overlap is
        necessary. Linear interpolation requires this for example, as it reaches
        beyond cell boundaries to compute values. Default value is 0.

    Examples
    --------
    Initialize the Regridder object:

    >>> mean_regridder = imod.prepare.Regridder(method="mean")

    Then call the ``regrid`` method to regrid.

    >>> result = mean_regridder.regrid(source, like)

    The regridder can be re-used if the number of regridding dimensions
    match, saving some time by not (re)compiling the regridding method.

    >>> second_result = mean_regridder.regrid(second_source, like)

    A one-liner is possible for single use:

    >>> result = imod.prepare.Regridder(method="mean").regrid(source, like)

    It's possible to provide your own methods to the ``Regridder``, provided that
    numba can compile them. They need to take the arguments ``values`` and
    ``weights``. Make sure they deal with ``nan`` values gracefully!

    >>> def p30(values, weights):
    ...     return np.nanpercentile(values, 30)

    >>> p30_regridder = imod.prepare.Regridder(method=p30)
    >>> p30_result = p30_regridder.regrid(source, like)

    The Numba developers maintain a list of supported Numpy features here:
    https://numba.pydata.org/numba-doc/dev/reference/numpysupported.html

    In general, however, the provided methods should be adequate for your
    regridding needs.
    """

    def __init__(
        self, method, ndim_regrid=None, use_relative_weights=False, extra_overlap=0
    ):
        _method = common._get_method(method, common.METHODS)
        self.method = _method
        self.ndim_regrid = ndim_regrid
        # True until the first actual regridding has run (and thereby
        # triggered compilation).
        self._first_call = True
        # The conductance method always gets relative weights, whatever the
        # caller passed.
        if _method == common.METHODS["conductance"]:
            use_relative_weights = True
        self.use_relative_weights = use_relative_weights
        # The multilinear method always gets one cell of extra overlap,
        # whatever the caller passed.
        if _method == common.METHODS["multilinear"]:
            extra_overlap = 1
        self.extra_overlap = extra_overlap

    def _make_regrid(self):
        """
        Create ``self._nd_regrid``, the function doing the actual work for
        ``self.method`` and ``self.ndim_regrid``.

        Nothing is assigned for the "nearest" method; ``regrid`` handles that
        method without calling ``self._nd_regrid``.
        """
        iter_regrid = _make_regrid(self.method, self.ndim_regrid)
        iter_interp = interpolate._make_interp(self.ndim_regrid)

        def nd_regrid(src, dst, src_coords_regrid, dst_coords_regrid):
            return _nd_regrid(
                src,
                dst,
                src_coords_regrid,
                dst_coords_regrid,
                iter_regrid,
                self.use_relative_weights,
            )

        def nd_interp(src, dst, src_coords_regrid, dst_coords_regrid):
            return interpolate._nd_interp(
                src, dst, src_coords_regrid, dst_coords_regrid, iter_interp
            )

        if self.method == "nearest":
            pass
        elif self.method == "multilinear":
            self._nd_regrid = nd_interp
        else:
            self._nd_regrid = nd_regrid

    def _check_ndim_regrid(self, regrid_dims):
        """
        Raise a ValueError if the number of ``regrid_dims`` differs from
        ``self.ndim_regrid``.
        """
        if len(regrid_dims) != self.ndim_regrid:
            raise ValueError(
                "Number of dimensions to regrid does not match: "
                f"Regridder.ndim_regrid = {self.ndim_regrid}"
            )

    def _prepare(self, regrid_dims):
        """
        Settle ``self.ndim_regrid`` (infer it, or check it against
        ``regrid_dims``) and create the tailor made regridding function.

        Raises
        ------
        ValueError
            If ``ndim_regrid`` does not match, or if the conductance method
            would be applied to more than two dimensions.
        """
        # Create tailor made regridding function: take method and ndims into
        # account and call it
        if self.ndim_regrid is None:
            self.ndim_regrid = len(regrid_dims)
        else:
            self._check_ndim_regrid(regrid_dims)

        if self.method == common.METHODS["conductance"] and len(regrid_dims) > 2:
            raise ValueError(
                "The conductance method should not be applied to "
                "regridding more than two dimensions"
            )
        # Create the method.
        self._make_regrid()

    @staticmethod
    def _regrid_info(src, like):
        """
        Collect dimensions, shape and coordinates needed to regrid ``src``
        to ``like``, returned as a ``_RegridInfo``.
        """
        # Find coordinates that already match, and those that have to be
        # regridded, and those that exist in source but not in like (left
        # untouched)
        matching_dims, regrid_dims, add_dims = common._match_dims(src, like)

        # Order dimensions in the right way:
        # dimensions that are regridded end up at the end for efficient iteration
        dst_dims = (*add_dims, *matching_dims, *regrid_dims)
        dims_from_src = (*add_dims, *matching_dims)
        dims_from_like = tuple(regrid_dims)

        # Gather destination coordinates
        dst_da_coords, dst_shape = common._dst_coords(
            src, like, dims_from_src, dims_from_like
        )

        # Placeholder DataArray, only used to look up destination coordinates
        dst_tmp = xr.DataArray(
            data=dask.array.empty(dst_shape), coords=dst_da_coords, dims=dst_dims
        )

        # TODO: check that axes are aligned
        src_coords_regrid = [common._coord(src, dim) for dim in regrid_dims]
        dst_coords_regrid = [common._coord(dst_tmp, dim) for dim in regrid_dims]

        return _RegridInfo(
            matching_dims=matching_dims,
            regrid_dims=regrid_dims,
            add_dims=add_dims,
            dst_shape=dst_shape,
            dst_da_coords=dst_da_coords,
            dst_dims=dst_dims,
            src_coords_regrid=src_coords_regrid,
            dst_coords_regrid=dst_coords_regrid,
        )

    def _regrid(self, src, fill_value, info):
        """
        Regrid a single (chunk of a) DataArray and return the result as a
        numpy array of shape ``info.dst_shape``.
        """
        # Allocate dst
        dst = np.full(info.dst_shape, fill_value)
        # No overlap whatsoever, early exit
        if any(size == 0 for size in src.shape):
            return dst

        # Transpose src so that dims to regrid are last
        src = src.transpose(*info.dst_dims)

        # Exit early if nothing is to be done
        if len(info.regrid_dims) == 0:
            return src.values.copy()

        return self._nd_regrid(
            src.values, dst, info.src_coords_regrid, info.dst_coords_regrid
        )

    def _delayed_regrid(self, src, like, fill_value, info):
        """
        Deal with chunks in dimensions that will NOT be regridded.

        Every chunk along the ``info.add_dims`` is regridded separately with
        ``_chunked_regrid``; the results are assembled into one dask array.
        """
        if len(info.add_dims) == 0:
            return self._chunked_regrid(src, like, fill_value)

        # Per add dimension: one slice for every chunk along that dimension
        src_dim_slices = []
        shape_chunks = []
        for dim, chunksize in zip(src.dims, src.chunks):
            if dim in info.add_dims:
                end = np.cumsum(chunksize)
                start = end - chunksize
                src_dim_slices.append([slice(s, e) for s, e in zip(start, end)])
                shape_chunks.append(len(chunksize))

        # Every combination of chunk slices, one row per combination
        src_expanded_slices = np.stack(
            [a.ravel() for a in np.meshgrid(*src_dim_slices, indexing="ij")], axis=-1
        )
        src_das = common._sel_chunks(src, info.add_dims, src_expanded_slices)
        n_das = len(src_das)
        np_collection = np.full(n_das, None)

        for i, src_da in enumerate(src_das):
            np_collection[i] = self._chunked_regrid(src_da, like, fill_value)

        # dask.array.block concatenates the innermost list level along the
        # last axis, the next level along the second-to-last axis, and so on.
        # The regridded blocks have dimensions
        # (*add_dims, *matching_dims, *regrid_dims), so one singleton level is
        # required for every matching AND every regrid dimension. Otherwise
        # the add_dims levels do not line up with the leading axes.
        n_trailing = len(info.matching_dims) + len(info.regrid_dims)
        shape_chunks = shape_chunks + [1] * n_trailing
        reshaped_collection = np.reshape(np_collection, shape_chunks).tolist()
        data = dask.array.block(reshaped_collection)
        return data

    def _chunked_regrid(self, src, like, fill_value):
        """
        Deal with chunks in dimensions that will be regridded.
        """
        like_expanded_slices, shape_chunks = common._define_slices(src, like)
        like_das = common._sel_chunks(like, like.dims, like_expanded_slices)
        n_das = len(like_das)
        np_collection = np.full(n_das, None)

        # Regridder should compute first chunk once
        # so numba has compiled the necessary functions for subsequent chunks
        for i, dst_da in enumerate(like_das):
            chunk_src = common._slice_src(src, dst_da, self.extra_overlap)
            info = self._regrid_info(chunk_src, dst_da)

            if any(
                size == 0 for size in chunk_src.shape
            ):  # zero overlap for the chunk, zero size chunk
                # N.B. Make sure to include chunks=-1, defaults to chunks="auto", which
                # automatically results in unnecessary, error prone chunks.
                # TODO: Not covered by tests -- but also rather hard to test.
                dask_array = dask.array.full(
                    shape=info.dst_shape,
                    fill_value=fill_value,
                    dtype=src.dtype,
                    chunks=-1,
                )
            elif self._first_call:
                # NOT delayed, trigger compilation
                a = self._regrid(chunk_src, fill_value, info)
                dask_array = dask.array.from_array(a, chunks=-1)
                self._first_call = False
            else:
                # Allocation occurs inside
                a = dask.delayed(self._regrid, pure=True)(chunk_src, fill_value, info)
                dask_array = dask.array.from_delayed(
                    a, shape=info.dst_shape, dtype=src.dtype
                )

            np_collection[i] = dask_array

        # Determine the shape of the chunks, and reshape so dask.block does the right thing
        reshaped_collection = np.reshape(np_collection, shape_chunks).tolist()
        data = dask.array.block(reshaped_collection)
        return data

    def regrid(self, source, like, fill_value=np.nan):
        """
        Regrid ``source`` along dimensions that ``source`` and ``like`` share.
        These dimensions will be inferred the first time ``.regrid`` is called
        for the Regridder object.

        Following xarray conventions, nodata is assumed to be ``np.nan``.

        Parameters
        ----------
        source : xr.DataArray of floats
        like : xr.DataArray of floats
            The like array present what the coordinates should look like.
        fill_value : float
            The fill_value. Defaults to np.nan

        Returns
        -------
        result : xr.DataArray
            Regridded result.

        Raises
        ------
        TypeError
            If ``source`` or ``like`` is not an xr.DataArray.
        ValueError
            If the number of dimensions to regrid does not match
            ``ndim_regrid``.
        """
        if not isinstance(source, xr.DataArray):
            raise TypeError("source must be a DataArray")
        if not isinstance(like, xr.DataArray):
            raise TypeError("like must be a DataArray")

        # Don't mutate source; src stands for source, dst for destination
        src = source.copy(deep=False)
        like = like.copy(deep=False)
        _, regrid_dims, _ = common._match_dims(src, like)
        # Exit early if nothing is to be done
        if len(regrid_dims) == 0:
            return source.copy(deep=True)

        # Collect dimensions to flip to make everything ascending
        src, _ = common._increasing_dims(src, regrid_dims)
        like, flip_dst = common._increasing_dims(like, regrid_dims)

        info = self._regrid_info(source, like)
        # Use xarray for nearest
        # TODO: replace by more efficient, specialized method
        if self.method == "nearest":
            # NOTE(review): this early return skips the flip_dst step at the
            # end of this method, so the result follows the flipped ``like``
            # along the regrid dimensions -- confirm this is intended.
            dst = source.reindex_like(like, method="nearest")
            dst = dst.assign_coords(info.dst_da_coords)
            return dst

        # Prepare for regridding; quick checks
        if self._first_call:
            self._prepare(info.regrid_dims)
        self._check_ndim_regrid(info.regrid_dims)

        if src.chunks is None:
            src = common._slice_src(src, like, self.extra_overlap)
            # Recollect info with sliced part of src
            info = self._regrid_info(src, like)
            data = self._regrid(src, fill_value, info)
            self._first_call = False
        else:
            # Ensure all dimensions have a dx coordinate, so that if the chunks
            # results in chunks which are size 1 along a dimension, the cellsize
            # can still be determined.
            src = common._set_cellsizes(src, info.regrid_dims)
            like = common._set_cellsizes(like, info.regrid_dims)
            data = self._delayed_regrid(src, like, fill_value, info)

        dst = xr.DataArray(data=data, coords=info.dst_da_coords, dims=info.dst_dims)
        # Replace equidistant cellsize arrays by scalar values
        dst = common._set_scalar_cellsizes(dst)

        # Flip dimensions to return as like
        for dim in flip_dst:
            dst = dst.sel({dim: slice(None, None, -1)})

        # Transpose to original dimension coordinates
        return dst.transpose(*source.dims)

635 

636 # Transpose to original dimension coordinates 

637 return dst.transpose(*source.dims)