Coverage for src/scores/probability/roc_impl.py: 100%
27 statements
coverage.py v7.3.2, created at 2024-02-28 12:51 +1100
1"""
2Implementation of Reciever Operating Characteristic (ROC) calculations
3"""
4from collections.abc import Iterable, Sequence
5from typing import Optional
7import numpy as np
8import xarray as xr
10from scores.categorical import probability_of_detection, probability_of_false_detection
11from scores.processing import binary_discretise
12from scores.utils import gather_dimensions


def roc_curve_data(  # pylint: disable=too-many-arguments
    fcst: xr.DataArray,
    obs: xr.DataArray,
    thresholds: Iterable[float],
    reduce_dims: Optional[Sequence[str]] = None,
    preserve_dims: Optional[Sequence[str]] = None,
    weights: Optional[xr.DataArray] = None,
    check_args: bool = True,
) -> xr.Dataset:
24 """
25 Calculates data required for plotting a Receiver (Relative) Operating Characteristic (ROC)
26 curve including the AUC. The ROC curve is used as a way to measure the discrimination
27 ability of a particular forecast.
29 The AUC is the probability that the forecast probability of a random event is higher
30 than the forecast probability of a random non-event.
32 Args:
33 fcst: An array of probabilistic forecasts for a binary event in the range [0, 1].
34 obs: An array of binary values where 1 is an event and 0 is a non-event.
35 thresholds: Monotonic increasing values between 0 and 1, the thresholds at and
36 above which to convert the probabilistic forecast to a value of 1 (an 'event')
37 reduce_dims: Optionally specify which dimensions to reduce when
38 calculating the ROC curve data. All other dimensions will be preserved. As a
39 special case, 'all' will allow all dimensions to be reduced. Only one
40 of `reduce_dims` and `preserve_dims` can be supplied. The default behaviour
41 if neither are supplied is to reduce all dims.
42 preserve_dims: Optionally specify which dimensions to preserve
43 when calculating ROC curve data. All other dimensions will be reduced.
44 As a special case, 'all' will allow all dimensions to be
45 preserved. In this case, the result will be in the same
46 shape/dimensionality as the forecast, and the values will be
47 the ROC curve at each point (i.e. single-value comparison
48 against observed) for each threshold, and the forecast and observed dimensions
49 must match precisely. Only one of `reduce_dims` and `preserve_dims` can be
50 supplied. The default behaviour if neither are supplied is to reduce all dims.
51 weights: Optionally provide an array for weighted averaging (e.g. by area, by latitude,
52 by population, custom).
53 check_args: Checks if `obs` data only contains values in the set
54 {0, 1, np.nan}. You may want to skip this check if you are sure about your
55 input data and want to improve the performance when working with dask.
57 Returns:
58 An xarray.Dataset with data variables:
60 - 'POD' (the probability of detection)
61 - 'POFD' (the probability of false detection)
62 - 'AUC' (the area under the ROC curve)
64 `POD` and `POFD` have dimensions `dims` + 'threshold', while `AUC` has
65 dimensions `dims`.
67 Notes:
68 The probabilistic `fcst` is converted to a deterministic forecast
69 for each threshold in `thresholds`. If a value in `fcst` is greater
70 than or equal to the threshold, then it is converted into a
71 'forecast event' (fcst = 1), and a 'forecast non-event' (fcst = 0)
72 otherwise. The probability of detection (POD) and probability of false
73 detection (POFD) are calculated for the converted forecast. From the
74 POD and POFD data, the area under the ROC curve is calculated.
76 Ideally concave ROC curves should be generated rather than traditional
77 ROC curves.
79 Raises:
80 ValueError: if `fcst` contains values outside of the range [0, 1]
81 ValueError: if `obs` contains non-nan values not in the set {0, 1}
82 ValueError: if 'threshold' is a dimension in `fcst`.
83 ValueError: if values in `thresholds` are not montonic increasing or are outside
84 the range [0, 1]
85 """
    if check_args:
        if fcst.max().item() > 1 or fcst.min().item() < 0:
            raise ValueError("`fcst` contains values outside of the range [0, 1]")

        if np.max(thresholds) > 1 or np.min(thresholds) < 0:
            raise ValueError("`thresholds` contains values outside of the range [0, 1]")

        if not np.all(np.array(thresholds)[1:] >= np.array(thresholds)[:-1]):
            raise ValueError("`thresholds` is not monotonic increasing between 0 and 1")
    # make a discrete forecast for each threshold in thresholds
    # discrete_fcst has an extra dimension 'threshold'
    discrete_fcst = binary_discretise(fcst, thresholds, ">=")
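    # Resolve which dimensions are reduced vs preserved: gather_dimensions returns the
    # dimensions to be reduced, and everything else is preserved. 'threshold' is then
    # appended to the preserved dimensions so that POD and POFD are computed separately
    # for each threshold (one ROC point per threshold), while AUC is returned over
    # auc_dims only.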
    all_dims = set(fcst.dims).union(set(obs.dims))
    final_reduce_dims = gather_dimensions(fcst.dims, obs.dims, reduce_dims, preserve_dims)
    final_preserve_dims = all_dims - set(final_reduce_dims)  # type: ignore
    auc_dims = () if final_preserve_dims is None else tuple(final_preserve_dims)
    final_preserve_dims = auc_dims + ("threshold",)  # type: ignore[assignment]
    pod = probability_of_detection(
        discrete_fcst, obs, preserve_dims=final_preserve_dims, weights=weights, check_args=check_args
    )

    pofd = probability_of_false_detection(
        discrete_fcst, obs, preserve_dims=final_preserve_dims, weights=weights, check_args=check_args
    )
    # Need to ensure ordering of dims is consistent for xr.apply_ufunc
    pod = pod.transpose(*final_preserve_dims)
    pofd = pofd.transpose(*final_preserve_dims)
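    # np.trapz integrates POD with respect to POFD along the trailing 'threshold' axis.
    # POFD decreases as the threshold increases, so the integral comes out negative;
    # the factor of -1 flips the sign so AUC is the positive area under the ROC curve.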
    auc = -1 * xr.apply_ufunc(
        np.trapz,
        pod,
        pofd,
        input_core_dims=[pod.dims, pofd.dims],
        output_core_dims=[auc_dims],
        dask="parallelized",
    )
    return xr.Dataset({"POD": pod, "POFD": pofd, "AUC": auc})
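

# Minimal usage sketch: the synthetic forecast/observation DataArrays below are
# illustrative assumptions, not data from the scores package. It shows the default
# behaviour of reducing all dimensions, which yields a scalar AUC and a 'threshold'
# dimension on POD and POFD.
if __name__ == "__main__":
    rng = np.random.default_rng(42)
    demo_fcst = xr.DataArray(rng.random(100), dims="time")
    # Binary observations loosely consistent with the forecast probabilities.
    demo_obs = xr.DataArray((rng.random(100) < demo_fcst.values).astype(float), dims="time")

    demo_result = roc_curve_data(demo_fcst, demo_obs, thresholds=np.linspace(0, 1, 11))
    # demo_result["POD"] and demo_result["POFD"] have one value per threshold;
    # demo_result["AUC"] is a scalar because all dims are reduced by default.
    print(demo_result)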