Coverage for src/scores/probability/roc_impl.py: 100%

27 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2024-02-28 12:51 +1100

1""" 

2Implementation of Receiver Operating Characteristic (ROC) calculations 

3""" 

4from collections.abc import Iterable, Sequence 

5from typing import Optional 

6 

7import numpy as np 

8import xarray as xr 

9 

10from scores.categorical import probability_of_detection, probability_of_false_detection 

11from scores.processing import binary_discretise 

12from scores.utils import gather_dimensions 

13 

14 

def roc_curve_data(  # pylint: disable=too-many-arguments
    fcst: xr.DataArray,
    obs: xr.DataArray,
    thresholds: Iterable[float],
    reduce_dims: Optional[Sequence[str]] = None,
    preserve_dims: Optional[Sequence[str]] = None,
    weights: Optional[xr.DataArray] = None,
    check_args: bool = True,
) -> xr.Dataset:
    """
    Calculates data required for plotting a Receiver (Relative) Operating Characteristic (ROC)
    curve including the AUC. The ROC curve is used as a way to measure the discrimination
    ability of a particular forecast.

    The AUC is the probability that the forecast probability of a random event is higher
    than the forecast probability of a random non-event.

    Args:
        fcst: An array of probabilistic forecasts for a binary event in the range [0, 1].
        obs: An array of binary values where 1 is an event and 0 is a non-event.
        thresholds: Monotonic increasing values between 0 and 1, the thresholds at and
            above which to convert the probabilistic forecast to a value of 1 (an 'event').
            One-shot iterators (e.g. generators) are accepted and consumed exactly once.
        reduce_dims: Optionally specify which dimensions to reduce when
            calculating the ROC curve data. All other dimensions will be preserved. As a
            special case, 'all' will allow all dimensions to be reduced. Only one
            of `reduce_dims` and `preserve_dims` can be supplied. The default behaviour
            if neither are supplied is to reduce all dims.
        preserve_dims: Optionally specify which dimensions to preserve
            when calculating ROC curve data. All other dimensions will be reduced.
            As a special case, 'all' will allow all dimensions to be
            preserved. In this case, the result will be in the same
            shape/dimensionality as the forecast, and the values will be
            the ROC curve at each point (i.e. single-value comparison
            against observed) for each threshold, and the forecast and observed dimensions
            must match precisely. Only one of `reduce_dims` and `preserve_dims` can be
            supplied. The default behaviour if neither are supplied is to reduce all dims.
        weights: Optionally provide an array for weighted averaging (e.g. by area, by latitude,
            by population, custom).
        check_args: Checks if `obs` data only contains values in the set
            {0, 1, np.nan}. You may want to skip this check if you are sure about your
            input data and want to improve the performance when working with dask.

    Returns:
        An xarray.Dataset with data variables:

        - 'POD' (the probability of detection)
        - 'POFD' (the probability of false detection)
        - 'AUC' (the area under the ROC curve)

        `POD` and `POFD` have the preserved (non-reduced) dimensions plus 'threshold',
        while `AUC` has only the preserved dimensions. Preserved dimensions are ordered
        by first appearance in `fcst`, then `obs`.

    Notes:
        The probabilistic `fcst` is converted to a deterministic forecast
        for each threshold in `thresholds`. If a value in `fcst` is greater
        than or equal to the threshold, then it is converted into a
        'forecast event' (fcst = 1), and a 'forecast non-event' (fcst = 0)
        otherwise. The probability of detection (POD) and probability of false
        detection (POFD) are calculated for the converted forecast. From the
        POD and POFD data, the area under the ROC curve is calculated.

        Ideally concave ROC curves should be generated rather than traditional
        ROC curves.

    Raises:
        ValueError: if `fcst` contains values outside of the range [0, 1]
        ValueError: if `obs` contains non-nan values not in the set {0, 1}
        ValueError: if 'threshold' is a dimension in `fcst`.
        ValueError: if values in `thresholds` are not monotonic increasing or are outside
            the range [0, 1]
    """
    from collections.abc import Iterator  # pylint: disable=import-outside-toplevel

    # A one-shot iterator would be exhausted by the argument checks below (or
    # rejected by numpy outright) before it reached `binary_discretise`.
    if isinstance(thresholds, Iterator):
        thresholds = list(thresholds)

    if check_args:
        if fcst.max().item() > 1 or fcst.min().item() < 0:
            raise ValueError("`fcst` contains values outside of the range [0, 1]")

        threshold_values = np.asarray(thresholds)
        if threshold_values.max() > 1 or threshold_values.min() < 0:
            raise ValueError("`thresholds` contains values outside of the range [0, 1]")

        if not np.all(threshold_values[1:] >= threshold_values[:-1]):
            raise ValueError("`thresholds` is not monotonic increasing between 0 and 1")

    # make a discrete forecast for each threshold in thresholds
    # discrete_fcst has an extra dimension 'threshold'
    discrete_fcst = binary_discretise(fcst, thresholds, ">=")

    final_reduce_dims = set(gather_dimensions(fcst.dims, obs.dims, reduce_dims, preserve_dims))  # type: ignore
    # Deterministic ordering (first appearance in fcst, then obs). Ordering via a
    # set would vary between interpreter runs because of string hash randomisation.
    ordered_dims = dict.fromkeys((*fcst.dims, *obs.dims))
    auc_dims = tuple(dim for dim in ordered_dims if dim not in final_reduce_dims)
    curve_dims = auc_dims + ("threshold",)

    pod = probability_of_detection(
        discrete_fcst, obs, preserve_dims=curve_dims, weights=weights, check_args=check_args
    )

    pofd = probability_of_false_detection(
        discrete_fcst, obs, preserve_dims=curve_dims, weights=weights, check_args=check_args
    )

    # Need to ensure ordering of dims is consistent for xr.apply_ufunc; this also
    # puts 'threshold' last, which is the axis the trapezoid rule integrates over.
    pod = pod.transpose(*curve_dims)
    pofd = pofd.transpose(*curve_dims)

    # `np.trapz` is deprecated since NumPy 2.0 in favour of `np.trapezoid`;
    # fall back to `trapz` on older NumPy versions.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz

    # POFD decreases as the threshold increases, so integrating POD against POFD
    # along increasing thresholds yields a negative area; negate it.
    auc = -1 * xr.apply_ufunc(
        trapezoid,
        pod,
        pofd,
        input_core_dims=[pod.dims, pofd.dims],
        output_core_dims=[auc_dims],
        dask="parallelized",
    )

    return xr.Dataset({"POD": pod, "POFD": pofd, "AUC": auc})