Coverage for C:\src\imod-python\imod\mf6\package.py: 96%
287 statements
coverage.py v7.5.1, created at 2024-05-08 14:15 +0200
from __future__ import annotations

import abc
import pathlib
from collections import defaultdict
from typing import Any, Mapping, Optional, Tuple, Union

import cftime
import jinja2
import numpy as np
import xarray as xr
import xugrid as xu

import imod
from imod.logging import standard_log_decorator
from imod.mf6.auxiliary_variables import (
    get_variable_names,
)
from imod.mf6.interfaces.ipackage import IPackage
from imod.mf6.pkgbase import (
    EXCHANGE_PACKAGES,
    TRANSPORT_PACKAGES,
    PackageBase,
)
from imod.mf6.utilities.mask import _mask
from imod.mf6.utilities.regrid import (
    RegridderType,
    RegridderWeightsCache,
    _regrid_like,
)
from imod.mf6.utilities.schemata import filter_schemata_dict
from imod.mf6.validation import validation_pkg_error_message
from imod.mf6.write_context import WriteContext
from imod.schemata import (
    AllNoDataSchema,
    EmptyIndexesSchema,
    SchemaType,
    ValidationError,
)
from imod.typing import GridDataArray


class Package(PackageBase, IPackage, abc.ABC):
    """
    Package is used to share methods for specific packages with no time
    component.

    It is not meant to be used directly; it is only meant to be inherited
    from when implementing new packages.

    This class only supports `array input
    <https://water.usgs.gov/water-resources/software/MODFLOW-6/mf6io_6.0.4.pdf#page=16>`_,
    not the list input which is used in :class:`BoundaryCondition`.
    """

    _pkg_id = ""
    _init_schemata: dict[str, list[SchemaType] | Tuple[SchemaType, ...]] = {}
    _write_schemata: dict[str, list[SchemaType] | Tuple[SchemaType, ...]] = {}
    _keyword_map: dict[str, str] = {}

    def __init__(self, allargs: Mapping[str, GridDataArray | float | int | bool | str]):
        super().__init__(allargs)

    def isel(self):
        raise NotImplementedError(
            "Selection on packages not yet supported. To make a selection on "
            f"the xr.Dataset, call {self._pkg_id}.dataset.isel instead. "
            "You can create a new package with a selection by calling "
            f"{__class__.__name__}(**{self._pkg_id}.dataset.isel(**selection))"
        )

    def sel(self):
        raise NotImplementedError(
            "Selection on packages not yet supported. To make a selection on "
            f"the xr.Dataset, call {self._pkg_id}.dataset.sel instead. "
            "You can create a new package with a selection by calling "
            f"{__class__.__name__}(**{self._pkg_id}.dataset.sel(**selection))"
        )

    def _valid(self, value):
        """
        Filters values that are None, False, or a numpy.bool_ False.
        Needs to be this specific, since 0.0 and 0 are valid values, but are
        equal to a boolean False.
        """
        # Test singletons
        if value is False or value is None:
            return False
        # Test numpy bool (not singleton)
        elif isinstance(value, np.bool_) and not value:
            return False
        # When dumping to netCDF and reading back, None will have been
        # converted into a NaN. Only check NaN if it's a floating type to avoid
        # TypeErrors.
        elif np.issubdtype(type(value), np.floating) and np.isnan(value):
            return False
        else:
            return True
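
    # Illustrative sketch (not part of the original module): expected outcomes
    # of _valid for a few representative values, assuming ``pkg`` is a
    # hypothetical Package instance.
    #
    # >>> pkg._valid(None)             # False: unset option
    # >>> pkg._valid(False)            # False: explicitly disabled
    # >>> pkg._valid(np.bool_(False))  # False: numpy bool, not the singleton
    # >>> pkg._valid(np.nan)           # False: None round-tripped via netCDF
    # >>> pkg._valid(0)                # True: 0 and 0.0 are legitimate values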

    @staticmethod
    def _number_format(dtype: type):
        if np.issubdtype(dtype, np.integer):
            return "%i"
        elif np.issubdtype(dtype, np.floating):
            return "%.18G"
        else:
            raise TypeError("dtype should be either integer or float")

    @staticmethod
    def _initialize_template(pkg_id):
        loader = jinja2.PackageLoader("imod", "templates/mf6")
        env = jinja2.Environment(loader=loader, keep_trailing_newline=True)
        if pkg_id == "ims":
            fname = "sln-ims.j2"
        elif pkg_id == "tdis":
            fname = "sim-tdis.j2"
        elif pkg_id in TRANSPORT_PACKAGES:
            fname = f"gwt-{pkg_id}.j2"
        elif pkg_id in EXCHANGE_PACKAGES:
            fname = f"exg-{pkg_id}.j2"
        elif pkg_id == "api":
            fname = f"{pkg_id}.j2"
        else:
            fname = f"gwf-{pkg_id}.j2"
        return env.get_template(fname)
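
    # Illustrative sketch (not part of the original module): how a package id
    # maps onto a template filename under the rules above. "npf" is used here
    # as an example of a regular GWF package.
    #
    # >>> Package._initialize_template("ims")  # loads "sln-ims.j2"
    # >>> Package._initialize_template("npf")  # falls through to "gwf-npf.j2"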

    def write_blockfile(self, pkgname, globaltimes, write_context: WriteContext):
        directory = write_context.get_formatted_write_directory()

        content = self.render(
            directory=directory,
            pkgname=pkgname,
            globaltimes=globaltimes,
            binary=write_context.use_binary,
        )
        filename = write_context.write_directory / f"{pkgname}.{self._pkg_id}"
        with open(filename, "w") as f:
            f.write(content)

    def write_binary_griddata(self, outpath, da, dtype):
        # From the modflow6 source, the header is defined as:
        # integer(I4B) :: kstp --> np.int32 : 1
        # integer(I4B) :: kper --> np.int32 : 2
        # real(DP) :: pertim --> 2 * np.int32 : 4
        # real(DP) :: totim --> 2 * np.int32 : 6
        # character(len=16) :: text --> 4 * np.int32 : 10
        # integer(I4B) :: m1, m2, m3 --> 3 * np.int32 : 13
        # so writing 13 np.int32 values suffices to create a header.

        # The following code is commented out due to modflow issue 189
        # https://github.com/MODFLOW-USGS/modflow6/issues/189
        # We never write LAYERED data.
        # The (structured) dis array reader results in an error if you try to
        # read a 3D botm array. By storing nlayer * nrow * ncol in the first
        # header entry, the array is read properly.

        # haslayer = "layer" in da.dims
        # if haslayer:
        #     nlayer, nrow, ncol = da.shape
        # else:
        #     nrow, ncol = da.shape
        #     nlayer = 1

        # This is a workaround for the above-mentioned issue.
        nval = np.prod(da.shape)
        header = np.zeros(13, np.int32)
        header[-3] = np.int32(nval)  # ncol
        header[-2] = np.int32(1)  # nrow
        header[-1] = np.int32(1)  # nlayer

        with open(outpath, "w") as f:
            header.tofile(f)
            da.values.flatten().astype(dtype).tofile(f)

    def write_text_griddata(self, outpath, da, dtype):
        with open(outpath, "w") as f:
            # Note: reshaping here avoids writing newlines after every number.
            # This dumps all the values in a single row rather than a single
            # column. This is to be preferred, since editors can easily
            # "reshape" a long row with "word wrap"; they cannot as easily
            # ignore newlines.
            fmt = self._number_format(dtype)
            data = da.values
            if data.ndim > 2:
                np.savetxt(fname=f, X=da.values.reshape((1, -1)), fmt=fmt)
            else:
                np.savetxt(fname=f, X=da.values, fmt=fmt)

    def _get_render_dictionary(
        self,
        directory: pathlib.Path,
        pkgname: str,
        globaltimes: Union[list[np.datetime64], np.ndarray],
        binary: bool,
    ) -> dict[str, Any]:
        d = {}
        if directory is None:
            pkg_directory = pkgname
        else:
            pkg_directory = pathlib.Path(directory) / pkgname

        for varname in self.dataset.data_vars:
            key = self._keyword_map.get(str(varname), str(varname))

            if hasattr(self, "_grid_data") and varname in self._grid_data:
                layered, value = self._compose_values(
                    self.dataset[varname], pkg_directory, key, binary=binary
                )
                if self._valid(value):  # skip False or None
                    d[f"{key}_layered"], d[key] = layered, value
            else:
                value = self[varname].values[()]
                if self._valid(value):  # skip False or None
                    d[key] = value

        if (hasattr(self, "_auxiliary_data")) and (names := get_variable_names(self)):
            d["auxiliary"] = names
        return d

    def render(self, directory, pkgname, globaltimes, binary):
        d = self._get_render_dictionary(directory, pkgname, globaltimes, binary)
        return self._template.render(d)

    @staticmethod
    def _is_xy_data(obj):
        if isinstance(obj, (xr.DataArray, xr.Dataset)):
            xy = "x" in obj.dims and "y" in obj.dims
        elif isinstance(obj, (xu.UgridDataArray, xu.UgridDataset)):
            xy = obj.ugrid.grid.face_dimension in obj.dims
        else:
            raise TypeError(
                "obj should be DataArray or UgridDataArray, "
                f"received {type(obj)} instead"
            )
        return xy

    def _compose_values(self, da, directory, name, binary):
        """
        Compose values of dictionary.

        Ignores times. Time-dependent boundary conditions use the method from
        BoundaryCondition.

        See the documentation of ``imod.wq``.
        """
        layered = False
        values = []
        if self._is_xy_data(da):
            if binary:
                path = (directory / f"{name}.bin").as_posix()
                values.append(f"open/close {path} (binary)")
            else:
                path = (directory / f"{name}.dat").as_posix()
                values.append(f"open/close {path}")
        else:
            if "layer" in da.dims:
                layered = True
                for layer in da.coords["layer"]:
                    values.append(f"constant {da.sel(layer=layer).values[()]}")
            else:
                value = da.values[()]
                if self._valid(value):  # skip None or False
                    values.append(f"constant {value}")
                else:
                    values = None

        return layered, values
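
    # Illustrative sketch (not part of the original module): the kind of output
    # _compose_values produces for a hypothetical DataArray ``da`` with a
    # "layer" dimension but no x/y dimensions (values 10.0 and 5.0 assumed).
    #
    # >>> layered, values = pkg._compose_values(da, directory, "k", binary=False)
    # >>> layered  # True, because "layer" is a dimension without x/y
    # >>> values   # ["constant 10.0", "constant 5.0"]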

    @standard_log_decorator()
    def write(
        self,
        pkgname: str,
        globaltimes: Union[list[np.datetime64], np.ndarray],
        write_context: WriteContext,
    ):
        directory = write_context.write_directory
        binary = write_context.use_binary
        self.write_blockfile(pkgname, globaltimes, write_context)

        if hasattr(self, "_grid_data"):
            if self._is_xy_data(self.dataset):
                pkgdirectory = directory / pkgname
                pkgdirectory.mkdir(exist_ok=True, parents=True)
                for varname, dtype in self._grid_data.items():
                    key = self._keyword_map.get(varname, varname)
                    da = self.dataset[varname]
                    if self._is_xy_data(da):
                        if binary:
                            path = pkgdirectory / f"{key}.bin"
                            self.write_binary_griddata(path, da, dtype)
                        else:
                            path = pkgdirectory / f"{key}.dat"
                            self.write_text_griddata(path, da, dtype)

    @standard_log_decorator()
    def _validate(self, schemata: dict, **kwargs) -> dict[str, list[ValidationError]]:
        errors = defaultdict(list)
        for variable, var_schemata in schemata.items():
            for schema in var_schemata:
                if (
                    variable in self.dataset.keys()
                ):  # concentration only added to dataset if specified
                    try:
                        schema.validate(self.dataset[variable], **kwargs)
                    except ValidationError as e:
                        errors[variable].append(e)
        return errors

    def is_empty(self) -> bool:
        """
        Returns True if the package is empty, for example when it contains only
        no-data values.
        """

        # Create a schemata dict containing only the variables with an
        # AllNoDataSchema or EmptyIndexesSchema (in case of HFB) in their
        # write schemata.
        allnodata_schemata = filter_schemata_dict(
            self._write_schemata, (AllNoDataSchema, EmptyIndexesSchema)
        )

        # Check whether the package throws a ValidationError for the
        # AllNoDataSchema or EmptyIndexesSchema.
        allnodata_errors = self._validate(allnodata_schemata)
        return len(allnodata_errors) > 0

    def _validate_init_schemata(self, validate: bool):
        """
        Run the "cheap" schema validations.

        The expensive validations are run during writing. Some are only
        possible then, e.g. using idomain to determine the active part of the
        domain.
        """
        if not validate:
            return
        errors = self._validate(self._init_schemata)
        if len(errors) > 0:
            message = validation_pkg_error_message(errors)
            raise ValidationError(message)
        return

    def copy(self) -> Any:
        # All state should be contained in the dataset.
        return type(self)(**self.dataset.copy().to_dict())

    @staticmethod
    def _clip_repeat_stress(
        repeat_stress: xr.DataArray,
        time,
        time_start,
        time_end,
    ):
        """
        Selection may remove the original data which are repeated.
        These should be re-inserted at the first occurring "key".
        Next, remove these keys as they've been "promoted" to regular
        timestamps with data.
        """
        # First, "pop" and filter.
        keys, values = repeat_stress.values.T
        keep = (keys >= time_start) & (keys <= time_end)
        new_keys = keys[keep]
        new_values = values[keep]
        # Now detect which "value" entries have gone missing
        insert_values, index = np.unique(new_values, return_index=True)
        insert_keys = new_keys[index]
        # Set up the indexer
        indexer = xr.DataArray(
            data=np.arange(time.size),
            coords={"time": time},
            dims=("time",),
        ).sel(time=insert_values)
        indexer["time"] = insert_keys

        # Update the key-value pairs. Discard keys that have been "promoted".
        keep = np.in1d(new_keys, insert_keys, assume_unique=True, invert=True)
        new_keys = new_keys[keep]
        new_values = new_values[keep]
        # Set the values to their new source.
        new_values = insert_keys[np.searchsorted(insert_values, new_values)]
        repeat_stress = xr.DataArray(
            data=np.column_stack((new_keys, new_values)),
            dims=("repeat", "repeat_items"),
        )
        return indexer, repeat_stress
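
    # Illustrative sketch (not part of the original module): suppose data are
    # defined for 2000 and repeated for 2001 and 2002, and the clip starts at
    # 2001. The 2000 data fall outside the selection, so they are re-inserted
    # at the first repeating key (2001), which is then dropped from
    # repeat_stress; the 2002 entry is redirected to point at the promoted
    # 2001 timestamp. The returned indexer later selects the matching rows
    # along the time dimension.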

    @staticmethod
    def _clip_time_indexer(
        time,
        time_start,
        time_end,
    ):
        original = xr.DataArray(
            data=np.arange(time.size),
            coords={"time": time},
            dims=("time",),
        )
        indexer = original.sel(time=slice(time_start, time_end))

        # The selection might return a 0-sized dimension.
        if indexer.size > 0:
            first_time = indexer["time"].values[0]
        else:
            first_time = None

        # If the first time matches exactly, xarray will have done the thing we
        # wanted and our work with the time dimension is finished.
        if (time_start is not None) and (time_start != first_time):
            # If the requested start lies before the first original time, we
            # need to backfill; otherwise, we need to ffill the first timestamp.
            if time_start < time[0]:
                method = "bfill"
            else:
                method = "ffill"
            # Index with a list rather than a scalar to preserve the time
            # dimension.
            first = original.sel(time=[time_start], method=method)
            first["time"] = [time_start]
            indexer = xr.concat([first, indexer], dim="time")

        return indexer
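
    # Illustrative sketch (not part of the original module): with stress
    # periods at 2000 and 2005 and time_start=2002, the plain slice starts at
    # 2005, so the 2000 entry is forward filled and relabeled to 2002. With
    # time_start=1999 (before the first period) the first entry is backfilled
    # instead.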

    def __to_datetime(self, time, use_cftime):
        """
        Helper function that converts to datetime, except when None.
        """
        if time is None:
            return time
        else:
            return imod.util.time.to_datetime_internal(time, use_cftime)

    def clip_box(
        self,
        time_min: Optional[cftime.datetime | np.datetime64 | str] = None,
        time_max: Optional[cftime.datetime | np.datetime64 | str] = None,
        layer_min: Optional[int] = None,
        layer_max: Optional[int] = None,
        x_min: Optional[float] = None,
        x_max: Optional[float] = None,
        y_min: Optional[float] = None,
        y_max: Optional[float] = None,
        top: Optional[GridDataArray] = None,
        bottom: Optional[GridDataArray] = None,
    ) -> Package:
        """
        Clip a package by a bounding box (time, layer, y, x).

        Slicing intervals may be half-bounded by providing None:

        * To select 500.0 <= x <= 1000.0:
          ``clip_box(x_min=500.0, x_max=1000.0)``.
        * To select x <= 1000.0: ``clip_box(x_min=None, x_max=1000.0)``
          or ``clip_box(x_max=1000.0)``.
        * To select x >= 500.0: ``clip_box(x_min=500.0, x_max=None)``
          or ``clip_box(x_min=500.0)``.

        Parameters
        ----------
        time_min: optional
        time_max: optional
        layer_min: optional, int
        layer_max: optional, int
        x_min: optional, float
        x_max: optional, float
        y_min: optional, float
        y_max: optional, float
        top: optional, GridDataArray
        bottom: optional, GridDataArray

        Returns
        -------
        clipped: Package
        """
        if not self.is_clipping_supported():
            raise ValueError("this package does not support clipping.")

        selection = self.dataset
        if "time" in selection:
            time = selection["time"].values
            use_cftime = isinstance(time[0], cftime.datetime)
            time_start = self.__to_datetime(time_min, use_cftime)
            time_end = self.__to_datetime(time_max, use_cftime)

            indexer = self._clip_time_indexer(
                time=time,
                time_start=time_start,
                time_end=time_end,
            )

            if "repeat_stress" in selection.data_vars and self._valid(
                selection["repeat_stress"].values[()]
            ):
                repeat_indexer, repeat_stress = self._clip_repeat_stress(
                    repeat_stress=selection["repeat_stress"],
                    time=time,
                    time_start=time_start,
                    time_end=time_end,
                )
                selection = selection.drop_vars("repeat_stress")
                selection["repeat_stress"] = repeat_stress
                indexer = repeat_indexer.combine_first(indexer).astype(int)

            selection = selection.drop_vars("time").isel(time=indexer)

        if "layer" in selection.coords:
            layer_slice = slice(layer_min, layer_max)
            # Cannot select if it's not a dimension!
            if "layer" not in selection.dims:
                selection = (
                    selection.expand_dims("layer")
                    .sel(layer=layer_slice)
                    .squeeze("layer")
                )
            else:
                selection = selection.sel(layer=layer_slice)

        x_slice = slice(x_min, x_max)
        y_slice = slice(y_min, y_max)
        if isinstance(selection, xu.UgridDataset):
            selection = selection.ugrid.sel(x=x_slice, y=y_slice)
        elif ("x" in selection.coords) and ("y" in selection.coords):
            if selection.indexes["y"].is_monotonic_decreasing:
                y_slice = slice(y_max, y_min)
            selection = selection.sel(x=x_slice, y=y_slice)

        cls = type(self)
        new = cls.__new__(cls)
        new.dataset = selection
        return new
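
    # Illustrative sketch (not part of the original module): clipping a
    # hypothetical package ``riv`` to a window in time and space; all names
    # and values below are assumptions for the example.
    #
    # >>> clipped = riv.clip_box(
    # ...     time_min="2000-01-01", time_max="2001-01-01",
    # ...     x_min=500.0, x_max=1000.0, y_min=2000.0, y_max=3000.0,
    # ... )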

    def mask(self, mask: GridDataArray) -> Any:
        """
        Mask values outside of domain.

        Floating values outside of the condition are set to NaN (nodata).
        Integer values outside of the condition are set to 0 (inactive in
        MODFLOW terms).

        Parameters
        ----------
        mask: xr.DataArray, xu.UgridDataArray of ints
            idomain-like integer array. 1 sets cells to active, 0 sets cells to
            inactive, -1 sets cells to vertical passthrough.

        Returns
        -------
        masked: Package
            The package with part masked.
        """

        return _mask(self, mask)
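
    # Illustrative sketch (not part of the original module): masking a
    # hypothetical package ``npf`` with an idomain-like integer grid
    # ``idomain`` (both names are assumptions).
    #
    # >>> masked_npf = npf.mask(idomain)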

    def regrid_like(
        self,
        target_grid: GridDataArray,
        regrid_context: RegridderWeightsCache,
        regridder_types: Optional[dict[str, Tuple[RegridderType, str]]] = None,
    ) -> "Package":
        """
        Creates a package of the same type as this package, based on another
        discretization. It regrids all the arrays in this package to the desired
        discretization, and leaves the options unmodified. At the moment only
        regridding to a different planar grid is supported, meaning
        ``target_grid`` has different ``"x"`` and ``"y"`` or different
        ``cell2d`` coords.

        The default regridding methods are specified in the ``_regrid_method``
        attribute of the package; they specify how each array should be
        regridded and can be overridden using the input parameters of this
        function.

        Examples
        --------
        To regrid the npf package with a non-default method for the k-field,
        call regrid_like with these arguments:

        >>> new_npf = npf.regrid_like(like, regrid_context, {"k": (imod.RegridderType.OVERLAP, "mean")})

        Parameters
        ----------
        target_grid: xr.DataArray or xu.UgridDataArray
            a grid defined over the same discretization as the one we want to
            regrid the package to
        regrid_context: RegridderWeightsCache
            stores regridder weights for different regridders. Can be used to
            speed up regridding, if the same regridders are used several times
            for regridding different arrays.
        regridder_types: dict(str -> (RegridderType, str)), optional
            dictionary mapping array names (str) to a tuple of regridder type
            (a specialization class of BaseRegridder) and function name (str).
            This dictionary can be used to override the default regridding
            method.

        Returns
        -------
        Package
            a package with the same options as this package, and with all the
            data arrays regridded to another discretization, similar to the one
            used in the input argument "target_grid"
        """
        try:
            result = _regrid_like(self, target_grid, regrid_context, regridder_types)
        except ValueError as e:
            raise e
        except Exception:
            raise ValueError("package could not be regridded.")
        return result

    def _skip_masking_dataarray(self, array_name: str) -> bool:
        if hasattr(self, "_skip_mask_arrays"):
            return array_name in self._skip_mask_arrays
        return False

    @classmethod
    def is_grid_agnostic_package(cls) -> bool:
        return False

    def __repr__(self) -> str:
        typename = type(self).__name__
        return f"{typename}\n{self.dataset.__repr__()}"

    def _repr_html_(self) -> str:
        typename = type(self).__name__
        return f"<div>{typename}</div>{self.dataset._repr_html_()}"

    @property
    def auxiliary_data_fields(self) -> dict[str, str]:
        if hasattr(self, "_auxiliary_data"):
            return self._auxiliary_data
        return {}

    def get_non_grid_data(self, grid_names: list[str]) -> dict[str, Any]:
        """
        This function copies the attributes of a dataset that are scalars,
        such as options.

        Parameters
        ----------
        grid_names: list of str
            the names of the attributes of a dataset that are grids.
        """
        result = {}
        all_non_grid_data = list(self.dataset.keys())
        for name in (
            gridname for gridname in grid_names if gridname in all_non_grid_data
        ):
            all_non_grid_data.remove(name)
        for name in all_non_grid_data:
            if "time" in self.dataset[name].coords:
                result[name] = self.dataset[name]
            else:
                result[name] = self.dataset[name].values[()]
        return result
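
    # Illustrative sketch (not part of the original module): for a hypothetical
    # river package whose grids are "stage", "conductance", and
    # "bottom_elevation", the call below returns only the remaining scalar
    # options (all names are assumptions).
    #
    # >>> options = riv.get_non_grid_data(["stage", "conductance", "bottom_elevation"])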

    def is_splitting_supported(self) -> bool:
        return True

    def is_regridding_supported(self) -> bool:
        return True

    def is_clipping_supported(self) -> bool:
        return True