# Source file: C:\src\imod-python\imod\mf6\model.py
# Text recovered from a coverage report (coverage.py v7.4.4, created at
# 2024-04-08 13:27 +0200): 88% coverage, 272 statements.
from __future__ import annotations

import abc
import collections
import inspect
import pathlib
from copy import deepcopy
from pathlib import Path
from typing import Optional, Tuple, Union

import cftime
import jinja2
import numpy as np
import tomli
import tomli_w
import xarray as xr
import xugrid as xu
from jinja2 import Template

import imod
from imod.logging import standard_log_decorator
from imod.mf6.interfaces.imodel import IModel
from imod.mf6.package import Package
from imod.mf6.statusinfo import NestedStatusInfo, StatusInfo, StatusInfoBase
from imod.mf6.utilities.mask import _mask_all_packages
from imod.mf6.utilities.regrid import RegridderWeightsCache, _regrid_like
from imod.mf6.validation import pkg_errors_to_status_info
from imod.mf6.write_context import WriteContext
from imod.schemata import ValidationError
from imod.typing import GridDataArray


class Modflow6Model(collections.UserDict, IModel, abc.ABC):
    """
    Abstract base for Modflow 6 models.

    The model behaves as a dictionary that maps user-chosen package names
    (at most 16 characters) to package objects. Model-level options are kept
    separately in ``self._options``.
    """

    # Package ids (``_pkg_id`` values) that must be present in addition to a
    # discretization package; checked by ``_check_for_required_packages``.
    _mandatory_packages: tuple[str, ...] = ()
    # Identifier returned by ``model_id()``; ``None`` means "not set".
    _model_id: Optional[str] = None
    # Jinja2 template used by ``render`` to produce the model name file.
    _template: Template

    @staticmethod
    def _initialize_template(name: str) -> Template:
        """Load template ``name`` from the ``templates/mf6`` directory of ``imod``."""
        loader = jinja2.PackageLoader("imod", "templates/mf6")
        env = jinja2.Environment(loader=loader, keep_trailing_newline=True)
        return env.get_template(name)

    def __init__(self, **kwargs):
        """Create a model and store every keyword argument as a package."""
        collections.UserDict.__init__(self)
        for k, v in kwargs.items():
            self[k] = v

        self._options = {}

    def __setitem__(self, key, value):
        """
        Store ``value`` under ``key``.

        Raises
        ------
        KeyError
            If ``key`` is longer than 16 characters.
        """
        if len(key) > 16:
            raise KeyError(
                f"Received key with more than 16 characters: '{key}'. "
                "Modflow 6 has a character limit of 16."
            )

        super().__setitem__(key, value)

    def update(self, *args, **kwargs):
        """Update the model like ``dict.update``, routing through ``__setitem__``."""
        # Assign item by item so the key-length check is applied to every key.
        for k, v in dict(*args, **kwargs).items():
            self[k] = v

    def _get_diskey(self):
        """
        Return the key of the model's single discretization package.

        Raises
        ------
        ValueError
            If no discretization package, or more than one, is present.
        """
        candidates = (self._get_pkgkey(pkg_id) for pkg_id in ("dis", "disv", "disu"))
        diskeys = [key for key in candidates if key is not None]

        if len(diskeys) > 1:
            raise ValueError(f"Found multiple discretizations {diskeys}")
        if not diskeys:
            raise ValueError("No model discretization found")
        return diskeys[0]

    def _get_pkgkey(self, pkg_id):
        """
        Get package key that belongs to a certain pkg_id, since the keys are
        user specified.

        Returns ``None`` when no package has this ``pkg_id``; raises
        ``ValueError`` when more than one package has it.
        """
        key = [pkgname for pkgname, pkg in self.items() if pkg._pkg_id == pkg_id]
        nkey = len(key)
        if nkey > 1:
            raise ValueError(f"Multiple instances of {key} detected")
        if nkey == 1:
            return key[0]
        return None

    def _check_for_required_packages(self, modelkey: str) -> None:
        """
        Raise ``ValueError`` if a discretization package or any package id
        listed in ``_mandatory_packages`` is missing. ``modelkey`` is only
        used in the error message.
        """
        pkg_ids = {pkg._pkg_id for pkg in self.values()}
        if pkg_ids.isdisjoint(("dis", "disv", "disu")):
            raise ValueError(f"No dis/disv/disu package found in model {modelkey}")
        for required in self._mandatory_packages:
            if required not in pkg_ids:
                raise ValueError(f"No {required} package found in model {modelkey}")

    def _use_cftime(self):
        """
        Return True if the time coordinates of the packages are cftime
        datetimes, False if they are ``numpy.datetime64`` or if no package has
        a time coordinate.

        Also checks if datetime types are homogeneous across packages.

        Raises
        ------
        ValueError
            If packages mix datetime types, or use an unsupported type.
        """
        time_types = {
            type(pkg.dataset["time"].values[0])
            for pkg in self.values()
            if "time" in pkg.dataset.coords
        }
        # No types means there is no time dependent input.
        if not time_types:
            return False
        if len(time_types) > 1:
            raise ValueError(
                f"Multiple datetime types detected: {time_types}. "
                "Use either cftime or numpy.datetime64[ns]."
            )
        (time_type,) = time_types
        # Since we compare types and not instances, we use issubclass
        if issubclass(time_type, cftime.datetime):
            return True
        if issubclass(time_type, np.datetime64):
            return False
        raise ValueError("Use either cftime or numpy.datetime64[ns].")

    def _yield_times(self):
        """
        Collect the time values of all packages.

        Returns a list with, per package that has a ``time`` coordinate, its
        time values, plus the first ``repeat_items`` entry of ``repeat_stress``
        when that variable is present and set.
        """
        modeltimes = []
        for pkg in self.values():
            if "time" in pkg.dataset.coords:
                modeltimes.append(pkg.dataset["time"].values)
            repeat_stress = pkg.dataset.get("repeat_stress")
            if repeat_stress is not None and repeat_stress.values[()] is not None:
                modeltimes.append(repeat_stress.isel(repeat_items=0).values)
        return modeltimes

    def render(self, modelname: str, write_context: WriteContext):
        """
        Render the model name file content.

        Options that are ``None`` or ``False`` are left out. Every package is
        listed as ``(<pkg_id>6, <posix path>, <package name>)`` with its path
        located under ``write_context.root_directory / modelname``.
        """
        dir_for_render = write_context.root_directory / modelname

        d = {k: v for k, v in self._options.items() if v is not None and v is not False}
        packages = []
        for pkgname, pkg in self.items():
            # Add the six to the package id
            pkg_id = pkg._pkg_id
            key = f"{pkg_id}6"
            path = dir_for_render / f"{pkgname}.{pkg_id}"
            packages.append((key, path.as_posix(), pkgname))
        d["packages"] = packages
        return self._template.render(d)

    def _model_checks(self, modelkey: str):
        """
        Check model integrity (called before writing)
        """

        self._check_for_required_packages(modelkey)

    def __get_domain_geometry(
        self,
    ) -> tuple[
        Union[xr.DataArray, xu.UgridDataArray],
        Union[xr.DataArray, xu.UgridDataArray],
        Union[xr.DataArray, xu.UgridDataArray],
    ]:
        """
        Return ``(top, bottom, idomain)`` of the discretization package.

        Raises ``ValueError`` (from ``_get_diskey``) when the model does not
        contain exactly one discretization package.
        """
        discretization = self[self._get_diskey()]
        return (
            discretization["top"],
            discretization["bottom"],
            discretization["idomain"],
        )

    def __get_k(self):
        """
        Return ``k`` of the model's NodePropertyFlow package.

        The package is located by its package id rather than by a fixed key,
        since package keys are user specified.

        Raises
        ------
        ValidationError
            If the model contains no NodePropertyFlow package.
        ValueError
            If the model contains more than one (raised by ``_get_pkgkey``).
        """
        npf_key = self._get_pkgkey(imod.mf6.NodePropertyFlow._pkg_id)
        if npf_key is None:
            raise ValidationError("expected one package of type NodePropertyFlow")
        return self[npf_key]["k"]

    @standard_log_decorator()
    def validate(self, model_name: str = "") -> StatusInfoBase:
        """
        Validate all packages against their init and write schemata.

        Parameters
        ----------
        model_name: str
            Name used in the title of the returned status info.

        Returns
        -------
        StatusInfoBase
            A ``StatusInfo`` holding a single error when the discretization
            cannot be determined; otherwise a ``NestedStatusInfo`` with one
            entry per package that has validation errors.
        """
        try:
            diskey = self._get_diskey()
        except Exception as e:
            # Report the problem instead of raising: without a discretization
            # none of the packages can be checked.
            status_info = StatusInfo(f"{model_name} model")
            status_info.add_error(str(e))
            return status_info

        dis = self[diskey]
        # We'll use the idomain for checking dims, shape, nodata.
        idomain = dis["idomain"]
        bottom = dis["bottom"]

        model_status_info = NestedStatusInfo(f"{model_name} model")
        for pkg_name, pkg in self.items():
            # Check for all schemata when writing. Types and dimensions
            # may have been changed after initialization...

            if pkg_name in ["adv"]:
                continue  # some packages can be skipped

            # Concatenate write and init schemata. Work on a copy so the
            # schemata stored on the package are not extended.
            schemata = deepcopy(pkg._init_schemata)
            for key, value in pkg._write_schemata.items():
                if key in schemata:
                    schemata[key] += value
                else:
                    schemata[key] = value

            pkg_errors = pkg._validate(
                schemata=schemata,
                idomain=idomain,
                bottom=bottom,
            )
            if len(pkg_errors) > 0:
                model_status_info.add(pkg_errors_to_status_info(pkg_name, pkg_errors))

        return model_status_info

    @standard_log_decorator()
    def write(
        self, modelname, globaltimes, validate: bool, write_context: WriteContext
    ) -> StatusInfoBase:
        """
        Write model namefile
        Write packages

        The model directory ``write_context.simulation_directory / modelname``
        is created first. When ``validate`` is True and validation reports
        errors, that status info is returned and nothing else is written.
        Otherwise an empty ``NestedStatusInfo`` is returned after writing.

        Raises
        ------
        Exception
            An error raised while writing a package is re-raised as the same
            exception type, with the package name appended to the message.
        """

        workdir = write_context.simulation_directory
        modeldirectory = workdir / modelname
        Path(modeldirectory).mkdir(exist_ok=True, parents=True)
        if validate:
            model_status_info = self.validate(modelname)
            if model_status_info.has_errors():
                return model_status_info

        # write model namefile
        namefile_content = self.render(modelname, write_context)
        namefile_path = modeldirectory / f"{modelname}.nam"
        with open(namefile_path, "w") as f:
            f.write(namefile_content)

        # write package contents
        pkg_write_context = write_context.copy_with_new_write_directory(
            new_write_directory=modeldirectory
        )
        for pkg_name, pkg in self.items():
            try:
                # Wells and horizontal flow barriers are first converted to a
                # package that can be written, using the model geometry and k.
                if isinstance(pkg, imod.mf6.Well):
                    top, bottom, idomain = self.__get_domain_geometry()
                    k = self.__get_k()
                    mf6_pkg = pkg.to_mf6_pkg(
                        idomain,
                        top,
                        bottom,
                        k,
                        validate,
                        pkg_write_context.is_partitioned,
                    )
                elif isinstance(pkg, imod.mf6.HorizontalFlowBarrierBase):
                    top, bottom, idomain = self.__get_domain_geometry()
                    k = self.__get_k()
                    mf6_pkg = pkg.to_mf6_pkg(idomain, top, bottom, k, validate)
                else:
                    mf6_pkg = pkg

                mf6_pkg.write(
                    pkgname=pkg_name,
                    globaltimes=globaltimes,
                    write_context=pkg_write_context,
                )
            except Exception as e:
                # Keep the original exception as the explicit cause.
                raise type(e)(f"{e}\nError occured while writing {pkg_name}") from e

        return NestedStatusInfo(modelname)

    @standard_log_decorator()
    def dump(
        self, directory, modelname, validate: bool = True, mdal_compliant: bool = False
    ):
        """
        Dump every package to ``<directory>/<modelname>/<pkgname>.nc`` and write
        a ``<modelname>.toml`` file that maps package class names and package
        names to those files.

        Parameters
        ----------
        directory: path-like
            Parent directory; a ``modelname`` subdirectory is created in it.
        modelname: str
        validate: bool
            Validate the model first and raise on errors.
        mdal_compliant: bool
            For ugrid datasets, convert with
            ``imod.util.spatial.mdal_compliant_ugrid2d`` before writing.

        Returns
        -------
        pathlib.Path
            Path of the toml file.

        Raises
        ------
        ValidationError
            If ``validate`` is True and validation reports errors.
        """
        modeldirectory = pathlib.Path(directory) / modelname
        modeldirectory.mkdir(exist_ok=True, parents=True)
        if validate:
            # Pass the name so the status messages identify the model.
            statusinfo = self.validate(modelname)
            if statusinfo.has_errors():
                raise ValidationError(statusinfo.to_string())

        toml_content: dict = collections.defaultdict(dict)
        for pkgname, pkg in self.items():
            pkg_path = f"{pkgname}.nc"
            toml_content[type(pkg).__name__][pkgname] = pkg_path
            dataset = pkg.dataset
            if isinstance(dataset, xu.UgridDataset):
                if mdal_compliant:
                    dataset = pkg.dataset.ugrid.to_dataset()
                    mdal_dataset = imod.util.spatial.mdal_compliant_ugrid2d(dataset)
                    mdal_dataset.to_netcdf(modeldirectory / pkg_path)
                else:
                    pkg.dataset.ugrid.to_netcdf(modeldirectory / pkg_path)
            else:
                pkg.to_netcdf(modeldirectory / pkg_path)

        toml_path = modeldirectory / f"{modelname}.toml"
        with open(toml_path, "wb") as f:
            tomli_w.dump(toml_content, f)

        return toml_path

    @classmethod
    def from_file(cls, toml_path):
        """
        Create a model from a toml file as written by ``dump``.

        Top-level toml keys are names of ``Package`` subclasses found in
        ``imod.mf6``; each maps package names to file paths relative to the
        directory of the toml file.
        """
        pkg_classes = {
            name: pkg_cls
            for name, pkg_cls in inspect.getmembers(imod.mf6, inspect.isclass)
            if issubclass(pkg_cls, Package)
        }

        toml_path = pathlib.Path(toml_path)
        with open(toml_path, "rb") as f:
            toml_content = tomli.load(f)

        parentdir = toml_path.parent
        instance = cls()
        for key, entry in toml_content.items():
            pkg_cls = pkg_classes[key]
            for pkgname, path in entry.items():
                instance[pkgname] = pkg_cls.from_file(parentdir / path)

        return instance

    @classmethod
    def model_id(cls) -> str:
        """Return the model id; raise ``ValueError`` if it has not been set."""
        if cls._model_id is None:
            raise ValueError("Model id has not been set")
        return cls._model_id

    def clip_box(
        self,
        time_min: Optional[cftime.datetime | np.datetime64 | str] = None,
        time_max: Optional[cftime.datetime | np.datetime64 | str] = None,
        layer_min: Optional[int] = None,
        layer_max: Optional[int] = None,
        x_min: Optional[float] = None,
        x_max: Optional[float] = None,
        y_min: Optional[float] = None,
        y_max: Optional[float] = None,
        state_for_boundary: Optional[GridDataArray] = None,
    ):
        """
        Clip a model by a bounding box (time, layer, y, x).

        Slicing intervals may be half-bounded, by providing None:

        * To select 500.0 <= x <= 1000.0:
          ``clip_box(x_min=500.0, x_max=1000.0)``.
        * To select x <= 1000.0: ``clip_box(x_min=None, x_max=1000.0)``
          or ``clip_box(x_max=1000.0)``.
        * To select x >= 500.0: ``clip_box(x_min=500.0, x_max=None)``
          or ``clip_box(x_min=500.0)``.

        Parameters
        ----------
        time_min: optional
        time_max: optional
        layer_min: optional, int
        layer_max: optional, int
        x_min: optional, float
        x_max: optional, float
        y_min: optional, float
        y_max: optional, float
        state_for_boundary : optional, GridDataArray
            Passed on unchanged to the ``clip_box`` method of every package.

        Returns
        -------
        clipped : Modflow6Model

        Raises
        ------
        ValueError
            If one of the packages does not support clipping.
        """
        supported, error_with_object = self.is_clipping_supported()
        if not supported:
            raise ValueError(
                f"model cannot be clipped due to presence of package '{error_with_object}' in model"
            )

        return self._clip_box_packages(
            time_min,
            time_max,
            layer_min,
            layer_max,
            x_min,
            x_max,
            y_min,
            y_max,
            state_for_boundary,
        )

    def _clip_box_packages(
        self,
        time_min: Optional[cftime.datetime | np.datetime64 | str] = None,
        time_max: Optional[cftime.datetime | np.datetime64 | str] = None,
        layer_min: Optional[int] = None,
        layer_max: Optional[int] = None,
        x_min: Optional[float] = None,
        x_max: Optional[float] = None,
        y_min: Optional[float] = None,
        y_max: Optional[float] = None,
        state_for_boundary: Optional[GridDataArray] = None,
    ):
        """
        Clip a model by a bounding box (time, layer, y, x).

        Unlike ``clip_box``, this does not check first whether all packages
        support clipping.

        Slicing intervals may be half-bounded, by providing None:

        * To select 500.0 <= x <= 1000.0:
          ``clip_box(x_min=500.0, x_max=1000.0)``.
        * To select x <= 1000.0: ``clip_box(x_min=None, x_max=1000.0)``
          or ``clip_box(x_max=1000.0)``.
        * To select x >= 500.0: ``clip_box(x_min=500.0, x_max=None)``
          or ``clip_box(x_min=500.0)``.

        Parameters
        ----------
        time_min: optional
        time_max: optional
        layer_min: optional, int
        layer_max: optional, int
        x_min: optional, float
        x_max: optional, float
        y_min: optional, float
        y_max: optional, float
        state_for_boundary : optional, GridDataArray
            Passed on unchanged to the ``clip_box`` method of every package.

        Returns
        -------
        clipped : Modflow6Model
        """

        top, bottom, _ = self.__get_domain_geometry()

        # The clipped model is a new instance of the same class, created from
        # the options of this model.
        clipped = type(self)(**self._options)
        for key, pkg in self.items():
            clipped[key] = pkg.clip_box(
                time_min=time_min,
                time_max=time_max,
                layer_min=layer_min,
                layer_max=layer_max,
                x_min=x_min,
                x_max=x_max,
                y_min=y_min,
                y_max=y_max,
                top=top,
                bottom=bottom,
                state_for_boundary=state_for_boundary,
            )

        return clipped

    def regrid_like(
        self,
        target_grid: GridDataArray,
        validate: bool = True,
        regrid_context: Optional[RegridderWeightsCache] = None,
    ) -> "Modflow6Model":
        """
        Creates a model by regridding the packages of this model to another discretization.
        It regrids all the arrays in the package using the default regridding methods.
        At the moment only regridding to a different planar grid is supported, meaning
        ``target_grid`` has different ``"x"`` and ``"y"`` or different ``cell2d`` coords.

        Parameters
        ----------
        target_grid: xr.DataArray or xu.UgridDataArray
            a grid defined over the same discretization as the one we want to regrid the package to
        validate: bool
            set to true to validate the regridded packages
        regrid_context: Optional RegridderWeightsCache
            stores regridder weights for different regridders. Can be used to speed up regridding,
            if the same regridders are used several times for regridding different arrays.

        Returns
        -------
        a model with similar packages to the input model, and with all the data-arrays regridded to another discretization,
        similar to the one used in input argument "target_grid"
        """
        return _regrid_like(self, target_grid, validate, regrid_context)

    def mask_all_packages(
        self,
        mask: GridDataArray,
    ):
        """
        This function applies a mask to all packages in a model. The mask must
        be presented as an idomain-like integer array that has 0 (inactive) or
        -1 (vertical passthrough) values in filtered cells and 1 in active
        cells.
        Masking will overwrite idomain with the mask where the mask is 0 or -1.
        Where the mask is 1, the original value of idomain will be kept. Masking
        will update the packages accordingly, blanking their input where needed,
        and is therefore not a reversible operation.

        Parameters
        ----------
        mask: xr.DataArray, xu.UgridDataArray of ints
            idomain-like integer array. 1 sets cells to active, 0 sets cells to inactive,
            -1 sets cells to vertical passthrough
        """

        _mask_all_packages(self, mask)

    def purge_empty_packages(self, model_name: Optional[str] = "") -> None:
        """
        This function removes empty packages from the model.

        ``model_name`` is accepted but not used by this implementation.
        """
        # Collect the names first: the model cannot be modified while it is
        # being iterated over.
        empty_packages = [
            package_name for package_name, package in self.items() if package.is_empty()
        ]
        for package_name in empty_packages:
            self.pop(package_name)

    @property
    def domain(self):
        """The ``idomain`` of the discretization package."""
        dis = self._get_diskey()
        return self[dis]["idomain"]

    @property
    def bottom(self):
        """The ``bottom`` of the discretization package."""
        dis = self._get_diskey()
        return self[dis]["bottom"]

    def __repr__(self) -> str:
        """Return the class name, the options, and each package name with its type."""
        # NOTE(review): this literal is reproduced as found; whitespace inside
        # it may have been collapsed when the source text was extracted.
        INDENT = " "
        typename = type(self).__name__
        options = [
            f"{INDENT}{key}={repr(value)}," for key, value in self._options.items()
        ]
        packages = [
            f"{INDENT}{repr(key)}: {type(value).__name__},"
            for key, value in self.items()
        ]
        # Place the empty dict on the same line. Looks silly otherwise.
        if packages:
            content = [f"{typename}("] + options + ["){"] + packages + ["}"]
        else:
            content = [f"{typename}("] + options + ["){}"]
        return "\n".join(content)

    def is_use_newton(self):
        """Return False; this base implementation never reports Newton use."""
        return False

    def is_splitting_supported(self) -> Tuple[bool, str]:
        """
        Returns ``(True, "")`` if all the packages in the model support
        splitting. Otherwise returns ``(False, name)`` with the name of the
        first package that does not support splitting.
        """
        for package_name, package in self.items():
            if not package.is_splitting_supported():
                return False, package_name
        return True, ""

    def is_regridding_supported(self) -> Tuple[bool, str]:
        """
        Returns ``(True, "")`` if all the packages in the model support
        regridding. Otherwise returns ``(False, name)`` with the name of the
        first package that does not support regridding.
        """
        for package_name, package in self.items():
            if not package.is_regridding_supported():
                return False, package_name
        return True, ""

    def is_clipping_supported(self) -> Tuple[bool, str]:
        """
        Returns ``(True, "")`` if all the packages in the model support
        clipping. Otherwise returns ``(False, name)`` with the name of the
        first package that does not support clipping.
        """
        for package_name, package in self.items():
            if not package.is_clipping_supported():
                return False, package_name
        return True, ""