Coverage for src/scores/utils.py: 100%
121 statements
« prev ^ index » next coverage.py v7.3.2, created at 2024-02-28 12:51 +1100
« prev ^ index » next coverage.py v7.3.2, created at 2024-02-28 12:51 +1100
1"""
2Contains frequently-used functions of a general nature within scores
3"""
import inspect
import warnings
from collections.abc import Hashable, Iterable, Sequence
from typing import Optional, Union

import xarray as xr

from scores.typing import FlexibleDimensionTypes, XarrayLike
# Warning text passed to `warnings.warn` by `gather_dimensions` and `gather_dimensions2`
# when the string 'all' is requested and 'all' is also a dimension name in the data.
WARN_ALL_DATA_CONFLICT_MSG = """
You are requesting to reduce or preserve every dimension by specifying the string 'all'.
In this case, 'all' is also a named dimension in your data, leading to an ambiguity.
In order to reduce or preserve the named data dimension, specify ['all'] as a list item
rather than relying on string interpretation. The program will continue to interpret the
string as an instruction to reduce or preserve every dimension.
"""

# ValueError text used by `gather_dimensions` when `preserve_dims` names a dimension
# that is in neither the forecast nor the observation dimensions.
ERROR_SPECIFIED_NONPRESENT_PRESERVE_DIMENSION = """
You are requesting to preserve a dimension which does not appear in your data (fcst or obs).
It is ambiguous how to proceed therefore an exception has been raised instead.
"""

# ValueError text used by `gather_dimensions2`, which also considers `weights` dimensions.
ERROR_SPECIFIED_NONPRESENT_PRESERVE_DIMENSION2 = """
You are requesting to preserve a dimension which does not appear in your data
(fcst, obs or weights). It is ambiguous how to proceed therefore an exception has been
raised instead.
"""

# ValueError text used by `gather_dimensions` when `reduce_dims` names a dimension
# that is in neither the forecast nor the observation dimensions.
ERROR_SPECIFIED_NONPRESENT_REDUCE_DIMENSION = """
You are requesting to reduce a dimension which does not appear in your data (fcst or obs).
It is ambiguous how to proceed therefore an exception has been raised instead.
"""

# ValueError text used by `gather_dimensions2`, which also considers `weights` dimensions.
ERROR_SPECIFIED_NONPRESENT_REDUCE_DIMENSION2 = """
You are requesting to reduce a dimension which does not appear in your data
(fcst, obs or weights). It is ambiguous how to proceed therefore an exception has been
raised instead.
"""

# ValueError text used by both gather functions when `preserve_dims` and `reduce_dims`
# are supplied together.
ERROR_OVERSPECIFIED_PRESERVE_REDUCE = """
You have specified both preserve_dims and reduce_dims. This method doesn't know how
to properly interpret that, therefore an exception has been raised.
"""
class DimensionError(Exception):
    """
    Raised when xarray DataArray or Dataset objects do not have dimensions that are
    compatible with the requested operation.
    """
def gather_dimensions(  # pylint: disable=too-many-branches
    fcst_dims: Iterable[Hashable],
    obs_dims: Iterable[Hashable],
    reduce_dims: FlexibleDimensionTypes = None,
    preserve_dims: FlexibleDimensionTypes = None,
) -> set[Hashable]:
    """
    Establish which dimensions to reduce when calculating errors but before taking means.

    Note: `scores.utils.gather_dimensions` and `scores.utils.gather_dimensions2` will be
    integrated at some point in the future. `scores.utils.gather_dimensions2` offers
    more comprehensive and less restrictive dimension checking and should be preferred in
    the meantime. See `scores.probability.crps_cdf` for an example of
    `scores.utils.gather_dimensions2` usage.

    Args:
        fcst_dims: Forecast dimensions inputs
        obs_dims: Observation dimensions inputs.
        reduce_dims: Dimensions to reduce. A single name, an iterable of names, or
            the string "all" to reduce every dimension.
        preserve_dims: Dimensions to preserve. A single name, an iterable of names, or
            the string "all" to preserve every dimension.

    Returns:
        The set of dimensions to reduce. When neither `reduce_dims` nor `preserve_dims`
        is given, this is the union of `fcst_dims` and `obs_dims`.

    Raises:
        ValueError: When `preserve_dims` and `reduce_dims` are both specified.
        ValueError: When `preserve_dims` or `reduce_dims` names a dimension that is in
            neither `fcst_dims` nor `obs_dims`.

    Warns:
        UserWarning: When the string "all" is requested and "all" is also a dimension name.

    See also:
        `scores.utils.gather_dimensions2`
    """
    all_dims = set(fcst_dims).union(obs_dims)

    if preserve_dims is not None and reduce_dims is not None:
        raise ValueError(ERROR_OVERSPECIFIED_PRESERVE_REDUCE)

    # An empty `preserve_dims` is falsy, so `specified` is None in that case and the
    # validation below is skipped; it is handled further down.
    specified = preserve_dims or reduce_dims

    # Materialise the requested names exactly once. Converting repeatedly would exhaust a
    # one-shot iterable during validation and then silently produce an empty result.
    names = None
    if specified == "all":
        if "all" in all_dims:
            warnings.warn(WARN_ALL_DATA_CONFLICT_MSG)
    elif specified is not None:
        names = {specified} if isinstance(specified, str) else set(specified)
        if not names.issubset(all_dims):
            if preserve_dims is not None:
                raise ValueError(ERROR_SPECIFIED_NONPRESENT_PRESERVE_DIMENSION)
            raise ValueError(ERROR_SPECIFIED_NONPRESENT_REDUCE_DIMENSION)

    if preserve_dims is not None:
        if preserve_dims == "all":
            return set()
        if names is None:
            # Falsy (empty) preserve_dims: nothing was validated above.
            names = {preserve_dims} if isinstance(preserve_dims, str) else set(preserve_dims)
        return all_dims.difference(names)

    if names is None:
        # reduce_dims is None or the string "all": reduce everything.
        return set(all_dims)

    return names
def gather_dimensions2(
    fcst: xr.DataArray,
    obs: xr.DataArray,
    weights: Optional[xr.DataArray] = None,
    reduce_dims: FlexibleDimensionTypes = None,
    preserve_dims: FlexibleDimensionTypes = None,
    special_fcst_dims: FlexibleDimensionTypes = None,
) -> set[Hashable]:
    """
    Performs standard dimensions checks for inputs of functions that calculate (mean) scores.
    Returns a set of the dimensions to reduce.

    Note: `scores.utils.gather_dimensions` and `scores.utils.gather_dimensions2` will be
    integrated at some point in the future. `scores.utils.gather_dimensions2` offers
    more comprehensive and less restrictive dimension checking and should be preferred in
    the meantime. See `scores.probability.crps_cdf` for an example of
    `scores.utils.gather_dimensions2` usage.

    Args:
        fcst: Forecast data
        obs: Observation data
        weights: Weights for calculating a weighted mean of scores
        reduce_dims: Dimensions to reduce. Can be "all" to reduce all dimensions.
        preserve_dims: Dimensions to preserve. Can be "all" to preserve all dimensions.
        special_fcst_dims: Dimension(s) in `fcst` that are reduced to calculate individual scores.
            Must not appear as a dimension in `obs`, `weights`, `reduce_dims` or `preserve_dims`.
            e.g. the ensemble member dimension if calculating CRPS for ensembles, or the
            threshold dimension if calculating CRPS for CDFs.

    Returns:
        Set of dimensions over which to take the mean once the checks are passed.

    Raises:
        ValueError: when `preserve_dims` and `reduce_dims` are both specified.
        ValueError: when `special_fcst_dims` is not a subset of `fcst.dims`.
        ValueError: when `obs.dims`, `weights.dims`, `reduce_dims` or `preserve_dims`
            contains elements from `special_fcst_dims`.
        ValueError: when `preserve_dims` or `reduce_dims` contain elements not among dimensions
            of the data (`fcst`, `obs` or `weights`).

    Warns:
        UserWarning: when the string "all" is requested and "all" is also a dimension name.

    See also:
        `scores.utils.gather_dimensions`
    """
    fcst_dims = set(fcst.dims)
    all_data_dims = fcst_dims.union(obs.dims)
    if weights is not None:
        all_data_dims = all_data_dims.union(weights.dims)

    if preserve_dims is not None and reduce_dims is not None:
        raise ValueError(ERROR_OVERSPECIFIED_PRESERVE_REDUCE)

    # An empty `preserve_dims` is falsy, so `specified_dims` is None in that case.
    specified_dims = preserve_dims or reduce_dims

    # Materialise the requested names exactly once so that one-shot iterables are not
    # exhausted by an early check and then seen as empty by later ones.
    specified_names = None
    if specified_dims == "all":
        if "all" in all_data_dims:
            warnings.warn(WARN_ALL_DATA_CONFLICT_MSG)
    elif specified_dims is not None:
        if isinstance(specified_dims, str):
            specified_names = {specified_dims}
        else:
            specified_names = set(specified_dims)

    # all_scoring_dims is the set of dims remaining after individual scores are computed.
    all_scoring_dims = set(all_data_dims)

    # check that special_fcst_dims are in fcst.dims only
    if special_fcst_dims is not None:
        if isinstance(special_fcst_dims, str):
            special_names = {special_fcst_dims}
        else:
            special_names = set(special_fcst_dims)
        if not special_names.issubset(fcst_dims):
            raise ValueError("`special_fcst_dims` must be a subset of `fcst` dimensions")
        if special_names.intersection(obs.dims):
            raise ValueError("`obs.dims` must not contain any `special_fcst_dims`")
        if weights is not None and special_names.intersection(weights.dims):
            raise ValueError("`weights.dims` must not contain any `special_fcst_dims`")
        if specified_names is not None and specified_names.intersection(special_names):
            raise ValueError("`reduce_dims` and `preserve_dims` must not contain any `special_fcst_dims`")
        # remove special_fcst_dims from all_scoring_dims
        all_scoring_dims -= special_names

    if specified_names is not None and not specified_names.issubset(all_scoring_dims):
        if preserve_dims is not None:
            raise ValueError(ERROR_SPECIFIED_NONPRESENT_PRESERVE_DIMENSION2)
        raise ValueError(ERROR_SPECIFIED_NONPRESENT_REDUCE_DIMENSION2)

    # all errors have been captured, so now return the set of dims to reduce
    if specified_names is None:
        # `specified_dims` is either None (reduce everything) or the string "all".
        if specified_dims is not None and preserve_dims is not None:
            return set()
        return all_scoring_dims
    if reduce_dims is not None:
        return specified_names
    return all_scoring_dims.difference(specified_names)
def dims_complement(data, dims=None) -> list[str]:
    """Returns the dimensions of `data` that are not listed in `dims`.

    Args:
        data: Input xarray object
        dims: an Iterable of strings corresponding to dimension names.
            Treated as empty when None.

    Returns:
        A sorted list of the dimension names in `data.dims` that are not in `dims`.
    """
    excluded = [] if dims is None else dims

    # `excluded` must be a valid collection of names that all appear in data.dims
    check_dims(data, excluded, mode="superset")

    return sorted(set(data.dims).difference(excluded))
def check_dims(xr_data: XarrayLike, expected_dims: Sequence[str], mode: Optional[str] = None) -> None:
    """
    Checks the dimensions xr_data with expected_dims, according to `mode`.

    Args:
        xr_data: if a Dataset is supplied,
            all of its data variables (DataArray objects) are checked.
        expected_dims: an Iterable of dimension names. One-shot iterables
            (iterators, generators) are accepted; they are consumed by this call.
        mode: one of 'equal' (default), 'subset' or 'superset'.
            If 'equal', checks that the data object has the same dimensions
            as `expected_dims`.
            If 'subset', checks that the dimensions of the data object is a
            subset of `expected_dims`.
            If 'superset', checks that the dimensions of the data object is a
            superset of `expected_dims`, (i.e. contains `expected_dims`).
            If 'proper subset', checks that the dimensions of the data object is a
            proper subset of `expected_dims`, (i.e. is a subset, but not equal to
            `expected_dims`).
            If 'proper superset', checks that the dimensions of the data object
            is a proper superset of `expected_dims`, (i.e. contains but is not
            equal to `expected_dims`).
            If 'disjoint', checks that the dimensions of the data object shares no
            elements with `expected_dims`.

    Raises:
        scores.utils.DimensionError: the dimensions of `xr_data` does
            not pass the check as specified by `mode`, or `xr_data` has no
            `dims` attribute.
        TypeError: `expected_dims` is a single string rather than an iterable of strings.
        ValueError: `expected_dims` contains duplicate values.
        ValueError: `expected_dims` cannot be coerced into a set.
        ValueError: `mode` is not one of 'equal', 'subset', 'superset',
            'proper subset', 'proper superset', or 'disjoint'
    """
    if isinstance(expected_dims, str):
        raise TypeError(f"Supplied dimensions '{expected_dims}' must be an iterable of strings, not a string itself.")

    try:
        # Materialise once: a one-shot iterable cannot be both turned into a set and
        # counted afterwards, and has no len() at all.
        dims_list = list(expected_dims)
        dims_set = set(dims_list)
    except Exception as exc:
        raise ValueError(
            f"Cannot convert supplied dims {expected_dims} into a set. Check debug log for more information."
        ) from exc

    if len(dims_set) != len(dims_list):
        # Sized inputs are reported exactly as supplied; for one-shot iterables the
        # materialised list is shown because their own repr carries no names.
        shown_dims = expected_dims if hasattr(expected_dims, "__len__") else dims_list
        raise ValueError(f"Supplied dimensions {shown_dims} contains duplicate values.")

    if not hasattr(xr_data, "dims"):
        raise DimensionError("Supplied object has no dimensions")

    # internal functions to check a data array
    check_modes = {
        "equal": lambda da, dims_set: set(da.dims) == dims_set,
        "subset": lambda da, dims_set: set(da.dims) <= dims_set,
        "superset": lambda da, dims_set: set(da.dims) >= dims_set,
        "proper subset": lambda da, dims_set: set(da.dims) < dims_set,
        "proper superset": lambda da, dims_set: set(da.dims) > dims_set,
        "disjoint": lambda da, dims_set: len(set(da.dims) & dims_set) == 0,
    }

    if mode is None:
        mode = "equal"
    if mode not in check_modes:
        raise ValueError(f"No such mode {mode}, mode must be one of: {list(check_modes.keys())}")

    check_fn = check_modes[mode]

    # check the dims of the object itself
    if not check_fn(xr_data, dims_set):
        raise DimensionError(
            f"Dimensions {list(xr_data.dims)} of data object are not {mode} to the "
            f"dimensions {sorted(dims_set)}."
        )

    if isinstance(xr_data, xr.Dataset):
        # every data variable must pass the dims check too!
        for data_var in xr_data.data_vars:
            if not check_fn(xr_data[data_var], dims_set):
                raise DimensionError(
                    f"Dimensions {list(xr_data[data_var].dims)} of data variable "
                    f"'{data_var}' are not {mode} to the dimensions {sorted(dims_set)}"
                )
def tmp_coord_name(xr_data: xr.DataArray, count=1) -> Union[str, list[str]]:
    """
    Generates temporary coordinate names that are not among the coordinate or dimension
    names of `xr_data`.

    Args:
        xr_data: Input xarray data array
        count: Number of unique names to generate

    Returns:
        If count = 1, a string which is the concatenation of 'new' with all coordinate and
        dimension names in the input array. (this is the default)
        Otherwise, a list of `count` such strings, each prefixed with its index so that
        they are unique from one another.
    """
    # Names are passed through str() so that non-string (hashable) dimension or
    # coordinate names do not break the join. The result contains 'new' plus every
    # existing name, so it is longer than, and therefore different from, each of them.
    all_names = ["new", *xr_data.dims, *xr_data.coords]
    result = "".join(str(name) for name in all_names)

    if count == 1:
        return result

    return [f"{i}{result}" for i in range(count)]