Coverage for src/scores/processing.py: 99%

74 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2024-02-28 12:51 +1100

1"""Tools for processing data for verification""" 

2import operator 

3from collections.abc import Iterable 

4from typing import Optional, Union 

5 

6import numpy as np 

7import pandas as pd 

8import xarray as xr 

9 

10from scores.typing import FlexibleDimensionTypes, XarrayLike 

11from scores.utils import gather_dimensions 

12 

# Inequality modes map the mode string to a pair of
# (comparison operator, sign applied to abs_tolerance). comparative_discretise
# compares `data` against `comparison + sign * abs_tolerance`, so values within
# abs_tolerance of the threshold are treated as equal to it: they count as
# events for '>=' and '<=', and as non-events for '>' and '<'.
INEQUALITY_MODES = {
    ">=": (operator.ge, -1),
    ">": (operator.gt, +1),
    "<=": (operator.le, +1),
    "<": (operator.lt, -1),
}

# '==' deliberately does not map to `operator.eq`, nor '!=' to `operator.ne`:
# equality is tested within a tolerance, by comparing |data - comparison|
# against abs_tolerance with `<=` for '==' and with `>` for '!='.
EQUALITY_MODES = {
    "==": operator.le,
    "!=": operator.gt,
}

22 

23 

def check_binary(data: XarrayLike, name: str):
    """
    Checks that `data` contains only the values 0, 1 and NaN.

    Args:
        data: An xarray.DataArray or xarray.Dataset to check. A Dataset is
            checked across all of its data variables.
        name: A label for `data`, used in the error message.

    Raises:
        ValueError: if `data` contains any non-NaN value that is not 0 or 1.
    """
    # Collapse a Dataset into a single DataArray so both input types share one path.
    as_array = data if isinstance(data, xr.DataArray) else data.to_array()
    distinct_values = pd.unique(as_array.values.flatten())
    observed = set(distinct_values[~np.isnan(distinct_values)])

    if not observed <= {0, 1}:
        raise ValueError(f"`{name}` contains values that are not in the set {{0, 1, np.nan}}")

43 

44 

def comparative_discretise(
    data: XarrayLike, comparison: Union[xr.DataArray, float, int], mode: str, abs_tolerance: Optional[float] = None
) -> XarrayLike:
    """
    Converts the values of `data` to 0 or 1 based on how they relate to the
    values in `comparison` via the `mode` operator.

    Args:
        data: The data to convert to discrete values.
        comparison: The values to which to compare `data`. May be an
            xarray.DataArray, a Python int/float, or a NumPy integer or
            floating-point scalar.
        mode: Specifies the required relation of `data` to `comparison`
            for a value to fall in the 'event' category (i.e. assigned to 1).
            Allowed modes are:
            - '>=' values in `data` greater than or equal to the
              corresponding comparison value are assigned as 1.
            - '>' values in `data` greater than the corresponding comparison
              value are assigned as 1.
            - '<=' values in `data` less than or equal to the corresponding
              comparison value are assigned as 1.
            - '<' values in `data` less than the corresponding comparison
              value are assigned as 1.
            - '==' values in `data` equal to the corresponding comparison
              value are assigned as 1.
            - '!=' values in `data` not equal to the corresponding comparison
              value are assigned as 1.
        abs_tolerance: If supplied, values in data that are within
            abs_tolerance of a comparison value are considered to be equal to
            it. This is generally used to correct for floating point rounding,
            e.g. we may want to consider 1.0000000000000002 as equal to 1.
            Defaults to 0.

    Returns:
        An xarray data object of the same type as `data`. The dimensions of the
        output are the union of all dimensions in `data` and `comparison`. The
        values of the output are either 0 or 1 or NaN (NaN wherever `data` or
        `comparison` is NaN), depending on the truth of the operation
        `data <mode> comparison`. The attributes 'discretisation_tolerance'
        and 'discretisation_mode' record the settings used.

    Raises:
        ValueError: if `abs_tolerance` is negative or NaN.
        TypeError: if `comparison` is not a real scalar or xarray.DataArray.
        ValueError: if `mode` is not valid.
    """
    # Sanitise abs_tolerance. `not x >= 0` (rather than `x < 0`) also rejects
    # NaN, which would otherwise silently make every comparison False.
    if abs_tolerance is None:
        abs_tolerance = 0
    elif not abs_tolerance >= 0:
        raise ValueError(f"value {abs_tolerance} of abs_tolerance is invalid, it must be a non-negative float")

    # NumPy scalars such as np.int64 or np.float32 are not subclasses of the
    # Python int/float types, so they are accepted explicitly.
    if isinstance(comparison, (float, int, np.integer, np.floating)):
        comparison = xr.DataArray(comparison)
    elif not isinstance(comparison, xr.DataArray):
        raise TypeError("comparison must be a float, int or xarray.DataArray")

    # Do the discretisation (an invalid mode is rejected before any mask work).
    if mode in INEQUALITY_MODES:
        operator_func, factor = INEQUALITY_MODES[mode]
        # Shift the comparison value by the signed tolerance (see INEQUALITY_MODES).
        discrete_data = operator_func(data, comparison + (abs_tolerance * factor))
    elif mode in EQUALITY_MODES:
        operator_func = EQUALITY_MODES[mode]
        discrete_data = operator_func(abs(data - comparison), abs_tolerance)
    else:
        raise ValueError(
            f"'{mode}' is not a valid mode. Available modes are: "
            f"{sorted(INEQUALITY_MODES) + sorted(EQUALITY_MODES)}"
        )

    # Restore NaN wherever either input was missing; comparisons against NaN
    # would otherwise yield a definite 0.
    notnull_mask = data.notnull() * comparison.notnull()
    discrete_data = discrete_data.where(notnull_mask)

    discrete_data.attrs["discretisation_tolerance"] = abs_tolerance
    discrete_data.attrs["discretisation_mode"] = mode

    return discrete_data

118 

119 

def binary_discretise(
    data: XarrayLike,
    thresholds: FlexibleDimensionTypes,
    mode: str,
    abs_tolerance: Optional[float] = None,
    autosqueeze: Optional[bool] = False,
) -> XarrayLike:
    """
    Converts the values of `data` to 0 or 1 for each threshold in `thresholds`
    according to the operation defined by `mode`.

    Args:
        data: The data to convert to discrete values. Must not already have a
            'threshold' dimension.
        thresholds: Threshold(s) at which to convert the values of `data` to
            0 or 1. A scalar or a monotonic non-decreasing 1-D sequence.
        mode: Specifies the required relation of `data` to `thresholds`
            for a value to fall in the 'event' category (i.e. assigned to 1).
            Allowed modes are:

            - '>=' values in `data` greater than or equal to the
              corresponding threshold are assigned as 1.
            - '>' values in `data` greater than the corresponding threshold
              are assigned as 1.
            - '<=' values in `data` less than or equal to the corresponding
              threshold are assigned as 1.
            - '<' values in `data` less than the corresponding threshold
              are assigned as 1.
            - '==' values in `data` equal to the corresponding threshold
              are assigned as 1.
            - '!=' values in `data` not equal to the corresponding threshold
              are assigned as 1.

        abs_tolerance: If supplied, values in data that are within
            abs_tolerance of a threshold are considered to be equal to that
            threshold. This is generally used to correct for floating point
            rounding, e.g. we may want to consider 1.0000000000000002 as
            equal to 1.
        autosqueeze: If True and only one threshold is supplied, then the
            dimension 'threshold' is squeezed out of the output (the
            'threshold' coordinate is kept as a scalar coordinate). If
            `thresholds` is a scalar, this is forced to True; otherwise it
            defaults to False.

    Returns:
        An xarray data object with the type and dimensions of `data`, plus an
        extra dimension 'threshold' unless it was squeezed out. The values of
        the output are either 0 or 1, depending on whether
        `data <mode> threshold` is True or not (NaNs are preserved).

    Raises:
        ValueError: if 'threshold' is a dimension in `data`.
        ValueError: if the values in `thresholds` are not monotonic
            non-decreasing.
    """
    if "threshold" in data.dims:
        raise ValueError("'threshold' must not be in the supplied data object dimensions")

    # A 0-D threshold becomes a length-1 1-D array, and autosqueeze is forced
    # on so the 'threshold' dimension is dropped from the result.
    thresholds_np = np.array(thresholds)
    if thresholds_np.ndim == 0:
        thresholds_np = np.expand_dims(thresholds_np, 0)
        autosqueeze = True

    # Compare neighbours directly instead of subtracting them: subtraction
    # wraps around for unsigned integer dtypes, which would let a decreasing
    # sequence pass. Any comparison with NaN is False, so NaN still fails.
    if not np.all(thresholds_np[1:] >= thresholds_np[:-1]):
        raise ValueError("Values in `thresholds` are not montonic increasing")

    # make thresholds DataArray
    thresholds_da = xr.DataArray(thresholds_np, dims=["threshold"], coords={"threshold": thresholds_np})

    # do the discretisation
    discrete_data = comparative_discretise(data, thresholds_da, mode, abs_tolerance=abs_tolerance)

    if autosqueeze and len(thresholds_np) == 1:
        # squeeze out the 'threshold' dimension, but keep the coordinate
        discrete_data = discrete_data.squeeze(dim="threshold")

    return discrete_data

200 

201 

def broadcast_and_match_nan(*args: XarrayLike) -> tuple[XarrayLike, ...]:
    """
    Broadcasts the supplied xarray data objects against each other and forces
    the NaN positions of every input onto all of them.

    A combined mask is built that is True only where every DataArray argument,
    and every data variable of every Dataset argument, is non-null. Each
    argument is then returned as `arg.where(mask)`.

    Args:
        *args: any number of xarray data objects (DataArrays and/or Datasets)
            supplied as positional arguments. See examples below.

    Returns:
        A tuple containing one 'matched' data object per input, in the same
        order as the inputs.

    Raises:
        ValueError: if any input argument is not an xarray.DataArray or
            xarray.Dataset.

    Examples:

        >>> # Matching xarray data objects
        >>> da1_matched, ds_matched, da2_matched = broadcast_and_match_nan(da1, ds, da2)

        >>> # Matching a tuple of xarray data objects
        >>> input_tuple = (da1, ds, da2)
        >>> matched_tuple = broadcast_and_match_nan(*input_tuple)
        >>> da1_matched = matched_tuple[0]
        >>> ds_matched = matched_tuple[1]
        >>> da2_matched = matched_tuple[2]
    """
    # Validate every argument before doing any work.
    for position, candidate in enumerate(args):
        if isinstance(candidate, (xr.Dataset, xr.DataArray)):
            continue
        raise ValueError(
            f"Argument {position} is not an xarray data object. (counting from 0, i.e. "
            "argument 0 is the first argument)"
        )

    # AND together the not-null masks of every DataArray, and of every data
    # variable inside each Dataset, starting from the scalar True.
    valid = True
    for candidate in args:
        if isinstance(candidate, xr.DataArray):
            members = (candidate,)
        else:
            # validation above guarantees this is an xr.Dataset
            members = tuple(candidate[var_name] for var_name in candidate.data_vars)
        for member in members:
            valid = valid & member.notnull()

    return tuple(candidate.where(valid) for candidate in args)

265 

266 

def proportion_exceeding(
    data: XarrayLike,
    thresholds: Iterable,
    preserve_dims: FlexibleDimensionTypes = None,
    reduce_dims: FlexibleDimensionTypes = None,
):
    """
    Calculates the proportion of `data` that is greater than or equal to each
    of `thresholds`.

    Args:
        data (xarray.Dataset or xarray.DataArray): The data from which
            to calculate the proportion exceeding `thresholds`.
        thresholds (iterable): The thresholds at which to calculate the
            proportion of `data` that is equal to or exceeds them.
        reduce_dims: Dimensions to reduce.
        preserve_dims: Dimensions to preserve.

    Returns:
        An xarray data object with the type of `data`, the dimensions of
        `data` that were preserved (not reduced), plus a 'threshold'
        dimension. The values are the proportion of `data` that are greater
        than or equal to the corresponding threshold.
    """
    exceedance_mode = ">="
    return _binary_discretise_proportion(
        data,
        thresholds,
        exceedance_mode,
        preserve_dims=preserve_dims,
        reduce_dims=reduce_dims,
    )

292 

293 

def _binary_discretise_proportion(
    data: XarrayLike,
    thresholds: Iterable,
    mode: str,
    preserve_dims: FlexibleDimensionTypes = None,
    reduce_dims: FlexibleDimensionTypes = None,
    abs_tolerance: Optional[float] = None,
    autosqueeze: bool = False,
):
    """
    Returns the proportion of `data` in each category. The categories are
    defined by the relationship of data to threshold as specified by
    the operation `mode`.

    Args:
        data: The data to convert into 0 and 1 according to the thresholds
            before calculating the proportion.
        thresholds: The thresholds defining the categories; passed to
            `binary_discretise`.
        mode: Specifies the required relation of `data` to `thresholds`
            for a value to fall in the 'event' category (i.e. assigned to 1).
            Allowed modes are:

            - '>=' values in `data` greater than or equal to the
              corresponding threshold are assigned as 1.
            - '>' values in `data` greater than the corresponding threshold
              are assigned as 1.
            - '<=' values in `data` less than or equal to the corresponding
              threshold are assigned as 1.
            - '<' values in `data` less than the corresponding threshold
              are assigned as 1.
            - '==' values in `data` equal to the corresponding threshold
              are assigned as 1.
            - '!=' values in `data` not equal to the corresponding threshold
              are assigned as 1.
        preserve_dims: Dimensions to preserve.
        reduce_dims: Dimensions to reduce.
        abs_tolerance: If supplied, values in data that are within
            abs_tolerance of a threshold are considered to be equal to that
            threshold. This is generally used to correct for floating point
            rounding, e.g. we may want to consider 1.0000000000000002 as
            equal to 1.
        autosqueeze: If True and only one threshold is supplied, then the
            dimension 'threshold' is squeezed out of the output. If
            `thresholds` is a scalar, this is forced to True; otherwise it
            defaults to False.

    Returns:
        An xarray data object with the type of `data`, the dimensions of
        `data` that were preserved (not reduced), plus 'threshold' unless it
        was squeezed out. The values of the output are the proportion of
        non-NaN `data` that satisfy the relationship to `thresholds` as
        specified by `mode`.

    Examples:

        >>> data = xr.DataArray([0, 0.5, 0.5, 1])

        >>> _binary_discretise_proportion(data, [0, 0.5, 1], '==')
        <xarray.DataArray (threshold: 3)>
        array([ 0.25, 0.5 , 0.25])
        Coordinates:
          * threshold  (threshold) float64 0.0 0.5 1.0
        Attributes:
            discretisation_tolerance:  0
            discretisation_mode:       ==

        >>> _binary_discretise_proportion(data, [0, 0.5, 1], '>=')
        <xarray.DataArray (threshold: 3)>
        array([ 1.  , 0.75, 0.25])
        Coordinates:
          * threshold  (threshold) float64 0.0 0.5 1.0
        Attributes:
            discretisation_tolerance:  0
            discretisation_mode:       >=

    See also:
        `scores.processing.binary_discretise`
    """
    # values are 1 when (data {mode} threshold), and 0 when ~(data {mode} threshold).
    discrete_data = binary_discretise(data, thresholds, mode, abs_tolerance=abs_tolerance, autosqueeze=autosqueeze)

    # The proportion in each category. The dimension selections are passed by
    # keyword so the result does not depend on the positional parameter order
    # of gather_dimensions (defined in scores.utils, not in this module).
    dims = gather_dimensions(data.dims, data.dims, reduce_dims=reduce_dims, preserve_dims=preserve_dims)
    proportion = discrete_data.mean(dim=dims)

    # attach attributes (mean() drops them by default)
    proportion.attrs = discrete_data.attrs

    return proportion