Coverage for C:\src\imod-python\imod\typing\grid.py: 91%

191 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-08 14:15 +0200

1import pickle 

2import textwrap 

3from typing import Callable, Mapping, Sequence 

4 

5import numpy as np 

6import xarray as xr 

7import xugrid as xu 

8from fastcore.dispatch import typedispatch 

9 

10from imod.typing import GridDataArray, GridDataset, structured 

11from imod.util.spatial import _polygonize 

12 

13 

@typedispatch  # type: ignore [no-redef]
def zeros_like(grid: xr.DataArray, *args, **kwargs):
    """
    Create a structured grid of zeros matching ``grid``.

    Any extra positional/keyword arguments are passed on to ``xr.zeros_like``.
    """
    zeroed = xr.zeros_like(grid, *args, **kwargs)
    return zeroed


@typedispatch  # type: ignore [no-redef]
def zeros_like(grid: xu.UgridDataArray, *args, **kwargs):  # noqa: F811
    """
    Create an unstructured grid of zeros matching ``grid``.

    Any extra positional/keyword arguments are passed on to ``xu.zeros_like``.
    """
    zeroed = xu.zeros_like(grid, *args, **kwargs)
    return zeroed

22 

23 

@typedispatch  # type: ignore [no-redef]
def ones_like(grid: xr.DataArray, *args, **kwargs):
    """
    Create a structured grid of ones matching ``grid``.

    Any extra positional/keyword arguments are passed on to ``xr.ones_like``.
    """
    filled = xr.ones_like(grid, *args, **kwargs)
    return filled


@typedispatch  # type: ignore [no-redef]
def ones_like(grid: xu.UgridDataArray, *args, **kwargs):  # noqa: F811
    """
    Create an unstructured grid of ones matching ``grid``.

    Any extra positional/keyword arguments are passed on to ``xu.ones_like``.
    """
    filled = xu.ones_like(grid, *args, **kwargs)
    return filled

32 

33 

@typedispatch  # type: ignore [no-redef]
def nan_like(grid: xr.DataArray, dtype=np.float32, *args, **kwargs):
    """
    Create a structured grid matching ``grid`` filled with NaN.

    ``dtype`` defaults to ``np.float32``. Extra arguments go to ``xr.full_like``.
    """
    # NOTE(review): extra positional args are forwarded right after ``grid``,
    # where they would collide with the ``fill_value`` keyword; pass extras by
    # keyword only.
    return xr.full_like(grid, *args, fill_value=np.nan, dtype=dtype, **kwargs)


@typedispatch  # type: ignore [no-redef]
def nan_like(grid: xu.UgridDataArray, dtype=np.float32, *args, **kwargs):  # noqa: F811
    """
    Create an unstructured grid matching ``grid`` filled with NaN.

    ``dtype`` defaults to ``np.float32``. Extra arguments go to ``xu.full_like``.
    """
    # NOTE(review): as above, extra positional args would collide with
    # ``fill_value``; pass extras by keyword only.
    return xu.full_like(grid, *args, fill_value=np.nan, dtype=dtype, **kwargs)

42 

43 

@typedispatch  # type: ignore [no-redef]
def is_unstructured(grid: xu.UgridDataArray | xu.UgridDataset) -> bool:
    """Return True: xugrid objects are unstructured grids."""
    return True


@typedispatch  # type: ignore [no-redef]
def is_unstructured(grid: xr.DataArray | xr.Dataset) -> bool:  # noqa: F811
    """Return False: plain xarray objects are structured grids."""
    return False


@typedispatch  # type: ignore [no-redef]
def is_unstructured(grid: object) -> bool:  # noqa: F811
    """
    Return False for any object that is not a grid.

    Without this fallback, fastcore's typedispatch hands back the argument
    itself when no overload matches, so e.g. ``is_unstructured(1.0)`` would be
    truthy. This mirrors the ``object`` fallbacks of is_spatial_2D, is_equal
    and is_same_domain.
    """
    return False

52 

53 

def _force_decreasing_y(structured_grid: xr.DataArray | xr.Dataset):
    """
    Return ``structured_grid`` with its "y" coordinate in decreasing order.

    A monotonically increasing "y" index is reversed; a monotonically
    decreasing one is returned unchanged.

    Raises
    ------
    RuntimeError
        If the "y" index is neither monotonically increasing nor decreasing.
    """
    y_index = structured_grid.indexes["y"]
    if y_index.is_monotonic_increasing:
        return structured_grid.isel(y=slice(None, None, -1))
    if not y_index.is_monotonic_decreasing:
        # xr.Dataset has no ``name`` attribute; reading it directly would raise
        # AttributeError and mask the intended RuntimeError.
        grid_name = getattr(structured_grid, "name", None)
        raise RuntimeError(f"Non-monotonous y-coordinates for grid: {grid_name}.")
    return structured_grid

63 

64 

65def _get_first_item(objects: Sequence): 

66 return next(iter(objects)) 

67 

68 

# Typedispatching doesn't work based on types of list elements, therefore resort to
# isinstance testing
def _type_dispatch_functions_on_grid_sequence(
    objects: Sequence[GridDataArray | GridDataset],
    unstructured_func: Callable,
    structured_func: Callable,
    *args,
    **kwargs,
) -> GridDataArray | GridDataset:
    """
    Type dispatch functions on sequence of grids. Functions like merging or concatenating.

    Every item must be an instance of the first item's type. If that type is
    unstructured (xu.UgridDataArray / xu.UgridDataset), ``unstructured_func`` is
    called. If it is structured (xr.DataArray / xr.Dataset), ``structured_func``
    is called and its result is passed through ``_force_decreasing_y``. Both are
    called as ``func(objects, *args, **kwargs)``.

    Raises
    ------
    ValueError
        If ``objects`` is empty.
    TypeError
        If ``objects`` is not homogeneous or holds an unsupported type.
    """
    if not objects:
        raise ValueError("Received an empty sequence, cannot determine grid type.")
    first_object = _get_first_item(objects)
    start_type = type(first_object)
    homogeneous = all(isinstance(o, start_type) for o in objects)
    if not homogeneous:
        unique_types = {type(o) for o in objects}
        raise TypeError(
            f"Only homogeneous sequences can be reduced, received sequence of {unique_types}"
        )
    if isinstance(first_object, (xu.UgridDataArray, xu.UgridDataset)):
        return unstructured_func(objects, *args, **kwargs)
    elif isinstance(first_object, (xr.DataArray, xr.Dataset)):
        return _force_decreasing_y(structured_func(objects, *args, **kwargs))
    # Reuse the already-extracted type: ``objects[0]`` fails on non-indexable input.
    raise TypeError(
        f"'{unstructured_func.__name__}' not supported for type {start_type}"
    )

96 

97 

# Typedispatching doesn't work based on types of dict elements, therefore resort
# to manual type testing
def _type_dispatch_functions_on_dict(
    dict_of_objects: Mapping[str, GridDataArray | float | bool | int] | None,
    unstructured_func: Callable,
    structured_func: Callable,
    *args,
    **kwargs,
):
    """
    Typedispatch function on grid and scalar variables provided in dictionary.

    Types do not need to be homogeneous, because scalars and grids can be
    mixed. Structured and unstructured grids may not be mixed. A dictionary
    holding only scalars is also accepted; in that case the xarray
    (structured) function is called.

    Parameters
    ----------
    dict_of_objects
        Mapping of variable names to grids or scalars. ``None`` returns an
        empty ``xr.Dataset``.
    unstructured_func, structured_func
        Called as ``func([dict_of_objects], *args, **kwargs)``.

    Raises
    ------
    TypeError
        If both an xu.UgridDataArray and a spatial (x/y) xr.DataArray are
        present.
    """
    if dict_of_objects is None:
        return xr.Dataset()

    values = list(dict_of_objects.values())
    # isinstance also recognizes subclasses, unlike an exact type lookup.
    has_unstructured = any(isinstance(value, xu.UgridDataArray) for value in values)
    # Only xr.DataArrays with x and y count as structured grids.
    has_structured_grid = any(
        isinstance(value, xr.DataArray) and is_spatial_2D(value) for value in values
    )
    if has_structured_grid and has_unstructured:
        error_msg = textwrap.dedent(
            """
            Received both structured grid (xr.DataArray) and xu.UgridDataArray. This
            means structured grids as well as unstructured grids were provided.
            """
        )
        raise TypeError(error_msg)
    if has_unstructured:
        return unstructured_func([dict_of_objects], *args, **kwargs)

    return structured_func([dict_of_objects], *args, **kwargs)

138 

139 

def merge(
    objects: Sequence[GridDataArray | GridDataset], *args, **kwargs
) -> GridDataset:
    """
    Merge a homogeneous sequence of grids.

    Unstructured input goes to ``xu.merge``. Structured input goes to
    ``xr.merge`` and is then returned with y in decreasing order.
    """
    unstructured_merge = xu.merge
    structured_merge = xr.merge
    return _type_dispatch_functions_on_grid_sequence(
        objects, unstructured_merge, structured_merge, *args, **kwargs
    )

146 

147 

def merge_partitions(
    objects: Sequence[GridDataArray | GridDataset], *args, **kwargs
) -> GridDataArray | GridDataset:
    """
    Merge a homogeneous sequence of grid partitions back together.

    Unstructured input goes to ``xu.merge_partitions``. Structured input goes to
    ``structured.merge_partitions`` and is then returned with y in decreasing
    order.
    """
    unstructured_merge = xu.merge_partitions
    structured_merge = structured.merge_partitions
    return _type_dispatch_functions_on_grid_sequence(
        objects, unstructured_merge, structured_merge, *args, **kwargs
    )

154 

155 

def concat(
    objects: Sequence[GridDataArray | GridDataset], *args, **kwargs
) -> GridDataArray | GridDataset:
    """
    Concatenate a homogeneous sequence of grids.

    Unstructured input goes to ``xu.concat``. Structured input goes to
    ``xr.concat`` and is then returned with y in decreasing order.
    """
    unstructured_concat = xu.concat
    structured_concat = xr.concat
    return _type_dispatch_functions_on_grid_sequence(
        objects, unstructured_concat, structured_concat, *args, **kwargs
    )

162 

163 

def merge_unstructured_dataset(variables_to_merge: list[dict], *args, **kwargs):
    """
    Work around xugrid issue https://github.com/Deltares/xugrid/issues/179

    Expects exactly one dictionary in the list. The list exists only to match
    the ``xr.merge()`` API.

    The unstructured grids are merged first, and the scalar variables are then
    assigned manually.

    Raises
    ------
    ValueError
        If the list does not hold exactly one element, or if the merged grids
        do not share a single topology.
    TypeError
        If the single element is not a dict.
    """
    # ``!= 1`` rather than ``> 1``: an empty list would otherwise fail with an
    # opaque IndexError on the lookup below.
    if len(variables_to_merge) != 1:
        raise ValueError(
            f"Only one dict of variables expected, got {len(variables_to_merge)}"
        )

    variables_to_merge_dict = variables_to_merge[0]

    if not isinstance(variables_to_merge_dict, dict):
        raise TypeError(f"Expected dict, got {type(variables_to_merge_dict)}")

    # Separate variables into list of grids and dict of scalar variables
    grids_ls = []
    scalar_dict = {}
    for name, variable in variables_to_merge_dict.items():
        if isinstance(variable, xu.UgridDataArray):
            grids_ls.append(variable.rename(name))
        else:
            scalar_dict[name] = variable

    # Merge grids
    dataset = xu.merge(grids_ls, *args, **kwargs)

    # Temporarily work around this xugrid issue, until fixed:
    # https://github.com/Deltares/xugrid/issues/206
    unique_grid_hashes = {hash(pickle.dumps(grid)) for grid in dataset.ugrid.grids}
    if len(unique_grid_hashes) > 1:
        raise ValueError(
            "Multiple grids provided, please provide data on one unique grid"
        )
    # Possibly won't work anymore if this ever gets implemented:
    # https://github.com/Deltares/xugrid/issues/195
    dataset._grids = [dataset.grids[0]]

    # Assign scalar variables manually
    for name, variable in scalar_dict.items():
        dataset[name] = variable

    return dataset

213 

214 

def merge_with_dictionary(
    variables_to_merge: Mapping[str, GridDataArray | float | bool | int],
    *args,
    **kwargs,
):
    """
    Merge a dictionary of named grids and scalars into a single dataset.

    Uses ``merge_unstructured_dataset`` when an xu.UgridDataArray is present,
    and ``xr.merge`` otherwise.
    """
    unstructured_merger = merge_unstructured_dataset
    structured_merger = xr.merge
    return _type_dispatch_functions_on_dict(
        variables_to_merge, unstructured_merger, structured_merger, *args, **kwargs
    )

223 

224 

@typedispatch  # type: ignore [no-redef]
def bounding_polygon(active: xr.DataArray):
    """Return the polygons that cover the active cells of a structured grid."""
    # Inactive cells become NaN, so only active cells keep a data value.
    # NOTE(review): assumes ``active`` is boolean, so active cells become 1.0.
    polygons = _polygonize(active.where(active, other=np.nan))
    # Keep only the polygons of active cells; inactive (NaN) ones are dropped.
    keep = polygons["value"] == 1.0
    return polygons.loc[keep]


@typedispatch  # type: ignore [no-redef]
def bounding_polygon(active: xu.UgridDataArray):  # noqa: F811
    """Return the bounding polygon of the active faces of an unstructured grid."""
    # Indices along the first axis where ``active`` is positive.
    selected = np.where(active > 0)[0]
    face_dim = active.ugrid.grid.face_dimension
    clipped = active.isel({face_dim: selected}, missing_dims="ignore")

    return clipped.ugrid.grid.bounding_polygon()

243 

244 

@typedispatch  # type: ignore [no-redef]
def is_spatial_2D(array: xr.DataArray) -> bool:
    """Return True if both "x" and "y" are present as coordinates and as dimensions."""
    return all(
        axis in array.coords and axis in array.dims for axis in ("x", "y")
    )


@typedispatch  # type: ignore [no-redef]
def is_spatial_2D(array: xu.UgridDataArray) -> bool:  # noqa: F811
    """Return True if the grid's face dimension is both a dimension and a coordinate."""
    face_dim = array.ugrid.grid.face_dimension
    return face_dim in array.dims and face_dim in array.coords


@typedispatch  # type: ignore [no-redef]
def is_spatial_2D(_: object) -> bool:  # noqa: F811
    """Fallback for any non-grid object: never spatial."""
    return False

269 

270 

@typedispatch  # type: ignore [no-redef]
def is_equal(array1: xu.UgridDataArray, array2: xu.UgridDataArray) -> bool:
    """Return True when both data and grid topology are equal."""
    data_equal = array1.equals(array2)
    return data_equal and array1.ugrid.grid.equals(array2.ugrid.grid)


@typedispatch  # type: ignore [no-redef]
def is_equal(array1: xr.DataArray, array2: xr.DataArray) -> bool:  # noqa: F811
    """Return True when both structured arrays are equal (``DataArray.equals``)."""
    data_equal = array1.equals(array2)
    return data_equal


@typedispatch  # type: ignore [no-redef]
def is_equal(array1: object, array2: object) -> bool:  # noqa: F811
    """Fallback for any other combination of types: never equal."""
    return False

284 

285 

@typedispatch  # type: ignore [no-redef]
def is_same_domain(grid1: xu.UgridDataArray, grid2: xu.UgridDataArray) -> bool:
    """Return True when coordinates and grid topology are both equal."""
    coords_equal = grid1.coords.equals(grid2.coords)
    return coords_equal and grid1.ugrid.grid.equals(grid2.ugrid.grid)


@typedispatch  # type: ignore [no-redef]
def is_same_domain(grid1: xr.DataArray, grid2: xr.DataArray) -> bool:  # noqa: F811
    """Return True when the coordinates of both structured grids are equal."""
    coords_equal = grid1.coords.equals(grid2.coords)
    return coords_equal


@typedispatch  # type: ignore [no-redef]
def is_same_domain(grid1: object, grid2: object) -> bool:  # noqa: F811
    """Fallback for any other combination of types: never the same domain."""
    return False

301 

302 

@typedispatch  # type: ignore [no-redef]
def get_spatial_dimension_names(grid: xr.DataArray) -> list[str]:
    """Return the names treated as spatial for structured grids (a new list per call)."""
    spatial_names = ["x", "y", "layer", "dx", "dy"]
    return spatial_names


@typedispatch  # type: ignore [no-redef]
def get_spatial_dimension_names(grid: xu.UgridDataArray) -> list[str]:  # noqa: F811
    """Return the grid's face dimension followed by "layer"."""
    spatial_names = [grid.ugrid.grid.face_dimension, "layer"]
    return spatial_names


@typedispatch  # type: ignore [no-redef]
def get_spatial_dimension_names(grid: object) -> list[str]:  # noqa: F811
    """Fallback for non-grid objects: no spatial names."""
    spatial_names: list[str] = []
    return spatial_names

317 

318 

@typedispatch  # type: ignore [no-redef]
def get_grid_geometry_hash(grid: xr.DataArray) -> tuple[int, int]:
    """
    Return hashes of the x and y coordinate values of a structured grid.

    The annotation previously claimed ``int``, but a tuple has always been
    returned.

    Notes
    -----
    Built on Python's ``hash`` of pickled bytes, which is salted per
    interpreter process: compare these hashes within one session only.
    """
    hash_x = hash(pickle.dumps(grid["x"].values))
    hash_y = hash(pickle.dumps(grid["y"].values))
    return (hash_x, hash_y)


@typedispatch  # type: ignore [no-redef]
def get_grid_geometry_hash(grid: xu.UgridDataArray) -> tuple[int, int, int]:  # noqa: F811
    """
    Return hashes of node x, node y and node-face connectivity of an unstructured grid.

    Notes
    -----
    Same caveat as for structured grids: the hashes are salted per interpreter
    process.
    """
    hash_x = hash(pickle.dumps(grid.ugrid.grid.node_x))
    hash_y = hash(pickle.dumps(grid.ugrid.grid.node_y))
    hash_connectivity = hash(pickle.dumps(grid.ugrid.grid.node_face_connectivity))
    return (hash_x, hash_y, hash_connectivity)


@typedispatch  # type: ignore [no-redef]
def get_grid_geometry_hash(grid: object) -> tuple[int, ...]:  # noqa: F811
    """Fallback for non-grid objects.

    Raises
    ------
    ValueError
        Always.
    """
    raise ValueError("get_grid_geometry_hash not supported for this object.")

337 

338 

@typedispatch  # type: ignore [no-redef]
def enforce_dim_order(grid: xr.DataArray) -> xr.DataArray:
    """Transpose to the iMOD Python order (species, time, layer, y, x); absent dims are skipped."""
    standard_order = ("species", "time", "layer", "y", "x")
    return grid.transpose(*standard_order, missing_dims="ignore")


@typedispatch  # type: ignore [no-redef]
def enforce_dim_order(grid: xu.UgridDataArray) -> xu.UgridDataArray:  # noqa: F811
    """Transpose to the iMOD Python order (species, time, layer, face); absent dims are skipped."""
    standard_order = ("species", "time", "layer", grid.ugrid.grid.face_dimension)
    return grid.transpose(*standard_order, missing_dims="ignore")

352 

353 

def _enforce_unstructured(obj: GridDataArray, ugrid2d=xu.Ugrid2d) -> xu.UgridDataArray:
    """
    Wrap ``obj`` as an xu.UgridDataArray on the grid topology ``ugrid2d``.

    Parameters
    ----------
    obj
        An xr.DataArray or xu.UgridDataArray holding the data.
    ugrid2d
        The grid topology to attach.
        NOTE(review): the default is the ``xu.Ugrid2d`` class itself, not a
        grid instance; callers are expected to always pass a grid.
    """
    if isinstance(obj, xu.UgridDataArray):
        # Unwrap to the underlying DataArray. xr.DataArray(...) only copies
        # coordinates from genuine DataArrays, so passing the xugrid wrapper
        # directly would silently drop coordinates such as "layer".
        obj = obj.obj
    return xu.UgridDataArray(xr.DataArray(obj), ugrid2d)

357 

358 

def preserve_gridtype(func):
    """
    Decorator to preserve gridtype, this is to work around the following xugrid
    behavior:

    >>> UgridDataArray() * DataArray() -> UgridDataArray
    >>> DataArray() * UgridDataArray() -> DataArray

    with this decorator:

    >>> UgridDataArray() * DataArray() -> UgridDataArray
    >>> DataArray() * UgridDataArray() -> UgridDataArray

    Only positional arguments are scanned. If several are unstructured, the
    grid of the last one is used. A tuple result has each element converted.
    """
    from functools import wraps

    @wraps(func)  # keep the wrapped function's name and docstring
    def decorator(*args, **kwargs):
        unstructured = False
        grid = None
        for arg in args:
            # isinstance rather than is_unstructured(): fastcore's typedispatch
            # returns its argument when no overload matches, so a truthy
            # non-grid argument (e.g. a float) would be taken for a grid.
            if isinstance(arg, (xu.UgridDataArray, xu.UgridDataset)):
                unstructured = True
                grid = arg.ugrid.grid

        x = func(*args, **kwargs)

        if unstructured:
            # Multiple grids returned
            if isinstance(x, tuple):
                return tuple(_enforce_unstructured(i, grid) for i in x)
            return _enforce_unstructured(x, grid)
        return x

    return decorator