Coverage for C:\src\imod-python\imod\mf6\package.py: 96%

287 statements  

coverage.py v7.5.1, created at 2024-05-08 14:15 +0200

from __future__ import annotations

import abc
import pathlib
from collections import defaultdict
from typing import Any, Mapping, Optional, Tuple, Union

import cftime
import jinja2
import numpy as np
import xarray as xr
import xugrid as xu

import imod
from imod.logging import standard_log_decorator
from imod.mf6.auxiliary_variables import (
    get_variable_names,
)
from imod.mf6.interfaces.ipackage import IPackage
from imod.mf6.pkgbase import (
    EXCHANGE_PACKAGES,
    TRANSPORT_PACKAGES,
    PackageBase,
)
from imod.mf6.utilities.mask import _mask
from imod.mf6.utilities.regrid import (
    RegridderType,
    RegridderWeightsCache,
    _regrid_like,
)
from imod.mf6.utilities.schemata import filter_schemata_dict
from imod.mf6.validation import validation_pkg_error_message
from imod.mf6.write_context import WriteContext
from imod.schemata import (
    AllNoDataSchema,
    EmptyIndexesSchema,
    SchemaType,
    ValidationError,
)
from imod.typing import GridDataArray


class Package(PackageBase, IPackage, abc.ABC):
    """
    Package is used to share methods for specific packages with no time
    component.

    It is not meant to be used directly, only to inherit from, to implement new
    packages.

    This class only supports `array input
    <https://water.usgs.gov/water-resources/software/MODFLOW-6/mf6io_6.0.4.pdf#page=16>`_,
    not the list input which is used in :class:`BoundaryCondition`.
    """

    _pkg_id = ""
    _init_schemata: dict[str, list[SchemaType] | Tuple[SchemaType, ...]] = {}
    _write_schemata: dict[str, list[SchemaType] | Tuple[SchemaType, ...]] = {}
    _keyword_map: dict[str, str] = {}

    def __init__(self, allargs: Mapping[str, GridDataArray | float | int | bool | str]):
        super().__init__(allargs)

    def isel(self):
        raise NotImplementedError(
            "Selection on packages not yet supported. To make a selection on "
            f"the xr.Dataset, call {self._pkg_id}.dataset.isel instead. "
            "You can create a new package with a selection by calling "
            f"{__class__.__name__}(**{self._pkg_id}.dataset.isel(**selection))"
        )

    def sel(self):
        raise NotImplementedError(
            "Selection on packages not yet supported. To make a selection on "
            f"the xr.Dataset, call {self._pkg_id}.dataset.sel instead. "
            "You can create a new package with a selection by calling "
            f"{__class__.__name__}(**{self._pkg_id}.dataset.sel(**selection))"
        )

    def _valid(self, value):
        """
        Filters values that are None, False, or a numpy.bool_ False.
        Needs to be this specific, since 0.0 and 0 are valid values, but are
        equal to a boolean False.
        """
        # Test singletons
        if value is False or value is None:
            return False
        # Test numpy bool (not singleton)
        elif isinstance(value, np.bool_) and not value:
            return False
        # When dumping to netCDF and reading back, None will have been
        # converted into a NaN. Only check NaN if it's a floating type to avoid
        # TypeErrors.
        elif np.issubdtype(type(value), np.floating) and np.isnan(value):
            return False
        else:
            return True

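    # Illustrative sketch (not part of the original file): how _valid treats a
    # few typical values on a hypothetical package instance ``pkg``. Note that
    # 0 and 0.0 pass, because identity and NaN checks are used rather than
    # truthiness.
    #
    # >>> pkg._valid(0.0), pkg._valid(0)
    # (True, True)
    # >>> pkg._valid(None), pkg._valid(False), pkg._valid(np.nan)
    # (False, False, False)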

    @staticmethod
    def _number_format(dtype: type):
        if np.issubdtype(dtype, np.integer):
            return "%i"
        elif np.issubdtype(dtype, np.floating):
            return "%.18G"
        else:
            raise TypeError("dtype should be either integer or float")

    @staticmethod
    def _initialize_template(pkg_id):
        loader = jinja2.PackageLoader("imod", "templates/mf6")
        env = jinja2.Environment(loader=loader, keep_trailing_newline=True)
        if pkg_id == "ims":
            fname = "sln-ims.j2"
        elif pkg_id == "tdis":
            fname = "sim-tdis.j2"
        elif pkg_id in TRANSPORT_PACKAGES:
            fname = f"gwt-{pkg_id}.j2"
        elif pkg_id in EXCHANGE_PACKAGES:
            fname = f"exg-{pkg_id}.j2"
        elif pkg_id == "api":
            fname = f"{pkg_id}.j2"
        else:
            fname = f"gwf-{pkg_id}.j2"
        return env.get_template(fname)

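    # Illustrative sketch (not part of the original file): the pkg_id selects a
    # Jinja2 template shipped with imod under "templates/mf6". Assuming "riv"
    # is a GWF package id (not a transport or exchange package), it falls
    # through to the final branch:
    #
    # >>> Package._initialize_template("riv").name
    # 'gwf-riv.j2'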

    def write_blockfile(self, pkgname, globaltimes, write_context: WriteContext):
        directory = write_context.get_formatted_write_directory()

        content = self.render(
            directory=directory,
            pkgname=pkgname,
            globaltimes=globaltimes,
            binary=write_context.use_binary,
        )
        filename = write_context.write_directory / f"{pkgname}.{self._pkg_id}"
        with open(filename, "w") as f:
            f.write(content)

    def write_binary_griddata(self, outpath, da, dtype):
        # From the modflow6 source, the header is defined as:
        # integer(I4B) :: kstp --> np.int32 : 1
        # integer(I4B) :: kper --> np.int32 : 2
        # real(DP) :: pertim --> 2 * np.int32 : 4
        # real(DP) :: totim --> 2 * np.int32 : 6
        # character(len=16) :: text --> 4 * np.int32 : 10
        # integer(I4B) :: m1, m2, m3 --> 3 * np.int32 : 13
        # so writing 13 np.int32 values suffices to create a header.

        # The following code is commented out due to modflow issue 189
        # https://github.com/MODFLOW-USGS/modflow6/issues/189
        # We never write LAYERED data.
        # The (structured) dis array reader results in an error if you try to
        # read a 3D botm array. By storing nlayer * nrow * ncol in the first
        # header entry, the array is read properly.

        # haslayer = "layer" in da.dims
        # if haslayer:
        #     nlayer, nrow, ncol = da.shape
        # else:
        #     nrow, ncol = da.shape
        #     nlayer = 1

        # This is a work around for the above-mentioned issue.
        nval = np.product(da.shape)
        header = np.zeros(13, np.int32)
        header[-3] = np.int32(nval)  # ncol
        header[-2] = np.int32(1)  # nrow
        header[-1] = np.int32(1)  # nlayer

        with open(outpath, "w") as f:
            header.tofile(f)
            da.values.flatten().astype(dtype).tofile(f)

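    # Illustrative sketch (not part of the original file): reading back a
    # header written by the workaround above, for a hypothetical grid of
    # 3 layers, 4 rows and 5 columns. The flattened size (60) is stored in the
    # "ncol" slot, while nrow and nlayer are set to 1:
    #
    # >>> header = np.fromfile(outpath, dtype=np.int32, count=13)
    # >>> header[-3:]
    # array([60,  1,  1], dtype=int32)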

    def write_text_griddata(self, outpath, da, dtype):
        with open(outpath, "w") as f:
            # Note: reshaping here avoids writing newlines after every number.
            # This dumps all the values in a single row rather than a single
            # column. This is to be preferred, since editors can easily
            # "reshape" a long row with "word wrap"; they cannot as easily
            # ignore newlines.
            fmt = self._number_format(dtype)
            data = da.values
            if data.ndim > 2:
                np.savetxt(fname=f, X=da.values.reshape((1, -1)), fmt=fmt)
            else:
                np.savetxt(fname=f, X=da.values, fmt=fmt)

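    # Illustrative sketch (not part of the original file, assumes ``import
    # sys``): the reshape above flattens a 3D array into a single text row,
    # e.g. a (2, 2, 2) array becomes one line of eight numbers:
    #
    # >>> values = np.arange(8.0).reshape((2, 2, 2))
    # >>> np.savetxt(sys.stdout, values.reshape((1, -1)), fmt="%.18G")
    # 0 1 2 3 4 5 6 7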

    def _get_render_dictionary(
        self,
        directory: pathlib.Path,
        pkgname: str,
        globaltimes: Union[list[np.datetime64], np.ndarray],
        binary: bool,
    ) -> dict[str, Any]:
        d = {}
        if directory is None:
            pkg_directory = pkgname
        else:
            pkg_directory = pathlib.Path(directory) / pkgname

        for varname in self.dataset.data_vars:
            key = self._keyword_map.get(str(varname), str(varname))

            if hasattr(self, "_grid_data") and varname in self._grid_data:
                layered, value = self._compose_values(
                    self.dataset[varname], pkg_directory, key, binary=binary
                )
                if self._valid(value):  # skip False or None
                    d[f"{key}_layered"], d[key] = layered, value
            else:
                value = self[varname].values[()]
                if self._valid(value):  # skip False or None
                    d[key] = value

        if (hasattr(self, "_auxiliary_data")) and (names := get_variable_names(self)):
            d["auxiliary"] = names
        return d

    def render(self, directory, pkgname, globaltimes, binary):
        d = self._get_render_dictionary(directory, pkgname, globaltimes, binary)
        return self._template.render(d)

    @staticmethod
    def _is_xy_data(obj):
        if isinstance(obj, (xr.DataArray, xr.Dataset)):
            xy = "x" in obj.dims and "y" in obj.dims
        elif isinstance(obj, (xu.UgridDataArray, xu.UgridDataset)):
            xy = obj.ugrid.grid.face_dimension in obj.dims
        else:
            raise TypeError(
                "obj should be DataArray or UgridDataArray, "
                f"received {type(obj)} instead"
            )
        return xy

    def _compose_values(self, da, directory, name, binary):
        """
        Compose values of dictionary.

        Ignores times. Time dependent boundary conditions use the method from
        BoundaryCondition.

        See documentation of wq
        """
        layered = False
        values = []
        if self._is_xy_data(da):
            if binary:
                path = (directory / f"{name}.bin").as_posix()
                values.append(f"open/close {path} (binary)")
            else:
                path = (directory / f"{name}.dat").as_posix()
                values.append(f"open/close {path}")
        else:
            if "layer" in da.dims:
                layered = True
                for layer in da.coords["layer"]:
                    values.append(f"constant {da.sel(layer=layer).values[()]}")
            else:
                value = da.values[()]
                if self._valid(value):  # skip None or False
                    values.append(f"constant {value}")
                else:
                    values = None

        return layered, values

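    # Illustrative sketch (not part of the original file): the strings composed
    # here end up in the rendered MODFLOW 6 blockfile. A spatial (x/y) array
    # becomes an external open/close record, while a layered scalar becomes one
    # "constant" entry per layer. With hypothetical inputs this returns, e.g.:
    #
    #   (False, ["open/close mymodel/npf/k.bin (binary)"])
    #   (True, ["constant 10.0", "constant 5.0"])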

    @standard_log_decorator()
    def write(
        self,
        pkgname: str,
        globaltimes: Union[list[np.datetime64], np.ndarray],
        write_context: WriteContext,
    ):
        directory = write_context.write_directory
        binary = write_context.use_binary
        self.write_blockfile(pkgname, globaltimes, write_context)

        if hasattr(self, "_grid_data"):
            if self._is_xy_data(self.dataset):
                pkgdirectory = directory / pkgname
                pkgdirectory.mkdir(exist_ok=True, parents=True)
                for varname, dtype in self._grid_data.items():
                    key = self._keyword_map.get(varname, varname)
                    da = self.dataset[varname]
                    if self._is_xy_data(da):
                        if binary:
                            path = pkgdirectory / f"{key}.bin"
                            self.write_binary_griddata(path, da, dtype)
                        else:
                            path = pkgdirectory / f"{key}.dat"
                            self.write_text_griddata(path, da, dtype)

    @standard_log_decorator()
    def _validate(self, schemata: dict, **kwargs) -> dict[str, list[ValidationError]]:
        errors = defaultdict(list)
        for variable, var_schemata in schemata.items():
            for schema in var_schemata:
                if (
                    variable in self.dataset.keys()
                ):  # concentration only added to dataset if specified
                    try:
                        schema.validate(self.dataset[variable], **kwargs)
                    except ValidationError as e:
                        errors[variable].append(e)
        return errors

    def is_empty(self) -> bool:
        """
        Returns True if the package is empty, for example if it contains only
        no-data values.
        """

        # Create a schemata dict containing only the variables with an
        # AllNoDataSchema or EmptyIndexesSchema (in the case of HFB) in the
        # write schemata.
        allnodata_schemata = filter_schemata_dict(
            self._write_schemata, (AllNoDataSchema, EmptyIndexesSchema)
        )

        # Find out whether the package throws a ValidationError for
        # AllNoDataSchema or EmptyIndexesSchema.
        allnodata_errors = self._validate(allnodata_schemata)
        return len(allnodata_errors) > 0

    def _validate_init_schemata(self, validate: bool):
        """
        Run the "cheap" schema validations.

        The expensive validations are run during writing. Some are only
        available then: e.g. idomain to determine the active part of the
        domain.
        """
        if not validate:
            return
        errors = self._validate(self._init_schemata)
        if len(errors) > 0:
            message = validation_pkg_error_message(errors)
            raise ValidationError(message)
        return

    def copy(self) -> Any:
        # All state should be contained in the dataset.
        return type(self)(**self.dataset.copy().to_dict())

    @staticmethod
    def _clip_repeat_stress(
        repeat_stress: xr.DataArray,
        time,
        time_start,
        time_end,
    ):
        """
        Selection may remove the original data which are repeated.
        These should be re-inserted at the first occurring "key".
        Next, remove these keys as they've been "promoted" to regular
        timestamps with data.
        """
        # First, "pop" and filter.
        keys, values = repeat_stress.values.T
        keep = (keys >= time_start) & (keys <= time_end)
        new_keys = keys[keep]
        new_values = values[keep]
        # Now detect which "value" entries have gone missing
        insert_values, index = np.unique(new_values, return_index=True)
        insert_keys = new_keys[index]
        # Setup indexer
        indexer = xr.DataArray(
            data=np.arange(time.size),
            coords={"time": time},
            dims=("time",),
        ).sel(time=insert_values)
        indexer["time"] = insert_keys

        # Update the key-value pairs. Discard keys that have been "promoted".
        keep = np.in1d(new_keys, insert_keys, assume_unique=True, invert=True)
        new_keys = new_keys[keep]
        new_values = new_values[keep]
        # Set the values to their new source.
        new_values = insert_keys[np.searchsorted(insert_values, new_values)]
        repeat_stress = xr.DataArray(
            data=np.column_stack((new_keys, new_values)),
            dims=("repeat", "repeat_items"),
        )
        return indexer, repeat_stress

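    # Illustrative sketch (not part of the original file): a worked example of
    # the "promotion" above, with hypothetical timestamps. Suppose repeated
    # stresses point 2001 and 2002 back to data stored at 2000:
    #
    #   repeat_stress = [(2001, 2000), (2002, 2000)]
    #
    # Clipping to time_start=2001 drops the original 2000 data, so the 2000
    # values are re-inserted ("promoted") at the first surviving key, 2001, and
    # the remaining pair is rewired to point at it:
    #
    #   indexer["time"]  -> [2001]            # take 2000's data, relabel as 2001
    #   repeat_stress    -> [(2002, 2001)]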

    @staticmethod
    def _clip_time_indexer(
        time,
        time_start,
        time_end,
    ):
        original = xr.DataArray(
            data=np.arange(time.size),
            coords={"time": time},
            dims=("time",),
        )
        indexer = original.sel(time=slice(time_start, time_end))

        # The selection might return a 0-sized dimension.
        if indexer.size > 0:
            first_time = indexer["time"].values[0]
        else:
            first_time = None

        # If the first time matches exactly, xarray will have done the thing we
        # wanted and our work with the time dimension is finished.
        if (time_start is not None) and (time_start != first_time):
            # If the first time is before the original time, we need to
            # backfill; otherwise, we need to ffill the first timestamp.
            if time_start < time[0]:
                method = "bfill"
            else:
                method = "ffill"
            # Index with a list rather than a scalar to preserve the time
            # dimension.
            first = original.sel(time=[time_start], method=method)
            first["time"] = [time_start]
            indexer = xr.concat([first, indexer], dim="time")

        return indexer

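    # Illustrative sketch (not part of the original file): if the requested
    # time_start falls between two stress periods, the indexer forward-fills so
    # the clipped package still starts with valid data. With hypothetical
    # yearly timestamps 2000-01-01 .. 2003-01-01 and time_start=2001-07-01, the
    # entry for 2001-01-01 is repeated and relabelled as 2001-07-01:
    #
    #   indexer["time"] -> [2001-07-01, 2002-01-01, 2003-01-01]
    #   indexer.values  -> [1, 2, 3]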

    def __to_datetime(self, time, use_cftime):
        """
        Helper function that converts to datetime, except when None.
        """
        if time is None:
            return time
        else:
            return imod.util.time.to_datetime_internal(time, use_cftime)

    def clip_box(
        self,
        time_min: Optional[cftime.datetime | np.datetime64 | str] = None,
        time_max: Optional[cftime.datetime | np.datetime64 | str] = None,
        layer_min: Optional[int] = None,
        layer_max: Optional[int] = None,
        x_min: Optional[float] = None,
        x_max: Optional[float] = None,
        y_min: Optional[float] = None,
        y_max: Optional[float] = None,
        top: Optional[GridDataArray] = None,
        bottom: Optional[GridDataArray] = None,
    ) -> Package:
        """
        Clip a package by a bounding box (time, layer, y, x).

        Slicing intervals may be half-bounded, by providing None:

        * To select 500.0 <= x <= 1000.0:
          ``clip_box(x_min=500.0, x_max=1000.0)``.
        * To select x <= 1000.0: ``clip_box(x_min=None, x_max=1000.0)``
          or ``clip_box(x_max=1000.0)``.
        * To select x >= 500.0: ``clip_box(x_min=500.0, x_max=None)``
          or ``clip_box(x_min=500.0)``.

        Parameters
        ----------
        time_min: optional
        time_max: optional
        layer_min: optional, int
        layer_max: optional, int
        x_min: optional, float
        x_max: optional, float
        y_min: optional, float
        y_max: optional, float
        top: optional, GridDataArray
        bottom: optional, GridDataArray

        Returns
        -------
        clipped: Package
        """
        if not self.is_clipping_supported():
            raise ValueError("this package does not support clipping.")

        selection = self.dataset
        if "time" in selection:
            time = selection["time"].values
            use_cftime = isinstance(time[0], cftime.datetime)
            time_start = self.__to_datetime(time_min, use_cftime)
            time_end = self.__to_datetime(time_max, use_cftime)

            indexer = self._clip_time_indexer(
                time=time,
                time_start=time_start,
                time_end=time_end,
            )

            if "repeat_stress" in selection.data_vars and self._valid(
                selection["repeat_stress"].values[()]
            ):
                repeat_indexer, repeat_stress = self._clip_repeat_stress(
                    repeat_stress=selection["repeat_stress"],
                    time=time,
                    time_start=time_start,
                    time_end=time_end,
                )
                selection = selection.drop_vars("repeat_stress")
                selection["repeat_stress"] = repeat_stress
                indexer = repeat_indexer.combine_first(indexer).astype(int)

            selection = selection.drop_vars("time").isel(time=indexer)

        if "layer" in selection.coords:
            layer_slice = slice(layer_min, layer_max)
            # Cannot select if it's not a dimension!
            if "layer" not in selection.dims:
                selection = (
                    selection.expand_dims("layer")
                    .sel(layer=layer_slice)
                    .squeeze("layer")
                )
            else:
                selection = selection.sel(layer=layer_slice)

        x_slice = slice(x_min, x_max)
        y_slice = slice(y_min, y_max)
        if isinstance(selection, xu.UgridDataset):
            selection = selection.ugrid.sel(x=x_slice, y=y_slice)
        elif ("x" in selection.coords) and ("y" in selection.coords):
            if selection.indexes["y"].is_monotonic_decreasing:
                y_slice = slice(y_max, y_min)
            selection = selection.sel(x=x_slice, y=y_slice)

        cls = type(self)
        new = cls.__new__(cls)
        new.dataset = selection
        return new

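    # Illustrative usage sketch (not part of the original file), where ``riv``
    # stands for any hypothetical package instance that supports clipping:
    #
    # >>> clipped = riv.clip_box(
    # ...     time_min="2001-01-01", time_max="2005-01-01",
    # ...     layer_min=1, layer_max=3,
    # ...     x_min=100_000.0, x_max=150_000.0,
    # ...     y_min=400_000.0, y_max=450_000.0,
    # ... )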

    def mask(self, mask: GridDataArray) -> Any:
        """
        Mask values outside of domain.

        Floating values outside of the condition are set to NaN (nodata).
        Integer values outside of the condition are set to 0 (inactive in
        MODFLOW terms).

        Parameters
        ----------
        mask: xr.DataArray, xu.UgridDataArray of ints
            idomain-like integer array. 1 sets cells to active, 0 sets cells to
            inactive, -1 sets cells to vertical passthrough.

        Returns
        -------
        masked: Package
            The package with part masked.
        """

        return _mask(self, mask)

    def regrid_like(
        self,
        target_grid: GridDataArray,
        regrid_context: RegridderWeightsCache,
        regridder_types: Optional[dict[str, Tuple[RegridderType, str]]] = None,
    ) -> "Package":
        """
        Creates a package of the same type as this package, based on another
        discretization. It regrids all the arrays in this package to the
        desired discretization, and leaves the options unmodified. At the
        moment only regridding to a different planar grid is supported, meaning
        ``target_grid`` has different ``"x"`` and ``"y"`` or different
        ``cell2d`` coords.

        The regridding methods can be specified in the ``_regrid_method``
        attribute of the package. These are the defaults that specify how each
        array should be regridded. These defaults can be overridden using the
        input parameters of this function.

        Examples
        --------
        To regrid the npf package with a non-default method for the k-field,
        call regrid_like with these arguments:

        >>> new_npf = npf.regrid_like(like, regrid_context, {"k": (imod.RegridderType.OVERLAP, "mean")})

        Parameters
        ----------
        target_grid: xr.DataArray or xu.UgridDataArray
            a grid defined over the same discretization as the one we want to
            regrid the package to
        regrid_context: RegridderWeightsCache
            stores regridder weights for different regridders. Can be used to
            speed up regridding, if the same regridders are used several times
            for regridding different arrays.
        regridder_types: dict(str -> (regridder type, str)), optional
            dictionary mapping array names (str) to a tuple of regridder type
            (a specialization class of BaseRegridder) and function name (str).
            This dictionary can be used to override the default mapping method.

        Returns
        -------
        a package with the same options as this package, and with all the
        data-arrays regridded to another discretization, similar to the one
        used in the input argument ``target_grid``
        """
        try:
            result = _regrid_like(self, target_grid, regrid_context, regridder_types)
        except ValueError as e:
            raise e
        except Exception:
            raise ValueError("package could not be regridded.")
        return result

    def _skip_masking_dataarray(self, array_name: str) -> bool:
        if hasattr(self, "_skip_mask_arrays"):
            return array_name in self._skip_mask_arrays
        return False

    @classmethod
    def is_grid_agnostic_package(cls) -> bool:
        return False

    def __repr__(self) -> str:
        typename = type(self).__name__
        return f"{typename}\n{self.dataset.__repr__()}"

    def _repr_html_(self) -> str:
        typename = type(self).__name__
        return f"<div>{typename}</div>{self.dataset._repr_html_()}"

    @property
    def auxiliary_data_fields(self) -> dict[str, str]:
        if hasattr(self, "_auxiliary_data"):
            return self._auxiliary_data
        return {}

    def get_non_grid_data(self, grid_names: list[str]) -> dict[str, Any]:
        """
        This function copies the attributes of a dataset that are scalars, such
        as options.

        Parameters
        ----------
        grid_names: list of str
            The names of the attributes of a dataset that are grids.
        """
        result = {}
        all_non_grid_data = list(self.dataset.keys())
        for name in (
            gridname for gridname in grid_names if gridname in all_non_grid_data
        ):
            all_non_grid_data.remove(name)
        for name in all_non_grid_data:
            if "time" in self.dataset[name].coords:
                result[name] = self.dataset[name]
            else:
                result[name] = self.dataset[name].values[()]
        return result

    def is_splitting_supported(self) -> bool:
        return True

    def is_regridding_supported(self) -> bool:
        return True

    def is_clipping_supported(self) -> bool:
        return True