Coverage for src/scores/categorical/multicategorical_impl.py: 100%

58 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2024-02-28 12:51 +1100

1""" 

2This module contains methods which may be used for scoring multicategorical forecasts 

3""" 

4from collections.abc import Sequence 

5from typing import Optional, Union 

6 

7import numpy as np 

8import xarray as xr 

9 

10from scores.functions import apply_weights 

11from scores.typing import FlexibleDimensionTypes 

12from scores.utils import check_dims, gather_dimensions 

13 

14 

def firm(  # pylint: disable=too-many-arguments
    fcst: xr.DataArray,
    obs: xr.DataArray,
    risk_parameter: float,
    categorical_thresholds: Sequence[float],
    threshold_weights: Sequence[Union[float, xr.DataArray]],
    discount_distance: Optional[float] = 0,
    reduce_dims: FlexibleDimensionTypes = None,
    preserve_dims: FlexibleDimensionTypes = None,
    weights: Optional[xr.DataArray] = None,
    threshold_assignment: Optional[str] = "lower",
) -> xr.Dataset:
    """
    Calculates the FIxed Risk Multicategorical (FIRM) score including the
    underforecast and overforecast penalties.

    `categorical_thresholds` and `threshold_weights` must be the same length.

    Args:
        fcst: An array of real-valued forecasts that we want to treat categorically.
        obs: An array of real-valued observations that we want to treat categorically.
        risk_parameter: Risk parameter (alpha) for the FIRM score. The value must
            satisfy 0 < `risk_parameter` < 1.
        categorical_thresholds: Category thresholds (thetas) to delineate the
            categories.
        threshold_weights: Weights that specify the relative importance of forecasting on
            the correct side of each category threshold. Either a strictly positive
            float can be supplied for each categorical threshold or an
            xr.DataArray (whose values must all be > 0) can be provided for each
            categorical threshold as long as its dims are a subset of `obs` dims.
            NaN values are allowed in the xr.DataArray. For each NaN value at a
            given coordinate, the FIRM score will be NaN at that coordinate,
            before dims are collapsed.
        discount_distance: An optional discounting distance parameter which
            satisfies `discount_distance` >= 0 such that the cost of misses and
            false alarms is discounted whenever the observation is within
            distance `discount_distance` of the forecast category. A value of 0
            will not apply any discounting.
        reduce_dims: Optionally specify which dimensions to reduce when
            calculating the FIRM score. All other dimensions will be preserved. As a
            special case, 'all' will allow all dimensions to be reduced. Only one
            of `reduce_dims` and `preserve_dims` can be supplied. The default behaviour
            if neither are supplied is to reduce all dims.
        preserve_dims: Optionally specify which dimensions to preserve
            when calculating FIRM. All other dimensions will be reduced.
            As a special case, 'all' will allow all dimensions to be
            preserved. In this case, the result will be in the same
            shape/dimensionality as the forecast, and the errors will be
            the FIRM score at each point (i.e. single-value comparison
            against observed), and the forecast and observed dimensions
            must match precisely. Only one of `reduce_dims` and `preserve_dims` can be
            supplied. The default behaviour if neither are supplied is to reduce all dims.
        weights: Optionally provide an array for weighted averaging (e.g. by area, by latitude,
            by population, custom)
        threshold_assignment: Specifies whether the intervals defining the categories are
            left or right closed. That is whether the decision threshold is included in
            the upper (left closed) or lower (right closed) category. Must be either
            "upper" or "lower". Defaults to "lower".

    Returns:
        An xarray Dataset with data vars:

        * firm_score: The FIRM score, i.e. the sum over all
          `categorical_thresholds` of each single-threshold score multiplied by
          its threshold weight, then averaged (using `weights`, if supplied)
          over the reduced dimensions.
        * overforecast_penalty: Penalty for False Alarms, aggregated the same way.
        * underforecast_penalty: Penalty for Misses, aggregated the same way.

    Raises:
        ValueError: if `len(categorical_thresholds) < 1`.
        ValueError: if `categorical_thresholds` and `threshold_weights` lengths
            are not equal.
        ValueError: if `risk_parameter` <= 0 or >= 1.
        ValueError: if any values in `threshold_weights` are <= 0.
        ValueError: if `discount_distance` is not None and < 0.
        ValueError: if `threshold_assignment` is not "upper" or "lower".
        scores.utils.DimensionError: if `threshold_weights` is a list of xr.DataArrays
            and if the dimensions of these xr.DataArrays is not a subset of the `obs` dims.

    Note:
        Setting `discount_distance` to None or 0 will mean that no
        discounting is applied. This means that errors will be penalised
        strictly categorically.

        Setting `discount_distance` to np.inf means that the cost of a miss
        is always proportional to the distance of the observation from the
        threshold, and similarly for false alarms.

    References:
        Taggart, R., Loveday, N. and Griffiths, D., 2022. A scoring framework for tiered
        warnings and multicategorical forecasts based on fixed risk measures. Quarterly
        Journal of the Royal Meteorological Society, 148(744), pp.1389-1406.
    """
    _check_firm_inputs(
        obs, risk_parameter, categorical_thresholds, threshold_weights, discount_distance, threshold_assignment
    )

    # Each per-threshold term is an xr.Dataset (firm_score plus both penalties),
    # so scaling and summing applies to all three data variables at once.
    summed_score = sum(
        weight
        * _single_category_score(
            fcst, obs, risk_parameter, categorical_threshold, discount_distance, threshold_assignment
        )
        for categorical_threshold, weight in zip(categorical_thresholds, threshold_weights)
    )

    reduce_dims = gather_dimensions(fcst.dims, obs.dims, reduce_dims, preserve_dims)  # type: ignore[assignment]
    summed_score = apply_weights(summed_score, weights)
    return summed_score.mean(dim=reduce_dims)

120 

121 

def _check_firm_inputs(
    obs, risk_parameter, categorical_thresholds, threshold_weights, discount_distance, threshold_assignment
):
    """
    Checks that the FIRM inputs are suitable.

    Args:
        obs: Observations. Only `obs.dims` is used here, to validate the dims
            of any xr.DataArray threshold weights.
        risk_parameter: Must satisfy 0 < `risk_parameter` < 1 (NaN is rejected).
        categorical_thresholds: Must contain at least one threshold.
        threshold_weights: One weight per threshold. Each is either a number > 0
            or an xr.DataArray whose dims are a subset of `obs.dims` and which
            contains no values <= 0. NaN values pass this check.
        discount_distance: Either None (no discounting) or a value >= 0.
        threshold_assignment: Either "upper" or "lower".

    Raises:
        ValueError: if any of the constraints above is violated.
        scores.utils.DimensionError: raised by `check_dims` if the dims of an
            xr.DataArray threshold weight are not a subset of `obs.dims`.
    """
    if len(categorical_thresholds) < 1:
        raise ValueError("`categorical_thresholds` must have at least one threshold")

    if len(categorical_thresholds) != len(threshold_weights):
        raise ValueError("The length of `categorical_thresholds` and `weights` must be equal")

    # Written as a single chained comparison so that NaN (for which every
    # comparison is False) is rejected rather than silently accepted.
    if not 0 < risk_parameter < 1:
        raise ValueError("0 < `risk_parameter` < 1 must be satisfied")

    for count, weight in enumerate(threshold_weights):
        if isinstance(weight, xr.DataArray):
            check_dims(weight, obs.dims, "subset")
            # NaN <= 0 is False, so NaN weights are allowed through; they
            # propagate as NaN scores at those coordinates.
            if np.any(weight <= 0):
                raise ValueError(
                    f"""
                    No values <= 0 are allowed in `weights`. At least one
                    negative value was found in index {count} of `weights`
                    """
                )
        elif weight <= 0:
            raise ValueError("All values in `weights` must be > 0")

    # None is an explicitly supported value (no discounting); comparing it with
    # 0 would raise a TypeError, so only validate actual numbers.
    if discount_distance is not None and discount_distance < 0:
        raise ValueError("`discount_distance` must be >= 0")

    if threshold_assignment not in ["upper", "lower"]:
        raise ValueError(""" `threshold_assignment` must be either \"upper\" or \"lower\" """)

154 

155 

def _single_category_score(
    fcst: xr.DataArray,
    obs: xr.DataArray,
    risk_parameter: float,
    categorical_threshold: float,
    discount_distance: Optional[float] = None,
    threshold_assignment: Optional[str] = "lower",
) -> xr.Dataset:
    """
    Calculates the score for a single category for the `firm` metric at each
    coord. Under-forecast and over-forecast penalties are also calculated.

    Args:
        fcst: An array of real-valued forecasts.
        obs: An array of real-valued observations.
        risk_parameter: Risk parameter (alpha) for the FIRM score.
            Must satisfy 0 < risk parameter < 1. Note that `firm` checks this
            rather than this function.
        categorical_threshold: Category threshold (theta) to delineate the
            category.
        discount_distance: A discounting distance parameter which must
            be >= 0 such that the cost of misses and false alarms is
            discounted whenever the observation is within distance
            `discount_distance` of the forecast category. A value of None or 0
            will not apply any discounting.
        threshold_assignment: Specifies whether the intervals defining the categories are
            left or right closed. That is whether the decision threshold is included in
            the upper (left closed) or lower (right closed) category. Defaults to "lower".
            Any value other than "lower" is treated as "upper".

    Returns:
        An xarray Dataset with data vars:

        * firm_score: a score for a single category for each coord
          based on the FIRM framework. All dimensions are preserved, ordered
          as in `fcst` followed by any dimensions present only in `obs`.
        * overforecast_penalty: Penalty for False Alarms,
          (1 - `risk_parameter`) times the (possibly discounted) scale.
        * underforecast_penalty: Penalty for Misses,
          `risk_parameter` times the (possibly discounted) scale.
    """
    # pylint: disable=unbalanced-tuple-unpacking
    fcst, obs = xr.align(fcst, obs)

    if threshold_assignment == "lower":
        # The threshold itself belongs to the lower category.
        # False Alarms
        condition1 = (obs <= categorical_threshold) & (categorical_threshold < fcst)
        # Misses
        condition2 = (fcst <= categorical_threshold) & (categorical_threshold < obs)
    else:
        # The threshold itself belongs to the upper category.
        # False Alarms
        condition1 = (obs < categorical_threshold) & (categorical_threshold <= fcst)
        # Misses
        condition2 = (fcst < categorical_threshold) & (categorical_threshold <= obs)

    # Comparisons involving NaN evaluate to False, so bring NaNs back wherever
    # either the forecast or the observation is missing.
    valid = ~np.isnan(fcst) & ~np.isnan(obs)
    condition1 = condition1.where(valid)
    condition2 = condition2.where(valid)

    if discount_distance:
        # Penalty grows with the observation's distance from the threshold,
        # capped at `discount_distance`.
        scale_1 = np.minimum(categorical_threshold - obs, discount_distance)
        scale_2 = np.minimum(obs - categorical_threshold, discount_distance)
    else:
        # None or 0: purely categorical penalties.
        scale_1 = 1
        scale_2 = 1

    overforecast_penalty = (1 - risk_parameter) * scale_1 * condition1
    underforecast_penalty = risk_parameter * scale_2 * condition2
    firm_score = overforecast_penalty + underforecast_penalty

    score = xr.Dataset(
        {
            "firm_score": firm_score,
            "overforecast_penalty": overforecast_penalty,
            "underforecast_penalty": underforecast_penalty,
        }
    )
    # Order dims like `fcst`; the trailing ellipsis keeps any dims that only
    # `obs` has instead of raising because `fcst.dims` is not a full permutation.
    return score.transpose(*fcst.dims, ...)