Coverage for C:\src\imod-python\imod\mf6\utilities\regrid.py: 93%

178 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-08 13:27 +0200

1import abc 

2import copy 

3from collections import defaultdict 

4from typing import Any, Optional, Union 

5 

6import xarray as xr 

7import xugrid as xu 

8from fastcore.dispatch import typedispatch 

9from xarray.core.utils import is_scalar 

10from xugrid.regrid.regridder import BaseRegridder 

11 

12from imod.mf6.auxiliary_variables import ( 

13 expand_transient_auxiliary_variables, 

14 remove_expanded_auxiliary_variables_from_dataset, 

15) 

16from imod.mf6.interfaces.ilinedatapackage import ILineDataPackage 

17from imod.mf6.interfaces.imodel import IModel 

18from imod.mf6.interfaces.ipackage import IPackage 

19from imod.mf6.interfaces.ipointdatapackage import IPointDataPackage 

20from imod.mf6.interfaces.iregridpackage import IRegridPackage 

21from imod.mf6.interfaces.isimulation import ISimulation 

22from imod.mf6.statusinfo import NestedStatusInfo 

23from imod.mf6.utilities.clip import clip_by_grid 

24from imod.mf6.utilities.regridding_types import RegridderType 

25from imod.schemata import ValidationError 

26from imod.typing.grid import GridDataArray, get_grid_geometry_hash, ones_like 

27 

28 

class RegridderWeightsCache:
    """
    Cache of regridding weights, so that regridders can be re-used for
    different arrays defined on the same grids.

    Computing regridding weights is a costly affair. Weights are therefore
    stored per (source grid geometry hash, target grid geometry hash,
    regridder class) combination and re-used through the regridder's
    ``from_weights`` constructor. The available regridder types are described
    in `these docs <https://deltares.github.io/xugrid/examples/regridder_overview.html>`_.

    At most ``max_cache_size`` sets of weights are kept. When the cache is
    full, the entry that was inserted first is evicted.

    Parameters
    ----------
    source_grid: xr.DataArray or xu.UgridDataArray
        Stored on the instance. Weights are computed from the grids passed to
        :meth:`get_regridder`, not from this attribute.
    target_grid: xr.DataArray or xu.UgridDataArray
        Stored on the instance. See ``source_grid``.
    max_cache_size: int, default 6
        Maximum number of weight sets kept. A value of 0 or less disables
        caching: regridders are still returned, but no weights are stored.
    """

    def __init__(
        self,
        source_grid: Union[xr.DataArray, xu.UgridDataArray],
        target_grid: Union[xr.DataArray, xu.UgridDataArray],
        max_cache_size: int = 6,
    ) -> None:
        # NOTE(review): not read or written elsewhere in this class; kept for
        # backward compatibility with external users.
        self.regridder_instances: dict[
            tuple[type[BaseRegridder], Optional[str]], BaseRegridder
        ] = {}
        self._source_grid = source_grid
        self._target_grid = target_grid

        # (source hash, target hash, regridder class) -> weights.
        # Dict insertion order gives first-in-first-out eviction.
        self.weights_cache: dict = {}
        self.max_cache_size = max_cache_size

    def __get_regridder_class(
        self, regridder_type: Union[RegridderType, type[BaseRegridder]]
    ) -> type[BaseRegridder]:
        """
        Resolve ``regridder_type`` (a RegridderType member or a regridder
        class) to a regridder class.

        Raises
        ------
        ValueError
            If a class is passed that does not derive from BaseRegridder, or if
            ``regridder_type`` is neither an ABCMeta class nor a RegridderType.
        """
        if isinstance(regridder_type, abc.ABCMeta):
            # A class was passed directly (presumably a BaseRegridder subclass,
            # which uses the ABCMeta metaclass).
            if not issubclass(regridder_type, BaseRegridder):
                raise ValueError(
                    "only derived types of BaseRegridder can be instantiated"
                )
            return regridder_type
        elif isinstance(regridder_type, RegridderType):
            # The enum member's value is the regridder class itself.
            return regridder_type.value

        raise ValueError("invalid type for regridder")

    def get_regridder(
        self,
        source_grid: GridDataArray,
        target_grid: GridDataArray,
        regridder_type: Union[RegridderType, type[BaseRegridder]],
        method: Optional[str] = None,
    ) -> BaseRegridder:
        """
        Return a regridder of the specified type and with the specified method.

        The desired type can be passed through the argument ``regridder_type``
        either as a RegridderType member or as a class. The following two
        calls are equivalent:

            cache.get_regridder(source, target, RegridderType.OVERLAP, "mean")
            cache.get_regridder(source, target, xu.OverlapRegridder, "mean")

        Parameters
        ----------
        source_grid: xr.DataArray or xu.UgridDataArray
            grid the data is regridded from
        target_grid: xr.DataArray or xu.UgridDataArray
            grid the data is regridded to
        regridder_type: RegridderType or regridder class
            indicates the desired regridder type
        method: str or None
            indicates the method the regridder should apply; when None, the
            regridder's default method is used

        Returns
        -------
        A regridder of the specified characteristics. If weights for the same
        grid geometries and regridder class are cached, the regridder is built
        from those weights instead of recomputing them.

        Raises
        ------
        ValueError
            If ``regridder_type`` cannot be resolved to a regridder class.
        """
        regridder_class = self.__get_regridder_class(regridder_type)

        # When the source grid has no "layer" coordinate, the target's "layer"
        # coordinate is dropped before hashing and before building the regridder.
        if "layer" not in source_grid.coords and "layer" in target_grid.coords:
            target_grid = target_grid.drop_vars("layer")

        source_hash = get_grid_geometry_hash(source_grid)
        target_hash = get_grid_geometry_hash(target_grid)
        # ``method`` is not part of the key: cached weights are re-used with any
        # method through ``from_weights`` (assumes the weights themselves are
        # independent of the method).
        key = (source_hash, target_hash, regridder_class)

        method_kwargs = {}
        if method is not None:
            method_kwargs["method"] = method

        if key in self.weights_cache:
            return regridder_class.from_weights(
                weights=self.weights_cache[key], target=target_grid, **method_kwargs
            )

        if self.max_cache_size > 0 and len(self.weights_cache) >= self.max_cache_size:
            self.remove_first_regridder()
        regridder = regridder_class(
            source=source_grid, target=target_grid, **method_kwargs
        )
        # A non-positive cache size disables caching instead of crashing on an
        # empty cache.
        if self.max_cache_size > 0:
            self.weights_cache[key] = regridder.weights
        return regridder

    def remove_first_regridder(self) -> None:
        """
        Evict the oldest (first inserted) set of weights from the cache.

        Does nothing when the cache is empty.
        """
        if self.weights_cache:
            oldest_key = next(iter(self.weights_cache))
            del self.weights_cache[oldest_key]

120 

121 

def assign_coord_if_present(
    coordname: str, target_grid: GridDataArray, maybe_has_coords_attr: Any
):
    """
    Copy coordinate ``coordname`` from ``target_grid`` onto
    ``maybe_has_coords_attr``.

    The coordinate is only assigned when ``target_grid`` has it and
    ``maybe_has_coords_attr`` has a ``coords`` attribute (e.g. it is a
    DataArray rather than a plain scalar or None). Otherwise the input is
    returned unchanged.

    Parameters
    ----------
    coordname: str
        name of the coordinate to copy, e.g. "dx" or "dy"
    target_grid: xr.DataArray or xu.UgridDataArray
        grid to take the coordinate from
    maybe_has_coords_attr: Any
        object to assign the coordinate to, if it supports coordinates

    Returns
    -------
    The result of ``maybe_has_coords_attr.assign_coords(...)`` if the
    coordinate was assigned, otherwise ``maybe_has_coords_attr`` itself.
    """
    if coordname in target_grid.coords and hasattr(maybe_has_coords_attr, "coords"):
        # ``.values[()]`` unwraps the coordinate values (a plain scalar when the
        # coordinate is 0-dimensional).
        maybe_has_coords_attr = maybe_has_coords_attr.assign_coords(
            {coordname: target_grid.coords[coordname].values[()]}
        )
    return maybe_has_coords_attr

135 

136 

def _regrid_array(
    package: IRegridPackage,
    varname: str,
    regridder_collection: RegridderWeightsCache,
    regridder_name: str,
    regridder_function: str,
    target_grid: GridDataArray,
) -> Optional[GridDataArray]:
    """
    Regrid the variable ``varname`` of ``package.dataset`` onto ``target_grid``.

    Depending on what the variable holds, the result is:

    - ``None`` when the package does not consider the value valid (e.g. None);
    - the bare value when the variable is a scalar valid for the whole grid;
    - the variable itself when it is an xr.DataArray without both "x" and "y"
      coordinates (e.g. one value per layer), since it needs no regridding;
    - otherwise the variable regridded with a regridder obtained from
      ``regridder_collection``, cast back to the variable's original dtype.

    Raises
    ------
    ValueError
        If an xr.DataArray with "x" and "y" coordinates lacks "dx" or "dy".
    """
    source = package.dataset[varname]

    # Nothing to regrid for variables without a valid value (such as None).
    if not package._valid(source.values[()]):
        return None

    # Scalars apply to the whole grid and are returned unchanged.
    if is_scalar(source):
        return source.values[()]

    if isinstance(source, xr.DataArray):
        coordinates = source.coords
        is_planar = "x" in coordinates and "y" in coordinates
        if not is_planar:
            # Layer-based (non-planar) data does not need regridding.
            return source
        # dx and dy are optional in general, but structured grids need them
        # for regridding.
        if "dx" not in coordinates or "dy" not in coordinates:
            raise ValueError(
                f"DataArray {varname} does not have both a dx and dy coordinates"
            )

    regridder = regridder_collection.get_regridder(
        source, target_grid, regridder_name, regridder_function
    )
    # Regridding may change the dtype; convert back to the source's dtype.
    return regridder.regrid(source).astype(source.dtype)

190 

191 

def _get_unique_regridder_types(model: IModel) -> defaultdict[RegridderType, list[str]]:
    """
    Collect the regridder types, and the methods used with each, across all
    regriddable packages of ``model``.

    A package variable is only taken into account when it is a data variable
    of the package's dataset and its value is not None.

    Returns
    -------
    defaultdict mapping each regridder type to the distinct method names used
    with it, in order of first appearance.
    """
    used_methods: defaultdict = defaultdict(list)
    for _, package in model.items():
        if not isinstance(package, IRegridPackage):
            continue
        regrid_settings = package.get_regrid_methods()
        for varname in regrid_settings:
            if varname not in package.dataset.data_vars:
                continue
            if package.dataset[varname].values[()] is None:
                continue
            setting = regrid_settings[varname]
            # Indexing the defaultdict registers the regridder type even when
            # the method is already known.
            known_methods = used_methods[setting[0]]
            if setting[1] not in known_methods:
                known_methods.append(setting[1])
    return used_methods

210 

211 

@typedispatch  # type: ignore[no-redef]
def _regrid_like(
    package: IRegridPackage,
    target_grid: GridDataArray,
    regrid_context: RegridderWeightsCache,
    regridder_types: Optional[dict[str, tuple[RegridderType, str]]] = None,
) -> IPackage:
    """
    Creates a package of the same type as this package, based on another discretization.
    It regrids all the arrays in this package to the desired discretization, and leaves the options
    unmodified. At the moment only regridding to a different planar grid is supported, meaning
    ``target_grid`` has different ``"x"`` and ``"y"`` or different ``cell2d`` coords.

    The regridding methods can be specified in the _regrid_method attribute of the package. These are the defaults
    that specify how each array should be regridded. These defaults can be overridden using the input
    parameters of this function. The package's own default mapping is not modified.

    Examples
    --------
    To regrid the npf package with a non-default method for the k-field, call regrid_like with these arguments:

    >>> new_npf = npf.regrid_like(like, {"k": (imod.RegridderType.OVERLAP, "mean")})


    Parameters
    ----------
    target_grid: xr.DataArray or xu.UgridDataArray
        a grid defined over the same discretization as the one we want to regrid the package to
    regridder_types: dict(str->(regridder type,str))
        dictionary mapping arraynames (str) to a tuple of regrid type (a specialization class of BaseRegridder) and function name (str)
        this dictionary can be used to override the default mapping method.
    regrid_context: RegridderWeightsCache
        stores regridder weights for different regridders. Can be used to speed up regridding,
        if the same regridders are used several times for regridding different arrays.

    Returns
    -------
    a package with the same options as this package, and with all the data-arrays regridded to another discretization,
    similar to the one used in input argument "target_grid"

    Raises
    ------
    NotImplementedError
        If the package has no ``_regrid_method`` attribute.
    """
    if not hasattr(package, "_regrid_method"):
        raise NotImplementedError(
            f"Package {type(package).__name__} does not support regridding"
        )

    # The expanded auxiliary variables are temporarily removed from the input
    # package's dataset. They are restored in ``finally`` so the caller's
    # package is not left altered if regridding raises.
    if hasattr(package, "auxiliary_data_fields"):
        remove_expanded_auxiliary_variables_from_dataset(package)
    try:
        # Shallow copy: get_regrid_methods() may hand out the package's own
        # (possibly class-level) mapping, and update() below must not mutate it.
        regridder_settings = copy.copy(package.get_regrid_methods())
        if regridder_types is not None:
            regridder_settings.update(regridder_types)

        new_package_data = package.get_non_grid_data(list(regridder_settings.keys()))

        for varname, (
            regridder_name,
            regridder_function,
        ) in regridder_settings.items():
            # skip variables that are not in this dataset
            if varname not in package.dataset.keys():
                continue

            regridded = _regrid_array(
                package,
                varname,
                regrid_context,
                regridder_name,
                regridder_function,
                target_grid,
            )
            # set dx and dy if present in target_grid
            regridded = assign_coord_if_present("dx", target_grid, regridded)
            regridded = assign_coord_if_present("dy", target_grid, regridded)
            new_package_data[varname] = regridded
    finally:
        if hasattr(package, "auxiliary_data_fields"):
            expand_transient_auxiliary_variables(package)

    return package.__class__(**new_package_data)

296 

297 

@typedispatch  # type: ignore[no-redef]
def _regrid_like(
    model: IModel,
    target_grid: GridDataArray,
    validate: bool = True,
    regrid_context: Optional[RegridderWeightsCache] = None,
) -> IModel:
    """
    Build a new model by regridding every package of ``model`` onto
    ``target_grid`` with the packages' default regridding methods.

    Only regridding to a different planar grid is supported, i.e.
    ``target_grid`` differs in its ``"x"`` and ``"y"`` or its ``cell2d``
    coords. After the packages are regridded, they are masked with a domain
    in which only cells that every used regridder considers active remain
    active, and empty packages are purged.

    Parameters
    ----------
    target_grid: xr.DataArray or xu.UgridDataArray
        a grid defined over the discretization the model is regridded to
    validate: bool
        when True, validate the regridded model and raise on errors
    regrid_context: Optional RegridderWeightsCache
        cache of regridder weights shared by all packages. A new cache is
        created from ``model.domain`` and ``target_grid`` when omitted.

    Returns
    -------
    a model with similar packages to the input model, and with all the
    data-arrays regridded to the discretization of ``target_grid``

    Raises
    ------
    ValueError
        If the model contains a package that prevents regridding.
    NotImplementedError
        If a package is neither regriddable, line data nor point data.
    ValidationError
        If ``validate`` is True and the regridded model fails validation.
    """
    is_supported, offending_object = model.is_regridding_supported()
    if not is_supported:
        raise ValueError(
            f"regridding this model cannot be done due to the presence of package {offending_object}"
        )

    regridded_model = model.__class__()
    context = (
        RegridderWeightsCache(model.domain, target_grid)
        if regrid_context is None
        else regrid_context
    )

    regriddable_types = (IRegridPackage, ILineDataPackage, IPointDataPackage)
    for package_name, package in model.items():
        if not isinstance(package, regriddable_types):
            raise NotImplementedError(
                f"regridding is not implemented for package {package_name} of type {type(package)}"
            )
        regridded_model[package_name] = package.regrid_like(target_grid, context)

    used_methods = _get_unique_regridder_types(model)
    new_domain = _get_regridding_domain(model, target_grid, context, used_methods)
    regridded_model.mask_all_packages(new_domain)
    regridded_model.purge_empty_packages()

    if validate:
        status_info = NestedStatusInfo("Model validation status")
        status_info.add(regridded_model.validate("Regridded model"))
        if status_info.has_errors():
            raise ValidationError("\n" + status_info.to_string())
    return regridded_model

352 

353 

@typedispatch  # type: ignore[no-redef]
def _regrid_like(
    simulation: ISimulation,
    regridded_simulation_name: str,
    target_grid: GridDataArray,
    validate: bool = True,
) -> ISimulation:
    """
    Create a new simulation whose models are regridded copies of the models
    in ``simulation``.

    Non-regriddable packages stored directly on the simulation (such as time
    discretization and solver settings) are deep-copied. Items stored under
    the key "gwtgwf_exchanges" are not carried over. All models share one
    regridder weights cache, built from the flow model's domain.

    Parameters
    ----------
    regridded_simulation_name: str
        name given to the output simulation
    target_grid: xr.DataArray or xu.UgridDataArray
        discretization onto which the models in this simulation will be regridded
    validate: bool
        set to true to validate the regridded packages

    Returns
    -------
    a new simulation object with regridded models

    Raises
    ------
    RuntimeError
        If the simulation has been split.
    ValueError
        If the simulation does not have exactly one flow model, or if a model
        contains a package that prevents regridding.
    NotImplementedError
        If the simulation holds an item that cannot be regridded or copied.
    """
    if simulation.is_split():
        raise RuntimeError(
            "Unable to regrid simulation. Regridding can only be done on simulations that haven't been split."
            + " Therefore regridding should be done before splitting the simulation."
        )
    if not simulation.has_one_flow_model():
        raise ValueError(
            "Unable to regrid simulation. Regridding can only be done on simulations that have a single flow model."
        )

    flow_models = simulation.get_models_of_type("gwf6")
    first_flow_model = list(flow_models.values())[0]
    shared_context = RegridderWeightsCache(first_flow_model.domain, target_grid)

    # Check every model up front so that nothing is regridded when any model
    # cannot be.
    for model_name, model in simulation.get_models().items():
        is_supported, offending_object = model.is_regridding_supported()
        if not is_supported:
            raise ValueError(
                f"Unable to regrid simulation, due to the presence of package '{offending_object}' in model {model_name} "
            )

    regridded_simulation = simulation.__class__(regridded_simulation_name)
    for key, item in simulation.items():
        if isinstance(item, IModel):
            regridded_simulation[key] = item.regrid_like(
                target_grid, validate, shared_context
            )
        elif key == "gwtgwf_exchanges":
            # Not carried over to the regridded simulation.
            continue
        elif isinstance(item, IPackage) and not isinstance(item, IRegridPackage):
            regridded_simulation[key] = copy.deepcopy(item)
        else:
            raise NotImplementedError(f"regridding not supported for {key}")

    return regridded_simulation

414 

415 

@typedispatch  # type: ignore[no-redef]
def _regrid_like(
    package: ILineDataPackage, target_grid: GridDataArray, *_
) -> ILineDataPackage:
    """
    Line data is grid-agnostic, so nothing is regridded. Instead the package
    is clipped to the exterior of ``target_grid``. Any additional positional
    arguments (such as a regrid context) are ignored.
    """
    clipped_package = clip_by_grid(package, target_grid)
    return clipped_package

426 

427 

@typedispatch  # type: ignore[no-redef]
def _regrid_like(
    package: IPointDataPackage, target_grid: GridDataArray, *_
) -> IPointDataPackage:
    """
    The regrid_like method is irrelevant for this package as it is
    grid-agnostic. Instead, this method clips the package based on the grid
    exterior. Any additional positional arguments (such as a regrid context)
    are ignored.

    Before clipping, ``target_grid`` is reduced to its first layer (the
    "layer" dimension is dropped). Grids without a "layer" dimension are used
    as-is, because of ``missing_dims="ignore"``.
    """
    target_grid_2d = target_grid.isel(layer=0, drop=True, missing_dims="ignore")
    return clip_by_grid(package, target_grid_2d)

439 

440 

@typedispatch  # type: ignore[no-redef]
def _regrid_like(package: object, target_grid: GridDataArray, *_) -> None:
    """
    Catch-all overload for objects of any other type: always raises.

    Raises
    ------
    TypeError
        Always.
    """
    message = "this object cannot be regridded"
    raise TypeError(message)

444 

445 

def _get_regridding_domain(
    model: IModel,
    target_grid: GridDataArray,
    regrid_context: RegridderWeightsCache,
    methods: defaultdict[RegridderType, list[str]],
) -> GridDataArray:
    """
    Compute the output domain of a regridding operation.

    The model's domain is regridded with one regridder per (regridder type,
    method) pair in ``methods``. Starting from a grid of ones shaped like
    ``target_grid``, each regridding pass does two things:

    - it blanks cells for which that regridder produced no value;
    - it copies over the regridder's values that are <= 0.

    Cells left without a value are finally set to 0, and the result is cast
    to int. With an empty ``methods``, every target cell is 1.
    """
    source_domain = model.domain
    active_in_all = ones_like(target_grid)

    # Obtain all regridders first, in the same order as ``methods``.
    regridders = []
    for regridder_type, method_names in methods.items():
        for method_name in method_names:
            regridders.append(
                regrid_context.get_regridder(
                    source_domain, target_grid, regridder_type, method_name
                )
            )

    for regridder in regridders:
        regridded_domain = regridder.regrid(source_domain)
        active_in_all = active_in_all.where(regridded_domain.notnull())
        active_in_all = regridded_domain.where(
            regridded_domain <= 0, other=active_in_all
        )

    filled_domain = active_in_all.where(active_in_all.notnull(), other=0)
    return filled_domain.astype(int)