Coverage for C:\src\imod-python\imod\mf6\utilities\regrid.py: 97%
178 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-08 14:15 +0200
1import abc
2import copy
3from collections import defaultdict
4from typing import Any, Optional, Union
6import xarray as xr
7from fastcore.dispatch import typedispatch
8from xarray.core.utils import is_scalar
9from xugrid.regrid.regridder import BaseRegridder
11from imod.mf6.auxiliary_variables import (
12 expand_transient_auxiliary_variables,
13 remove_expanded_auxiliary_variables_from_dataset,
14)
15from imod.mf6.interfaces.ilinedatapackage import ILineDataPackage
16from imod.mf6.interfaces.imodel import IModel
17from imod.mf6.interfaces.ipackage import IPackage
18from imod.mf6.interfaces.ipointdatapackage import IPointDataPackage
19from imod.mf6.interfaces.iregridpackage import IRegridPackage
20from imod.mf6.interfaces.isimulation import ISimulation
21from imod.mf6.statusinfo import NestedStatusInfo
22from imod.mf6.utilities.clip import clip_by_grid
23from imod.mf6.utilities.regridding_types import RegridderType
24from imod.schemata import ValidationError
25from imod.typing.grid import GridDataArray, get_grid_geometry_hash, ones_like
class RegridderWeightsCache:
    """
    Cache of regridding weights, so that a regridder for a given source grid and
    target grid can be re-used for different arrays on the same grids.

    Computing regridding weights is a costly affair, so the weights are computed
    once and later regridders are rebuilt from them with ``from_weights``.
    Entries are keyed on the geometry hash of the source grid, the geometry hash
    of the target grid, and the regridder class (`see these docs
    <https://deltares.github.io/xugrid/examples/regridder_overview.html>`_).
    The regridding method is not part of the key: cached weights are shared by
    every method of the same regridder class.

    At most ``max_cache_size`` entries are kept. When the cache is full, the
    oldest entry is evicted first (FIFO). With ``max_cache_size <= 0`` nothing
    is cached.

    Parameters
    ----------
    max_cache_size: int, default 6
        Maximum number of weight entries to keep.
    """

    def __init__(
        self,
        max_cache_size: int = 6,
    ) -> None:
        # Not read anywhere in this class; presumably kept as a public
        # attribute for backward compatibility.
        self.regridder_instances: dict[
            tuple[type[BaseRegridder], Optional[str]], BaseRegridder
        ] = {}
        # (source geometry hash, target geometry hash, regridder class) -> weights
        self.weights_cache: dict = {}
        self.max_cache_size = max_cache_size

    def __get_regridder_class(
        self, regridder_type: Union[RegridderType, type[BaseRegridder]]
    ) -> type[BaseRegridder]:
        """
        Resolve ``regridder_type`` to a regridder class.

        Raises
        ------
        ValueError
            If a class is passed that does not derive from BaseRegridder, or
            if ``regridder_type`` is neither a class nor a RegridderType.
        """
        if isinstance(regridder_type, abc.ABCMeta):
            # A class was passed directly (BaseRegridder subclasses are
            # expected to have ABCMeta as metaclass).
            if not issubclass(regridder_type, BaseRegridder):
                raise ValueError(
                    "only derived types of BaseRegridder can be instantiated"
                )
            return regridder_type
        elif isinstance(regridder_type, RegridderType):
            # Enum member: its value is the regridder class.
            return regridder_type.value

        raise ValueError("invalid type for regridder")

    def get_regridder(
        self,
        source_grid: GridDataArray,
        target_grid: GridDataArray,
        regridder_type: Union[RegridderType, type[BaseRegridder]],
        method: Optional[str] = None,
    ) -> BaseRegridder:
        """
        Return a regridder of the specified type and with the specified method.

        The desired type can be passed through the argument "regridder_type" as
        an enumerator or as a class. The following two are equivalent:

            instancesCollection.get_regridder(source, target, RegridderType.OVERLAP, "mean")
            instancesCollection.get_regridder(source, target, xu.OverlapRegridder, "mean")

        Parameters
        ----------
        source_grid: xr.DataArray or xu.UgridDataArray
            grid the regridder maps from
        target_grid: xr.DataArray or xu.UgridDataArray
            grid the regridder maps to
        regridder_type: RegridderType or regridder class
            indicates the desired regridder type
        method: str or None
            indicates the method the regridder should apply; when None the
            regridder's own default is used

        Returns
        -------
        a regridder of the specified characteristics

        Raises
        ------
        ValueError
            If ``regridder_type`` cannot be resolved to a BaseRegridder subclass.
        """
        regridder_class = self.__get_regridder_class(regridder_type)

        # When the source has no layer coordinate, drop it from the target too,
        # so both grids are handled without a layer coordinate.
        if "layer" not in source_grid.coords and "layer" in target_grid.coords:
            target_grid = target_grid.drop_vars("layer")

        key = (
            get_grid_geometry_hash(source_grid),
            get_grid_geometry_hash(target_grid),
            regridder_class,
        )

        kwargs: dict[str, Any] = {"target": target_grid}
        if method is not None:
            kwargs["method"] = method

        if key in self.weights_cache:
            return regridder_class.from_weights(
                weights=self.weights_cache[key], **kwargs
            )

        # Evict the oldest entries (FIFO) until there is room for the new one.
        # A loop, rather than a single eviction, also restores the limit when
        # max_cache_size was lowered after construction; the emptiness check
        # avoids evicting from an empty cache when max_cache_size <= 0.
        while self.weights_cache and len(self.weights_cache) >= self.max_cache_size:
            self.remove_first_regridder()

        regridder = regridder_class(source=source_grid, **kwargs)
        if self.max_cache_size > 0:
            self.weights_cache[key] = regridder.weights
        return regridder

    def remove_first_regridder(self):
        """Evict the oldest cached weights entry (insertion order); no-op if empty."""
        if self.weights_cache:
            del self.weights_cache[next(iter(self.weights_cache))]
def assign_coord_if_present(
    coordname: str, target_grid: GridDataArray, maybe_has_coords_attr: Any
):
    """
    Copy coordinate ``coordname`` from ``target_grid`` onto ``maybe_has_coords_attr``.

    The coordinate is only assigned when ``coordname`` is present in
    ``target_grid.coords`` and ``maybe_has_coords_attr`` has a ``coords``
    attribute. Otherwise the object is returned unchanged, which lets values
    without coordinates (such as ``None`` or plain scalars) pass through.

    Parameters
    ----------
    coordname: str
        name of the coordinate to copy
    target_grid: xr.DataArray or xu.UgridDataArray
        grid to take the coordinate from
    maybe_has_coords_attr: Any
        object to assign the coordinate to, if it supports coordinates

    Returns
    -------
    ``maybe_has_coords_attr``, with the coordinate assigned when applicable.
    """
    if coordname in target_grid.coords and hasattr(maybe_has_coords_attr, "coords"):
        # .values[()] yields the raw value: a scalar for a 0-d coordinate,
        # the full ndarray otherwise.
        maybe_has_coords_attr = maybe_has_coords_attr.assign_coords(
            {coordname: target_grid.coords[coordname].values[()]}
        )
    return maybe_has_coords_attr
def _regrid_array(
    package: IRegridPackage,
    varname: str,
    regridder_collection: RegridderWeightsCache,
    regridder_name: Union[RegridderType, type[BaseRegridder]],
    regridder_function: Optional[str],
    target_grid: GridDataArray,
) -> Optional[GridDataArray]:
    """
    Regrid a single variable of ``package.dataset`` onto ``target_grid``.

    Each variable can represent:
    - None (or another value rejected by ``package._valid``): returns None
    - a scalar value, valid for the whole grid: returned as-is, unregridded
    - an array without x and y coordinates (e.g. one value per layer):
      returned as-is, unregridded
    - an array with a value per grid block: regridded, and cast back to its
      original dtype

    Parameters
    ----------
    package: IRegridPackage
        package whose dataset holds the variable
    varname: str
        name of the variable in ``package.dataset``
    regridder_collection: RegridderWeightsCache
        cache used to obtain (and re-use) the regridder
    regridder_name: RegridderType or regridder class
        type of regridder to use
    regridder_function: str or None
        regridding method, or None for the regridder's default
    target_grid: xr.DataArray or xu.UgridDataArray
        grid to regrid onto

    Returns
    -------
    The regridded array, the unchanged value or array, or None.

    Raises
    ------
    ValueError
        If an xr.DataArray with x and y coordinates lacks dx or dy coordinates.
    """
    # Look the variable up once; every step below works on the same array.
    source = package.dataset[varname]

    # skip regridding for arrays with no valid values (such as "None")
    if not package._valid(source.values[()]):
        return None

    # the dataarray might be a scalar. If it is, then it does not need regridding.
    if is_scalar(source):
        return source.values[()]

    if isinstance(source, xr.DataArray):
        coords = source.coords
        # without planar coordinates the array may be layer-based; then no
        # regridding is needed
        if not ("x" in coords and "y" in coords):
            return source

        # structured arrays need the dx, dy coordinates for regridding, which
        # are otherwise not mandatory
        if not ("dx" in coords and "dy" in coords):
            raise ValueError(
                f"DataArray {varname} does not have both a dx and dy coordinates"
            )

    # obtain an instance of a regridder for the chosen method
    regridder = regridder_collection.get_regridder(
        source,
        target_grid,
        regridder_name,
        regridder_function,
    )

    # regrid, then reconvert the result to the same dtype as the original
    original_dtype = source.dtype
    regridded_array = regridder.regrid(source)
    return regridded_array.astype(original_dtype)
def _get_unique_regridder_types(model: IModel) -> defaultdict[RegridderType, list[str]]:
    """
    Collect, per regridder type, the distinct regridding functions in use by the
    regriddable packages of ``model``.

    Only variables that actually exist in a package's ``dataset.data_vars`` are
    considered. A regrid method given without a function name contributes
    ``None`` to the list of its regridder type.
    """
    collected: defaultdict = defaultdict(list)
    regrid_packages = (
        package for _, package in model.items() if isinstance(package, IRegridPackage)
    )
    for package in regrid_packages:
        for name, method_spec in package.get_regrid_methods().items():
            if name not in package.dataset.data_vars:
                continue
            regridder_kind = method_spec[0]
            function_name = method_spec[1] if len(method_spec) > 1 else None
            known_functions = collected[regridder_kind]
            if function_name not in known_functions:
                known_functions.append(function_name)
    return collected
@typedispatch  # type: ignore[no-redef]
def _regrid_like(
    package: IRegridPackage,
    target_grid: GridDataArray,
    regrid_context: RegridderWeightsCache,
    regridder_types: Optional[dict[str, tuple[RegridderType, str]]] = None,
) -> IPackage:
    """
    Creates a package of the same type as this package, based on another discretization.
    It regrids all the arrays in this package to the desired discretization, and leaves the options
    unmodified. At the moment only regridding to a different planar grid is supported, meaning
    ``target_grid`` has different ``"x"`` and ``"y"`` or different ``cell2d`` coords.

    The regridding methods can be specified in the _regrid_method attribute of the package. These are the defaults
    that specify how each array should be regridded. These defaults can be overridden using the input
    parameters of this function; the overrides apply to this call only.

    Examples
    --------
    To regrid the npf package with a non-default method for the k-field, call regrid_like with these arguments:

    >>> new_npf = npf.regrid_like(like, {"k": (imod.RegridderType.OVERLAP, "mean")})

    Parameters
    ----------
    target_grid: xr.DataArray or xu.UgridDataArray
        a grid defined over the same discretization as the one we want to regrid the package to
    regrid_context: RegridderWeightsCache
        stores regridder weights for different regridders. Can be used to speed up regridding,
        if the same regridders are used several times for regridding different arrays.
    regridder_types: dict(str->(regridder type,str))
        dictionary mapping arraynames (str) to a tuple of regrid type (a specialization class of BaseRegridder) and function name (str)
        this dictionary can be used to override the default mapping method.

    Returns
    -------
    a package with the same options as this package, and with all the data-arrays regridded to another discretization,
    similar to the one used in input argument "target_grid"

    Raises
    ------
    NotImplementedError
        If the package has no ``_regrid_method`` attribute.
    """
    if not hasattr(package, "_regrid_method"):
        raise NotImplementedError(
            f"Package {type(package).__name__} does not support regridding"
        )

    has_auxiliary_fields = hasattr(package, "auxiliary_data_fields")
    if has_auxiliary_fields:
        remove_expanded_auxiliary_variables_from_dataset(package)

    # The input package was modified just above; the finally clause restores
    # it even when regridding one of its variables raises.
    try:
        regridder_settings = package.get_regrid_methods()
        if regridder_types is not None:
            # Merge into a new dict rather than update() in place:
            # get_regrid_methods() may return the package's own mapping, and
            # mutating it would leak these overrides into later calls.
            regridder_settings = {**regridder_settings, **regridder_types}

        new_package_data = package.get_non_grid_data(list(regridder_settings.keys()))

        for varname, regridder_type_and_function in regridder_settings.items():
            # skip variables that are not in this dataset
            if varname not in package.dataset.keys():
                continue

            regridder_name = regridder_type_and_function[0]
            regridder_function = (
                regridder_type_and_function[1]
                if len(regridder_type_and_function) > 1
                else None
            )

            regridded = _regrid_array(
                package,
                varname,
                regrid_context,
                regridder_name,
                regridder_function,
                target_grid,
            )
            # set dx and dy if present in target_grid
            regridded = assign_coord_if_present("dx", target_grid, regridded)
            regridded = assign_coord_if_present("dy", target_grid, regridded)
            new_package_data[varname] = regridded
    finally:
        if has_auxiliary_fields:
            expand_transient_auxiliary_variables(package)

    return package.__class__(**new_package_data)
@typedispatch  # type: ignore[no-redef]
def _regrid_like(
    model: IModel,
    target_grid: GridDataArray,
    validate: bool = True,
    regrid_context: Optional[RegridderWeightsCache] = None,
) -> IModel:
    """
    Build a new model whose packages are regridded copies of the packages of
    ``model``, using each package's default regridding methods.

    Only planar regridding is supported: ``target_grid`` differs from the
    model's grid in its ``"x"`` and ``"y"`` or its ``cell2d`` coords.

    Parameters
    ----------
    target_grid: xr.DataArray or xu.UgridDataArray
        a grid defined over the discretization the model is regridded to
    validate: bool
        when true, validate the regridded model and raise on errors
    regrid_context: Optional RegridderWeightsCache
        cache of regridder weights shared by all packages; a fresh cache is
        created when None

    Returns
    -------
    a model of the same class as the input, with regridded packages, masked
    to the common regridding domain and with empty packages purged

    Raises
    ------
    ValueError
        If the model reports that it does not support regridding.
    NotImplementedError
        If a package is not regriddable, line data, or point data.
    ValidationError
        If ``validate`` is true and the regridded model fails validation.
    """
    supported, error_with_object_name = model.is_regridding_supported()
    if not supported:
        raise ValueError(
            f"regridding this model cannot be done due to the presence of package {error_with_object_name}"
        )

    new_model = model.__class__()
    context = RegridderWeightsCache() if regrid_context is None else regrid_context

    regriddable_kinds = (IRegridPackage, ILineDataPackage, IPointDataPackage)
    for pkg_name, pkg in model.items():
        if not isinstance(pkg, regriddable_kinds):
            raise NotImplementedError(
                f"regridding is not implemented for package {pkg_name} of type {type(pkg)}"
            )
        new_model[pkg_name] = pkg.regrid_like(target_grid, context)

    methods_in_use = _get_unique_regridder_types(model)
    domain = _get_regridding_domain(model, target_grid, context, methods_in_use)
    new_model.mask_all_packages(domain)
    new_model.purge_empty_packages()

    if not validate:
        return new_model

    status_info = NestedStatusInfo("Model validation status")
    status_info.add(new_model.validate("Regridded model"))
    if status_info.has_errors():
        raise ValidationError("\n" + status_info.to_string())
    return new_model
@typedispatch  # type: ignore[no-redef]
def _regrid_like(
    simulation: ISimulation,
    regridded_simulation_name: str,
    target_grid: GridDataArray,
    validate: bool = True,
) -> ISimulation:
    """
    Create a new simulation, named ``regridded_simulation_name``, whose models
    are regridded versions of the models in ``simulation``.

    Non-regriddable packages (such as time discretization and solver settings)
    are deep-copied, and the "gwtgwf_exchanges" entry is not carried over. All
    models share one RegridderWeightsCache.

    Parameters
    ----------
    regridded_simulation_name: str
        name given to the output simulation
    target_grid: xr.DataArray or xu.UgridDataArray
        discretization onto which the models in this simulation will be regridded
    validate: bool
        set to true to validate the regridded models

    Returns
    -------
    a new simulation object with regridded models

    Raises
    ------
    RuntimeError
        If the simulation has been split.
    ValueError
        If the simulation does not have exactly one flow model, or a model
        does not support regridding.
    NotImplementedError
        If the simulation holds an item that is neither a model, the
        "gwtgwf_exchanges" entry, nor a non-regriddable package.
    """
    if simulation.is_split():
        raise RuntimeError(
            "Unable to regrid simulation. Regridding can only be done on simulations that haven't been split."
            + " Therefore regridding should be done before splitting the simulation."
        )
    if not simulation.has_one_flow_model():
        raise ValueError(
            "Unable to regrid simulation. Regridding can only be done on simulations that have a single flow model."
        )
    shared_context = RegridderWeightsCache()

    # Check every model up front, before any regridding work is done.
    for model_name, model in simulation.get_models().items():
        supported, error_with_object_name = model.is_regridding_supported()
        if not supported:
            raise ValueError(
                f"Unable to regrid simulation, due to the presence of package '{error_with_object_name}' in model {model_name} "
            )

    regridded = simulation.__class__(regridded_simulation_name)
    for key, item in simulation.items():
        if isinstance(item, IModel):
            regridded[key] = item.regrid_like(target_grid, validate, shared_context)
            continue
        if key == "gwtgwf_exchanges":
            continue
        if isinstance(item, IPackage) and not isinstance(item, IRegridPackage):
            regridded[key] = copy.deepcopy(item)
            continue
        raise NotImplementedError(f"regridding not supported for {key}")

    return regridded
@typedispatch  # type: ignore[no-redef]
def _regrid_like(
    package: ILineDataPackage, target_grid: GridDataArray, *_
) -> ILineDataPackage:
    """
    Line data are not tied to a grid, so there is nothing to regrid. Instead,
    the package is clipped to the exterior of ``target_grid``. Any further
    positional arguments (e.g. a regrid context) are accepted and ignored.
    """
    clipped_package = clip_by_grid(package, target_grid)
    return clipped_package
@typedispatch  # type: ignore[no-redef]
def _regrid_like(
    package: IPointDataPackage, target_grid: GridDataArray, *_
) -> IPointDataPackage:
    """
    The regrid_like method is irrelevant for this package as it is
    grid-agnostic, instead this method clips the package based on the grid
    exterior. Any further positional arguments (e.g. a regrid context) are
    accepted and ignored.
    """
    # Clip against a planar grid: take the first layer and drop the layer
    # coordinate. missing_dims="ignore" makes this a no-op for grids without a
    # layer dimension.
    target_grid_2d = target_grid.isel(layer=0, drop=True, missing_dims="ignore")
    return clip_by_grid(package, target_grid_2d)
@typedispatch  # type: ignore[no-redef]
def _regrid_like(package: object, target_grid: GridDataArray, *_) -> None:
    """
    Fallback overload for objects not handled by a more specific ``_regrid_like``.

    Raises
    ------
    TypeError
        Always.
    """
    message = "this object cannot be regridded"
    raise TypeError(message)
def _get_regridding_domain(
    model: IModel,
    target_grid: GridDataArray,
    regrid_context: RegridderWeightsCache,
    methods: defaultdict[RegridderType, list[str]],
) -> GridDataArray:
    """
    Compute the output domain of a regridding operation by regridding the
    model's ``domain`` with every (regridder type, function) pair in ``methods``.

    The result starts as all ones on ``target_grid``. For each regridder:
    - cells the regridder leaves NaN become NaN;
    - cells it maps to a value <= 0 take that value;
    - all other cells keep their current value.
    Cells still NaN at the end become 0, and the result is cast to int.

    NOTE(review): because of this order, a value <= 0 produced by a later
    regridder overrides a NaN produced by an earlier one, so exclusion is not
    strictly "sticky" across regridders. Confirm this is intended.
    """
    source_domain = model.domain
    active_everywhere = ones_like(target_grid)

    regridders = []
    for regridder_kind, function_names in methods.items():
        for function_name in function_names:
            regridders.append(
                regrid_context.get_regridder(
                    source_domain, target_grid, regridder_kind, function_name
                )
            )

    for regridder in regridders:
        candidate = regridder.regrid(source_domain)
        # exclude the cells this regridder could not fill
        active_everywhere = active_everywhere.where(candidate.notnull())
        # cells this regridder marks as <= 0 take the regridded value
        active_everywhere = candidate.where(candidate <= 0, other=active_everywhere)

    filled = active_everywhere.where(active_everywhere.notnull(), other=0)
    return filled.astype(int)