Coverage for C:\src\imod-python\imod\util\spatial.py: 93%

238 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-08 14:15 +0200

1""" 

2Utility functions for dealing with the spatial 

3location of rasters: :func:`imod.util.spatial.coord_reference`, 

4:func:`imod.util.spatial_reference` and :func:`imod.util.transform`. These are 

5used internally, but are not private since they may be useful to users as well. 

6""" 

7 

8import collections 

9import re 

10from typing import TYPE_CHECKING, Any, Dict, Sequence, Tuple, Union 

11 

12import affine 

13import numpy as np 

14import pandas as pd 

15import xarray as xr 

16import xugrid as xu 

17 

18from imod.typing import FloatArray, GridDataset, IntArray 

19from imod.util.imports import MissingOptionalModule 

20 

# Since rasterio, shapely, and geopandas are big dependencies that are
# sometimes hard to install and not always required, we made them optional
# dependencies.

24try: 

25 import rasterio 

26except ImportError: 

27 rasterio = MissingOptionalModule("rasterio") 

28 

29try: 

30 import shapely 

31except ImportError: 

32 shapely = MissingOptionalModule("shapely") 

33 

34if TYPE_CHECKING: 

35 import geopandas as gpd 

36else: 

37 try: 

38 import geopandas as gpd 

39 except ImportError: 

40 gpd = MissingOptionalModule("geopandas") 

41 

42 

43def _xycoords(bounds, cellsizes) -> Dict[str, Any]: 

44 """Based on bounds and cellsizes, construct coords with spatial information""" 

45 # unpack tuples 

46 xmin, xmax, ymin, ymax = bounds 

47 dx, dy = cellsizes 

48 coords: collections.OrderedDict[str, Any] = collections.OrderedDict() 

49 # from cell size to x and y coordinates 

50 if isinstance(dx, (int, float, np.int_)): # equidistant 

51 coords["x"] = np.arange(xmin + dx / 2.0, xmax, dx) 

52 coords["y"] = np.arange(ymax + dy / 2.0, ymin, dy) 

53 coords["dx"] = np.array(float(dx)) 

54 coords["dy"] = np.array(float(dy)) 

55 else: # nonequidistant 

56 # even though IDF may store them as float32, we always convert them to float64 

57 dx = dx.astype(np.float64) 

58 dy = dy.astype(np.float64) 

59 coords["x"] = xmin + np.cumsum(dx) - 0.5 * dx 

60 coords["y"] = ymax + np.cumsum(dy) - 0.5 * dy 

61 if np.allclose(dx, dx[0]) and np.allclose(dy, dy[0]): 

62 coords["dx"] = np.array(float(dx[0])) 

63 coords["dy"] = np.array(float(dy[0])) 

64 else: 

65 coords["dx"] = ("x", dx) 

66 coords["dy"] = ("y", dy) 

67 return coords 

68 

69 

def coord_reference(da_coord) -> Tuple[float, float, float]:
    """
    Extracts dx, xmin, xmax for a coordinate DataArray, where x is any coordinate.

    If the DataArray coordinates are nonequidistant, dx will be returned as
    1D ndarray instead of float.

    Parameters
    ----------
    da_coord : xarray.DataArray of a coordinate
        The cell sizes are read from a coordinate named ``d{da_coord.name}``
        (e.g. "dx" for "x") when present; otherwise they are derived from the
        spacing of the coordinate values.

    Returns
    --------------
    tuple
        (dx, xmin, xmax) for a coordinate x

    Raises
    ------
    ValueError
        If no cell size coordinate is present and the coordinate either has a
        single value, or is not equidistant.
    """
    x = da_coord.values

    # Possibly non-equidistant
    dx_string = f"d{da_coord.name}"
    if dx_string in da_coord.coords:
        dx = da_coord.coords[dx_string]
        if (dx.shape == x.shape) and (dx.size != 1):
            # choose correctly for decreasing coordinate: the smallest value
            # of x is then the last entry rather than the first.
            if dx[0] < 0.0:
                end = 0
                start = -1
            else:
                start = 0
                end = -1
            dx = dx.values.astype(np.float64)
            xmin = float(x.min()) - 0.5 * abs(dx[start])
            xmax = float(x.max()) + 0.5 * abs(dx[end])
            # As a single value if equidistant
            if np.allclose(dx, dx[0]):
                dx = dx[0]
        else:
            dx = float(dx)
            xmin = float(x.min()) - 0.5 * abs(dx)
            xmax = float(x.max()) + 0.5 * abs(dx)
    elif x.size == 1:
        raise ValueError(
            f"DataArray has size 1 along {da_coord.name}, so cellsize must be provided"
            f" as a coordinate named d{da_coord.name}."
        )
    else:  # Equidistant
        # TODO: decide on decent criterium for what equidistant means
        # make use of floating point epsilon? E.g:
        # https://github.com/ioam/holoviews/issues/1869#issuecomment-353115449
        dxs = np.diff(x.astype(np.float64))
        dx = dxs[0]
        atolx = abs(1.0e-4 * dx)
        # The tolerance must be passed by keyword: the third positional
        # argument of np.allclose is rtol, which would make the accepted
        # deviation scale with dx squared. rtol is zeroed so that the
        # criterion is exactly 1e-4 of the cell size, independent of units.
        if not np.allclose(dxs, dx, rtol=0.0, atol=atolx):
            raise ValueError(
                f"DataArray has to be equidistant along {da_coord.name}, or cellsizes"
                f" must be provided as a coordinate named d{da_coord.name}."
            )

        # as xarray uses midpoint coordinates
        xmin = float(x.min()) - 0.5 * abs(dx)
        xmax = float(x.max()) + 0.5 * abs(dx)

    return dx, xmin, xmax

133 

134 

def spatial_reference(
    a: xr.DataArray,
) -> Tuple[float, float, float, float, float, float]:
    """
    Collect the cell size and extent of a DataArray along "x" and "y".

    The "x" and "y" coordinates of ``a`` are each passed to
    ``coord_reference``; a cell size is returned as a 1D ndarray instead of a
    float when that coordinate is nonequidistant.

    Parameters
    ----------
    a : xarray.DataArray
        Must have coordinates named "x" and "y".

    Returns
    --------------
    tuple
        (dx, xmin, xmax, dy, ymin, ymax)

    """
    # x is evaluated before y; both triples are flattened into one tuple.
    x_reference = coord_reference(a["x"])
    y_reference = coord_reference(a["y"])
    return (*x_reference, *y_reference)

157 

158 

def transform(a: xr.DataArray) -> affine.Affine:
    """
    Build an affine.Affine object from the spatial reference of a DataArray,
    for writing to e.g. rasterio supported formats.

    Parameters
    ----------
    a : xarray.DataArray

    Returns
    -------
    affine.Affine

    Raises
    ------
    ValueError
        If the DataArray is not equidistant along x or y, if dx is negative,
        or if dy is positive.

    """

    def _single_cellsize(cellsize, axis):
        # Anything that is not an array is already a single value.
        if not isinstance(cellsize, np.ndarray):
            return cellsize
        if np.unique(cellsize).size != 1:
            raise ValueError(f"DataArray is not equidistant along {axis}")
        return cellsize[0]

    dx, xmin, _, dy, _, ymax = spatial_reference(a)
    dx = _single_cellsize(dx, "x")
    dy = _single_cellsize(dy, "y")

    # The affine transform is anchored at the (xmin, ymax) corner, which
    # requires x to increase and y to decrease.
    if dx < 0.0:
        raise ValueError("dx must be positive")
    if dy > 0.0:
        raise ValueError("dy must be negative")
    return affine.Affine(dx, 0.0, xmin, 0.0, dy, ymax)

192 

193 

def ugrid2d_data(da: xr.DataArray, face_dim: str) -> xr.DataArray:
    """
    Flatten the trailing ("y", "x") dimensions of a structured DataArray into
    a single face dimension. Leading dimensions are kept as they are:
    e.g. (time, layer, y, x) becomes (time, layer, face).

    Parameters
    ----------
    da: xr.DataArray
        Structured DataArray with last two dimensions ("y", "x").
    face_dim: str
        Name given to the new, flattened dimension.

    Returns
    -------
    Unstructured DataArray with dimensions ("y", "x") replaced by (face_dim,).
    Only the coordinates of the leading dimensions and the name are carried
    over.

    Raises
    ------
    ValueError
        If the last two dimensions are not ("y", "x").
    """
    trailing_dims = da.dims[-2:]
    if trailing_dims != ("y", "x"):
        raise ValueError('Last two dimensions of da must be ("y", "x")')

    leading_dims = da.dims[:-2]
    leading_shape = da.shape[:-2]
    leading_coords = {dim: da.coords[dim] for dim in leading_dims}
    # -1 lets numpy infer the number of faces (nrow * ncol).
    flat_values = da.data.reshape(*leading_shape, -1)
    return xr.DataArray(
        flat_values,
        coords=leading_coords,
        dims=[*leading_dims, face_dim],
        name=da.name,
    )

219 

220 

def unstack_dim_into_variable(dataset: GridDataset, dim: str) -> GridDataset:
    """
    Split every data variable that has dimension ``dim`` into one variable per
    value of ``dim``, named ``{variable}_{dim}_{value}``.

    The input dataset is copied first; the ``dim`` coordinate is dropped from
    the result when it is present.
    """
    result = dataset.copy()

    # Determined up front from the input, since the result is modified below.
    names_with_dim = list(
        filter(lambda name: dim in dataset[name].dims, dataset.data_vars)
    )

    for name in names_with_dim:
        source = result[name]
        result = result.drop_vars(name)  # type: ignore
        for label in source[dim].values:
            new_name = f"{name}_{dim}_{label}"
            result[new_name] = source.sel(indexers={dim: label}, drop=True)

    if dim not in result.coords:
        return result
    return result.drop_vars(dim)

241 

242 

def mdal_compliant_ugrid2d(dataset: xr.Dataset) -> xr.Dataset:
    """
    Prepare an xarray Dataset so that it can be written to a UGRID netCDF
    that MDAL accepts.

    * Unstacks variables with a layer dimension into separate variables.
    * Drops the coordinates on the UGRID node, edge and face dimensions.
    * Removes absent entries from the mesh topology attributes.
    * Sets encoding to float for datetime coordinates.

    Parameters
    ----------
    dataset: xarray.Dataset

    Returns
    -------
    unstacked: xr.Dataset

    """
    ds = unstack_dim_into_variable(dataset, "layer")

    # Topology attributes that may exist:
    #
    # "cf_role"
    # "long_name"
    # "topology_dimension"
    # "node_dimension": required
    # "node_coordinates": required
    # "edge_dimension": optional
    # "edge_node_connectivity": optional
    # "face_dimension": required
    # "face_node_connectivity": required
    # "max_face_nodes_dimension": required
    # "face_coordinates": optional
    ugrid_dimension_keys = ("node_dimension", "edge_dimension", "face_dimension")

    for name in ds.data_vars:
        topology_attrs = ds[name].attrs
        if topology_attrs.get("cf_role") != "mesh_topology":
            continue

        # Drop the coordinates on the UGRID dimensions
        ugrid_dims = [topology_attrs.get(key) for key in ugrid_dimension_keys]
        ds = ds.drop_vars(
            [dim for dim in ugrid_dims if dim is not None and dim in ds.coords]
        )

        # Each optional attribute is removed when the thing it refers to is
        # not present in the (updated) dataset.
        optional_entries = (
            ("edge_dimension", ds.dims),
            ("face_coordinates", ds.coords),
            ("edge_node_connectivity", ds),
        )
        for key, container in optional_entries:
            target = topology_attrs.get(key)
            if target and target not in container:
                topology_attrs.pop(key)

    # Make sure time is encoded as a float for MDAL
    # TODO: MDAL requires all data variables to be float (this excludes the UGRID topology data)
    for coord_name in ds.coords:
        if np.issubdtype(ds[coord_name].dtype, np.datetime64):
            ds[coord_name].encoding["dtype"] = np.float64

    return ds

310 

311 

def from_mdal_compliant_ugrid2d(dataset: xu.UgridDataset) -> xu.UgridDataset:
    """
    Undo some of the changes of ``mdal_compliant_ugrid2d``: re-stack the
    layers.

    Data variables whose name matches ``{name}_layer_{number}`` are
    concatenated along a new "layer" dimension into a single variable
    ``{name}``, ordered by layer number. All other data variables are carried
    over unchanged and keep their original order.

    Parameters
    ----------
    dataset: xugrid.UgridDataset

    Returns
    -------
    restacked: xugrid.UgridDataset
        The input itself is returned when no variable matches the pattern.

    """
    ds = dataset.ugrid.obj
    pattern = re.compile(r"(\w+)_layer_(\d+)")
    matches = {}
    for variable in ds.data_vars:
        match = pattern.search(variable)
        if match is not None:
            matches[variable] = match
    if not matches:
        return dataset

    # First deal with the variables that may remain untouched. A list (rather
    # than a set difference) keeps the variable order of the input, so the
    # result does not depend on string hash randomization.
    other_vars = [variable for variable in ds.data_vars if variable not in matches]
    restacked = ds[other_vars]

    # Next group by name, which will be the output dataset variable name.
    grouped = collections.defaultdict(list)
    for variable, match in matches.items():
        name, layer = match.groups()
        grouped[name].append(ds[variable].assign_coords(layer=int(layer)))

    # Concatenate, and make sure the dimension order is natural.
    ugrid_dims = {dim for grid in dataset.ugrid.grids for dim in grid.dimensions}
    for variable, das in grouped.items():
        da = xr.concat(sorted(das, key=lambda da: int(da["layer"])), dim="layer")
        newdims = list(da.dims)
        newdims.remove("layer")
        # If it's a spatial dataset, the layer should be second last.
        if ugrid_dims.intersection(newdims):
            newdims.insert(-1, "layer")
        # If not, the layer should be last.
        else:
            newdims.append("layer")
        if tuple(newdims) != da.dims:
            da = da.transpose(*newdims)

        restacked[variable] = da

    return xu.UgridDataset(restacked, grids=dataset.ugrid.grids)

362 

363 

def to_ugrid2d(data: Union[xr.DataArray, xr.Dataset]) -> xr.Dataset:
    """
    Turn a structured DataArray or Dataset into the equivalent UGRID-2D
    quadrilateral Dataset.

    See:
    https://ugrid-conventions.github.io/ugrid-conventions/#2d-flexible-mesh-mixed-triangles-quadrilaterals-etc-topology

    Parameters
    ----------
    data: Union[xr.DataArray, xr.Dataset]
        Dataset or DataArray with last two dimensions ("y", "x").
        In case of a Dataset, the 2D topology is defined once and variables are
        added one by one.
        In case of a DataArray, a name is required; a name can be set with:
        ``da.name = "..."``

    Returns
    -------
    ugrid2d_dataset: xr.Dataset
        The equivalent data, in UGRID-2D quadrilateral form, passed through
        ``mdal_compliant_ugrid2d``.

    Raises
    ------
    TypeError
        If data is neither a DataArray nor a Dataset.
    ValueError
        If data is a DataArray without a name.
    """
    if not isinstance(data, (xr.DataArray, xr.Dataset)):
        raise TypeError("data must be xarray.DataArray or xr.Dataset")

    # The topology is derived once; the data is added to it afterwards.
    grid = xu.Ugrid2d.from_structured(data)
    ds = grid.to_dataset()

    if isinstance(data, xr.Dataset):
        named_arrays = ((name, data[name]) for name in data.data_vars)
    else:
        if data.name is None:
            raise ValueError(
                'A name is required for the DataArray. It can be set with ``da.name = "..."`'
            )
        named_arrays = iter([(data.name, data)])

    for name, array in named_arrays:
        ds[name] = ugrid2d_data(array, grid.face_dimension)
    return mdal_compliant_ugrid2d(ds)

402 

403 

def empty_2d(
    dx: Union[float, FloatArray],
    xmin: float,
    xmax: float,
    dy: Union[float, FloatArray],
    ymin: float,
    ymax: float,
) -> xr.DataArray:
    """
    Create an empty 2D DataArray with dimensions ("y", "x").

    ``dx`` and ``dy`` may be provided as:

        * scalar: for equidistant spacing
        * array: for non-equidistant spacing

    Only the magnitude of ``dx`` and ``dy`` is used: x is generated with a
    positive cell size and y with a negative one.

    Note that xarray (and netCDF4) uses midpoint coordinates. ``xmin`` and
    ``xmax`` are used to generate the appropriate midpoints.

    Parameters
    ----------
    dx: float, 1d array of floats
        cell size along x
    xmin: float
    xmax: float
    dy: float, 1d array of floats
        cell size along y
    ymin: float
    ymax: float

    Returns
    -------
    empty: xr.DataArray
        Filled with NaN.
    """
    coords = _xycoords(
        (xmin, xmax, ymin, ymax),
        (np.abs(dx), -np.abs(dy)),
    )
    shape = (coords["y"].size, coords["x"].size)
    return xr.DataArray(data=np.full(shape, np.nan), coords=coords, dims=["y", "x"])

447 

448 

def empty_3d(
    dx: Union[float, FloatArray],
    xmin: float,
    xmax: float,
    dy: Union[float, FloatArray],
    ymin: float,
    ymax: float,
    layer: Union[int, Sequence[int], IntArray],
) -> xr.DataArray:
    """
    Create an empty 3D DataArray with dimensions ("layer", "y", "x").

    ``dx`` and ``dy`` may be provided as:

        * scalar: for equidistant spacing
        * array: for non-equidistant spacing

    Only the magnitude of ``dx`` and ``dy`` is used: x is generated with a
    positive cell size and y with a negative one.

    Note that xarray (and netCDF4) uses midpoint coordinates. ``xmin`` and
    ``xmax`` are used to generate the appropriate midpoints.

    Parameters
    ----------
    dx: float, 1d array of floats
        cell size along x
    xmin: float
    xmax: float
    dy: float, 1d array of floats
        cell size along y
    ymin: float
    ymax: float
    layer: int, sequence of integers, 1d array of integers

    Returns
    -------
    empty: xr.DataArray
        Filled with NaN.
    """
    coords = _xycoords(
        (xmin, xmax, ymin, ymax),
        (np.abs(dx), -np.abs(dy)),
    )
    spatial_shape = (coords["y"].size, coords["x"].size)
    layer = _layer(layer)
    coords["layer"] = layer

    return xr.DataArray(
        data=np.full((layer.size, *spatial_shape), np.nan),
        coords=coords,
        dims=["layer", "y", "x"],
    )

499 

500 

def empty_2d_transient(
    dx: Union[float, FloatArray],
    xmin: float,
    xmax: float,
    dy: Union[float, FloatArray],
    ymin: float,
    ymax: float,
    time: Any,
) -> xr.DataArray:
    """
    Create an empty transient 2D DataArray with dimensions
    ("time", "y", "x").

    ``dx`` and ``dy`` may be provided as:

        * scalar: for equidistant spacing
        * array: for non-equidistant spacing

    Only the magnitude of ``dx`` and ``dy`` is used: x is generated with a
    positive cell size and y with a negative one.

    Note that xarray (and netCDF4) uses midpoint coordinates. ``xmin`` and
    ``xmax`` are used to generate the appropriate midpoints.

    Parameters
    ----------
    dx: float, 1d array of floats
        cell size along x
    xmin: float
    xmax: float
    dy: float, 1d array of floats
        cell size along y
    ymin: float
    ymax: float
    time: Any
        One or more of: str, numpy datetime64, pandas Timestamp

    Returns
    -------
    empty: xr.DataArray
        Filled with NaN.
    """
    coords = _xycoords(
        (xmin, xmax, ymin, ymax),
        (np.abs(dx), -np.abs(dy)),
    )
    spatial_shape = (coords["y"].size, coords["x"].size)
    time = _time(time)
    coords["time"] = time
    return xr.DataArray(
        data=np.full((time.size, *spatial_shape), np.nan),
        coords=coords,
        dims=["time", "y", "x"],
    )

551 

552 

def empty_3d_transient(
    dx: Union[float, FloatArray],
    xmin: float,
    xmax: float,
    dy: Union[float, FloatArray],
    ymin: float,
    ymax: float,
    layer: Union[int, Sequence[int], IntArray],
    time: Any,
) -> xr.DataArray:
    """
    Create an empty transient 3D DataArray with dimensions
    ("time", "layer", "y", "x").

    ``dx`` and ``dy`` may be provided as:

        * scalar: for equidistant spacing
        * array: for non-equidistant spacing

    Only the magnitude of ``dx`` and ``dy`` is used: x is generated with a
    positive cell size and y with a negative one.

    Note that xarray (and netCDF4) uses midpoint coordinates. ``xmin`` and
    ``xmax`` are used to generate the appropriate midpoints.

    Parameters
    ----------
    dx: float, 1d array of floats
        cell size along x
    xmin: float
    xmax: float
    dy: float, 1d array of floats
        cell size along y
    ymin: float
    ymax: float
    layer: int, sequence of integers, 1d array of integers
    time: Any
        One or more of: str, numpy datetime64, pandas Timestamp

    Returns
    -------
    empty: xr.DataArray
        Filled with NaN.
    """
    coords = _xycoords(
        (xmin, xmax, ymin, ymax),
        (np.abs(dx), -np.abs(dy)),
    )
    spatial_shape = (coords["y"].size, coords["x"].size)
    layer = _layer(layer)
    coords["layer"] = layer
    time = _time(time)
    coords["time"] = time
    return xr.DataArray(
        data=np.full((time.size, layer.size, *spatial_shape), np.nan),
        coords=coords,
        dims=["time", "layer", "y", "x"],
    )

607 

608 

609def _layer(layer: Union[int, Sequence[int], IntArray]) -> IntArray: 

610 layer = np.atleast_1d(layer) 

611 if layer.ndim > 1: 

612 raise ValueError("layer must be 1d") 

613 return layer 

614 

615 

616def _time(time: Any) -> Any: 

617 time = np.atleast_1d(time) 

618 if time.ndim > 1: 

619 raise ValueError("time must be 1d") 

620 return pd.to_datetime(time) 

621 

622 

def is_divisor(numerator: Union[float, FloatArray], denominator: float) -> bool:
    """
    Check whether ``denominator`` divides every value of ``numerator`` a whole
    number of times, within floating point tolerance. Signs are ignored.

    Parameters
    ----------
    numerator: np.array of floats or float
    denominator: float

    Returns
    -------
    is_divisor: bool
    """
    step = np.abs(denominator)
    leftover = np.mod(np.abs(numerator), step)
    # Floating point error can leave a remainder just below the step itself
    # rather than near zero; both count as an exact division.
    near_zero = np.isclose(leftover, 0.0)
    near_step = np.isclose(leftover, step)
    return bool(np.all(np.logical_or(near_zero, near_step)))

637 

638 

def _polygonize(da: xr.DataArray) -> "gpd.GeoDataFrame":
    """
    Convert a 2D DataArray into a GeoDataFrame with one polygon per shape
    returned by ``rasterio.features.shapes``, and the corresponding raster
    value in the "value" column.

    The ``crs`` of the result is taken from ``da.attrs["crs"]`` (None when
    absent).

    Private method located in util.spatial to work around circular imports.
    """
    if da.dims != ("y", "x"):
        raise ValueError('Dimensions must be ("y", "x")')

    raster = da.values
    # NOTE(review): presumably float64 is not accepted by
    # rasterio.features.shapes, hence the downcast -- confirm.
    if raster.dtype == np.float64:
        raster = raster.astype(np.float32)

    affine_transform = transform(da)
    shape_value_pairs = list(
        rasterio.features.shapes(raster, transform=affine_transform)
    )

    # Only the first ring of coordinates of each shape is used.
    polygons = [
        shapely.geometry.Polygon(shape["coordinates"][0])
        for shape, _ in shape_value_pairs
    ]
    cell_values = [value for _, value in shape_value_pairs]

    gdf = gpd.GeoDataFrame({"value": cell_values, "geometry": polygons})
    gdf.crs = da.attrs.get("crs")
    return gdf