Coverage for src/scores/continuous/flip_flop_impl.py: 100%
86 statements
« prev ^ index » next coverage.py v7.3.2, created at 2024-02-28 12:51 +1100
« prev ^ index » next coverage.py v7.3.2, created at 2024-02-28 12:51 +1100
1"""
2This module contains functions for calculating flip flop indices
3"""
5from collections.abc import Generator, Iterable, Sequence
6from typing import Optional, Union, overload
8import numpy as np
9import xarray as xr
11from scores.functions import angular_difference
12from scores.processing import proportion_exceeding
13from scores.typing import FlexibleDimensionTypes, XarrayLike
14from scores.utils import DimensionError, check_dims, dims_complement
def _flip_flop_index(data: xr.DataArray, sampling_dim: str, is_angular: bool = False) -> xr.DataArray:
    """
    Calculates the flip-flop index by collapsing the dimension specified by
    `sampling_dim`.

    Args:
        data: Data from which to draw subsets.
        sampling_dim: The name of the dimension along which to calculate
            the flip-flop index.
        is_angular: specifies whether `data` is directional data (e.g. wind
            direction).

    Returns:
        A xarray.DataArray of the flip-flop index with the dimensions of
        `data`, except for the `sampling_dim` dimension which is collapsed.

    See also:
        `scores.continuous.flip_flop.flip_flop_index`
    """
    # `sampling_dim` must be present in `data`
    check_dims(data, [sampling_dim], mode="superset")
    # N points along `sampling_dim` admit at most N - 2 discrete flip-flops
    normaliser = len(data[sampling_dim]) - 2

    # consecutive differences: shifting by one leaves a leading row of nan
    shifted = data.shift({sampling_dim: 1})

    # compute the data range; skipna=False guarantees that a nan anywhere in
    # a row produces nan in the result
    if is_angular:
        # `encompassing_sector_size` takes the dimensions to preserve,
        # not the one to collapse
        preserved_dims = dims_complement(data, [sampling_dim])
        # cap the range at 180, the largest possible angular difference
        # between two forecasts
        sector_size = encompassing_sector_size(data=data, dims=preserved_dims)
        data_range = np.clip(sector_size, a_min=None, a_max=180.0)
        consecutive_diffs = angular_difference(shifted, data)
    else:
        data_range = data.max(dim=sampling_dim, skipna=False) - data.min(dim=sampling_dim, skipna=False)
        consecutive_diffs = shifted - data

    # sum of absolute consecutive differences; skipna is left at its default
    # because .shift introduced a row of nan
    total_movement = abs(consecutive_diffs).sum(dim=sampling_dim)
    # subtracting the range is where nan gets introduced, then normalise
    return (total_movement - data_range) / normaliser
# Overload: when `selections` keyword arguments are supplied, an xr.Dataset is
# always returned (one data variable per selection keyword).
@overload
def flip_flop_index(
    data: xr.DataArray, sampling_dim: str, is_angular: bool = False, **selections: Iterable[int]
) -> xr.Dataset:
    ...
# Overload: when no `selections` keyword arguments are supplied, an
# xr.DataArray is always returned.
@overload
def flip_flop_index(
    data: xr.DataArray, sampling_dim: str, is_angular: bool = False, **selections: None
) -> xr.DataArray:
    ...
# Return type is more precise at runtime when it is known if selections are being used
def flip_flop_index(
    data: xr.DataArray, sampling_dim: str, is_angular: bool = False, **selections: Optional[Iterable[int]]
) -> XarrayLike:
    """
    Calculates the Flip-flop Index along the dimension `sampling_dim`.

    Args:
        data: Data from which to draw subsets.
        sampling_dim: The name of the dimension along which to calculate
            the flip-flop index.
        is_angular: specifies whether `data` is directional data (e.g. wind
            direction).
        **selections: Additional keyword arguments specify
            subsets to draw from the dimension `sampling_dim` of the supplied `data`
            before calculation of the flip_flop index. e.g. days123=[1, 2, 3]

    Returns:
        If `selections` are not supplied: An xarray.DataArray, the Flip-flop
        Index by collapsing the dimension `sampling_dim`.

        If `selections` are supplied: An xarray.Dataset. Each data variable
        is a supplied key-word argument, and corresponds to selecting the
        values specified from `sampling_dim` of `data`. The Flip-flop Index
        is calculated for each of these selections.

    Notes:

        .. math::

            \\text{Flip-Flop Index} = \\frac{1}{N-2}
            \\left [
            \\left(\\sum\\limits_{i=1}^{N-1}|x_i - x_{i+1}|\\right)
            - \\left(\\max_{j}\\{x_j\\} - \\min_{j}\\{x_j\\}\\right)
            \\right ]

        Where :math:`N` is the number of data points, and :math:`x_i` is the
        :math:`i^{\\text{th}}` data point.

    Examples:
        >>> data = xr.DataArray([50, 20, 40, 80], coords={'lead_day': [1, 2, 3, 4]})

        >>> flip_flop_index(data, 'lead_day')
        <xarray.DataArray ()>
        array(15.0)
        Attributes:
            sampling_dim: lead_day

        >>> flip_flop_index(data, 'lead_day', days123=[1, 2, 3], all_days=[1, 2, 3, 4])
        <xarray.Dataset>
        Dimensions:  ()
        Coordinates:
            *empty*
        Data variables:
            days123   float64 20.0
            all_days  float64 15.0
        Attributes:
            selections: {'days123': [1, 2, 3], 'all_days': [1, 2, 3, 4]}
            sampling_dim: lead_day

    """
    if selections or not isinstance(data, xr.DataArray):
        # one data variable per selection keyword, with the selections
        # recorded in the attributes
        result = xr.Dataset(attrs={"selections": selections})
        for name, subset in iter_selections(data, sampling_dim, **selections):
            result[name] = _flip_flop_index(subset, sampling_dim, is_angular=is_angular)
    else:
        result = _flip_flop_index(data, sampling_dim, is_angular=is_angular)
    result.attrs["sampling_dim"] = sampling_dim

    return result
# Overload: DataArray input types lead to DataArray output types
@overload
def iter_selections(
    data: xr.DataArray, sampling_dim: str, **selections: Optional[Iterable[int]]
) -> Generator[tuple[str, xr.DataArray], None, None]:
    ...
# Overload: Dataset input types lead to Dataset output types
@overload
def iter_selections(
    data: xr.Dataset, sampling_dim: str, **selections: Optional[Iterable[int]]
) -> Generator[tuple[str, xr.Dataset], None, None]:
    ...
def iter_selections(
    data: XarrayLike, sampling_dim: str, **selections: Optional[Iterable[int]]
) -> Generator[tuple[str, XarrayLike], None, None]:
    """
    Selects subsets of data along dimension sampling_dim according to
    `selections`.

    Args:
        data: The data to sample from.
        sampling_dim: The dimension from which to sample.
        selections: Each supplied keyword corresponds to a
            selection of `data` from the dimensions `sampling_dim`. The
            key is the first element of the yielded tuple.

    Yields:
        A tuple (key, data_subset), where key is the supplied `**selections`
        keyword, and data_subset is the `data` at the values along
        `sampling_dim` specified by `**selections`.

    Raises:
        KeyError: values in selections are not in data[sampling_dim]

    Examples:
        >>> data = xr.DataArray(
        ...     [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.7],
        ...     coords={'lead_day': [1, 2, 3, 4, 5, 6, 7]}
        ... )
        >>> for key, data_subset in iter_selections(
        ...     data, 'lead_day', days123=[1, 2, 3], all_days=[1, 2, 3, 4, 5, 6, 7]
        ... ):
        ...     print(key, ':', data_subset)
        days123 : <xarray.DataArray (lead_day: 3)>
        array([ 0. ,  0.1,  0.2])
        Coordinates:
          * lead_day (lead_day) int64 1 2 3
        all_days : <xarray.DataArray (lead_day: 7)>
        array([ 0. ,  0.1,  0.2,  0.3,  0.4,  0.5,  0.7])
        Coordinates:
          * lead_day (lead_day) int64 1 2 3 4 5 6 7

    """
    check_dims(data, [sampling_dim], "superset")
    for name, coord_values in selections.items():
        try:
            # A shallow copy is needed so that attributes added in
            # _iter_selections_with_attrs affect only the subset,
            # not the whole data object
            subset = data.sel({sampling_dim: coord_values}).copy(deep=False)
        except KeyError as ex:
            message = (
                f"for `selections` item {str({name: coord_values})}, not all values found in "
                f"dimension '{sampling_dim}'"
            )
            raise KeyError(message) from ex

        yield name, subset
def encompassing_sector_size(data: xr.DataArray, dims: Sequence[str], skipna: bool = False) -> xr.DataArray:
    """
    Calculates the minimum angular distance which encompasses all data points
    within an xarray.DataArray along a specified dimension. Assumes data is in
    degrees.

    Only one dimension may be collapsed each time, so `dims` must name exactly
    one fewer dimension than `data` has, otherwise an exception is raised.

    Args:
        data: direction data in degrees
        dims: Strings corresponding to the dimensions in the input
            xarray data objects that we wish to preserve in the output. All other
            dimensions in the input data objects are collapsed.
        skipna: specifies whether to ignore nans in the data. If False
            (default), will return a nan if one or more nans are present

    Returns:
        an xarray.DataArray of minimum encompassing sector sizes with
        dimensions `dims`.

    Raises:
        scores.utils.DimensionError: raised if

            - the set of data dimensions is not a proper superset of `dims`
            - dimension to be collapsed isn't 1
    """
    check_dims(data, dims, "proper superset")
    collapse_candidates = dims_complement(data, dims)
    if len(collapse_candidates) != 1:
        raise DimensionError("can only collapse one dimension")
    collapse_axis = data.get_axis_num(collapse_candidates[0])
    sector_sizes = _encompassing_sector_size_np(
        data=data.values,
        axis_to_collapse=collapse_axis,
        skipna=skipna,
    )
    # keep the surviving dimensions in their original order
    kept_dims = [dim for dim in data.dims if dim in dims]
    return xr.DataArray(
        sector_sizes,
        dims=kept_dims,
        coords=[data.coords[dim] for dim in kept_dims],
    )
276@np.errstate(invalid="ignore")
277def _encompassing_sector_size_np(
278 data: np.ndarray, axis_to_collapse: Union[int, tuple[int, ...]] = 0, skipna: bool = False
279) -> np.ndarray:
280 """
281 Calculates the minimum angular distance which encompasses all data points
282 within an xarray.DataArray along a specified dimension. Assumes data is in
283 degrees.
285 Args:
286 data: direction data in degrees
287 axis_to_collapse: number of axis to collapse in data, the numpy.ndarray
288 skipna: specifies whether to ignore nans in the data. If False
289 (default), will return a nan if one or more nans are present
291 Returns:
292 an numpy.ndarray of minimum encompassing sector sizes
293 """
294 # code will be simpler, and makes broadcasting easier if we are dealing
295 # with the axis=0
296 data = np.moveaxis(data, axis_to_collapse, 0)
297 # make data in range [0, 360)
298 data = data % 360
299 data = np.sort(data, axis=0)
300 if skipna:
301 # rotate so one angle is at zero, then we can replace Nans with zeroes
302 if data.ndim == 1:
303 data = (data - data[0]) % 360
304 else:
305 data = (data - data[0, :]) % 360
306 all_nans = np.all(np.isnan(data), axis=0)
307 # if all NaNs, we don't want to change, and still want end result to be
308 # NaN.
309 # if some NaNs but not all, then set to zero, which will just end up
310 # being a duplicate value after we've rotated so at least one zero value
311 data[np.isnan(data) & ~all_nans] = 0
312 # make a back-shifted copy of `data`
313 data_rolled = np.roll(data, shift=-1, axis=0)
314 # determine absolute angular difference between all adjacent angles
315 angular_diffs = np.abs(data - data_rolled)
316 angular_diffs = np.where(
317 # nan_to_num so doesn't complain about comparing with NaN
318 np.nan_to_num(angular_diffs) > 180,
319 360 - angular_diffs,
320 angular_diffs,
321 )
322 # the max difference between adjacent angles, or its complement, is
323 # equivalent to the smallest sector size which encompasses all angles in
324 # `data`.
325 max_args = np.argmax(angular_diffs, axis=0)
326 max_indices = tuple([max_args] + list(np.indices(max_args.shape)))
327 # determine the first of the two angles resulting in max difference
328 first_bounding_angle = data[max_indices]
329 # rotate all angles by `first_bounding_angle` (anticlockwise), and make any
330 # resulting negative angles positive. This ensures that the rotated
331 # `first_bounding_angle` is 0, and is therefore the smallest angle in the
332 # rotated set
333 rotated = (data_rolled - first_bounding_angle) % 360
334 # determine the second of the two angles, now rotated, resulting in max
335 # difference
336 second_bound_angle_rotated = rotated[max_indices]
337 max_of_rotated = np.max(rotated, axis=0)
338 # if `second_bounding_angle_rotated` is the largest element, then
339 # sector size is the clockwise span of 0 -> `second_bounding_angle_rotated`,
340 # otherwise it's the anticlockwise span
341 result = np.where(
342 max_of_rotated == second_bound_angle_rotated,
343 second_bound_angle_rotated,
344 360 - second_bound_angle_rotated,
345 )
346 # if there are only one or two distinct angles, return the unique difference
347 # calculated
348 n_unique_angles = (angular_diffs != 0).sum(axis=0)
349 result = np.where(n_unique_angles <= 2, np.max(angular_diffs, axis=0), result)
350 return result
def flip_flop_index_proportion_exceeding(
    data: xr.DataArray,
    sampling_dim: str,
    thresholds: Iterable,
    is_angular: bool = False,
    preserve_dims: FlexibleDimensionTypes = None,
    reduce_dims: FlexibleDimensionTypes = None,
    **selections: Iterable[int],
):
    """
    Calculates the flip-flop index and returns the proportion exceeding
    (or equal to) each of the supplied `thresholds`.

    Args:
        data: Data from which to draw subsets.
        sampling_dim: The name of the dimension along which to calculate
            the flip-flop index.
        thresholds: The proportion of Flip-Flop index results
            equal to or exceeding these thresholds will be calculated.
        is_angular: specifies whether `data` is directional data (e.g. wind
            direction).
        reduce_dims: Dimensions to reduce.
        preserve_dims: Dimensions to preserve.
        **selections: Additional keyword arguments specify
            subsets to draw from the dimension `sampling_dim` of the supplied `data`
            before calculation of the flip_flop index. e.g. days123=[1, 2, 3]

    Returns:
        If `selections` are not supplied - An xarray.DataArray with dimensions
        `dims` + 'threshold'. The DataArray is the proportion of the Flip-flop
        Index calculated by collapsing dimension `sampling_dim` exceeding or
        equal to `thresholds`.

        If `selections` are supplied - An xarray.Dataset with dimensions `dims`
        + 'threshold'. There is a data variable for each keyword in
        `selections`, and corresponds to the Flip-Flop Index proportion
        exceeding for the subset of data specified by the keyword values.

    Raises:
        scores.utils.DimensionError: if `sampling_dim` appears in
            `preserve_dims` or `reduce_dims`.

    Examples:
        >>> data = xr.DataArray(
        ...     [[50, 20, 40, 80], [10, 50, 10, 100], [0, 30, 20, 50]],
        ...     dims=['station_number', 'lead_day'],
        ...     coords=[[10001, 10002, 10003], [1, 2, 3, 4]]
        ... )

        >>> flip_flop_index_proportion_exceeding(data, 'lead_day', [20])
        <xarray.DataArray (threshold: 1)>
        array([ 0.33333333])
        Coordinates:
          * threshold (threshold) int64 20
        Attributes:
            sampling_dim: lead_day

        >>> flip_flop_index_proportion_exceeding(
        ...     data, 'lead_day', [20], days123=[1, 2, 3], all_days=[1, 2, 3, 4]
        ... )
        <xarray.Dataset>
        Dimensions:   (threshold: 1)
        Coordinates:
          * threshold (threshold) int64 20
        Data variables:
            days123   (threshold) float64 0.6667
            all_days  (threshold) float64 0.3333
        Attributes:
            selections: {'days123': [1, 2, 3], 'all_days': [1, 2, 3, 4]}
            sampling_dim: lead_day

    See also:
        `scores.continuous.flip_flop_index`

    """
    # `sampling_dim` is collapsed by the flip-flop index, so it cannot also
    # be preserved or reduced afterwards
    for verb, arg_label, dim_spec in (
        ("preserve", "preserve_dims", preserve_dims),
        ("reduce", "reduce_dims", reduce_dims),
    ):
        if dim_spec is not None and sampling_dim in list(dim_spec):
            raise DimensionError(
                f"`sampling_dim`: '{sampling_dim}' must not be in dimensions to {verb} "
                f"`{arg_label}`: {list(dim_spec)}"
            )
    # calculate the flip-flop index
    index_vals = flip_flop_index(data, sampling_dim, is_angular=is_angular, **selections)
    # calculate the proportion exceeding each threshold
    exceeding = proportion_exceeding(index_vals, thresholds, reduce_dims, preserve_dims)
    # overwrite the attributes with those from the flip-flop index
    exceeding.attrs = index_vals.attrs

    return exceeding