Coverage for tests/categorical/test_multicategorical.py: 100%
31 statements
Report generated by coverage.py v7.3.2 at 2024-02-28 12:51 +1100
1"""
2Contains unit tests for scores.categorical
3"""
4import dask
5import dask.array
6import numpy as np
7import pytest
8import xarray as xr
10from scores.categorical import firm
11from scores.categorical.multicategorical_impl import _single_category_score
12from scores.utils import DimensionError
13from tests.categorical import multicategorical_test_data as mtd
@pytest.mark.parametrize(
    "fcst, obs, categorical_threshold, discount_distance, threshold_assignment, expected",
    [
        # Threshold of 5, no discounting, every dimension kept
        (mtd.DA_FCST_SC, mtd.DA_OBS_SC, 5, 0, "lower", mtd.EXP_SC_CASE0),
        # Threshold of -200, no discounting, a single dimension kept
        (mtd.DA_FCST_SC, mtd.DA_OBS_SC, -200, 0, "lower", mtd.EXP_SC_CASE1),
        # Threshold of 200, no discounting, a single dimension kept
        (mtd.DA_FCST_SC, mtd.DA_OBS_SC, 200, 0, "lower", mtd.EXP_SC_CASE1),
        # Threshold of 5 with discount distance 7, every dimension kept;
        # the discount distance is the maximum for both false alarms and misses
        (mtd.DA_FCST_SC, mtd.DA_OBS_SC, 5, 7, "lower", mtd.EXP_SC_CASE2),
        # Threshold of 5 with discount distance 0.5, every dimension kept;
        # the discount distance is the minimum for both false alarms and misses
        (mtd.DA_FCST_SC, mtd.DA_OBS_SC, 5, 0.5, "lower", mtd.EXP_SC_CASE3),
        # Values on the threshold are assigned to the lower/right category
        (mtd.DA_FCST_SC2, mtd.DA_OBS_SC2, 2, None, "lower", mtd.EXP_SC_CASE4),
        # Values on the threshold are assigned to the upper/left category
        (mtd.DA_FCST_SC2, mtd.DA_OBS_SC2, 2, None, "upper", mtd.EXP_SC_CASE5),
    ],
)
def test__single_category_score(fcst, obs, categorical_threshold, discount_distance, threshold_assignment, expected):
    """
    Compare _single_category_score against precomputed expectations.

    The risk parameter is held fixed at 0.7 for every case; the threshold,
    discount distance and threshold assignment vary per case.
    """
    result = _single_category_score(
        fcst,
        obs,
        0.7,  # risk parameter shared by all cases
        categorical_threshold,
        discount_distance,
        threshold_assignment,
    )
    xr.testing.assert_allclose(result, expected)
@pytest.mark.parametrize(
    (
        "fcst",
        "obs",
        "risk_parameters",
        "categorical_thresholds",
        "weights",
        "reduce_dims",
        "preserve_dims",
        "discount_distance",
        "expected",
    ),
    [
        # Single category case identical to CASE0 in
        # test__single_category_score
        (
            mtd.DA_FCST_SC,
            mtd.DA_OBS_SC,
            0.7,
            [5],
            [1],
            None,
            ["i", "j", "k"],
            0.0,
            mtd.EXP_SC_CASE0,
        ),
        # Single category with a discount distance. Identical to CASE3
        # in test__single_category_score
        (
            mtd.DA_FCST_SC,
            mtd.DA_OBS_SC,
            0.7,
            [5],
            [1],
            None,
            ["i", "j", "k"],
            0.5,
            mtd.EXP_SC_CASE3,
        ),
        # Single category on a slightly bigger dataset, preserving all dimensions
        (
            mtd.DA_FCST_FIRM,
            mtd.DA_OBS_FIRM,
            0.7,
            [5],
            [1],
            None,
            ["i", "j", "k"],
            0,
            mtd.EXP_FIRM_CASE0,
        ),
        # Single category, only keep one dimension
        (
            mtd.DA_FCST_FIRM,
            mtd.DA_OBS_FIRM,
            0.7,
            [5],
            [1],
            None,
            ["i"],
            0,
            mtd.EXP_FIRM_CASE1,
        ),
        # Single category, no dimensions kept (reduce_dims and preserve_dims
        # both left as None)
        (
            mtd.DA_FCST_FIRM,
            mtd.DA_OBS_FIRM,
            0.7,
            [5],
            [1],
            None,
            None,
            0,
            mtd.EXP_FIRM_CASE2,
        ),
        # Single category, no dimensions kept via an explicit
        # reduce_dims="all"; expected to match the previous case
        (
            mtd.DA_FCST_FIRM,
            mtd.DA_OBS_FIRM,
            0.7,
            [5],
            [1],
            "all",
            None,
            0,
            mtd.EXP_FIRM_CASE2,
        ),
        # 2 categories, same weight
        (
            mtd.DA_FCST_FIRM,
            mtd.DA_OBS_FIRM,
            0.7,
            [0, 5],
            [1, 1],
            None,
            ["i", "j", "k"],
            0,
            mtd.EXP_FIRM_CASE3,
        ),
        # 3 categories, same weight
        (
            mtd.DA_FCST_FIRM,
            mtd.DA_OBS_FIRM,
            0.7,
            [0, 5, 8],
            [1, 1, 1],
            None,
            ["i", "j", "k"],
            0,
            mtd.EXP_FIRM_CASE4,
        ),
        # 2 categories, 2 different weights
        (
            mtd.DA_FCST_FIRM,
            mtd.DA_OBS_FIRM,
            0.7,
            [0, 5],
            [2, 1],
            None,
            ["i", "j", "k"],
            0,
            mtd.EXP_FIRM_CASE4,
        ),
        # 2 categories, 2 weights that are xr.DataArrays
        (
            mtd.DA_FCST_FIRM,
            mtd.DA_OBS_FIRM,
            0.7,
            [0, 5],
            mtd.LIST_WEIGHTS_FIRM0,
            None,
            ["i", "j", "k"],
            0,
            mtd.EXP_FIRM_CASE4,
        ),
        # 2 categories, 2 weights with the first an xr.DataArray and the
        # second a float
        (
            mtd.DA_FCST_FIRM,
            mtd.DA_OBS_FIRM,
            0.7,
            [0, 5],
            [mtd.LIST_WEIGHTS_FIRM0[0], 1],
            None,
            ["i", "j", "k"],
            0,
            mtd.EXP_FIRM_CASE4,
        ),
        # 2 categories, 2 weights that are xr.DataArrays with different values
        # for different coords
        (
            mtd.DA_FCST_FIRM,
            mtd.DA_OBS_FIRM,
            0.7,
            [0, 5],
            mtd.LIST_WEIGHTS_FIRM1,
            None,
            ["i", "j", "k"],
            0,
            mtd.EXP_FIRM_CASE5,
        ),
        # 2 categories, 2 weights that are xr.DataArrays with different values
        # for different coords, with NaN in weights
        (
            mtd.DA_FCST_FIRM,
            mtd.DA_OBS_FIRM,
            0.7,
            [0, 5],
            mtd.LIST_WEIGHTS_FIRM2,
            None,
            ["i", "j", "k"],
            0,
            mtd.EXP_FIRM_CASE6,
        ),
    ],
)
def test_firm(
    fcst,
    obs,
    risk_parameters,
    categorical_thresholds,
    weights,
    reduce_dims,
    preserve_dims,
    discount_distance,
    expected,
):
    """
    Tests firm against precomputed expected results.

    Cases cover single and multiple categories, scalar and xr.DataArray
    weights (including NaN weights), and different combinations of
    reduce_dims / preserve_dims.
    """
    calculated = firm(
        fcst,
        obs,
        risk_parameters,
        categorical_thresholds,
        weights,
        discount_distance=discount_distance,
        reduce_dims=reduce_dims,
        preserve_dims=preserve_dims,
    )
    # Dimension order of the result is not guaranteed to match `expected`, so
    # align it before comparing whenever dimensions were preserved.
    if preserve_dims is not None:
        calculated = calculated.transpose(*preserve_dims)
    xr.testing.assert_allclose(
        calculated,
        expected,
        atol=0.001,
    )
def test_firm_dask():
    """
    Check that firm stays lazy on dask-chunked inputs.

    The score should be a dask array before `.compute()`, a numpy array
    afterwards, and should then match the expected result for the
    NaN-weighted two-category case.
    """
    dims = ("i", "j", "k")
    lazy_fcst = mtd.DA_FCST_FIRM.chunk()
    lazy_obs = mtd.DA_OBS_FIRM.chunk()

    result = firm(
        lazy_fcst,
        lazy_obs,
        0.7,
        [0, 5],
        mtd.LIST_WEIGHTS_FIRM2,
        0,
        reduce_dims=None,
        preserve_dims=list(dims),
    ).transpose(*dims)

    # Before computing, the score should still be backed by a dask array.
    assert isinstance(result.firm_score.data, dask.array.Array)

    evaluated = result.compute()
    # After computing, the data should be materialised as numpy.
    assert isinstance(evaluated.firm_score.data, np.ndarray)
    xr.testing.assert_allclose(evaluated, mtd.EXP_FIRM_CASE6, atol=0.001)
@pytest.mark.parametrize(
    (
        "fcst",
        "obs",
        "risk_parameters",
        "categorical_thresholds",
        "weights",
        "preserve_dims",
        "discount_distance",
        "threshold_assignment",
        "error_type",
        "error_msg_snippet",
    ),
    [
        # len(categorical_thresholds) is 0
        (
            mtd.DA_FCST_FIRM,
            mtd.DA_OBS_FIRM,
            0.5,
            [],
            [],
            ["i", "j", "k"],
            0,
            "upper",
            ValueError,
            "`categorical_thresholds` must have at least",
        ),
        # weights and categories don't match. risk_parameters is a float
        (
            mtd.DA_FCST_FIRM,
            mtd.DA_OBS_FIRM,
            0.5,
            [1],
            [1, 2],
            ["i", "j", "k"],
            0,
            "upper",
            ValueError,
            "`categorical_thresholds` and `weights`",
        ),
        # risk_parameter = 0 (lower bound is exclusive)
        (
            mtd.DA_FCST_FIRM,
            mtd.DA_OBS_FIRM,
            0.0,
            [5],
            [1],
            ["i", "j", "k"],
            0,
            "upper",
            ValueError,
            "0 < `risk_parameter` < 1 must",
        ),
        # risk_parameter = 1 (upper bound is exclusive)
        (
            mtd.DA_FCST_FIRM,
            mtd.DA_OBS_FIRM,
            1.0,
            [5],
            [1],
            ["i", "j", "k"],
            0,
            "upper",
            ValueError,
            "0 < `risk_parameter` < 1 must",
        ),
        # negative weight
        (
            mtd.DA_FCST_FIRM,
            mtd.DA_OBS_FIRM,
            0.5,
            [5, 6],
            [1, -1],
            ["i", "j", "k"],
            0,
            "upper",
            ValueError,
            "`weights` must be > 0",
        ),
        # negative weight with weights being an xr.DataArray
        (
            mtd.DA_FCST_FIRM,
            mtd.DA_OBS_FIRM,
            0.5,
            [5, 6],
            mtd.LIST_WEIGHTS_FIRM3,
            ["i", "j", "k"],
            0,
            "upper",
            ValueError,
            "value was found in index 0 of `weights",
        ),
        # zero weight
        (
            mtd.DA_FCST_FIRM,
            mtd.DA_OBS_FIRM,
            0.5,
            [5, 6],
            [1, 0],
            ["i", "j", "k"],
            0,
            "upper",
            ValueError,
            "`weights` must be > 0",
        ),
        # zero weight with weights being an xr.DataArray
        (
            mtd.DA_FCST_FIRM,
            mtd.DA_OBS_FIRM,
            0.5,
            [5, 6],
            mtd.LIST_WEIGHTS_FIRM4,
            ["i", "j", "k"],
            0,
            "upper",
            ValueError,
            "No values <= 0 are allowed in `weights`",
        ),
        # bad dims in weights
        (
            mtd.DA_FCST_FIRM,
            mtd.DA_OBS_FIRM,
            0.5,
            [5, 6],
            mtd.LIST_WEIGHTS_FIRM5,
            ["i", "j", "k"],
            0,
            "upper",
            DimensionError,
            "of data object are not subset to",
        ),
        # negative discount distance
        (
            mtd.DA_FCST_FIRM,
            mtd.DA_OBS_FIRM,
            0.5,
            [5],
            [1],
            ["i", "j", "k"],
            -1,
            "upper",
            ValueError,
            "`discount_distance` must be >= 0",
        ),
        # wrong threshold assignment
        (
            mtd.DA_FCST_FIRM,
            mtd.DA_OBS_FIRM,
            0.5,
            [5],
            [1],
            ["i", "j", "k"],
            0.0,
            "up",
            ValueError,
            """ `threshold_assignment` must be either \"upper\" or \"lower\" """,
        ),
    ],
)
def test_firm_raises(
    fcst,
    obs,
    risk_parameters,
    categorical_thresholds,
    weights,
    preserve_dims,
    discount_distance,
    threshold_assignment,
    error_type,
    error_msg_snippet,
):
    """
    Tests that firm raises the correct error type and message for invalid inputs.

    `error_msg_snippet` is passed to `pytest.raises(match=...)`, which applies
    it as a regular-expression search against the exception message.
    """
    with pytest.raises(error_type, match=error_msg_snippet):
        firm(
            fcst,
            obs,
            risk_parameters,
            categorical_thresholds,
            weights,
            discount_distance=discount_distance,
            reduce_dims=None,
            preserve_dims=preserve_dims,
            threshold_assignment=threshold_assignment,
        )