Coverage for C:\src\imod-python\imod\mf6\model.py: 90%

272 statements  

coverage.py v7.5.1, created at 2024-05-08 14:15 +0200

from __future__ import annotations

import abc
import collections
import inspect
import pathlib
from copy import deepcopy
from pathlib import Path
from typing import Optional, Tuple, Union

import cftime
import jinja2
import numpy as np
import tomli
import tomli_w
import xarray as xr
import xugrid as xu
from jinja2 import Template

import imod
from imod.logging import standard_log_decorator
from imod.mf6.interfaces.imodel import IModel
from imod.mf6.package import Package
from imod.mf6.statusinfo import NestedStatusInfo, StatusInfo, StatusInfoBase
from imod.mf6.utilities.mask import _mask_all_packages
from imod.mf6.utilities.regrid import RegridderWeightsCache, _regrid_like
from imod.mf6.validation import pkg_errors_to_status_info
from imod.mf6.write_context import WriteContext
from imod.schemata import ValidationError
from imod.typing import GridDataArray


class Modflow6Model(collections.UserDict, IModel, abc.ABC):
    _mandatory_packages: tuple[str, ...] = ()
    _model_id: Optional[str] = None
    _template: Template

    @staticmethod
    def _initialize_template(name: str) -> Template:
        loader = jinja2.PackageLoader("imod", "templates/mf6")
        env = jinja2.Environment(loader=loader, keep_trailing_newline=True)
        return env.get_template(name)

    def __init__(self, **kwargs):
        collections.UserDict.__init__(self)
        for k, v in kwargs.items():
            self[k] = v

        self._options = {}

    def __setitem__(self, key, value):
        if len(key) > 16:
            raise KeyError(
                f"Received key with more than 16 characters: '{key}'. "
                "Modflow 6 has a character limit of 16."
            )

        super().__setitem__(key, value)
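
    # Illustrative sketch (not part of the original module): package keys double as
    # MODFLOW 6 package names, so keys longer than 16 characters are rejected.
    # ``gwf`` and ``river_package`` below are assumptions, not defined here.
    #
    #   gwf = imod.mf6.GroundwaterFlowModel()
    #   gwf["riv"] = river_package          # fine: 3 characters
    #   gwf["riv_package_with_a_long_name"] = river_package  # raises KeyError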

    def update(self, *args, **kwargs):
        for k, v in dict(*args, **kwargs).items():
            self[k] = v

    def _get_diskey(self):
        dis_pkg_ids = ["dis", "disv", "disu"]

        diskeys = [
            self._get_pkgkey(pkg_id)
            for pkg_id in dis_pkg_ids
            if self._get_pkgkey(pkg_id) is not None
        ]

        if len(diskeys) > 1:
            raise ValueError(f"Found multiple discretizations {diskeys}")
        elif len(diskeys) == 0:
            raise ValueError("No model discretization found")
        else:
            return diskeys[0]

    def _get_pkgkey(self, pkg_id):
        """
        Get package key that belongs to a certain pkg_id, since the keys are
        user specified.
        """
        key = [pkgname for pkgname, pkg in self.items() if pkg._pkg_id == pkg_id]
        nkey = len(key)
        if nkey > 1:
            raise ValueError(f"Multiple instances of {key} detected")
        elif nkey == 1:
            return key[0]
        else:
            return None

    def _check_for_required_packages(self, modelkey: str) -> None:
        # Check for mandatory packages
        pkg_ids = {pkg._pkg_id for pkg in self.values()}
        dispresent = "dis" in pkg_ids or "disv" in pkg_ids or "disu" in pkg_ids
        if not dispresent:
            raise ValueError(f"No dis/disv/disu package found in model {modelkey}")
        for required in self._mandatory_packages:
            if required not in pkg_ids:
                raise ValueError(f"No {required} package found in model {modelkey}")
        return

    def _use_cftime(self):
        """
        Return True if time-dependent input uses cftime datetimes, False if it
        uses numpy.datetime64 or if there is no time-dependent input at all.
        Also checks that datetime types are homogeneous across packages.
        """
        types = [
            type(pkg.dataset["time"].values[0])
            for pkg in self.values()
            if "time" in pkg.dataset.coords
        ]
        set_of_types = set(types)
        # Types will be empty if there's no time dependent input
        if len(set_of_types) == 0:
            return False
        else:  # there is time dependent input
            if not len(set_of_types) == 1:
                raise ValueError(
                    f"Multiple datetime types detected: {set_of_types}. "
                    "Use either cftime or numpy.datetime64[ns]."
                )
            # Since we compare types and not instances, we use issubclass
            if issubclass(types[0], cftime.datetime):
                return True
            elif issubclass(types[0], np.datetime64):
                return False
            else:
                raise ValueError("Use either cftime or numpy.datetime64[ns].")

    def _yield_times(self):
        modeltimes = []
        for pkg in self.values():
            if "time" in pkg.dataset.coords:
                modeltimes.append(pkg.dataset["time"].values)
            repeat_stress = pkg.dataset.get("repeat_stress")
            if repeat_stress is not None and repeat_stress.values[()] is not None:
                modeltimes.append(repeat_stress.isel(repeat_items=0).values)
        return modeltimes

    def render(self, modelname: str, write_context: WriteContext):
        dir_for_render = write_context.root_directory / modelname

        d = {k: v for k, v in self._options.items() if not (v is None or v is False)}
        packages = []
        for pkgname, pkg in self.items():
            # Append "6" to the package id (e.g. "riv" becomes "riv6")
            pkg_id = pkg._pkg_id
            key = f"{pkg_id}6"
            path = dir_for_render / f"{pkgname}.{pkg_id}"
            packages.append((key, path.as_posix(), pkgname))
        d["packages"] = packages
        return self._template.render(d)

    def _model_checks(self, modelkey: str):
        """
        Check model integrity (called before writing)
        """

        self._check_for_required_packages(modelkey)

    def __get_domain_geometry(
        self,
    ) -> tuple[
        Union[xr.DataArray, xu.UgridDataArray],
        Union[xr.DataArray, xu.UgridDataArray],
        Union[xr.DataArray, xu.UgridDataArray],
    ]:
        discretization = self[self._get_diskey()]
        if discretization is None:
            raise ValueError("Discretization not found")
        top = discretization["top"]
        bottom = discretization["bottom"]
        idomain = discretization["idomain"]
        return top, bottom, idomain

    def __get_k(self):
        try:
            npf = self[imod.mf6.NodePropertyFlow._pkg_id]
        except KeyError:
            raise ValidationError("expected one package of type NodePropertyFlow")

        k = npf["k"]
        return k

    @standard_log_decorator()
    def validate(self, model_name: str = "") -> StatusInfoBase:
        try:
            diskey = self._get_diskey()
        except Exception as e:
            status_info = StatusInfo(f"{model_name} model")
            status_info.add_error(str(e))
            return status_info

        dis = self[diskey]
        # We'll use the idomain for checking dims, shape, nodata.
        idomain = dis["idomain"]
        bottom = dis["bottom"]

        model_status_info = NestedStatusInfo(f"{model_name} model")
        for pkg_name, pkg in self.items():
            # Check for all schemata when writing. Types and dimensions
            # may have been changed after initialization...

            if pkg_name in ["adv"]:
                continue  # some packages can be skipped

            # Concatenate write and init schemata.
            schemata = deepcopy(pkg._init_schemata)
            for key, value in pkg._write_schemata.items():
                if key not in schemata.keys():
                    schemata[key] = value
                else:
                    schemata[key] += value

            pkg_errors = pkg._validate(
                schemata=schemata,
                idomain=idomain,
                bottom=bottom,
            )
            if len(pkg_errors) > 0:
                model_status_info.add(pkg_errors_to_status_info(pkg_name, pkg_errors))

        return model_status_info
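
    # Illustrative sketch (not part of the original module): validate before writing.
    # ``gwf`` is assumed to be a populated GroundwaterFlowModel.
    #
    #   status = gwf.validate(model_name="gwf")
    #   if status.has_errors():
    #       print(status.to_string())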

    @standard_log_decorator()
    def write(
        self, modelname, globaltimes, validate: bool, write_context: WriteContext
    ) -> StatusInfoBase:
        """
        Write the model namefile and all package files.
        """

        workdir = write_context.simulation_directory
        modeldirectory = workdir / modelname
        Path(modeldirectory).mkdir(exist_ok=True, parents=True)
        if validate:
            model_status_info = self.validate(modelname)
            if model_status_info.has_errors():
                return model_status_info

        # write model namefile
        namefile_content = self.render(modelname, write_context)
        namefile_path = modeldirectory / f"{modelname}.nam"
        with open(namefile_path, "w") as f:
            f.write(namefile_content)

        # write package contents
        pkg_write_context = write_context.copy_with_new_write_directory(
            new_write_directory=modeldirectory
        )
        for pkg_name, pkg in self.items():
            try:
                if isinstance(pkg, imod.mf6.Well):
                    top, bottom, idomain = self.__get_domain_geometry()
                    k = self.__get_k()
                    mf6_well_pkg = pkg.to_mf6_pkg(
                        idomain,
                        top,
                        bottom,
                        k,
                        validate,
                        pkg_write_context.is_partitioned,
                    )

                    mf6_well_pkg.write(
                        pkgname=pkg_name,
                        globaltimes=globaltimes,
                        write_context=pkg_write_context,
                    )
                elif isinstance(pkg, imod.mf6.HorizontalFlowBarrierBase):
                    top, bottom, idomain = self.__get_domain_geometry()
                    k = self.__get_k()
                    mf6_hfb_pkg = pkg.to_mf6_pkg(idomain, top, bottom, k, validate)
                    mf6_hfb_pkg.write(
                        pkgname=pkg_name,
                        globaltimes=globaltimes,
                        write_context=pkg_write_context,
                    )
                else:
                    pkg.write(
                        pkgname=pkg_name,
                        globaltimes=globaltimes,
                        write_context=pkg_write_context,
                    )
            except Exception as e:
                raise type(e)(f"{e}\nError occurred while writing {pkg_name}")

        return NestedStatusInfo(modelname)
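
    # Illustrative sketch (not part of the original module): ``write`` is normally
    # driven by a simulation object, which supplies ``globaltimes`` and the
    # WriteContext. The names below are assumptions.
    #
    #   simulation = imod.mf6.Modflow6Simulation("my_sim")
    #   simulation["gwf"] = gwf
    #   ...  # add solver settings and time discretization to the simulation
    #   simulation.write(directory="my_sim_dir")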

    @standard_log_decorator()
    def dump(
        self, directory, modelname, validate: bool = True, mdal_compliant: bool = False
    ):
        modeldirectory = pathlib.Path(directory) / modelname
        modeldirectory.mkdir(exist_ok=True, parents=True)
        if validate:
            statusinfo = self.validate()
            if statusinfo.has_errors():
                raise ValidationError(statusinfo.to_string())

        toml_content: dict = collections.defaultdict(dict)
        for pkgname, pkg in self.items():
            pkg_path = f"{pkgname}.nc"
            toml_content[type(pkg).__name__][pkgname] = pkg_path
            dataset = pkg.dataset
            if isinstance(dataset, xu.UgridDataset):
                if mdal_compliant:
                    dataset = pkg.dataset.ugrid.to_dataset()
                    mdal_dataset = imod.util.spatial.mdal_compliant_ugrid2d(dataset)
                    mdal_dataset.to_netcdf(modeldirectory / pkg_path)
                else:
                    pkg.dataset.ugrid.to_netcdf(modeldirectory / pkg_path)
            else:
                pkg.to_netcdf(modeldirectory / pkg_path)

        toml_path = modeldirectory / f"{modelname}.toml"
        with open(toml_path, "wb") as f:
            tomli_w.dump(toml_content, f)

        return toml_path
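
    # Illustrative sketch (not part of the original module): dump the model to a
    # directory of NetCDF files plus a TOML index. ``gwf`` is an assumption.
    #
    #   toml_path = gwf.dump(directory="dumped", modelname="gwf")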

    @classmethod
    def from_file(cls, toml_path):
        pkg_classes = {
            name: pkg_cls
            for name, pkg_cls in inspect.getmembers(imod.mf6, inspect.isclass)
            if issubclass(pkg_cls, Package)
        }

        toml_path = pathlib.Path(toml_path)
        with open(toml_path, "rb") as f:
            toml_content = tomli.load(f)

        parentdir = toml_path.parent
        instance = cls()
        for key, entry in toml_content.items():
            for pkgname, path in entry.items():
                pkg_cls = pkg_classes[key]
                instance[pkgname] = pkg_cls.from_file(parentdir / path)

        return instance
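
    # Illustrative sketch (not part of the original module): reload a dumped model
    # from its TOML index; the path mirrors the ``dump`` example above.
    #
    #   gwf = imod.mf6.GroundwaterFlowModel.from_file("dumped/gwf/gwf.toml")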

    @classmethod
    def model_id(cls) -> str:
        if cls._model_id is None:
            raise ValueError("Model id has not been set")
        return cls._model_id

    def clip_box(
        self,
        time_min: Optional[cftime.datetime | np.datetime64 | str] = None,
        time_max: Optional[cftime.datetime | np.datetime64 | str] = None,
        layer_min: Optional[int] = None,
        layer_max: Optional[int] = None,
        x_min: Optional[float] = None,
        x_max: Optional[float] = None,
        y_min: Optional[float] = None,
        y_max: Optional[float] = None,
        state_for_boundary: Optional[GridDataArray] = None,
    ):
        """
        Clip a model by a bounding box (time, layer, y, x).

        Slicing intervals may be half-bounded, by providing None:

        * To select 500.0 <= x <= 1000.0:
          ``clip_box(x_min=500.0, x_max=1000.0)``.
        * To select x <= 1000.0: ``clip_box(x_min=None, x_max=1000.0)``
          or ``clip_box(x_max=1000.0)``.
        * To select x >= 500.0: ``clip_box(x_min=500.0, x_max=None)``
          or ``clip_box(x_min=500.0)``.

        Parameters
        ----------
        time_min: optional
        time_max: optional
        layer_min: optional, int
        layer_max: optional, int
        x_min: optional, float
        x_max: optional, float
        y_min: optional, float
        y_max: optional, float
        state_for_boundary: optional, GridDataArray
        """
        supported, error_with_object = self.is_clipping_supported()
        if not supported:
            raise ValueError(
                f"Model cannot be clipped due to presence of package '{error_with_object}' in model"
            )

        clipped = self._clip_box_packages(
            time_min,
            time_max,
            layer_min,
            layer_max,
            x_min,
            x_max,
            y_min,
            y_max,
        )

        return clipped
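
    # Illustrative sketch (not part of the original module): clip a model to a
    # sub-domain; half-bounded selections leave the other side unrestricted.
    # ``gwf`` is an assumption.
    #
    #   clipped = gwf.clip_box(x_min=500.0, x_max=1000.0, y_max=2000.0, layer_max=3)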

    def _clip_box_packages(
        self,
        time_min: Optional[cftime.datetime | np.datetime64 | str] = None,
        time_max: Optional[cftime.datetime | np.datetime64 | str] = None,
        layer_min: Optional[int] = None,
        layer_max: Optional[int] = None,
        x_min: Optional[float] = None,
        x_max: Optional[float] = None,
        y_min: Optional[float] = None,
        y_max: Optional[float] = None,
    ):
        """
        Clip a model by a bounding box (time, layer, y, x).

        Slicing intervals may be half-bounded, by providing None:

        * To select 500.0 <= x <= 1000.0:
          ``clip_box(x_min=500.0, x_max=1000.0)``.
        * To select x <= 1000.0: ``clip_box(x_min=None, x_max=1000.0)``
          or ``clip_box(x_max=1000.0)``.
        * To select x >= 500.0: ``clip_box(x_min=500.0, x_max=None)``
          or ``clip_box(x_min=500.0)``.

        Parameters
        ----------
        time_min: optional
        time_max: optional
        layer_min: optional, int
        layer_max: optional, int
        x_min: optional, float
        x_max: optional, float
        y_min: optional, float
        y_max: optional, float

        Returns
        -------
        clipped : Modflow6Model
        """

        top, bottom, idomain = self.__get_domain_geometry()

        clipped = type(self)(**self._options)
        for key, pkg in self.items():
            clipped[key] = pkg.clip_box(
                time_min=time_min,
                time_max=time_max,
                layer_min=layer_min,
                layer_max=layer_max,
                x_min=x_min,
                x_max=x_max,
                y_min=y_min,
                y_max=y_max,
                top=top,
                bottom=bottom,
            )

        return clipped

    def regrid_like(
        self,
        target_grid: GridDataArray,
        validate: bool = True,
        regrid_context: Optional[RegridderWeightsCache] = None,
    ) -> "Modflow6Model":
        """
        Creates a new model by regridding the packages of this model to another
        discretization. All arrays in the packages are regridded using the default
        regridding methods. At the moment only regridding to a different planar
        grid is supported, meaning ``target_grid`` has different ``"x"`` and
        ``"y"`` or different ``cell2d`` coords.

        Parameters
        ----------
        target_grid: xr.DataArray or xu.UgridDataArray
            a grid defined over the discretization to which the packages are regridded
        validate: bool
            set to True to validate the regridded packages
        regrid_context: Optional RegridderWeightsCache
            stores regridder weights for different regridders. Can be used to speed up
            regridding when the same regridders are used several times for different arrays.

        Returns
        -------
        a model with packages similar to those of the input model, with all data arrays
        regridded to the discretization of ``target_grid``
        """
        return _regrid_like(self, target_grid, validate, regrid_context)
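
    # Illustrative sketch (not part of the original module): regrid a model to a new
    # planar grid. ``gwf`` and ``target_grid`` are assumptions; a shared
    # RegridderWeightsCache may be passed as ``regrid_context`` to reuse weights.
    #
    #   coarse_gwf = gwf.regrid_like(target_grid, validate=True)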

    def mask_all_packages(
        self,
        mask: GridDataArray,
    ):
        """
        This function applies a mask to all packages in a model. The mask must
        be presented as an idomain-like integer array that has 0 (inactive) or
        -1 (vertical passthrough) values in filtered cells and 1 in active
        cells.
        Masking will overwrite idomain with the mask where the mask is 0 or -1.
        Where the mask is 1, the original value of idomain will be kept. Masking
        will update the packages accordingly, blanking their input where needed,
        and is therefore not a reversible operation.

        Parameters
        ----------
        mask: xr.DataArray, xu.UgridDataArray of ints
            idomain-like integer array. 1 sets cells to active, 0 sets cells to inactive,
            -1 sets cells to vertical passthrough
        """

        _mask_all_packages(self, mask)
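
    # Illustrative sketch (not part of the original module): deactivate cells outside
    # an area of interest. ``gwf`` and ``area_of_interest`` (a boolean grid aligned
    # with the model domain) are assumptions.
    #
    #   mask = gwf.domain.where(area_of_interest, other=0).astype(int)
    #   gwf.mask_all_packages(mask)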

    def purge_empty_packages(self, model_name: Optional[str] = "") -> None:
        """
        This function removes empty packages from the model.
        """
        empty_packages = [
            package_name for package_name, package in self.items() if package.is_empty()
        ]
        for package_name in empty_packages:
            self.pop(package_name)

    @property
    def domain(self):
        dis = self._get_diskey()
        return self[dis]["idomain"]

    @property
    def bottom(self):
        dis = self._get_diskey()
        return self[dis]["bottom"]

    def __repr__(self) -> str:
        INDENT = "    "
        typename = type(self).__name__
        options = [
            f"{INDENT}{key}={repr(value)}," for key, value in self._options.items()
        ]
        packages = [
            f"{INDENT}{repr(key)}: {type(value).__name__},"
            for key, value in self.items()
        ]
        # Place the empty dict on the same line. Looks silly otherwise.
        if packages:
            content = [f"{typename}("] + options + ["){"] + packages + ["}"]
        else:
            content = [f"{typename}("] + options + ["){}"]
        return "\n".join(content)

    def is_use_newton(self):
        return False

    def is_splitting_supported(self) -> Tuple[bool, str]:
        """
        Returns True and an empty string if all the packages in the model
        support splitting. If a package does not support splitting, returns
        False and the name of the first such package.
        """
        for package_name, package in self.items():
            if not package.is_splitting_supported():
                return False, package_name
        return True, ""

    def is_regridding_supported(self) -> Tuple[bool, str]:
        """
        Returns True and an empty string if all the packages in the model
        support regridding. If a package does not support regridding, returns
        False and the name of the first such package.
        """
        for package_name, package in self.items():
            if not package.is_regridding_supported():
                return False, package_name
        return True, ""

    def is_clipping_supported(self) -> Tuple[bool, str]:
        """
        Returns True and an empty string if all the packages in the model
        support clipping. If a package does not support clipping, returns
        False and the name of the first such package.
        """
        for package_name, package in self.items():
            if not package.is_clipping_supported():
                return False, package_name
        return True, ""