Coverage for C:\src\imod-python\imod\mf6\utilities\regrid.py: 97%

178 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-08 14:15 +0200

1import abc 

2import copy 

3from collections import defaultdict 

4from typing import Any, Optional, Union 

5 

6import xarray as xr 

7from fastcore.dispatch import typedispatch 

8from xarray.core.utils import is_scalar 

9from xugrid.regrid.regridder import BaseRegridder 

10 

11from imod.mf6.auxiliary_variables import ( 

12 expand_transient_auxiliary_variables, 

13 remove_expanded_auxiliary_variables_from_dataset, 

14) 

15from imod.mf6.interfaces.ilinedatapackage import ILineDataPackage 

16from imod.mf6.interfaces.imodel import IModel 

17from imod.mf6.interfaces.ipackage import IPackage 

18from imod.mf6.interfaces.ipointdatapackage import IPointDataPackage 

19from imod.mf6.interfaces.iregridpackage import IRegridPackage 

20from imod.mf6.interfaces.isimulation import ISimulation 

21from imod.mf6.statusinfo import NestedStatusInfo 

22from imod.mf6.utilities.clip import clip_by_grid 

23from imod.mf6.utilities.regridding_types import RegridderType 

24from imod.schemata import ValidationError 

25from imod.typing.grid import GridDataArray, get_grid_geometry_hash, ones_like 

26 

27 

class RegridderWeightsCache:
    """
    Bounded cache of regridding weights.

    Computing regridding weights is a costly affair, so the weights of every
    regridder created through :meth:`get_regridder` are stored and re-used for
    later requests on the same grids. Entries are keyed on the geometry hash
    of the source grid, the geometry hash of the target grid and the regridder
    class (`see these docs<https://deltares.github.io/xugrid/examples/regridder_overview.html>`_).
    The regridding method (e.g. ``"mean"``) is not part of the key: the same
    cached weights are handed to ``from_weights`` together with the requested
    method.

    Parameters
    ----------
    max_cache_size: int
        number of weight sets after which the oldest entry is evicted before a
        new one is stored.
    """

    def __init__(
        self,
        max_cache_size: int = 6,
    ) -> None:
        # Not used by the methods of this class; kept because it is a public
        # attribute that external code may rely on.
        self.regridder_instances: dict[
            tuple[type[BaseRegridder], Optional[str]], BaseRegridder
        ] = {}
        # Maps (source hash, target hash, regridder class) -> regridder weights.
        self.weights_cache: dict[tuple, Any] = {}
        self.max_cache_size = max_cache_size

    def __get_regridder_class(
        self, regridder_type: Union[RegridderType, type[BaseRegridder]]
    ) -> type[BaseRegridder]:
        """
        Translate ``regridder_type`` into a regridder class.

        Raises
        ------
        ValueError
            if ``regridder_type`` is a class that does not derive from
            BaseRegridder, or is neither a class nor a RegridderType member.
        """
        if isinstance(regridder_type, abc.ABCMeta):
            if not issubclass(regridder_type, BaseRegridder):
                raise ValueError(
                    "only derived types of BaseRegridder can be instantiated"
                )
            return regridder_type
        if isinstance(regridder_type, RegridderType):
            # The enum members carry the regridder class as their value.
            return regridder_type.value

        raise ValueError("invalid type for regridder")

    def get_regridder(
        self,
        source_grid: GridDataArray,
        target_grid: GridDataArray,
        regridder_type: Union[RegridderType, type[BaseRegridder]],
        method: Optional[str] = None,
    ) -> BaseRegridder:
        """
        returns a regridder of the specified type and with the specified method.
        The desired type can be passed through the argument "regridder_type" as an enumerator or
        as a class.
        The following two are equivalent:
        instancesCollection.get_regridder(source, target, RegridderType.OVERLAP, "mean")
        instancesCollection.get_regridder(source, target, xu.OverlapRegridder, "mean")


        Parameters
        ----------
        source_grid: GridDataArray
            grid the data is regridded from
        target_grid: GridDataArray
            grid the data is regridded to
        regridder_type: RegridderType or regridder class
            indicates the desired regridder type
        method: str or None
            indicates the method the regridder should apply

        Returns
        -------
        a regridder of the specified characteristics

        Raises
        ------
        ValueError
            if ``regridder_type`` cannot be translated to a regridder class.
        """
        regridder_class = self.__get_regridder_class(regridder_type)

        # A layered target cannot be matched with a source without layers, so
        # the layer coordinate is dropped from the target in that case.
        if "layer" not in source_grid.coords and "layer" in target_grid.coords:
            target_grid = target_grid.drop_vars("layer")

        source_hash = get_grid_geometry_hash(source_grid)
        target_hash = get_grid_geometry_hash(target_grid)
        key = (source_hash, target_hash, regridder_class)

        # "method" is only forwarded when given, so that regridder classes
        # without a method argument keep working.
        kwargs: dict[str, Any] = {"target": target_grid}
        if method is not None:
            kwargs["method"] = method

        if key in self.weights_cache:
            # Re-use the previously computed weights.
            return regridder_class.from_weights(weights=self.weights_cache[key], **kwargs)

        if len(self.weights_cache) >= self.max_cache_size:
            self.remove_first_regridder()
        regridder = regridder_class(source=source_grid, **kwargs)
        self.weights_cache[key] = regridder.weights
        return regridder

    def remove_first_regridder(self):
        """
        Evict the entry that was inserted first (dicts preserve insertion
        order). Does nothing when the cache is empty, so a non-positive
        ``max_cache_size`` does not make ``get_regridder`` fail.
        """
        if not self.weights_cache:
            return
        oldest_key = next(iter(self.weights_cache))
        self.weights_cache.pop(oldest_key)

114 

115 

def assign_coord_if_present(
    coordname: str, target_grid: GridDataArray, maybe_has_coords_attr: Any
):
    """
    Copy coordinate ``coordname`` from ``target_grid`` onto
    ``maybe_has_coords_attr``, if possible.

    The coordinate is only assigned when ``coordname`` is present in
    ``target_grid.coords`` and ``maybe_has_coords_attr`` has a ``coords``
    attribute; in every other case ``maybe_has_coords_attr`` is returned
    unchanged (e.g. for scalars or ``None``).

    Parameters
    ----------
    coordname: str
        name of the coordinate to copy, e.g. ``"dx"``
    target_grid: GridDataArray
        grid to take the coordinate from
    maybe_has_coords_attr: Any
        object that may or may not have a ``coords`` attribute

    Returns
    -------
    the result of ``assign_coords`` when the coordinate could be copied,
    otherwise ``maybe_has_coords_attr`` itself.
    """
    if coordname in target_grid.coords and hasattr(maybe_has_coords_attr, "coords"):
        # ``values[()]`` unwraps the coordinate's underlying array value.
        maybe_has_coords_attr = maybe_has_coords_attr.assign_coords(
            {coordname: target_grid.coords[coordname].values[()]}
        )
    return maybe_has_coords_attr

129 

130 

def _regrid_array(
    package: IRegridPackage,
    varname: str,
    regridder_collection: RegridderWeightsCache,
    regridder_name: str,
    regridder_function: str,
    target_grid: GridDataArray,
) -> Optional[GridDataArray]:
    """
    Regrid variable ``varname`` of the package dataset onto ``target_grid``.

    Depending on what the variable holds, the result is:
    - ``None`` if the package does not consider the value valid
    - the plain value if the variable is a scalar
    - the unmodified array if it is an xr.DataArray without both "x" and "y"
      coordinates (e.g. one value per layer)
    - the regridded array, cast back to the dtype of the input, otherwise

    Raises
    ------
    ValueError
        if the variable is an xr.DataArray with "x" and "y" coordinates but
        without both "dx" and "dy" coordinates.
    """
    data = package.dataset[varname]

    # Nothing to regrid when the package deems the value invalid (such as "None").
    if not package._valid(data.values[()]):
        return None

    # Scalars hold for the whole grid and are passed on as plain values.
    if is_scalar(data):
        return data.values[()]

    if isinstance(data, xr.DataArray):
        coords = data.coords

        # No planar coordinates: the array may be layer-based, return as is.
        has_planar_coords = "x" in coords and "y" in coords
        if not has_planar_coords:
            return data

        # dx and dy are not mandatory in general, but regridding needs them.
        has_cell_sizes = "dx" in coords and "dy" in coords
        if not has_cell_sizes:
            raise ValueError(
                f"DataArray {varname} does not have both a dx and dy coordinates"
            )

    regridder = regridder_collection.get_regridder(
        data,
        target_grid,
        regridder_name,
        regridder_function,
    )

    # Remember the dtype so the regridded result can be cast back to it.
    source_dtype = data.dtype
    return regridder.regrid(data).astype(source_dtype)

184 

185 

def _get_unique_regridder_types(model: IModel) -> defaultdict[RegridderType, list[str]]:
    """
    Collect the regridder types, and per type the unique function names, used
    by the regrid methods of the IRegridPackage packages in ``model``.

    Only variables present in the data_vars of a package's dataset are taken
    into account. A regrid method without a function name is registered with
    ``None`` as its function.
    """
    methods: defaultdict = defaultdict(list)
    regrid_packages = (
        pkg for _, pkg in model.items() if isinstance(pkg, IRegridPackage)
    )
    for pkg in regrid_packages:
        for variable, regrid_method in pkg.get_regrid_methods().items():
            if variable not in pkg.dataset.data_vars:
                continue
            regriddertype = regrid_method[0]
            functiontype = regrid_method[1] if len(regrid_method) > 1 else None
            known_functions = methods[regriddertype]
            if functiontype not in known_functions:
                known_functions.append(functiontype)
    return methods

203 

204 

@typedispatch  # type: ignore[no-redef]
def _regrid_like(
    package: IRegridPackage,
    target_grid: GridDataArray,
    regrid_context: RegridderWeightsCache,
    regridder_types: Optional[dict[str, tuple[RegridderType, str]]] = None,
) -> IPackage:
    """
    Creates a package of the same type as this package, based on another discretization.
    It regrids all the arrays in this package to the desired discretization, and leaves the options
    unmodified. At the moment only regridding to a different planar grid is supported, meaning
    ``target_grid`` has different ``"x"`` and ``"y"`` or different ``cell2d`` coords.

    The regridding methods are obtained from ``package.get_regrid_methods()``. These are the defaults
    that specify how each array should be regridded. These defaults can be overridden using the
    ``regridder_types`` argument; the overrides are applied to a copy, so the mapping
    returned by the package is left untouched.

    Examples
    --------
    To regrid the npf package with a non-default method for the k-field, pass these overrides:

    >>> new_npf = _regrid_like(npf, like, regrid_context, {"k": (imod.RegridderType.OVERLAP, "mean")})


    Parameters
    ----------
    package: IRegridPackage
        the package to regrid
    target_grid: xr.DataArray or xu.UgridDataArray
        a grid defined over the same discretization as the one we want to regrid the package to
    regrid_context: RegridderWeightsCache
        stores regridder weights for different regridders. Can be used to speed up regridding,
        if the same regridders are used several times for regridding different arrays.
    regridder_types: dict(str->(regridder type,str))
        dictionary mapping arraynames (str) to a tuple of regrid type (a specialization class of BaseRegridder) and function name (str)
        this dictionary can be used to override the default mapping method.

    Returns
    -------
    a package with the same options as this package, and with all the data-arrays regridded to another discretization,
    similar to the one used in input argument "target_grid"

    Raises
    ------
    NotImplementedError
        if the package has no ``_regrid_method`` attribute.
    """
    if not hasattr(package, "_regrid_method"):
        raise NotImplementedError(
            f"Package {type(package).__name__} does not support regridding"
        )

    if hasattr(package, "auxiliary_data_fields"):
        remove_expanded_auxiliary_variables_from_dataset(package)

    try:
        # Work on a copy: what get_regrid_methods() hands out may be the
        # package's own defaults (not visible from here), and overrides for
        # this call must not leak into them.
        regridder_settings = dict(package.get_regrid_methods())
        if regridder_types is not None:
            regridder_settings.update(regridder_types)

        new_package_data = package.get_non_grid_data(list(regridder_settings.keys()))

        for varname, regridder_type_and_function in regridder_settings.items():
            # skip variables that are not in this dataset
            if varname not in package.dataset.keys():
                continue

            # The function name is optional in the settings tuple.
            regridder_name = regridder_type_and_function[0]
            regridder_function = (
                regridder_type_and_function[1]
                if len(regridder_type_and_function) > 1
                else None
            )

            regridded = _regrid_array(
                package,
                varname,
                regrid_context,
                regridder_name,
                regridder_function,
                target_grid,
            )
            # set dx and dy if present in target_grid
            regridded = assign_coord_if_present("dx", target_grid, regridded)
            regridded = assign_coord_if_present("dy", target_grid, regridded)
            new_package_data[varname] = regridded
    finally:
        # Restore the auxiliary variables of the input package, also when
        # regridding failed, so the caller's package is not left modified.
        if hasattr(package, "auxiliary_data_fields"):
            expand_transient_auxiliary_variables(package)

    return package.__class__(**new_package_data)

292 

293 

@typedispatch  # type: ignore[no-redef]
def _regrid_like(
    model: IModel,
    target_grid: GridDataArray,
    validate: bool = True,
    regrid_context: Optional[RegridderWeightsCache] = None,
) -> IModel:
    """
    Build a new model of the same class whose packages are the packages of
    ``model`` regridded to ``target_grid`` with their default regridding methods.
    At the moment only regridding to a different planar grid is supported, meaning
    ``target_grid`` has different ``"x"`` and ``"y"`` or different ``cell2d`` coords.

    After regridding, all packages are masked with the regridded domain and
    empty packages are purged.

    Parameters
    ----------
    model: IModel
        the model to regrid
    target_grid: xr.DataArray or xu.UgridDataArray
        a grid defined over the same discretization as the one we want to regrid the packages to
    validate: bool
        set to true to validate the regridded model
    regrid_context: Optional RegridderWeightsCache
        stores regridder weights for different regridders, so they can be re-used
        when the same regridders are needed for several arrays. A new cache is
        created when omitted.

    Returns
    -------
    a model with similar packages to the input model, and with all the data-arrays regridded to another discretization,
    similar to the one used in input argument "target_grid"

    Raises
    ------
    ValueError
        if the model reports that regridding is not supported.
    NotImplementedError
        if a package is not a regrid, line-data or point-data package.
    ValidationError
        if ``validate`` is true and the regridded model has validation errors.
    """
    supported, error_with_object_name = model.is_regridding_supported()
    if not supported:
        raise ValueError(
            f"regridding this model cannot be done due to the presence of package {error_with_object_name}"
        )

    new_model = model.__class__()
    if regrid_context is None:
        regrid_context = RegridderWeightsCache()

    regriddable_types = (IRegridPackage, ILineDataPackage, IPointDataPackage)
    for pkg_name, pkg in model.items():
        if not isinstance(pkg, regriddable_types):
            raise NotImplementedError(
                f"regridding is not implemented for package {pkg_name} of type {type(pkg)}"
            )
        new_model[pkg_name] = pkg.regrid_like(target_grid, regrid_context)

    # Restrict the new model to the domain derived from regridding idomain.
    output_domain = _get_regridding_domain(
        model, target_grid, regrid_context, _get_unique_regridder_types(model)
    )
    new_model.mask_all_packages(output_domain)
    new_model.purge_empty_packages()

    if not validate:
        return new_model

    status_info = NestedStatusInfo("Model validation status")
    status_info.add(new_model.validate("Regridded model"))
    if status_info.has_errors():
        raise ValidationError("\n" + status_info.to_string())
    return new_model

348 

349 

@typedispatch  # type: ignore[no-redef]
def _regrid_like(
    simulation: ISimulation,
    regridded_simulation_name: str,
    target_grid: GridDataArray,
    validate: bool = True,
) -> ISimulation:
    """
    Create a new simulation of the same class in which every model of the
    input simulation is replaced by its regridded version. Packages of the
    simulation that are not regrid packages are deep-copied; the
    "gwtgwf_exchanges" entry is not transferred.

    Parameters
    ----------
    simulation: ISimulation
        the simulation to regrid
    regridded_simulation_name: str
        name given to the output simulation
    target_grid: xr.DataArray or xu.UgridDataArray
        discretization onto which the models in this simulation will be regridded
    validate: bool
        set to true to validate the regridded models

    Returns
    -------
    a new simulation object with regridded models

    Raises
    ------
    RuntimeError
        if the simulation has been split.
    ValueError
        if the simulation does not have exactly one flow model, or if one of
        its models reports that regridding is not supported.
    NotImplementedError
        if the simulation contains an item that can be neither regridded nor copied.
    """
    if simulation.is_split():
        message = (
            "Unable to regrid simulation. Regridding can only be done on simulations that haven't been split."
            " Therefore regridding should be done before splitting the simulation."
        )
        raise RuntimeError(message)
    if not simulation.has_one_flow_model():
        raise ValueError(
            "Unable to regrid simulation. Regridding can only be done on simulations that have a single flow model."
        )

    # One weights cache shared by all models of the simulation.
    regrid_context = RegridderWeightsCache()

    # Check all models up front, before any regridding work is done.
    for model_name, model in simulation.get_models().items():
        supported, error_with_object_name = model.is_regridding_supported()
        if supported:
            continue
        raise ValueError(
            f"Unable to regrid simulation, due to the presence of package '{error_with_object_name}' in model {model_name} "
        )

    result = simulation.__class__(regridded_simulation_name)
    for key, item in simulation.items():
        if isinstance(item, IModel):
            result[key] = item.regrid_like(target_grid, validate, regrid_context)
            continue
        if key == "gwtgwf_exchanges":
            # Deliberately not carried over to the new simulation.
            continue
        if isinstance(item, IPackage) and not isinstance(item, IRegridPackage):
            result[key] = copy.deepcopy(item)
            continue
        raise NotImplementedError(f"regridding not supported for {key}")

    return result

408 

409 

@typedispatch  # type: ignore[no-redef]
def _regrid_like(
    package: ILineDataPackage, target_grid: GridDataArray, *_
) -> ILineDataPackage:
    """
    Line data packages are grid-agnostic, so there is nothing to regrid.
    Instead, the package is clipped with ``clip_by_grid`` using
    ``target_grid``. Any further positional arguments are ignored.
    """
    clipped_package = clip_by_grid(package, target_grid)
    return clipped_package

420 

421 

@typedispatch  # type: ignore[no-redef]
def _regrid_like(
    package: IPointDataPackage, target_grid: GridDataArray, *_
) -> IPointDataPackage:
    """
    Point data packages are grid-agnostic, so there is nothing to regrid.
    Instead, the package is clipped with ``clip_by_grid`` using the first
    layer of ``target_grid`` (or ``target_grid`` itself when it has no layer
    dimension). Any further positional arguments are ignored.
    """
    return clip_by_grid(
        package,
        target_grid.isel(layer=0, drop=True, missing_dims="ignore"),
    )

433 

434 

@typedispatch  # type: ignore[no-redef]
def _regrid_like(package: object, target_grid: GridDataArray, *_) -> None:
    """
    Fallback for objects without a dedicated regridding implementation:
    always raises a TypeError.
    """
    message = "this object cannot be regridded"
    raise TypeError(message)

438 

439 

def _get_regridding_domain(
    model: IModel,
    target_grid: GridDataArray,
    regrid_context: RegridderWeightsCache,
    methods: defaultdict[RegridderType, list[str]],
) -> GridDataArray:
    """
    Compute the output domain of a regridding operation.

    The domain of the model is regridded with every regridder type/function
    combination in ``methods``. Starting from an array of ones shaped like
    ``target_grid``, cells for which a regridder yields no value are blanked,
    and cells whose regridded idomain is zero or negative take over that
    value. Cells that remain blank are set to 0 and the result is cast to
    integer.
    """
    idomain = model.domain
    included_in_all = ones_like(target_grid)

    # Obtain all regridders first, then apply them one by one.
    regridders = []
    for regriddertype, functionlist in methods.items():
        for function in functionlist:
            regridders.append(
                regrid_context.get_regridder(
                    idomain, target_grid, regriddertype, function
                )
            )

    for regridder in regridders:
        regridded_idomain = regridder.regrid(idomain)
        # Blank the cells this regridder could not fill ...
        has_value = regridded_idomain.notnull()
        included_in_all = included_in_all.where(has_value)
        # ... and take over regridded values that are zero or negative.
        is_not_positive = regridded_idomain <= 0
        included_in_all = regridded_idomain.where(
            is_not_positive, other=included_in_all
        )

    is_filled = included_in_all.notnull()
    new_idomain = included_in_all.where(is_filled, other=0)
    return new_idomain.astype(int)