Coverage for src/scores/categorical/multicategorical_impl.py: 100%
58 statements
« prev ^ index » next coverage.py v7.3.2, created at 2024-02-28 12:51 +1100
« prev ^ index » next coverage.py v7.3.2, created at 2024-02-28 12:51 +1100
1"""
2This module contains methods which may be used for scoring multicategorical forecasts
3"""
4from collections.abc import Sequence
5from typing import Optional, Union
7import numpy as np
8import xarray as xr
10from scores.functions import apply_weights
11from scores.typing import FlexibleDimensionTypes
12from scores.utils import check_dims, gather_dimensions
def firm(  # pylint: disable=too-many-arguments
    fcst: xr.DataArray,
    obs: xr.DataArray,
    risk_parameter: float,
    categorical_thresholds: Sequence[float],
    threshold_weights: Sequence[Union[float, xr.DataArray]],
    discount_distance: Optional[float] = 0,
    reduce_dims: FlexibleDimensionTypes = None,
    preserve_dims: FlexibleDimensionTypes = None,
    weights: Optional[xr.DataArray] = None,
    threshold_assignment: Optional[str] = "lower",
) -> xr.Dataset:
    """
    Computes the FIxed Risk Multicategorical (FIRM) score together with its
    overforecast and underforecast penalty components.

    One single-category score is computed per entry of `categorical_thresholds`,
    multiplied by the matching entry of `threshold_weights`, and the results are
    summed. The sum is then passed through `apply_weights` and averaged over the
    dimensions selected by `reduce_dims` / `preserve_dims`.

    `categorical_thresholds` and `threshold_weights` must be the same length.

    Args:
        fcst: An array of real-valued forecasts that we want to treat categorically.
        obs: An array of real-valued observations that we want to treat categorically.
        risk_parameter: Risk parameter (alpha) for the FIRM score. Must satisfy
            0 < `risk_parameter` < 1.
        categorical_thresholds: Category thresholds (thetas) that delineate the
            categories.
        threshold_weights: Relative importance of forecasting on the correct side
            of each category threshold. Each entry is either a positive float or
            an xr.DataArray (no values <= 0) whose dims are a subset of the `obs`
            dims. NaN values are allowed in an xr.DataArray; wherever a weight is
            NaN, the FIRM score is NaN at that coordinate before dims are
            collapsed.
        discount_distance: Discounting distance, which must satisfy
            `discount_distance` >= 0. The cost of misses and false alarms is
            discounted whenever the observation is within `discount_distance`
            of the forecast category. A value of 0 applies no discounting.
        reduce_dims: Optionally specify which dimensions to reduce when
            calculating the FIRM score. All other dimensions are preserved. As a
            special case, 'all' reduces every dimension. Only one of
            `reduce_dims` and `preserve_dims` can be supplied. If neither is
            supplied, all dims are reduced.
        preserve_dims: Optionally specify which dimensions to preserve when
            calculating the FIRM score. All other dimensions are reduced. As a
            special case, 'all' preserves every dimension; the result then has
            the same shape/dimensionality as the forecast, holds the FIRM score
            at each point, and requires the forecast and observed dimensions to
            match precisely. Only one of `reduce_dims` and `preserve_dims` can be
            supplied. If neither is supplied, all dims are reduced.
        weights: Optional array used for weighted averaging (e.g. by area, by
            latitude, by population, custom).
        threshold_assignment: Whether the intervals defining the categories are
            left or right closed, i.e. whether the decision threshold belongs to
            the "upper" (left closed) or "lower" (right closed) category.
            Defaults to "lower".

    Returns:
        An xarray Dataset with data vars:

        * firm_score: The threshold-weighted sum of the single-category scores,
          averaged over the reduced dims.
        * overforecast_penalty: Penalty for False Alarms.
        * underforecast_penalty: Penalty for Misses.

    Raises:
        ValueError: if `len(categorical_thresholds) < 1`.
        ValueError: if `categorical_thresholds` and `threshold_weights` lengths
            are not equal.
        ValueError: if `risk_parameter` <= 0 or >= 1.
        ValueError: if any values in `threshold_weights` are <= 0.
        ValueError: if `discount_distance` is < 0.
        ValueError: if `threshold_assignment` is neither "upper" nor "lower".
        scores.utils.DimensionError: if `threshold_weights` contains
            xr.DataArrays whose dimensions are not a subset of the `obs` dims.

    Note:
        A `discount_distance` of 0 means no discounting is applied, so errors
        are penalised strictly categorically.
        NOTE(review): `None` is also meant to disable discounting; whether it is
        accepted depends on the validation done in `_check_firm_inputs` - confirm.

        Setting `discount_distance` to np.inf means that the cost of a miss
        is always proportional to the distance of the observation from the
        threshold, and similarly for false alarms.

    References:
        Taggart, R., Loveday, N. and Griffiths, D., 2022. A scoring framework for tiered
        warnings and multicategorical forecasts based on fixed risk measures. Quarterly
        Journal of the Royal Meteorological Society, 148(744), pp.1389-1406.
    """
    # Validate everything up front so the scoring below can assume sane inputs.
    _check_firm_inputs(
        obs=obs,
        risk_parameter=risk_parameter,
        categorical_thresholds=categorical_thresholds,
        threshold_weights=threshold_weights,
        discount_distance=discount_distance,
        threshold_assignment=threshold_assignment,
    )

    # One weighted Dataset per threshold; weight stays on the left of the product.
    weighted_scores = [
        threshold_weight
        * _single_category_score(
            fcst=fcst,
            obs=obs,
            risk_parameter=risk_parameter,
            categorical_threshold=theta,
            discount_distance=discount_distance,
            threshold_assignment=threshold_assignment,
        )
        for theta, threshold_weight in zip(categorical_thresholds, threshold_weights)
    ]
    combined = sum(weighted_scores)

    # Resolve the dims to collapse before applying the averaging weights.
    dims_to_reduce = gather_dimensions(fcst.dims, obs.dims, reduce_dims, preserve_dims)
    combined = apply_weights(combined, weights)

    return combined.mean(dim=dims_to_reduce)
def _check_firm_inputs(
    obs, risk_parameter, categorical_thresholds, threshold_weights, discount_distance, threshold_assignment
):
    """
    Checks that the FIRM inputs are suitable.

    Args:
        obs: Observations. Only `obs.dims` is read, to validate the dims of any
            xr.DataArray entries in `threshold_weights`.
        risk_parameter: Risk parameter; must satisfy 0 < `risk_parameter` < 1.
        categorical_thresholds: Category thresholds; at least one is required.
        threshold_weights: One weight per threshold. Each is either a number > 0
            or an xr.DataArray with no values <= 0 whose dims are a subset of
            `obs.dims`. NaN entries in an xr.DataArray pass the check because
            `NaN <= 0` is False.
        discount_distance: Discounting distance; either `None` (no discounting)
            or a number >= 0.
        threshold_assignment: Must be "upper" or "lower".

    Raises:
        ValueError: if `categorical_thresholds` is empty.
        ValueError: if `categorical_thresholds` and `threshold_weights` differ in
            length.
        ValueError: if `risk_parameter` <= 0 or >= 1.
        ValueError: if any threshold weight is <= 0.
        ValueError: if `discount_distance` is not None and < 0.
        ValueError: if `threshold_assignment` is neither "upper" nor "lower".

    Any exception raised by `check_dims` for xr.DataArray weights whose dims are
    not a subset of `obs.dims` is propagated unchanged.
    """
    if len(categorical_thresholds) < 1:
        raise ValueError("`categorical_thresholds` must have at least one threshold")

    if not len(categorical_thresholds) == len(threshold_weights):
        raise ValueError("The length of `categorical_thresholds` and `weights` must be equal")
    if risk_parameter <= 0 or risk_parameter >= 1:
        raise ValueError("0 < `risk_parameter` < 1 must be satisfied")

    for count, weight in enumerate(threshold_weights):
        if isinstance(weight, xr.DataArray):
            # Array weights must broadcast against obs, so their dims are checked first.
            check_dims(weight, obs.dims, "subset")
            if np.any(weight <= 0):
                raise ValueError(
                    f"""
                    No values <= 0 are allowed in `weights`. At least one
                    negative value was found in index {count} of `weights`
                    """
                )
        elif weight <= 0:
            raise ValueError("All values in `weights` must be > 0")

    # `None` is a documented way to request no discounting. Comparing `None < 0`
    # would raise TypeError, so only range-check a distance that was supplied.
    if discount_distance is not None and discount_distance < 0:
        raise ValueError("`discount_distance` must be >= 0")

    if threshold_assignment not in ["upper", "lower"]:
        raise ValueError(""" `threshold_assignment` must be either \"upper\" or \"lower\" """)
def _single_category_score(
    fcst: xr.DataArray,
    obs: xr.DataArray,
    risk_parameter: float,
    categorical_threshold: float,
    discount_distance: Optional[float] = None,
    threshold_assignment: Optional[str] = "lower",
) -> xr.Dataset:
    """
    Computes the `firm` score for one category threshold at every coordinate,
    along with its over-forecast and under-forecast penalty components.

    Args:
        fcst: An array of real-valued forecasts.
        obs: An array of real-valued observations.
        risk_parameter: Risk parameter (alpha) for the FIRM score.
            Must satisfy 0 < risk parameter < 1. This is validated by `firm`,
            not by this function.
        categorical_threshold: Category threshold (theta) that delineates the
            category.
        discount_distance: Discounting distance, which must be >= 0. The cost of
            misses and false alarms is discounted whenever the observation is
            within `discount_distance` of the forecast category. A falsy value
            (0 or None) applies no discounting.
        threshold_assignment: Whether the intervals defining the categories are
            left or right closed, i.e. whether the decision threshold belongs to
            the "upper" (left closed) or "lower" (right closed) category.
            Defaults to "lower". Any value other than "lower" is treated as
            "upper" here; validation happens in `firm`.

    Returns:
        An xarray Dataset with data vars:

        * firm_score: the score for this single category at each coord
          (overforecast penalty plus underforecast penalty). No dims are reduced.
        * overforecast_penalty: Penalty for False Alarms.
        * underforecast_penalty: Penalty for Misses.
    """
    # pylint: disable=unbalanced-tuple-unpacking
    fcst, obs = xr.align(fcst, obs)
    theta = categorical_threshold

    # Decide on which side of the threshold each value sits. The only difference
    # between the two assignments is which comparison is strict.
    if threshold_assignment == "lower":
        obs_not_exceeding = obs <= theta
        fcst_exceeding = theta < fcst
        fcst_not_exceeding = fcst <= theta
        obs_exceeding = theta < obs
    else:
        obs_not_exceeding = obs < theta
        fcst_exceeding = theta <= fcst
        fcst_not_exceeding = fcst < theta
        obs_exceeding = theta <= obs

    # False alarm: forecast in the upper category while the observation is not.
    false_alarm = obs_not_exceeding & fcst_exceeding
    # Miss: observation in the upper category while the forecast is not.
    miss = fcst_not_exceeding & obs_exceeding

    # Comparisons against NaN evaluate to False, which would look like "no
    # error"; put NaN back wherever either input was missing.
    fcst_present = ~np.isnan(fcst)
    obs_present = ~np.isnan(obs)
    false_alarm = false_alarm.where(fcst_present).where(obs_present)
    miss = miss.where(fcst_present).where(obs_present)

    if not discount_distance:
        # No discounting: every categorical error costs the full amount.
        false_alarm_scale = 1
        miss_scale = 1
    else:
        # Cost grows with the observation's distance from the threshold, capped
        # at the discount distance.
        false_alarm_scale = np.minimum(theta - obs, discount_distance)
        miss_scale = np.minimum(obs - theta, discount_distance)

    overforecast_penalty = (1 - risk_parameter) * false_alarm_scale * false_alarm
    underforecast_penalty = risk_parameter * miss_scale * miss

    components = {
        "firm_score": overforecast_penalty + underforecast_penalty,
        "overforecast_penalty": overforecast_penalty,
        "underforecast_penalty": underforecast_penalty,
    }
    # Present the result in the forecast's dimension order.
    return xr.Dataset(components).transpose(*fcst.dims)