Coverage for src/scores/processing.py: 99%
74 statements
« prev ^ index » next coverage.py v7.3.2, created at 2024-02-28 12:51 +1100
« prev ^ index » next coverage.py v7.3.2, created at 2024-02-28 12:51 +1100
1"""Tools for processing data for verification"""
2import operator
3from collections.abc import Iterable
4from typing import Optional, Union
6import numpy as np
7import pandas as pd
8import xarray as xr
10from scores.typing import FlexibleDimensionTypes, XarrayLike
11from scores.utils import gather_dimensions
# Inequality modes: mode -> (comparison operator, sign applied to abs_tolerance).
# The sign shifts the comparison value so that data lying within abs_tolerance of
# it is treated as *equal* to it: '>=' / '<=' then count as events, while '>' / '<'
# (strict inequalities) do not.
INEQUALITY_MODES = {
    ">=": (operator.ge, -1),
    ">": (operator.gt, 1),
    "<=": (operator.le, 1),
    "<": (operator.lt, -1),
}

# Equality modes deliberately do NOT use operator.eq / operator.ne. Equality is
# tested within a tolerance, so the operator is applied as
# `operator(abs(data - comparison), abs_tolerance)`.
EQUALITY_MODES = {
    "==": operator.le,
    "!=": operator.gt,
}
def check_binary(data: XarrayLike, name: str):
    """
    Checks that data does not have any non-NaN values out of the set {0, 1}.

    Args:
        data: The data to check. For a Dataset, every data variable is checked.
        name: Name used to identify `data` in the error message.

    Raises:
        ValueError: if `data` contains any non-NaN value that is not in the
            set {0, 1}.
    """
    if isinstance(data, xr.DataArray):
        values = data.values
    else:
        # Stack all data variables into one array so they are checked together.
        values = data.to_array().values

    unique_values = pd.unique(values.ravel())
    # pd.isna (unlike np.isnan) works for every dtype, so non-numeric data is
    # reported through the ValueError below instead of crashing with TypeError.
    unique_values = unique_values[~pd.isna(unique_values)]

    if not set(unique_values).issubset({0, 1}):
        raise ValueError(f"`{name}` contains values that are not in the set {{0, 1, np.nan}}")
def comparative_discretise(
    data: XarrayLike, comparison: Union[xr.DataArray, float, int], mode: str, abs_tolerance: Optional[float] = None
) -> XarrayLike:
    """
    Converts the values of `data` to 0 or 1 based on how they relate to the specified
    values in `comparison` via the `mode` operator.

    Args:
        data: The data to convert to discrete values.
        comparison: The values to which to compare `data`. Python and numpy
            int/float scalars are accepted as well as xarray.DataArray.
        mode: Specifies the required relation of `data` to `comparison`
            for a value to fall in the 'event' category (i.e. assigned to 1).
            Allowed modes are:

            - '>=' values in `data` greater than or equal to the
              corresponding threshold are assigned as 1.
            - '>' values in `data` greater than the corresponding threshold
              are assigned as 1.
            - '<=' values in `data` less than or equal to the corresponding
              threshold are assigned as 1.
            - '<' values in `data` less than the corresponding threshold
              are assigned as 1.
            - '==' values in `data` equal to the corresponding threshold
              are assigned as 1
            - '!=' values in `data` not equal to the corresponding threshold
              are assigned as 1.
        abs_tolerance: If supplied, values in data that are
            within abs_tolerance of a threshold are considered to be equal to
            that threshold. This is generally used to correct for floating
            point rounding, e.g. we may want to consider 1.0000000000000002 as
            equal to 1. Defaults to 0 (exact comparison).

    Returns:
        An xarray data object of the same type as `data`. The dimensions of the
        output are the union of all dimensions in `data` and `comparison`. The
        values of the output are either 0 or 1, depending on the truth of the
        operation `data <mode> comparison`, or NaN wherever `data` or
        `comparison` is NaN. The attributes `discretisation_tolerance` and
        `discretisation_mode` record the tolerance and mode used.

    Raises:
        ValueError: if abs_tolerance is not a non-negative float (including NaN).
        ValueError: if `mode` is not valid.
        TypeError: if `comparison` is not a float, int or xarray.DataArray.
    """
    # sanitise abs_tolerance
    if abs_tolerance is None:
        abs_tolerance = 0
    elif not abs_tolerance >= 0:
        # `not x >= 0` (rather than `x < 0`) also rejects NaN, which would
        # otherwise silently make every comparison below evaluate to False.
        raise ValueError(f"value {abs_tolerance} of abs_tolerance is invalid, it must be a non-negative float")

    # numpy scalars such as np.int64 or np.float32 are not subclasses of the
    # Python builtins, so they are listed explicitly.
    if isinstance(comparison, (float, int, np.integer, np.floating)):
        comparison = xr.DataArray(comparison)
    elif not isinstance(comparison, xr.DataArray):
        raise TypeError("comparison must be a float, int or xarray.DataArray")

    # mask to preserve NaN in data and comparison
    notnull_mask = data.notnull() * comparison.notnull()

    # do the discretisation
    if mode in INEQUALITY_MODES:
        operator_func, factor = INEQUALITY_MODES[mode]
        # shift the comparison by the signed tolerance so values within
        # abs_tolerance of it are treated as equal to it
        discrete_data = operator_func(data, comparison + (abs_tolerance * factor)).where(notnull_mask)
    elif mode in EQUALITY_MODES:
        operator_func = EQUALITY_MODES[mode]
        # equality is tested within a tolerance on the absolute difference
        discrete_data = operator_func(abs(data - comparison), abs_tolerance).where(notnull_mask)
    else:
        raise ValueError(
            f"'{mode}' is not a valid mode. Available modes are: "
            f"{sorted(INEQUALITY_MODES) + sorted(EQUALITY_MODES)}"
        )
    discrete_data.attrs["discretisation_tolerance"] = abs_tolerance
    discrete_data.attrs["discretisation_mode"] = mode

    return discrete_data
def binary_discretise(
    data: XarrayLike,
    thresholds: FlexibleDimensionTypes,
    mode: str,
    abs_tolerance: Optional[float] = None,
    autosqueeze: Optional[bool] = False,
) -> XarrayLike:
    """
    Converts the values of `data` to 0 or 1 for each threshold in `thresholds`
    according to the operation defined by `mode`.

    Args:
        data: The data to convert to discrete values. Must not have a
            'threshold' dimension.
        thresholds: Threshold(s) at which to convert the values of `data` to
            0 or 1. Either a scalar or a 1-D sequence of non-decreasing values.
        mode: Specifies the required relation of `data` to `thresholds`
            for a value to fall in the 'event' category (i.e. assigned to 1).
            Allowed modes are:

            - '>=' values in `data` greater than or equal to the
              corresponding threshold are assigned as 1.
            - '>' values in `data` greater than the corresponding threshold
              are assigned as 1.
            - '<=' values in `data` less than or equal to the corresponding
              threshold are assigned as 1.
            - '<' values in `data` less than the corresponding threshold
              are assigned as 1.
            - '==' values in `data` equal to the corresponding threshold
              are assigned as 1
            - '!=' values in `data` not equal to the corresponding threshold
              are assigned as 1.

        abs_tolerance: If supplied, values in data that are
            within abs_tolerance of a threshold are considered to be equal to
            that threshold. This is generally used to correct for floating
            point rounding, e.g. we may want to consider 1.0000000000000002 as
            equal to 1

        autosqueeze: If True and only one threshold is
            supplied, then the dimension 'threshold' is squeezed out of the
            output (the scalar 'threshold' coordinate is kept). If `thresholds`
            is float-like, then this is forced to True, otherwise defaults to
            False.

    Returns:
        An xarray data object with the type and dimensions of `data`, plus an
        extra dimension 'threshold' unless it was squeezed out (see
        `autosqueeze`). The values of the output are either 0 or 1, depending
        on whether `data <mode> threshold` is True or not (although NaNs are
        preserved).

    Raises:
        ValueError: if 'threshold' is a dimension in `data`.
        ValueError: if `thresholds` has more than one dimension.
        ValueError: if "Values in `thresholds` are not montonic increasing"
            (this includes thresholds containing NaN).
        ValueError / TypeError: propagated from `comparative_discretise` for an
            invalid `mode` or `abs_tolerance`.
    """
    if "threshold" in data.dims:
        raise ValueError("'threshold' must not be in the supplied data object dimensions")

    # if thresholds is 0-D, convert it to a length-1 1-D array
    # but autosqueeze=True so the 'threshold' dimension is dropped
    thresholds_np = np.array(thresholds)
    if thresholds_np.ndim == 0:
        thresholds_np = np.expand_dims(thresholds_np, 0)
        autosqueeze = True
    elif thresholds_np.ndim > 1:
        raise ValueError("`thresholds` must be a scalar or a 1-D sequence of values")

    # sanitise thresholds. Neighbours are compared directly instead of being
    # subtracted: subtraction wraps around for unsigned/overflowing integers and
    # is unsupported for booleans. NaN compares False, so NaNs are rejected too.
    if not (thresholds_np[1:] >= thresholds_np[:-1]).all():
        raise ValueError("Values in `thresholds` are not montonic increasing")

    # make thresholds DataArray
    thresholds_da = xr.DataArray(thresholds_np, dims=["threshold"], coords={"threshold": thresholds_np})

    # do the discretisation
    discrete_data = comparative_discretise(data, thresholds_da, mode, abs_tolerance=abs_tolerance)

    # squeeze
    if autosqueeze and len(thresholds_np) == 1:
        # squeeze out the 'threshold' dimension, but keep the coordinate
        discrete_data = discrete_data.squeeze(dim="threshold")

    return discrete_data
def broadcast_and_match_nan(*args: XarrayLike) -> tuple[XarrayLike, ...]:
    """
    Input xarray data objects are 'matched': they are broadcast against each
    other (forced to have the same dimensions), and the NaN positions of every
    input are propagated to all of them. This matching process is applied
    across all supplied DataArrays, as well as all DataArrays inside supplied
    Datasets.

    Args:
        *args: any number of xarray data objects supplied as positional arguments. See
            examples below.

    Returns:
        A tuple of data objects of the same length as the number of data objects
        supplied as input. Each returned object is the 'matched' version of the
        input.

    Raises:
        ValueError: if any input args is not an xarray data
            object.

    Examples:

        >>> # Matching xarray data objects
        >>> da1_matched, ds_matched, da2_matched = broadcast_and_match_nan(da1, ds, da2)

        >>> # Matching a tuple of xarray data objects
        >>> input_tuple = (da1, ds, da2)
        >>> matched_tuple = broadcast_and_match_nan(*input_tuple)
        >>> da1_matched = matched_tuple[0]
        >>> ds_matched = matched_tuple[1]
        >>> da2_matched = matched_tuple[2]
    """

    # sanitise inputs
    for i, arg in enumerate(args):
        if not isinstance(arg, (xr.Dataset, xr.DataArray)):
            raise ValueError(
                f"Argument {i} is not an xarray data object. (counting from 0, i.e. "
                "argument 0 is the first argument)"
            )

    # Build one mask that is True only where every DataArray (and every data
    # variable of every Dataset) is non-null. Starting from the scalar True lets
    # the first `&` broadcast it to the first DataArray's shape. After the
    # validation above, anything that is not a DataArray is a Dataset.
    mask = True
    for arg in args:
        data_arrays = [arg] if isinstance(arg, xr.DataArray) else arg.data_vars.values()
        for data_array in data_arrays:
            mask = mask & data_array.notnull()

    # return matched data objects
    return tuple(arg.where(mask) for arg in args)
def proportion_exceeding(
    data: XarrayLike,
    thresholds: Iterable,
    preserve_dims: FlexibleDimensionTypes = None,
    reduce_dims: FlexibleDimensionTypes = None,
) -> XarrayLike:
    """
    Calculates the proportion of `data` equal to or exceeding `thresholds`.

    Args:
        data (xarray.Dataset or xarray.DataArray): The data from which
            to calculate the proportion exceeding `thresholds`.
        thresholds (iterable): The proportion of `data` equal to or exceeding
            each of these thresholds will be calculated.
        preserve_dims: Dimensions to preserve.
        reduce_dims: Dimensions to reduce.

    Returns:
        An xarray data object with the type of `data`, the dimensions of `data`
        that are not reduced, plus a 'threshold' dimension (squeezed out when a
        scalar threshold is given). The values are the proportion of `data`
        that are greater than or equal to the corresponding threshold.
    """
    return _binary_discretise_proportion(
        data, thresholds, ">=", preserve_dims=preserve_dims, reduce_dims=reduce_dims
    )
def _binary_discretise_proportion(
    data: XarrayLike,
    thresholds: Iterable,
    mode: str,
    preserve_dims: FlexibleDimensionTypes = None,
    reduce_dims: FlexibleDimensionTypes = None,
    abs_tolerance: Optional[float] = None,
    autosqueeze: bool = False,
):
    """
    Returns the proportion of `data` in each category. The categories are
    defined by the relationship of data to threshold as specified by
    the operation `mode`.

    Args:
        data: The data to convert into 0 and 1 according to the thresholds
            before calculating the proportion.
        thresholds: The thresholds against which `data` is compared; one
            proportion is calculated per threshold.
        mode: Specifies the required relation of `data` to `thresholds`
            for a value to fall in the 'event' category (i.e. assigned to 1).
            Allowed modes are:

            - '>=' values in `data` greater than or equal to the
              corresponding threshold are assigned as 1.
            - '>' values in `data` greater than the corresponding threshold
              are assigned as 1.
            - '<=' values in `data` less than or equal to the corresponding
              threshold are assigned as 1.
            - '<' values in `data` less than the corresponding threshold
              are assigned as 1.
            - '==' values in `data` equal to the corresponding threshold
              are assigned as 1
            - '!=' values in `data` not equal to the corresponding threshold
              are assigned as 1.
        preserve_dims: Dimensions to preserve.
        reduce_dims: Dimensions to reduce.
        abs_tolerance: If supplied, values in data that are
            within abs_tolerance of a threshold are considered to be equal to
            that threshold. This is generally used to correct for floating
            point rounding, e.g. we may want to consider 1.0000000000000002 as
            equal to 1.
        autosqueeze: If True and only one threshold is
            supplied, then the dimension 'threshold' is squeezed out of the
            output. If `thresholds` is float-like, then this is forced to
            True, otherwise defaults to False.

    Returns:
        An xarray data object with the type of `data`, the dimensions of `data`
        that are not reduced (as determined by
        `scores.utils.gather_dimensions` from `preserve_dims` / `reduce_dims`),
        plus a 'threshold' dimension unless it was squeezed out. The values of
        the output are the proportion of `data` that satisfy the relationship
        to `thresholds` as specified by `mode`. The discretisation attributes
        from `binary_discretise` are carried over.

    Examples:

        >>> data = xr.DataArray([0, 0.5, 0.5, 1])

        >>> _binary_discretise_proportion(data, [0, 0.5, 1], '==')
        <xarray.DataArray (threshold: 3)>
        array([ 0.25, 0.5 , 0.25])
        Coordinates:
          * threshold  (threshold) float64 0.0 0.5 1.0
        Attributes:
            discretisation_tolerance:  0
            discretisation_mode:       ==

        >>> _binary_discretise_proportion(data, [0, 0.5, 1], '>=')
        <xarray.DataArray (threshold: 3)>
        array([ 1.  , 0.75, 0.25])
        Coordinates:
          * threshold  (threshold) float64 0.0 0.5 1.0
        Attributes:
            discretisation_tolerance:  0
            discretisation_mode:       >=

    See also:
        `scores.processing.binary_discretise`

    """
    # values are 1 when (data {mode} threshold), and 0 when ~(data {mode} threshold).
    discrete_data = binary_discretise(data, thresholds, mode, abs_tolerance=abs_tolerance, autosqueeze=autosqueeze)

    # The proportion in each category. Keyword arguments are used so the
    # result does not depend on gather_dimensions' positional parameter order.
    dims = gather_dimensions(data.dims, data.dims, reduce_dims=reduce_dims, preserve_dims=preserve_dims)
    proportion = discrete_data.mean(dim=dims)

    # attach attributes (mean() drops them); copy so the two objects do not
    # share one attrs dict
    proportion.attrs = dict(discrete_data.attrs)

    return proportion