Coverage for C:\src\imod-python\imod\util\structured.py: 96%

56 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-08 14:15 +0200

1from typing import Any, Optional 

2 

3import numpy as np 

4import xarray as xr 

5 

6 

def where(condition, if_true, if_false, keep_nan: bool = True) -> xr.DataArray:
    """
    Wrapped version of xarray's ``.where``.

    This wrapped version does two things differently:

    Firstly, it prioritizes the dimensions as: ``if_true > if_false > condition``.
    ``xarray.where(cond, a, b)`` will choose the dimension over ``a`` or ``b``.
    This may result in unwanted dimension orders such as ``("y", "x", "layer")``
    rather than ``("layer", "y", "x")``.

    Secondly, it preserves the NaN values of ``if_true`` by default. If we
    wish to replace all values over 5 by 5, yet keep the NoData parts, this
    requires two operations with xarray's ``where``.

    Parameters
    ----------
    condition: DataArray, Dataset
        Locations at which to preserve this object's values. dtype must be `bool`.
    if_true : scalar, DataArray or Dataset
        Value to use for locations where ``condition`` is True.
    if_false : scalar, DataArray or Dataset
        Value to use for locations where ``condition`` is False.
    keep_nan: bool, default: True
        Whether to keep the NaN values in place of ``if_true``. Only has an
        effect when ``if_true`` is a DataArray or Dataset.

    Returns
    -------
    DataArray or Dataset
        A new object shaped like the highest-priority xarray argument
        (``if_true``, then ``if_false``, then ``condition``). When
        ``if_false`` is the template and ``if_true`` is a numeric scalar, the
        dtype is promoted so that ``if_true`` is stored without truncation.

    Raises
    ------
    ValueError
        If none of ``condition``, ``if_true``, ``if_false`` is a DataArray or
        Dataset.
    """
    xr_obj = (xr.DataArray, xr.Dataset)
    da_true = isinstance(if_true, xr_obj)
    da_false = isinstance(if_false, xr_obj)
    da_cond = isinstance(condition, xr_obj)

    # Give priority to if_true, then if_false, then condition as the template,
    # so the result keeps the template's dimension order.
    if da_true:
        new = if_true.copy()
    elif da_false:
        # xr.full_like defaults to if_false's dtype, which would silently
        # truncate e.g. a float if_true into an integer template (2.5 -> 2) or
        # fail for NaN. Promote the dtype so the scalar fill value fits.
        fill_dtype = None
        if isinstance(if_false, xr.DataArray) and isinstance(
            if_true, (bool, int, float, complex, np.number, np.bool_)
        ):
            fill_dtype = np.result_type(if_true, if_false.dtype)
        new = xr.full_like(if_false, if_true, dtype=fill_dtype)
    elif da_cond:
        new = xr.full_like(condition, if_true, dtype=type(if_true))
    else:
        raise ValueError(
            "at least one of {condition, if_true, if_false} should be a "
            "DataArray or Dataset"
        )

    new = new.where(condition, other=if_false)
    if keep_nan and da_true:
        # Restore the NoData cells of if_true that may have been overwritten
        # by if_false above.
        new = new.where(if_true.notnull())

    return new

56 

57 

def replace(da: xr.DataArray, to_replace: Any, value: Any) -> xr.DataArray:
    """
    Replace values given in `to_replace` by `value`.

    Parameters
    ----------
    da: xr.DataArray
    to_replace: scalar or 1D array like
        Which values to replace. If to_replace and value are both array like,
        they must be the same length. NaN may be given to replace missing
        values.
    value: scalar or 1D array like
        Value to replace any values matching `to_replace` with.

    Returns
    -------
    xr.DataArray
        DataArray after replacement. The dtype of `da` is kept, unless the
        replacement values are of a different kind (e.g. floats replacing
        values in an integer array); then the dtype is promoted.

    Raises
    ------
    TypeError
        If `to_replace` is a scalar but `value` is not.
    ValueError
        If `to_replace` is not 1D, if `to_replace` and `value` differ in
        shape, or if `to_replace` contains duplicates.

    Examples
    --------

    Replace values of 1.0 by 10.0, and 2.0 by 20.0:

    >>> da = xr.DataArray([0.0, 1.0, 1.0, 2.0, 2.0])
    >>> replaced = imod.util.replace(da, to_replace=[1.0, 2.0], value=[10.0, 20.0])

    """
    from xarray.core.utils import is_scalar

    def _replace(
        a: np.ndarray, to_replace: np.ndarray, value: np.ndarray, dtype
    ) -> np.ndarray:
        # Work on the block handed in by apply_ufunc (a single chunk when dask
        # parallelizes), never on the enclosing DataArray.
        flat = a.ravel()

        # Look every element up in to_replace via a sorted index: searchsorted
        # gives a candidate position, and the equality test discards elements
        # that are not actually present in to_replace.
        sorter = np.argsort(to_replace)
        insertion = np.searchsorted(to_replace, flat, sorter=sorter)
        indices = np.take(sorter, insertion, mode="clip")
        replaceable = to_replace[indices] == flat

        # astype always copies here, so the input block is never mutated.
        out = flat.astype(dtype, copy=True)
        out[replaceable] = value[indices[replaceable]]
        return out.reshape(a.shape)

    if is_scalar(to_replace):
        if not is_scalar(value):
            raise TypeError("if to_replace is scalar, then value must be a scalar")
        if np.isnan(to_replace):  # type: ignore
            return da.fillna(value)
        return da.where(da != to_replace, other=value)

    to_replace = np.asarray(to_replace)
    if to_replace.ndim != 1:
        raise ValueError("to_replace must be 1D or scalar")
    if is_scalar(value):
        # np.full_like(to_replace, value) would inherit to_replace's dtype and
        # e.g. truncate a float value to 0 when to_replace holds integers.
        value = np.full(to_replace.shape, value)
    else:
        value = np.asarray(value)
    if to_replace.shape != value.shape:
        raise ValueError(
            f"Replacement arrays must match in shape. "
            f"Expecting {to_replace.shape} got {value.shape} "
        )
    if to_replace.size == 0:
        # Nothing to replace; np.take cannot index into an empty sorter.
        return da.copy()

    _, counts = np.unique(to_replace, return_counts=True)
    if (counts > 1).any():
        raise ValueError("to_replace contains duplicates")

    isnan = np.isnan(to_replace)
    if isnan.any():
        # NaN never compares equal in _replace, so missing values are filled
        # separately. Pass a scalar, not a 1-element array, to fillna.
        da = da.fillna(value[np.flatnonzero(isnan)[0]])

    # Keep da's dtype unless the replacement values are of another kind,
    # mirroring the promotion the scalar path gets from DataArray.where.
    if np.can_cast(value.dtype, da.dtype, casting="same_kind"):
        out_dtype = da.dtype
    else:
        out_dtype = np.result_type(da.dtype, value.dtype)

    return xr.apply_ufunc(
        _replace,
        da,
        kwargs={"to_replace": to_replace, "value": value, "dtype": out_dtype},
        dask="parallelized",
        output_dtypes=[out_dtype],
    )

138 

139 

def values_within_range(
    da: xr.DataArray, min: Optional[float] = None, max: Optional[float] = None
) -> xr.DataArray | bool:
    """
    Find which values are within range.

    The values are compared with ``da.clip(min=min, max=max)``: values that
    are unaffected by clipping lie within range. Using ``clip`` means a bound
    given as None is simply ignored. If both min and max are None, the
    scalar True is returned.

    NaN values compare unequal to themselves, so NaN cells are reported as
    False.

    Parameters
    ----------
    da: xr.DataArray
        DataArray to be checked
    min: float, optional
        Minimum value, if None no minimum value is set
    max: float, optional
        Maximum value, if None no maximum value is set

    Returns
    -------
    {bool, xr.DataArray}
        Boolean array with values which are within range as True, or the
        scalar True when neither bound is given.
    """
    if min is None and max is None:
        return True
    return da == da.clip(min=min, max=max)