Coverage for C:\src\imod-python\imod\mf6\model.py: 90%
272 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-08 14:15 +0200
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-08 14:15 +0200
1from __future__ import annotations
3import abc
4import collections
5import inspect
6import pathlib
7from copy import deepcopy
8from pathlib import Path
9from typing import Optional, Tuple, Union
11import cftime
12import jinja2
13import numpy as np
14import tomli
15import tomli_w
16import xarray as xr
17import xugrid as xu
18from jinja2 import Template
20import imod
21from imod.logging import standard_log_decorator
22from imod.mf6.interfaces.imodel import IModel
23from imod.mf6.package import Package
24from imod.mf6.statusinfo import NestedStatusInfo, StatusInfo, StatusInfoBase
25from imod.mf6.utilities.mask import _mask_all_packages
26from imod.mf6.utilities.regrid import RegridderWeightsCache, _regrid_like
27from imod.mf6.validation import pkg_errors_to_status_info
28from imod.mf6.write_context import WriteContext
29from imod.schemata import ValidationError
30from imod.typing import GridDataArray
class Modflow6Model(collections.UserDict, IModel, abc.ABC):
    """
    Dict-like container of Modflow 6 packages that together form one model.

    Packages are stored under user-chosen names (at most 16 characters, see
    ``__setitem__``). Because the names are free-form, packages of a given
    kind are located through their ``_pkg_id`` attribute, not through the key
    they are stored under.

    Subclasses are expected to provide ``_template`` (the jinja2 template used
    by :meth:`render`) and may set ``_mandatory_packages`` and ``_model_id``.
    """

    # Package ids that must be present for the model to be written.
    _mandatory_packages: tuple[str, ...] = ()
    # Identifier of the model type; None until a subclass sets it.
    _model_id: Optional[str] = None
    # Jinja2 template used to render the model namefile.
    _template: Template

    @staticmethod
    def _initialize_template(name: str) -> Template:
        """Load the template ``name`` from the ``templates/mf6`` folder of imod."""
        loader = jinja2.PackageLoader("imod", "templates/mf6")
        env = jinja2.Environment(loader=loader, keep_trailing_newline=True)
        return env.get_template(name)

    def __init__(self, **kwargs):
        """Create the model and store every keyword argument as a package."""
        collections.UserDict.__init__(self)
        # Go through __setitem__ so the key-length check applies here as well.
        for k, v in kwargs.items():
            self[k] = v

        self._options = {}

    def __setitem__(self, key, value):
        """
        Store ``value`` under ``key``.

        Raises
        ------
        KeyError
            If ``key`` is longer than 16 characters.
        """
        if len(key) > 16:
            raise KeyError(
                f"Received key with more than 16 characters: '{key}'. "
                "Modflow 6 has a character limit of 16."
            )

        super().__setitem__(key, value)

    def update(self, *args, **kwargs):
        """Add several packages at once; every key passes through ``__setitem__``."""
        for k, v in dict(*args, **kwargs).items():
            self[k] = v

    def _get_diskey(self):
        """
        Return the key of the model's discretization package.

        Raises
        ------
        ValueError
            If no dis/disv/disu package is present, or more than one is.
        """
        # Look every id up once, then drop the ones that are not present.
        candidates = (self._get_pkgkey(pkg_id) for pkg_id in ("dis", "disv", "disu"))
        diskeys = [key for key in candidates if key is not None]

        if len(diskeys) > 1:
            raise ValueError(f"Found multiple discretizations {diskeys}")
        if not diskeys:
            raise ValueError("No model discretization found")
        return diskeys[0]

    def _get_pkgkey(self, pkg_id):
        """
        Get package key that belongs to a certain pkg_id, since the keys are
        user specified.

        Returns None when no package has this id.

        Raises
        ------
        ValueError
            If more than one package has this id.
        """
        matches = [pkgname for pkgname, pkg in self.items() if pkg._pkg_id == pkg_id]
        if len(matches) > 1:
            raise ValueError(f"Multiple instances of {matches} detected")
        return matches[0] if matches else None

    def _check_for_required_packages(self, modelkey: str) -> None:
        """
        Check that a discretization and all mandatory packages are present.

        Raises
        ------
        ValueError
            Naming the first missing package and the model ``modelkey``.
        """
        pkg_ids = {pkg._pkg_id for pkg in self.values()}
        if pkg_ids.isdisjoint({"dis", "disv", "disu"}):
            raise ValueError(f"No dis/disv/disu package found in model {modelkey}")
        for required in self._mandatory_packages:
            if required not in pkg_ids:
                raise ValueError(f"No {required} package found in model {modelkey}")

    def _use_cftime(self):
        """
        Return True if the packages use cftime datetimes for their time
        coordinate, False if they use numpy.datetime64 or if there is no
        time-dependent input at all.

        Also checks if datetime types are homogeneous across packages.

        Raises
        ------
        ValueError
            If packages mix datetime types, or use an unsupported type.
        """
        types = [
            type(pkg.dataset["time"].values[0])
            for pkg in self.values()
            if "time" in pkg.dataset.coords
        ]
        set_of_types = set(types)
        # Types will be empty if there's no time dependent input
        if not set_of_types:
            return False

        if len(set_of_types) != 1:
            raise ValueError(
                f"Multiple datetime types detected: {set_of_types}. "
                "Use either cftime or numpy.datetime64[ns]."
            )
        # Since we compare types and not instances, we use issubclass
        time_type = types[0]
        if issubclass(time_type, cftime.datetime):
            return True
        if issubclass(time_type, np.datetime64):
            return False
        raise ValueError("Use either cftime or numpy.datetime64[ns].")

    def _yield_times(self):
        """
        Collect the time values of all packages.

        Returns a list holding, per package with a ``time`` coordinate, its
        time values, and per package with a non-empty ``repeat_stress``
        variable, the first entry along ``repeat_items``.
        """
        modeltimes = []
        for pkg in self.values():
            if "time" in pkg.dataset.coords:
                modeltimes.append(pkg.dataset["time"].values)
            repeat_stress = pkg.dataset.get("repeat_stress")
            if repeat_stress is not None and repeat_stress.values[()] is not None:
                modeltimes.append(repeat_stress.isel(repeat_items=0).values)
        return modeltimes

    def render(self, modelname: str, write_context: WriteContext):
        """
        Render the model namefile content with ``_template``.

        Options that are None or False are left out; every package is listed
        as ``(<pkg_id>6, <posix path>, <package name>)``.
        """
        dir_for_render = write_context.root_directory / modelname

        d = {k: v for k, v in self._options.items() if not (v is None or v is False)}
        packages = []
        for pkgname, pkg in self.items():
            # Add the six to the package id
            pkg_id = pkg._pkg_id
            path = dir_for_render / f"{pkgname}.{pkg_id}"
            packages.append((f"{pkg_id}6", path.as_posix(), pkgname))
        d["packages"] = packages
        return self._template.render(d)

    def _model_checks(self, modelkey: str):
        """
        Check model integrity (called before writing)
        """
        self._check_for_required_packages(modelkey)

    def __get_domain_geometry(
        self,
    ) -> tuple[
        Union[xr.DataArray, xu.UgridDataArray],
        Union[xr.DataArray, xu.UgridDataArray],
        Union[xr.DataArray, xu.UgridDataArray],
    ]:
        """
        Return ``(top, bottom, idomain)`` of the discretization package.

        Raises
        ------
        ValueError
            Propagated from ``_get_diskey`` when there is not exactly one
            discretization package.
        """
        discretization = self[self._get_diskey()]
        return (
            discretization["top"],
            discretization["bottom"],
            discretization["idomain"],
        )

    def __get_k(self):
        """
        Return the ``k`` variable of the NodePropertyFlow package.

        The package is located by its package id, since the key it is stored
        under is user specified.

        Raises
        ------
        ValidationError
            If the model has no NodePropertyFlow package.
        ValueError
            Propagated from ``_get_pkgkey`` if there are several.
        """
        npf_key = self._get_pkgkey(imod.mf6.NodePropertyFlow._pkg_id)
        if npf_key is None:
            raise ValidationError("expected one package of type NodePropertyFlow")
        return self[npf_key]["k"]

    @standard_log_decorator()
    def validate(self, model_name: str = "") -> StatusInfoBase:
        """
        Validate every package against its init and write schemata.

        Parameters
        ----------
        model_name: str
            Name used to label the returned status info.

        Returns
        -------
        StatusInfoBase
            A ``StatusInfo`` holding a single error when the discretization
            cannot be determined, otherwise a ``NestedStatusInfo`` with one
            entry per package that reported errors.
        """
        try:
            diskey = self._get_diskey()
        except Exception as e:
            # Without a discretization nothing else can be checked; report
            # the problem instead of raising.
            status_info = StatusInfo(f"{model_name} model")
            status_info.add_error(str(e))
            return status_info

        dis = self[diskey]
        # We'll use the idomain for checking dims, shape, nodata.
        idomain = dis["idomain"]
        bottom = dis["bottom"]

        model_status_info = NestedStatusInfo(f"{model_name} model")
        for pkg_name, pkg in self.items():
            # Check for all schemata when writing. Types and dimensions
            # may have been changed after initialization...

            # NOTE(review): this matches the user-given package name, not the
            # package id -- confirm that is intended.
            if pkg_name in ("adv",):
                continue  # some packages can be skipped

            # Concatenate write and init schemata.
            schemata = deepcopy(pkg._init_schemata)
            for key, value in pkg._write_schemata.items():
                if key in schemata:
                    schemata[key] += value
                else:
                    schemata[key] = value

            pkg_errors = pkg._validate(
                schemata=schemata,
                idomain=idomain,
                bottom=bottom,
            )
            if len(pkg_errors) > 0:
                model_status_info.add(pkg_errors_to_status_info(pkg_name, pkg_errors))

        return model_status_info

    @standard_log_decorator()
    def write(
        self, modelname, globaltimes, validate: bool, write_context: WriteContext
    ) -> StatusInfoBase:
        """
        Write model namefile
        Write packages

        Parameters
        ----------
        modelname:
            Name of the model; used for the model directory and the namefile.
        globaltimes:
            Passed on to the ``write`` method of every package.
        validate: bool
            When True the model is validated first; if that reports errors,
            the status info is returned and no namefile or package is written.
        write_context: WriteContext
            Supplies the simulation directory and further write settings.

        Returns
        -------
        StatusInfoBase
            The validation result when it contains errors, otherwise an empty
            ``NestedStatusInfo``.

        Raises
        ------
        Exception
            Any error raised while writing a package is re-raised as the same
            exception type, with the package name appended to the message and
            the original exception attached as its cause.
        """
        workdir = write_context.simulation_directory
        modeldirectory = workdir / modelname
        Path(modeldirectory).mkdir(exist_ok=True, parents=True)
        if validate:
            model_status_info = self.validate(modelname)
            if model_status_info.has_errors():
                return model_status_info

        # write model namefile
        namefile_content = self.render(modelname, write_context)
        namefile_path = modeldirectory / f"{modelname}.nam"
        with open(namefile_path, "w") as f:
            f.write(namefile_content)

        # write package contents
        pkg_write_context = write_context.copy_with_new_write_directory(
            new_write_directory=modeldirectory
        )
        for pkg_name, pkg in self.items():
            try:
                # Wells and horizontal flow barriers are first converted to a
                # package that can be written, using the model geometry and k.
                if isinstance(pkg, imod.mf6.Well):
                    top, bottom, idomain = self.__get_domain_geometry()
                    k = self.__get_k()
                    writable_pkg = pkg.to_mf6_pkg(
                        idomain,
                        top,
                        bottom,
                        k,
                        validate,
                        pkg_write_context.is_partitioned,
                    )
                elif isinstance(pkg, imod.mf6.HorizontalFlowBarrierBase):
                    top, bottom, idomain = self.__get_domain_geometry()
                    k = self.__get_k()
                    writable_pkg = pkg.to_mf6_pkg(idomain, top, bottom, k, validate)
                else:
                    writable_pkg = pkg

                writable_pkg.write(
                    pkgname=pkg_name,
                    globaltimes=globaltimes,
                    write_context=pkg_write_context,
                )
            except Exception as e:
                # Keep the exception type, add the package name, and chain the
                # original so its traceback is not lost.
                raise type(e)(f"{e}\nError occured while writing {pkg_name}") from e

        return NestedStatusInfo(modelname)

    @standard_log_decorator()
    def dump(
        self, directory, modelname, validate: bool = True, mdal_compliant: bool = False
    ):
        """
        Dump the model to ``<directory>/<modelname>``: one netCDF file per
        package plus a toml file listing them by package class.

        Parameters
        ----------
        directory:
            Parent directory; ``modelname`` is appended to it.
        modelname:
            Name of the model directory and of the toml file.
        validate: bool
            When True the model is validated before anything is written.
        mdal_compliant: bool
            When True, unstructured datasets are converted with
            ``imod.util.spatial.mdal_compliant_ugrid2d`` before saving.

        Returns
        -------
        pathlib.Path
            Path of the written toml file.

        Raises
        ------
        ValidationError
            If ``validate`` is True and validation reports errors.
        """
        modeldirectory = pathlib.Path(directory) / modelname
        modeldirectory.mkdir(exist_ok=True, parents=True)
        if validate:
            # Pass the name along so reported errors identify the model.
            statusinfo = self.validate(modelname)
            if statusinfo.has_errors():
                raise ValidationError(statusinfo.to_string())

        toml_content: dict = collections.defaultdict(dict)
        for pkgname, pkg in self.items():
            pkg_path = f"{pkgname}.nc"
            toml_content[type(pkg).__name__][pkgname] = pkg_path
            dataset = pkg.dataset
            if isinstance(dataset, xu.UgridDataset):
                if mdal_compliant:
                    dataset = dataset.ugrid.to_dataset()
                    mdal_dataset = imod.util.spatial.mdal_compliant_ugrid2d(dataset)
                    mdal_dataset.to_netcdf(modeldirectory / pkg_path)
                else:
                    dataset.ugrid.to_netcdf(modeldirectory / pkg_path)
            else:
                pkg.to_netcdf(modeldirectory / pkg_path)

        toml_path = modeldirectory / f"{modelname}.toml"
        with open(toml_path, "wb") as f:
            tomli_w.dump(toml_content, f)

        return toml_path

    @classmethod
    def from_file(cls, toml_path):
        """
        Build a model from a toml file that maps package class names to
        ``{package name: relative path}`` tables.

        Paths are resolved relative to the directory of ``toml_path``; every
        package is loaded with the ``from_file`` of its class, which is looked
        up among the ``Package`` subclasses exposed by ``imod.mf6``.

        Raises
        ------
        KeyError
            If the toml file names a class that is not such a subclass.
        """
        pkg_classes = {
            name: pkg_cls
            for name, pkg_cls in inspect.getmembers(imod.mf6, inspect.isclass)
            if issubclass(pkg_cls, Package)
        }

        toml_path = pathlib.Path(toml_path)
        with open(toml_path, "rb") as f:
            toml_content = tomli.load(f)

        parentdir = toml_path.parent
        instance = cls()
        for key, entry in toml_content.items():
            for pkgname, path in entry.items():
                # Resolved per package, so a class name with an empty table
                # is never looked up.
                pkg_cls = pkg_classes[key]
                instance[pkgname] = pkg_cls.from_file(parentdir / path)

        return instance

    @classmethod
    def model_id(cls) -> str:
        """
        Return the model id of this model class.

        Raises
        ------
        ValueError
            If the class has not set ``_model_id``.
        """
        if cls._model_id is None:
            raise ValueError("Model id has not been set")
        return cls._model_id

    def clip_box(
        self,
        time_min: Optional[cftime.datetime | np.datetime64 | str] = None,
        time_max: Optional[cftime.datetime | np.datetime64 | str] = None,
        layer_min: Optional[int] = None,
        layer_max: Optional[int] = None,
        x_min: Optional[float] = None,
        x_max: Optional[float] = None,
        y_min: Optional[float] = None,
        y_max: Optional[float] = None,
        state_for_boundary: Optional[GridDataArray] = None,
    ):
        """
        Clip a model by a bounding box (time, layer, y, x).

        Slicing intervals may be half-bounded, by providing None:

        * To select 500.0 <= x <= 1000.0:
          ``clip_box(x_min=500.0, x_max=1000.0)``.
        * To select x <= 1000.0: ``clip_box(x_min=None, x_max=1000.0)``
          or ``clip_box(x_max=1000.0)``.
        * To select x >= 500.0: ``clip_box(x_min=500.0, x_max=None)``
          or ``clip_box(x_min=500.0)``.

        Parameters
        ----------
        time_min: optional
        time_max: optional
        layer_min: optional, int
        layer_max: optional, int
        x_min: optional, float
        x_max: optional, float
        y_min: optional, float
        y_max: optional, float
        state_for_boundary: optional, GridDataArray
            Accepted for interface compatibility; this implementation does
            not use it.

        Returns
        -------
        clipped : Modflow6Model

        Raises
        ------
        ValueError
            If one of the packages does not support clipping.
        """
        supported, error_with_object = self.is_clipping_supported()
        if not supported:
            raise ValueError(
                f"model cannot be clipped due to presence of package '{error_with_object}' in model"
            )

        return self._clip_box_packages(
            time_min,
            time_max,
            layer_min,
            layer_max,
            x_min,
            x_max,
            y_min,
            y_max,
        )

    def _clip_box_packages(
        self,
        time_min: Optional[cftime.datetime | np.datetime64 | str] = None,
        time_max: Optional[cftime.datetime | np.datetime64 | str] = None,
        layer_min: Optional[int] = None,
        layer_max: Optional[int] = None,
        x_min: Optional[float] = None,
        x_max: Optional[float] = None,
        y_min: Optional[float] = None,
        y_max: Optional[float] = None,
    ):
        """
        Clip every package of the model by a bounding box (time, layer, y, x)
        and return the result as a new model of the same type.

        Slicing intervals may be half-bounded, by providing None:

        * To select 500.0 <= x <= 1000.0:
          ``clip_box(x_min=500.0, x_max=1000.0)``.
        * To select x <= 1000.0: ``clip_box(x_min=None, x_max=1000.0)``
          or ``clip_box(x_max=1000.0)``.
        * To select x >= 500.0: ``clip_box(x_min=500.0, x_max=None)``
          or ``clip_box(x_min=500.0)``.

        Parameters
        ----------
        time_min: optional
        time_max: optional
        layer_min: optional, int
        layer_max: optional, int
        x_min: optional, float
        x_max: optional, float
        y_min: optional, float
        y_max: optional, float

        Returns
        -------
        clipped : Modflow6Model
        """
        # Only top and bottom are handed to the packages; idomain is not needed.
        top, bottom, _ = self.__get_domain_geometry()

        clipped = type(self)(**self._options)
        for key, pkg in self.items():
            clipped[key] = pkg.clip_box(
                time_min=time_min,
                time_max=time_max,
                layer_min=layer_min,
                layer_max=layer_max,
                x_min=x_min,
                x_max=x_max,
                y_min=y_min,
                y_max=y_max,
                top=top,
                bottom=bottom,
            )

        return clipped

    def regrid_like(
        self,
        target_grid: GridDataArray,
        validate: bool = True,
        regrid_context: Optional[RegridderWeightsCache] = None,
    ) -> "Modflow6Model":
        """
        Creates a model by regridding the packages of this model to another discretization.
        It regrids all the arrays in the package using the default regridding methods.
        At the moment only regridding to a different planar grid is supported, meaning
        ``target_grid`` has different ``"x"`` and ``"y"`` or different ``cell2d`` coords.

        Parameters
        ----------
        target_grid: xr.DataArray or xu.UgridDataArray
            a grid defined over the same discretization as the one we want to regrid the package to
        validate: bool
            set to true to validate the regridded packages
        regrid_context: Optional RegridderWeightsCache
            stores regridder weights for different regridders. Can be used to speed up regridding,
            if the same regridders are used several times for regridding different arrays.

        Returns
        -------
        a model with similar packages to the input model, and with all the data-arrays regridded to another discretization,
        similar to the one used in input argument "target_grid"
        """
        return _regrid_like(self, target_grid, validate, regrid_context)

    def mask_all_packages(
        self,
        mask: GridDataArray,
    ):
        """
        This function applies a mask to all packages in a model. The mask must
        be presented as an idomain-like integer array that has 0 (inactive) or
        -1 (vertical passthrough) values in filtered cells and 1 in active
        cells.
        Masking will overwrite idomain with the mask where the mask is 0 or -1.
        Where the mask is 1, the original value of idomain will be kept. Masking
        will update the packages accordingly, blanking their input where needed,
        and is therefore not a reversible operation.

        Parameters
        ----------
        mask: xr.DataArray, xu.UgridDataArray of ints
            idomain-like integer array. 1 sets cells to active, 0 sets cells to inactive,
            -1 sets cells to vertical passthrough
        """
        _mask_all_packages(self, mask)

    def purge_empty_packages(self, model_name: Optional[str] = "") -> None:
        """
        This function removes empty packages from the model.

        ``model_name`` is accepted but not used by this implementation.
        """
        # Collect the names first: the dict must not change size while it is
        # being iterated.
        empty_packages = [
            package_name for package_name, package in self.items() if package.is_empty()
        ]
        for package_name in empty_packages:
            self.pop(package_name)

    @property
    def domain(self):
        """The ``idomain`` variable of the discretization package."""
        return self[self._get_diskey()]["idomain"]

    @property
    def bottom(self):
        """The ``bottom`` variable of the discretization package."""
        return self[self._get_diskey()]["bottom"]

    def __repr__(self) -> str:
        """Return the type name, the options, and the type of every package."""
        # NOTE(review): the listing this was recovered from may have collapsed
        # runs of whitespace; confirm the intended indent width.
        INDENT = " "
        typename = type(self).__name__
        options = [
            f"{INDENT}{key}={repr(value)}," for key, value in self._options.items()
        ]
        packages = [
            f"{INDENT}{repr(key)}: {type(value).__name__},"
            for key, value in self.items()
        ]
        # Place the empty dict on the same line. Looks silly otherwise.
        if packages:
            content = [f"{typename}("] + options + ["){"] + packages + ["}"]
        else:
            content = [f"{typename}("] + options + ["){}"]
        return "\n".join(content)

    def is_use_newton(self):
        """Return False; this base implementation never reports Newton use."""
        return False

    def is_splitting_supported(self) -> Tuple[bool, str]:
        """
        Returns ``(True, "")`` if all the packages in the model support
        splitting. Otherwise returns ``(False, name)`` with the name of the
        first package that does not.
        """
        for package_name, package in self.items():
            if not package.is_splitting_supported():
                return False, package_name
        return True, ""

    def is_regridding_supported(self) -> Tuple[bool, str]:
        """
        Returns ``(True, "")`` if all the packages in the model support
        regridding. Otherwise returns ``(False, name)`` with the name of the
        first package that does not.
        """
        for package_name, package in self.items():
            if not package.is_regridding_supported():
                return False, package_name
        return True, ""

    def is_clipping_supported(self) -> Tuple[bool, str]:
        """
        Returns ``(True, "")`` if all the packages in the model support
        clipping. Otherwise returns ``(False, name)`` with the name of the
        first package that does not.
        """
        for package_name, package in self.items():
            if not package.is_clipping_supported():
                return False, package_name
        return True, ""