Coverage for C:\src\imod-python\imod\typing\grid.py: 89%

167 statements  

« prev     ^ index     » next       coverage.py v7.4.4, created at 2024-04-08 13:27 +0200

1import pickle 

2import textwrap 

3from typing import Callable, Mapping, Sequence 

4 

5import numpy as np 

6import xarray as xr 

7import xugrid as xu 

8from fastcore.dispatch import typedispatch 

9 

10from imod.typing import GridDataArray, GridDataset, structured 

11from imod.util.spatial import _polygonize 

12 

13 

@typedispatch
def zeros_like(grid: xr.DataArray, *args, **kwargs):
    """Return a structured grid shaped like ``grid`` and filled with zeros.

    Extra positional and keyword arguments are forwarded to
    ``xr.zeros_like``.
    """
    zeros = xr.zeros_like(grid, *args, **kwargs)
    return zeros

17 

18 

@typedispatch
def zeros_like(grid: xu.UgridDataArray, *args, **kwargs):
    """Return an unstructured grid shaped like ``grid`` and filled with zeros.

    Extra positional and keyword arguments are forwarded to
    ``xu.zeros_like``.
    """
    zeros = xu.zeros_like(grid, *args, **kwargs)
    return zeros

22 

23 

@typedispatch
def ones_like(grid: xr.DataArray, *args, **kwargs):
    """Return a structured grid shaped like ``grid`` and filled with ones.

    Extra positional and keyword arguments are forwarded to
    ``xr.ones_like``.
    """
    ones = xr.ones_like(grid, *args, **kwargs)
    return ones

27 

28 

@typedispatch
def ones_like(grid: xu.UgridDataArray, *args, **kwargs):
    """Return an unstructured grid shaped like ``grid`` and filled with ones.

    Extra positional and keyword arguments are forwarded to
    ``xu.ones_like``.
    """
    ones = xu.ones_like(grid, *args, **kwargs)
    return ones

32 

33 

@typedispatch
def nan_like(grid: xr.DataArray, dtype=np.float32, *args, **kwargs):
    """Return a structured grid like ``grid`` filled with NaN.

    The result has dtype ``dtype`` (float32 by default). Remaining
    arguments are forwarded to ``xr.full_like``.
    """
    # NOTE(review): extra positional args are placed right after ``grid``,
    # where they likely collide with ``fill_value``; confirm whether *args
    # is ever used by callers.
    return xr.full_like(grid, *args, fill_value=np.nan, dtype=dtype, **kwargs)

37 

38 

@typedispatch
def nan_like(grid: xu.UgridDataArray, dtype=np.float32, *args, **kwargs):
    """Return an unstructured grid like ``grid`` filled with NaN.

    The result has dtype ``dtype`` (float32 by default). Remaining
    arguments are forwarded to ``xu.full_like``.
    """
    # NOTE(review): extra positional args are placed right after ``grid``,
    # where they likely collide with ``fill_value``; confirm whether *args
    # is ever used by callers.
    return xu.full_like(grid, *args, fill_value=np.nan, dtype=dtype, **kwargs)

42 

43 

@typedispatch
def is_unstructured(grid: xu.UgridDataArray | xu.UgridDataset) -> bool:
    """xugrid objects always live on an unstructured grid."""
    unstructured = True
    return unstructured

47 

48 

@typedispatch
def is_unstructured(grid: xr.DataArray | xr.Dataset) -> bool:
    """Plain xarray objects are treated as structured grids."""
    unstructured = False
    return unstructured

52 

53 

def _force_decreasing_y(structured_grid: xr.DataArray | xr.Dataset):
    """
    Return ``structured_grid`` with its y index ordered from high to low.

    A monotonically increasing y index is reversed with ``isel``; a
    monotonically decreasing one is returned unchanged. Objects that have no
    "y" index (e.g. non-spatial arrays) are returned unchanged.

    Raises
    ------
    RuntimeError
        If the y index is neither monotonically increasing nor decreasing.
    """
    if "y" not in structured_grid.indexes:
        # Nothing to reorder; previously this crashed with a KeyError.
        return structured_grid

    y_index = structured_grid.indexes["y"]
    if y_index.is_monotonic_increasing:
        return structured_grid.isel(y=slice(None, None, -1))
    if not y_index.is_monotonic_decreasing:
        # xr.Dataset has no ``name`` attribute; use getattr so the intended
        # RuntimeError is raised instead of an AttributeError.
        grid_name = getattr(structured_grid, "name", None)
        raise RuntimeError(f"Non-monotonous y-coordinates for grid: {grid_name}.")
    return structured_grid

63 

64 

def _get_first_item(objects: Sequence):
    """Return the first element produced by iterating over ``objects``.

    Raises ``StopIteration`` when ``objects`` is empty.
    """
    iterator = iter(objects)
    first = next(iterator)
    return first

67 

68 

# Typedispatching doesn't work based on types of list elements, therefore resort to
# isinstance testing
def _type_dispatch_functions_on_grid_sequence(
    objects: Sequence[GridDataArray | GridDataset],
    unstructured_func: Callable,
    structured_func: Callable,
    *args,
    **kwargs,
) -> GridDataArray | GridDataset:
    """
    Type dispatch functions on sequence of grids. Functions like merging or
    concatenating.

    Every element of ``objects`` must be an instance of the first element's
    type. xugrid sequences (UgridDataArray/UgridDataset) are passed to
    ``unstructured_func``. xarray sequences (DataArray/Dataset) are passed to
    ``structured_func``, and that result is returned with a decreasing y
    coordinate. Extra ``*args``/``**kwargs`` are forwarded to the chosen
    function.

    Raises
    ------
    ValueError
        If ``objects`` is empty.
    TypeError
        If the sequence is not homogeneous, or if its element type is neither
        an xugrid nor an xarray grid type.
    """
    try:
        first_object = _get_first_item(objects)
    except StopIteration:
        # Surface a clear error instead of leaking a bare StopIteration.
        raise ValueError(
            "Expected a non-empty sequence of grids, received an empty sequence."
        ) from None

    start_type = type(first_object)
    homogeneous = all(isinstance(o, start_type) for o in objects)
    if not homogeneous:
        unique_types = {type(o) for o in objects}
        raise TypeError(
            f"Only homogeneous sequences can be reduced, received sequence of {unique_types}"
        )
    if isinstance(first_object, (xu.UgridDataArray, xu.UgridDataset)):
        return unstructured_func(objects, *args, **kwargs)
    if isinstance(first_object, (xr.DataArray, xr.Dataset)):
        return _force_decreasing_y(structured_func(objects, *args, **kwargs))
    raise TypeError(
        f"'{unstructured_func.__name__}' not supported for type {type(first_object)}"
    )

96 

97 

# Typedispatching doesn't work based on types of dict elements, therefore resort
# to manual type testing
def _type_dispatch_functions_on_dict(
    dict_of_objects: Mapping[str, GridDataArray | float | bool | int],
    unstructured_func: Callable,
    structured_func: Callable,
    *args,
    **kwargs,
):
    """
    Typedispatch function on grid and scalar variables provided in dictionary.

    Types do not need to be homogeneous, as scalars and grids can be mixed. No
    mixing of structured and unstructured grids is allowed. A dictionary with
    only scalars is also accepted; in that case the xarray function
    (``structured_func``) is called.

    The chosen function receives ``[dict_of_objects]`` (a one-element list,
    mirroring the ``xr.merge`` API) plus the forwarded ``*args``/``**kwargs``.
    When ``dict_of_objects`` is None, an empty ``xr.Dataset`` is returned.

    Raises
    ------
    TypeError
        If both an xu.UgridDataArray and a spatial (x, y) xr.DataArray are
        present.
    """
    if dict_of_objects is None:
        return xr.Dataset()

    values = list(dict_of_objects.values())
    # Use isinstance (not exact-type membership) so subclasses of
    # xu.UgridDataArray are recognized, as in the sequence dispatcher.
    has_unstructured = any(isinstance(value, xu.UgridDataArray) for value in values)
    # Test structured if xr.DataArray and spatial.
    has_structured_grid = any(
        isinstance(value, xr.DataArray) and is_spatial_2D(value) for value in values
    )
    if has_structured_grid and has_unstructured:
        raise TypeError(
            textwrap.dedent(
                """
                Received both structured grid (xr.DataArray) and xu.UgridDataArray. This
                means structured grids as well as unstructured grids were provided.
                """
            )
        )
    if has_unstructured:
        return unstructured_func([dict_of_objects], *args, **kwargs)

    return structured_func([dict_of_objects], *args, **kwargs)

140 

141 

def merge(
    objects: Sequence[GridDataArray | GridDataset], *args, **kwargs
) -> GridDataset:
    """
    Merge a homogeneous sequence of grids.

    Unstructured grids are merged with ``xu.merge`` and structured grids with
    ``xr.merge``. Extra arguments are forwarded to the selected function.
    """
    dispatch = _type_dispatch_functions_on_grid_sequence
    return dispatch(objects, xu.merge, xr.merge, *args, **kwargs)

148 

149 

def merge_partitions(
    objects: Sequence[GridDataArray | GridDataset], *args, **kwargs
) -> GridDataArray | GridDataset:
    """
    Merge a homogeneous sequence of partitioned grids back into one grid.

    Unstructured partitions use ``xu.merge_partitions`` and structured
    partitions use ``structured.merge_partitions``. Extra arguments are
    forwarded to the selected function.
    """
    dispatch = _type_dispatch_functions_on_grid_sequence
    return dispatch(
        objects, xu.merge_partitions, structured.merge_partitions, *args, **kwargs
    )

156 

157 

def concat(
    objects: Sequence[GridDataArray | GridDataset], *args, **kwargs
) -> GridDataArray | GridDataset:
    """
    Concatenate a homogeneous sequence of grids.

    Unstructured grids are concatenated with ``xu.concat`` and structured
    grids with ``xr.concat``. Extra arguments (e.g. the dimension) are
    forwarded to the selected function.
    """
    dispatch = _type_dispatch_functions_on_grid_sequence
    return dispatch(objects, xu.concat, xr.concat, *args, **kwargs)

164 

165 

def merge_unstructured_dataset(variables_to_merge: list[dict], *args, **kwargs):
    """
    Work around xugrid issue https://github.com/Deltares/xugrid/issues/179

    Expects exactly one dictionary (any Mapping) in the list. A list is used
    to have the same API as xr.merge().

    Merges unstructured grids first, then manually assigns scalar variables.

    Parameters
    ----------
    variables_to_merge: list containing a single mapping
        Maps variable names to xu.UgridDataArray grids or to scalar values.
    *args, **kwargs
        Forwarded to ``xu.merge``.

    Returns
    -------
    The dataset produced by ``xu.merge`` with the scalar variables assigned.

    Raises
    ------
    ValueError
        If the list does not contain exactly one element, or if the merged
        grids do not all share one unique grid.
    TypeError
        If the single element is not a mapping.
    """
    # "!= 1" (not "> 1") so an empty list raises here instead of an
    # IndexError on the next line.
    if len(variables_to_merge) != 1:
        raise ValueError(
            f"Only one dict of variables expected, got {len(variables_to_merge)}"
        )

    variables_to_merge_dict = variables_to_merge[0]

    if not isinstance(variables_to_merge_dict, Mapping):
        raise TypeError(f"Expected dict, got {type(variables_to_merge_dict)}")

    # Separate variables into list of grids and dict of scalar variables
    grids_ls = []
    scalar_dict = {}
    for name, variable in variables_to_merge_dict.items():
        if isinstance(variable, xu.UgridDataArray):
            grids_ls.append(variable.rename(name))
        else:
            scalar_dict[name] = variable

    # Merge grids
    dataset = xu.merge(grids_ls, *args, **kwargs)

    # Temporarily work around this xugrid issue, until fixed:
    # https://github.com/Deltares/xugrid/issues/206
    grid_hashes = {hash(pickle.dumps(grid)) for grid in dataset.ugrid.grids}
    if len(grid_hashes) > 1:
        raise ValueError(
            "Multiple grids provided, please provide data on one unique grid"
        )
    # Possibly won't work anymore if this ever gets implemented:
    # https://github.com/Deltares/xugrid/issues/195
    dataset._grids = [dataset.grids[0]]

    # Assign scalar variables manually
    for name, variable in scalar_dict.items():
        dataset[name] = variable

    return dataset

215 

216 

def merge_with_dictionary(
    variables_to_merge: Mapping[str, GridDataArray | float | bool | int],
    *args,
    **kwargs,
):
    """
    Merge a mapping of grids and scalars into a single dataset.

    If any value is an xu.UgridDataArray, ``merge_unstructured_dataset`` is
    used; otherwise ``xr.merge`` is used. Mixing unstructured grids with
    spatial structured grids raises TypeError. Extra arguments are forwarded
    to the selected merge function.
    """
    unstructured_merge = merge_unstructured_dataset
    structured_merge = xr.merge
    return _type_dispatch_functions_on_dict(
        variables_to_merge, unstructured_merge, structured_merge, *args, **kwargs
    )

225 

226 

@typedispatch
def bounding_polygon(active: xr.DataArray):
    """Return bounding polygon of active cells.

    Inactive cells are masked to NaN before polygonizing. Only the rows of
    the ``_polygonize`` result whose "value" equals 1.0 are kept.
    """
    masked = active.where(active, other=np.nan)
    polygons = _polygonize(masked)
    # Rows with NaN value come from inactive cells and are dropped.
    keep = polygons["value"] == 1.0
    return polygons.loc[keep]

235 

236 

@typedispatch
def bounding_polygon(active: xu.UgridDataArray):
    """Return bounding polygon of active cells.

    Clips the grid to the faces where ``active > 0`` and returns the
    bounding polygon of the clipped grid.
    """
    # NOTE(review): np.where(...)[0] takes indices along the first axis only;
    # this assumes ``active`` is 1D over the face dimension. Confirm the
    # behaviour for layered input.
    selected_faces = np.where(active > 0)[0]
    face_dim = f"{active.ugrid.grid.face_dimension}"
    clipped = active.isel({face_dim: selected_faces}, missing_dims="ignore")
    return clipped.ugrid.grid.bounding_polygon()

245 

246 

@typedispatch
def is_spatial_2D(array: xr.DataArray) -> bool:
    """Return True if the array contains data in at least 2 spatial dimensions.

    Both "x" and "y" must be present as coordinates and as dimensions.
    """
    return all(
        axis in array.coords and axis in array.dims for axis in ("x", "y")
    )

255 

256 

@typedispatch
def is_spatial_2D(array: xu.UgridDataArray) -> bool:
    """Return True if the array contains data associated to cell faces.

    The grid's face dimension must be both a dimension and a coordinate.
    """
    face_dimension = array.ugrid.grid.face_dimension
    return face_dimension in array.dims and face_dimension in array.coords

266 

267 

@typedispatch
def is_spatial_2D(_: object) -> bool:
    """Fallback: any non-grid object is not spatial."""
    spatial = False
    return spatial

271 

272 

@typedispatch
def is_equal(array1: xu.UgridDataArray, array2: xu.UgridDataArray) -> bool:
    """Return True if both the data and the underlying grids are equal."""
    if not array1.equals(array2):
        return False
    return array1.ugrid.grid.equals(array2.ugrid.grid)

276 

277 

@typedispatch
def is_equal(array1: xr.DataArray, array2: xr.DataArray) -> bool:
    """Return True if the two structured arrays are equal (``DataArray.equals``)."""
    same = array1.equals(array2)
    return same

281 

282 

@typedispatch
def is_equal(array1: object, array2: object) -> bool:
    """Fallback for unsupported or mixed types: never considered equal."""
    same = False
    return same

286 

287 

@typedispatch
def is_same_domain(grid1: xu.UgridDataArray, grid2: xu.UgridDataArray) -> bool:
    """Return True if the coordinates and the underlying grids are equal."""
    if not grid1.coords.equals(grid2.coords):
        return False
    return grid1.ugrid.grid.equals(grid2.ugrid.grid)

293 

294 

@typedispatch
def is_same_domain(grid1: xr.DataArray, grid2: xr.DataArray) -> bool:
    """Return True if the two structured grids have equal coordinates."""
    same = grid1.coords.equals(grid2.coords)
    return same

298 

299 

@typedispatch
def is_same_domain(grid1: object, grid2: object) -> bool:
    """Fallback for unsupported or mixed types: never the same domain."""
    same = False
    return same

303 

304 

@typedispatch
def get_spatial_dimension_names(grid: xr.DataArray) -> list[str]:
    """Return the spatial dimension/coordinate names used for structured grids.

    A new list is built on every call, so callers may mutate it.
    """
    names = ("x", "y", "layer", "dx", "dy")
    return list(names)

308 

309 

@typedispatch
def get_spatial_dimension_names(grid: xu.UgridDataArray) -> list[str]:
    """Return the grid's face dimension name followed by "layer"."""
    return [grid.ugrid.grid.face_dimension, "layer"]

314 

315 

@typedispatch
def get_spatial_dimension_names(grid: object) -> list[str]:
    """Fallback: non-grid objects have no spatial dimensions."""
    return list()

319 

320 

@typedispatch
def get_grid_geometry_hash(grid: xr.DataArray) -> tuple[int, int]:
    """
    Return ``(hash_x, hash_y)``: hashes of the pickled x and y coordinate values.

    Grids whose x and y coordinate arrays pickle to the same bytes produce
    equal tuples. Because ``hash()`` of bytes is salted per interpreter
    process, results are only comparable within a single Python session.
    """
    hash_x = hash(pickle.dumps(grid["x"].values))
    hash_y = hash(pickle.dumps(grid["y"].values))
    return (hash_x, hash_y)

326 

327 

@typedispatch
def get_grid_geometry_hash(grid: xu.UgridDataArray) -> tuple[int, int, int]:
    """
    Return ``(hash_x, hash_y, hash_connectivity)`` for the unstructured grid.

    The hashes are computed from the pickled node x coordinates, node y
    coordinates and node-face connectivity of ``grid.ugrid.grid``. Because
    ``hash()`` of bytes is salted per interpreter process, results are only
    comparable within a single Python session.
    """
    ugrid = grid.ugrid.grid
    hash_x = hash(pickle.dumps(ugrid.node_x))
    hash_y = hash(pickle.dumps(ugrid.node_y))
    hash_connectivity = hash(pickle.dumps(ugrid.node_face_connectivity))
    return (hash_x, hash_y, hash_connectivity)

334 

335 

@typedispatch
def get_grid_geometry_hash(grid: object) -> tuple[int, ...]:
    """
    Fallback for unsupported objects: always raises.

    Raises
    ------
    ValueError
        Always; only xr.DataArray and xu.UgridDataArray grids are supported.
    """
    raise ValueError("get_grid_geometry_hash not supported for this object.")