Coverage for C:\src\imod-python\imod\mf6\package.py: 98% (287 statements). Generated by coverage.py v7.4.4 at 2024-04-08 13:27 +0200.

from __future__ import annotations

import abc
import pathlib
from collections import defaultdict
from typing import Any, Mapping, Optional, Tuple, Union

import cftime
import jinja2
import numpy as np
import xarray as xr
import xugrid as xu

import imod
from imod.logging import standard_log_decorator
from imod.mf6.auxiliary_variables import (
    get_variable_names,
)
from imod.mf6.interfaces.ipackage import IPackage
from imod.mf6.pkgbase import (
    EXCHANGE_PACKAGES,
    TRANSPORT_PACKAGES,
    PackageBase,
)
from imod.mf6.utilities.mask import _mask
from imod.mf6.utilities.regrid import (
    RegridderType,
    RegridderWeightsCache,
    _regrid_like,
)
from imod.mf6.utilities.schemata import filter_schemata_dict
from imod.mf6.validation import validation_pkg_error_message
from imod.mf6.write_context import WriteContext
from imod.schemata import (
    AllNoDataSchema,
    EmptyIndexesSchema,
    SchemaType,
    ValidationError,
)
from imod.typing import GridDataArray


class Package(PackageBase, IPackage, abc.ABC):
    """
    Package is used to share methods for specific packages with no time
    component.

    It is not meant to be used directly, only to inherit from, to implement new
    packages.

    This class only supports `array input
    <https://water.usgs.gov/water-resources/software/MODFLOW-6/mf6io_6.0.4.pdf#page=16>`_,
    not the list input which is used in :class:`BoundaryCondition`.
    """

    _pkg_id = ""
    _init_schemata: dict[str, list[SchemaType] | Tuple[SchemaType, ...]] = {}
    _write_schemata: dict[str, list[SchemaType] | Tuple[SchemaType, ...]] = {}
    _keyword_map: dict[str, str] = {}

    def __init__(self, allargs: Mapping[str, GridDataArray | float | int | bool | str]):
        super().__init__(allargs)

    def isel(self):
        raise NotImplementedError(
            "Selection on packages not yet supported. To make a selection on "
            f"the xr.Dataset, call {self._pkg_id}.dataset.isel instead. "
            "You can create a new package with a selection by calling "
            f"{__class__.__name__}(**{self._pkg_id}.dataset.isel(**selection))"
        )

    def sel(self):
        raise NotImplementedError(
            "Selection on packages not yet supported. To make a selection on "
            f"the xr.Dataset, call {self._pkg_id}.dataset.sel instead. "
            "You can create a new package with a selection by calling "
            f"{__class__.__name__}(**{self._pkg_id}.dataset.sel(**selection))"
        )

    def _valid(self, value):
        """
        Filters values that are None, False, or a numpy.bool_ False.
        Needs to be this specific, since 0.0 and 0 are valid values, but are
        equal to a boolean False.
        """
        # Test singletons
        if value is False or value is None:
            return False
        # Test numpy bool (not singleton)
        elif isinstance(value, np.bool_) and not value:
            return False
        # When dumping to netCDF and reading back, None will have been
        # converted into a NaN. Only check NaN if it's a floating type to avoid
        # TypeErrors.
        elif np.issubdtype(type(value), np.floating) and np.isnan(value):
            return False
        else:
            return True

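    # Illustrative behaviour (a sketch, not part of the original module):
    # values that merely compare equal to False are still considered valid.
    #   _valid(None)    -> False
    #   _valid(False)   -> False
    #   _valid(np.nan)  -> False
    #   _valid(0)       -> True
    #   _valid(0.0)     -> True
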
    @staticmethod
    def _number_format(dtype: type):
        if np.issubdtype(dtype, np.integer):
            return "%i"
        elif np.issubdtype(dtype, np.floating):
            return "%.18G"
        else:
            raise TypeError("dtype should be either integer or float")

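    # For example (illustrative): _number_format(np.int32) returns "%i" and
    # _number_format(np.float64) returns "%.18G"; these are the formats later
    # passed to np.savetxt in write_text_griddata.
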
    @staticmethod
    def _initialize_template(pkg_id):
        loader = jinja2.PackageLoader("imod", "templates/mf6")
        env = jinja2.Environment(loader=loader, keep_trailing_newline=True)
        if pkg_id == "ims":
            fname = "sln-ims.j2"
        elif pkg_id == "tdis":
            fname = "sim-tdis.j2"
        elif pkg_id in TRANSPORT_PACKAGES:
            fname = f"gwt-{pkg_id}.j2"
        elif pkg_id in EXCHANGE_PACKAGES:
            fname = f"exg-{pkg_id}.j2"
        elif pkg_id == "api":
            fname = f"{pkg_id}.j2"
        else:
            fname = f"gwf-{pkg_id}.j2"
        return env.get_template(fname)

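    # A sketch of the resulting lookup, assuming "riv" is a regular GWF
    # package id and "adv" is listed in TRANSPORT_PACKAGES:
    #   _initialize_template("riv")  -> templates/mf6/gwf-riv.j2
    #   _initialize_template("adv")  -> templates/mf6/gwt-adv.j2
    #   _initialize_template("ims")  -> templates/mf6/sln-ims.j2
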
    def write_blockfile(self, pkgname, globaltimes, write_context: WriteContext):
        directory = write_context.get_formatted_write_directory()

        content = self.render(
            directory=directory,
            pkgname=pkgname,
            globaltimes=globaltimes,
            binary=write_context.use_binary,
        )
        filename = write_context.write_directory / f"{pkgname}.{self._pkg_id}"
        with open(filename, "w") as f:
            f.write(content)

    def write_binary_griddata(self, outpath, da, dtype):
        # From the modflow6 source, the header is defined as:
        # integer(I4B) :: kstp --> np.int32 : 1
        # integer(I4B) :: kper --> np.int32 : 2
        # real(DP) :: pertim --> 2 * np.int32 : 4
        # real(DP) :: totim --> 2 * np.int32 : 6
        # character(len=16) :: text --> 4 * np.int32 : 10
        # integer(I4B) :: m1, m2, m3 --> 3 * np.int32 : 13
        # so writing 13 np.int32 values suffices to create a header.

        # The following code is commented out due to modflow issue 189
        # https://github.com/MODFLOW-USGS/modflow6/issues/189
        # We never write LAYERED data.
        # The (structured) dis array reader results in an error if you try to
        # read a 3D botm array. By storing nlayer * nrow * ncol in the first
        # header entry, the array is read properly.

        # haslayer = "layer" in da.dims
        # if haslayer:
        #     nlayer, nrow, ncol = da.shape
        # else:
        #     nrow, ncol = da.shape
        #     nlayer = 1

        # This is a workaround for the above-mentioned issue.
        nval = np.product(da.shape)
        header = np.zeros(13, np.int32)
        header[-3] = np.int32(nval)  # ncol
        header[-2] = np.int32(1)  # nrow
        header[-1] = np.int32(1)  # nlayer

        with open(outpath, "w") as f:
            header.tofile(f)
            da.values.flatten().astype(dtype).tofile(f)

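    # A minimal read-back sketch (illustrative only, not part of the package
    # API): the first 13 int32 values form the header and header[-3] holds
    # the number of data values that follow.
    #   with open(outpath, "rb") as f:
    #       header = np.fromfile(f, dtype=np.int32, count=13)
    #       values = np.fromfile(f, dtype=dtype, count=header[-3])
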
    def write_text_griddata(self, outpath, da, dtype):
        with open(outpath, "w") as f:
            # Note: reshaping here avoids writing newlines after every number.
            # This dumps all the values in a single row rather than a single
            # column. This is to be preferred, since editors can easily
            # "reshape" a long row with "word wrap"; they cannot as easily
            # ignore newlines.
            fmt = self._number_format(dtype)
            data = da.values
            if data.ndim > 2:
                np.savetxt(fname=f, X=da.values.reshape((1, -1)), fmt=fmt)
            else:
                np.savetxt(fname=f, X=da.values, fmt=fmt)

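    # Illustrative effect of the reshape above: a (2, 2, 2) float array is
    # written as a single row of 8 numbers, whereas 2D arrays are written by
    # np.savetxt row by row.
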
    def _get_render_dictionary(
        self,
        directory: pathlib.Path,
        pkgname: str,
        globaltimes: Union[list[np.datetime64], np.ndarray],
        binary: bool,
    ) -> dict[str, Any]:
        d = {}
        if directory is None:
            pkg_directory = pkgname
        else:
            pkg_directory = pathlib.Path(directory) / pkgname

        for varname in self.dataset.data_vars:
            key = self._keyword_map.get(str(varname), str(varname))

            if hasattr(self, "_grid_data") and varname in self._grid_data:
                layered, value = self._compose_values(
                    self.dataset[varname], pkg_directory, key, binary=binary
                )
                if self._valid(value):  # skip False or None
                    d[f"{key}_layered"], d[key] = layered, value
            else:
                value = self[varname].values[()]
                if self._valid(value):  # skip False or None
                    d[key] = value

        if (hasattr(self, "_auxiliary_data")) and (names := get_variable_names(self)):
            d["auxiliary"] = names
        return d

    def render(self, directory, pkgname, globaltimes, binary):
        d = self._get_render_dictionary(directory, pkgname, globaltimes, binary)
        return self._template.render(d)

    @staticmethod
    def _is_xy_data(obj):
        if isinstance(obj, (xr.DataArray, xr.Dataset)):
            xy = "x" in obj.dims and "y" in obj.dims
        elif isinstance(obj, (xu.UgridDataArray, xu.UgridDataset)):
            xy = obj.ugrid.grid.face_dimension in obj.dims
        else:
            raise TypeError(
                "obj should be DataArray or UgridDataArray, "
                f"received {type(obj)} instead"
            )
        return xy

    def _compose_values(self, da, directory, name, binary):
        """
        Compose values of dictionary.

        Ignores times. Time dependent boundary conditions use the method from
        BoundaryCondition.

        See documentation of wq
        """
        layered = False
        values = []
        if self._is_xy_data(da):
            if binary:
                path = (directory / f"{name}.bin").as_posix()
                values.append(f"open/close {path} (binary)")
            else:
                path = (directory / f"{name}.dat").as_posix()
                values.append(f"open/close {path}")
        else:
            if "layer" in da.dims:
                layered = True
                for layer in da.coords["layer"]:
                    values.append(f"constant {da.sel(layer=layer).values[()]}")
            else:
                value = da.values[()]
                if self._valid(value):  # skip None or False
                    values.append(f"constant {value}")
                else:
                    values = None

        return layered, values

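    # Illustrative return values (a sketch): for a DataArray with x/y (or
    # face) dimensions and binary=True this yields
    #   (False, ["open/close <directory>/<name>.bin (binary)"])
    # and for a layered scalar DataArray it yields
    #   (True, ["constant <value layer 1>", "constant <value layer 2>", ...]).
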
    @standard_log_decorator()
    def write(
        self,
        pkgname: str,
        globaltimes: Union[list[np.datetime64], np.ndarray],
        write_context: WriteContext,
    ):
        directory = write_context.write_directory
        binary = write_context.use_binary
        self.write_blockfile(pkgname, globaltimes, write_context)

        if hasattr(self, "_grid_data"):
            if self._is_xy_data(self.dataset):
                pkgdirectory = directory / pkgname
                pkgdirectory.mkdir(exist_ok=True, parents=True)
                for varname, dtype in self._grid_data.items():
                    key = self._keyword_map.get(varname, varname)
                    da = self.dataset[varname]
                    if self._is_xy_data(da):
                        if binary:
                            path = pkgdirectory / f"{key}.bin"
                            self.write_binary_griddata(path, da, dtype)
                        else:
                            path = pkgdirectory / f"{key}.dat"
                            self.write_text_griddata(path, da, dtype)

    @standard_log_decorator()
    def _validate(self, schemata: dict, **kwargs) -> dict[str, list[ValidationError]]:
        errors = defaultdict(list)
        for variable, var_schemata in schemata.items():
            for schema in var_schemata:
                if (
                    variable in self.dataset.keys()
                ):  # concentration only added to dataset if specified
                    try:
                        schema.validate(self.dataset[variable], **kwargs)
                    except ValidationError as e:
                        errors[variable].append(e)
        return errors

    def is_empty(self) -> bool:
        """
        Returns True if the package is empty, for example if it contains only
        no-data values.
        """

        # Create a schemata dict only containing the variables with an
        # AllNoDataSchema or EmptyIndexesSchema (in case of HFB) in the write
        # schemata.
        allnodata_schemata = filter_schemata_dict(
            self._write_schemata, (AllNoDataSchema, EmptyIndexesSchema)
        )

        # Find out if the package throws a ValidationError for AllNoDataSchema
        # or EmptyIndexesSchema.
        allnodata_errors = self._validate(allnodata_schemata)
        return len(allnodata_errors) > 0

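    # Illustrative example (a sketch): for a package whose _write_schemata
    # contains an AllNoDataSchema for some variable, is_empty() returns True
    # when that variable holds nothing but no-data (NaN) values.
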
    def _validate_init_schemata(self, validate: bool):
        """
        Run the "cheap" schema validations.

        The expensive validations are run during writing. Some are only
        available then: e.g. idomain to determine active part of domain.
        """
        if not validate:
            return
        errors = self._validate(self._init_schemata)
        if len(errors) > 0:
            message = validation_pkg_error_message(errors)
            raise ValidationError(message)
        return

    def copy(self) -> Any:
        # All state should be contained in the dataset.
        return type(self)(**self.dataset.copy().to_dict())

    @staticmethod
    def _clip_repeat_stress(
        repeat_stress: xr.DataArray,
        time,
        time_start,
        time_end,
    ):
        """
        Selection may remove the original data which are repeated.
        These should be re-inserted at the first occurring "key".
        Next, remove these keys as they've been "promoted" to regular
        timestamps with data.
        """
        # First, "pop" and filter.
        keys, values = repeat_stress.values.T
        keep = (keys >= time_start) & (keys <= time_end)
        new_keys = keys[keep]
        new_values = values[keep]
        # Now detect which "value" entries have gone missing
        insert_values, index = np.unique(new_values, return_index=True)
        insert_keys = new_keys[index]
        # Setup indexer
        indexer = xr.DataArray(
            data=np.arange(time.size),
            coords={"time": time},
            dims=("time",),
        ).sel(time=insert_values)
        indexer["time"] = insert_keys

        # Update the key-value pairs. Discard keys that have been "promoted".
        keep = np.in1d(new_keys, insert_keys, assume_unique=True, invert=True)
        new_keys = new_keys[keep]
        new_values = new_values[keep]
        # Set the values to their new source.
        new_values = insert_keys[np.searchsorted(insert_values, new_values)]
        repeat_stress = xr.DataArray(
            data=np.column_stack((new_keys, new_values)),
            dims=("repeat", "repeat_items"),
        )
        return indexer, repeat_stress

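    # Illustrative example (a sketch): with repeat_stress pairs
    # [(2001, 2000), (2002, 2000)] and a clip window starting at 2001, the
    # repeated 2000 record falls outside the selection, so its data are
    # "promoted" to timestamp 2001 by the indexer and the remaining pair
    # becomes (2002, 2001).
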
    @staticmethod
    def _clip_time_indexer(
        time,
        time_start,
        time_end,
    ):
        original = xr.DataArray(
            data=np.arange(time.size),
            coords={"time": time},
            dims=("time",),
        )
        indexer = original.sel(time=slice(time_start, time_end))

        # The selection might return a 0-sized dimension.
        if indexer.size > 0:
            first_time = indexer["time"].values[0]
        else:
            first_time = None

        # If the first time matches exactly, xarray will have done what we
        # wanted and our work with the time dimension is finished.
        if (time_start is not None) and (time_start != first_time):
            # If the first time is before the original time, we need to
            # backfill; otherwise, we need to ffill the first timestamp.
            if time_start < time[0]:
                method = "bfill"
            else:
                method = "ffill"
            # Index with a list rather than a scalar to preserve the time
            # dimension.
            first = original.sel(time=[time_start], method=method)
            first["time"] = [time_start]
            indexer = xr.concat([first, indexer], dim="time")

        return indexer

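    # Illustrative behaviour (a sketch): with time = [2000-01-01, 2000-02-01]
    # and time_start = 2000-01-15, the plain slice selection starts at
    # 2000-02-01, so the 2000-01-01 record is forward-filled and relabelled
    # to 2000-01-15, keeping data available at the start of the clip window.
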
    def __to_datetime(self, time, use_cftime):
        """
        Helper function that converts to datetime, except when None.
        """
        if time is None:
            return time
        else:
            return imod.util.time.to_datetime_internal(time, use_cftime)

    def clip_box(
        self,
        time_min: Optional[cftime.datetime | np.datetime64 | str] = None,
        time_max: Optional[cftime.datetime | np.datetime64 | str] = None,
        layer_min: Optional[int] = None,
        layer_max: Optional[int] = None,
        x_min: Optional[float] = None,
        x_max: Optional[float] = None,
        y_min: Optional[float] = None,
        y_max: Optional[float] = None,
        top: Optional[GridDataArray] = None,
        bottom: Optional[GridDataArray] = None,
        state_for_boundary: Optional[GridDataArray] = None,
    ) -> Package:
        """
        Clip a package by a bounding box (time, layer, y, x).

        Slicing intervals may be half-bounded, by providing None:

        * To select 500.0 <= x <= 1000.0:
          ``clip_box(x_min=500.0, x_max=1000.0)``.
        * To select x <= 1000.0: ``clip_box(x_min=None, x_max=1000.0)``
          or ``clip_box(x_max=1000.0)``.
        * To select x >= 500.0: ``clip_box(x_min=500.0, x_max=None)``
          or ``clip_box(x_min=500.0)``.

        Parameters
        ----------
        time_min: optional
        time_max: optional
        layer_min: optional, int
        layer_max: optional, int
        x_min: optional, float
        x_max: optional, float
        y_min: optional, float
        y_max: optional, float
        top: optional, GridDataArray
        bottom: optional, GridDataArray
        state_for_boundary: optional, GridDataArray

        Returns
        -------
        clipped: Package
        """
        if not self.is_clipping_supported():
            raise ValueError("this package does not support clipping.")

        selection = self.dataset
        if "time" in selection:
            time = selection["time"].values
            use_cftime = isinstance(time[0], cftime.datetime)
            time_start = self.__to_datetime(time_min, use_cftime)
            time_end = self.__to_datetime(time_max, use_cftime)

            indexer = self._clip_time_indexer(
                time=time,
                time_start=time_start,
                time_end=time_end,
            )

            if "repeat_stress" in selection.data_vars and self._valid(
                selection["repeat_stress"].values[()]
            ):
                repeat_indexer, repeat_stress = self._clip_repeat_stress(
                    repeat_stress=selection["repeat_stress"],
                    time=time,
                    time_start=time_start,
                    time_end=time_end,
                )
                selection = selection.drop_vars("repeat_stress")
                selection["repeat_stress"] = repeat_stress
                indexer = repeat_indexer.combine_first(indexer).astype(int)

            selection = selection.drop_vars("time").isel(time=indexer)

        if "layer" in selection.coords:
            layer_slice = slice(layer_min, layer_max)
            # Cannot select if it's not a dimension!
            if "layer" not in selection.dims:
                selection = (
                    selection.expand_dims("layer")
                    .sel(layer=layer_slice)
                    .squeeze("layer")
                )
            else:
                selection = selection.sel(layer=layer_slice)

        x_slice = slice(x_min, x_max)
        y_slice = slice(y_min, y_max)
        if isinstance(selection, xu.UgridDataset):
            selection = selection.ugrid.sel(x=x_slice, y=y_slice)
        elif ("x" in selection.coords) and ("y" in selection.coords):
            if selection.indexes["y"].is_monotonic_decreasing:
                y_slice = slice(y_max, y_min)
            selection = selection.sel(x=x_slice, y=y_slice)

        cls = type(self)
        new = cls.__new__(cls)
        new.dataset = selection
        return new

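    # A minimal usage sketch (variable names are illustrative, not from this
    # module):
    #   clipped = riv.clip_box(
    #       time_min="2000-01-01", time_max="2001-01-01",
    #       layer_min=1, layer_max=3,
    #       x_min=100_000.0, x_max=200_000.0,
    #   )
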
    def mask(self, mask: GridDataArray) -> Any:
        """
        Mask values outside of domain.

        Floating values outside of the condition are set to NaN (nodata).
        Integer values outside of the condition are set to 0 (inactive in
        MODFLOW terms).

        Parameters
        ----------
        mask: xr.DataArray, xu.UgridDataArray of ints
            idomain-like integer array. 1 sets cells to active, 0 sets cells to
            inactive, -1 sets cells to vertical passthrough

        Returns
        -------
        masked: Package
            The package with part masked.
        """

        return _mask(self, mask)

    def regrid_like(
        self,
        target_grid: GridDataArray,
        regrid_context: RegridderWeightsCache,
        regridder_types: Optional[dict[str, Tuple[RegridderType, str]]] = None,
    ) -> "Package":
        """
        Creates a package of the same type as this package, based on another
        discretization. It regrids all the arrays in this package to the desired
        discretization and leaves the options unmodified. At the moment only
        regridding to a different planar grid is supported, meaning
        ``target_grid`` has different ``"x"`` and ``"y"`` or different ``cell2d``
        coords.

        The regridding methods can be specified in the _regrid_method attribute
        of the package. These are the defaults that specify how each array
        should be regridded. These defaults can be overridden using the input
        parameters of this function.

        Examples
        --------
        To regrid the npf package with a non-default method for the k-field,
        call regrid_like with these arguments:

        >>> new_npf = npf.regrid_like(like, {"k": (imod.RegridderType.OVERLAP, "mean")})

        Parameters
        ----------
        target_grid: xr.DataArray or xu.UgridDataArray
            a grid defined over the same discretization as the one we want to
            regrid the package to
        regrid_context: RegridderWeightsCache
            stores regridder weights for different regridders. Can be used to
            speed up regridding, if the same regridders are used several times
            for regridding different arrays.
        regridder_types: dict(str -> (regridder type, str)), optional
            dictionary mapping array names (str) to a tuple of regrid type (a
            specialization class of BaseRegridder) and function name (str);
            this dictionary can be used to override the default mapping method.

        Returns
        -------
        a package with the same options as this package, and with all the
        data-arrays regridded to another discretization, similar to the one
        used in input argument "target_grid"
        """
        try:
            result = _regrid_like(self, target_grid, regrid_context, regridder_types)
        except ValueError as e:
            raise e
        except Exception:
            raise ValueError("package could not be regridded.")
        return result

    def _skip_masking_dataarray(self, array_name: str) -> bool:
        if hasattr(self, "_skip_mask_arrays"):
            return array_name in self._skip_mask_arrays
        return False

    @classmethod
    def is_grid_agnostic_package(cls) -> bool:
        return False

    def __repr__(self) -> str:
        typename = type(self).__name__
        return f"{typename}\n{self.dataset.__repr__()}"

    def _repr_html_(self) -> str:
        typename = type(self).__name__
        return f"<div>{typename}</div>{self.dataset._repr_html_()}"

    @property
    def auxiliary_data_fields(self) -> dict[str, str]:
        if hasattr(self, "_auxiliary_data"):
            return self._auxiliary_data
        return {}

    def get_non_grid_data(self, grid_names: list[str]) -> dict[str, Any]:
        """
        This function copies the attributes of a dataset that are scalars, such
        as options.

        Parameters
        ----------
        grid_names: list of str
            the names of the attributes of a dataset that are grids.
        """
        result = {}
        all_non_grid_data = list(self.dataset.keys())
        for name in (
            gridname for gridname in grid_names if gridname in all_non_grid_data
        ):
            all_non_grid_data.remove(name)
        for name in all_non_grid_data:
            if "time" in self.dataset[name].coords:
                result[name] = self.dataset[name]
            else:
                result[name] = self.dataset[name].values[()]
        return result

    def is_splitting_supported(self) -> bool:
        return True

    def is_regridding_supported(self) -> bool:
        return True

    def is_clipping_supported(self) -> bool:
        return True