Coverage for tests/probabilty/test_crps.py: 100%
87 statements
« prev ^ index » next coverage.py v7.3.2, created at 2024-02-28 12:51 +1100
« prev ^ index » next coverage.py v7.3.2, created at 2024-02-28 12:51 +1100
# pylint: disable=too-many-arguments
2"""
3Contains unit tests for scores.probability.crps
4"""
5import dask
6import dask.array
7import numpy as np
8import pytest
9import xarray as xr
11from scores.probability import (
12 adjust_fcst_for_crps,
13 crps_cdf,
14 crps_cdf_brier_decomposition,
15 crps_for_ensemble,
16)
17from scores.probability.crps_impl import (
18 crps_cdf_exact,
19 crps_cdf_reformat_inputs,
20 crps_cdf_trapz,
21 crps_step_threshold_weight,
22)
23from tests.assertions import assert_dataarray_equal, assert_dataset_equal
24from tests.probabilty import crps_test_data
@pytest.mark.parametrize(
    ("weight_upper", "expected"),
    [
        (True, crps_test_data.EXP_STEP_WEIGHT_UPPER),
        (False, crps_test_data.EXP_STEP_WEIGHT_LOWER),
    ],
)
def test_crps_stepweight(
    weight_upper,
    expected,
):
    """
    Tests `crps_step_threshold_weight` for both settings of `weight_upper`.

    All other positional arguments are held fixed: the step points from
    `DA_STEP_WEIGHT`, dimension "x", threshold values 1..6, `True` and `0.2`.
    Results must match the expected arrays to 7 decimal places.
    """
    threshold_values = [1, 2, 3, 4, 5, 6]
    step_weights = crps_step_threshold_weight(
        crps_test_data.DA_STEP_WEIGHT,
        "x",
        threshold_values,
        True,
        0.2,
        weight_upper,
    )
    assert_dataarray_equal(step_weights, expected, decimals=7)
def test_crps_cdf_exact():
    """
    Tests `crps_cdf_exact` with and without the component breakdown.

    With `include_components=True` the full dataset must match `EXP_CRPS_EXACT`.
    With `include_components=False` only the `total` variable must be returned,
    and its values must match the expected `total`.
    """
    result = crps_cdf_exact(
        crps_test_data.DA_FCST_CRPS_EXACT,
        crps_test_data.DA_OBS_CRPS_EXACT,
        crps_test_data.DA_WT_CRPS_EXACT,
        "x",
        include_components=True,
    )
    assert_dataset_equal(result, crps_test_data.EXP_CRPS_EXACT, decimals=7)

    result2 = crps_cdf_exact(
        crps_test_data.DA_FCST_CRPS_EXACT,
        crps_test_data.DA_OBS_CRPS_EXACT,
        crps_test_data.DA_WT_CRPS_EXACT,
        "x",
        include_components=False,
    )

    assert list(result2.data_vars) == ["total"]
    # Also pin the values, not just the variable names, of the component-free result.
    assert_dataset_equal(result2, crps_test_data.EXP_CRPS_EXACT[["total"]], decimals=7)
def test_crps_cdf_exact_dask():
    """
    Tests `crps_cdf_exact` works with Dask.

    Each result must be a Dask array before `.compute()` and a NumPy array
    afterwards. Its values must match `EXP_CRPS_EXACT`, both with and without
    the component breakdown.
    """
    result = crps_cdf_exact(
        crps_test_data.DA_FCST_CRPS_EXACT.chunk(),
        crps_test_data.DA_OBS_CRPS_EXACT.chunk(),
        # NOTE: the weight is not chunked in this first call (unlike the second call
        # below), so a mix of Dask-backed and in-memory inputs is exercised.
        crps_test_data.DA_WT_CRPS_EXACT,
        "x",
        include_components=True,
    )
    assert isinstance(result.total.data, dask.array.Array)
    result = result.compute()
    assert isinstance(result.total.data, np.ndarray)
    assert_dataset_equal(result, crps_test_data.EXP_CRPS_EXACT, decimals=7)

    result2 = crps_cdf_exact(
        crps_test_data.DA_FCST_CRPS_EXACT.chunk(),
        crps_test_data.DA_OBS_CRPS_EXACT.chunk(),
        crps_test_data.DA_WT_CRPS_EXACT.chunk(),
        "x",
        include_components=False,
    )
    assert isinstance(result2.total.data, dask.array.Array)
    result2 = result2.compute()
    assert isinstance(result2.total.data, np.ndarray)
    assert list(result2.data_vars) == ["total"]
    # Also pin the values, not just the variable names, of the component-free result.
    assert_dataset_equal(result2, crps_test_data.EXP_CRPS_EXACT[["total"]], decimals=7)
@pytest.mark.parametrize(
    ("weight_upper", "expected"),
    [
        (True, crps_test_data.EXP_STEP_WEIGHT_UPPER),
        (False, crps_test_data.EXP_STEP_WEIGHT_LOWER),
    ],
)
def test_crps_stepweight2(
    weight_upper,
    expected,
):
    """
    Tests `crps_step_threshold_weight` with a variety of inputs.

    NOTE(review): the parameters, arguments and expected values here are
    identical to those of `test_crps_stepweight`, so this test adds no extra
    coverage. Consider removing it or varying the inputs (e.g. different
    threshold values or precision) with matching expected data.
    """
    call_args = (
        crps_test_data.DA_STEP_WEIGHT,
        "x",
        [1, 2, 3, 4, 5, 6],
        True,
        0.2,
        weight_upper,
    )
    weights = crps_step_threshold_weight(*call_args)
    assert_dataarray_equal(weights, expected, decimals=7)
def test_crps_cdf_trapz():
    """
    Tests `crps_cdf_trapz`.

    Uses densely interpolated inputs so the trapezoidal result approximates the
    exact result to 4 decimal places. Both the component breakdown and the
    `total`-only output are checked.
    """
    result = crps_cdf_trapz(
        crps_test_data.DA_FCST_CRPS_DENSE,
        crps_test_data.DA_OBS_CRPS_DENSE,
        crps_test_data.DA_WT_CRPS_DENSE,
        "x",
        include_components=True,
    )
    assert_dataset_equal(result, crps_test_data.EXP_CRPS_EXACT, decimals=4)

    result2 = crps_cdf_trapz(
        crps_test_data.DA_FCST_CRPS_DENSE,
        crps_test_data.DA_OBS_CRPS_DENSE,
        crps_test_data.DA_WT_CRPS_DENSE,
        "x",
        include_components=False,
    )
    # Check `result2` itself; `result` was previously re-asserted here by mistake,
    # which left the component-free output unverified.
    assert_dataset_equal(result2, crps_test_data.EXP_CRPS_EXACT[["total"]], decimals=4)
    assert list(result2.data_vars) == ["total"]
@pytest.mark.parametrize(
    ("threshold_weight", "additional_thresholds", "expected"),
    [
        (crps_test_data.DA_WT_REFORMAT1, None, crps_test_data.EXP_REFORMAT1),
        (crps_test_data.DA_WT_REFORMAT1, [0, 1.5], crps_test_data.EXP_REFORMAT2),
        (None, None, crps_test_data.EXP_REFORMAT3),
    ],
)
def test_crps_cdf_reformat_inputs(threshold_weight, additional_thresholds, expected):
    """
    Tests `crps_cdf_reformat_inputs` with a variety of inputs.

    The function's output sequence must have the same length as `expected`.
    Each element must equal the corresponding expected array to 7 decimal places.
    """
    reformatted = crps_cdf_reformat_inputs(
        crps_test_data.DA_FCST_REFORMAT1,
        crps_test_data.DA_OBS_REFORMAT1,
        "x",
        threshold_weight,
        additional_thresholds,
        # presumably the fcst and threshold-weight fill methods, matching the
        # keyword order used in `test_crps_cdf` - confirm against the signature
        "linear",
        "forward",
    )
    assert len(reformatted) == len(expected)

    for position, expected_array in enumerate(expected):
        assert_dataarray_equal(reformatted[position], expected_array, decimals=7)
@pytest.mark.parametrize(
    (
        "fcst",
        "obs",
        "threshold_dim",
        "threshold_weight",
        "fcst_fill_method",
        "threshold_weight_fill_method",
        "integration_method",
        "dims",
        "error_msg_snippet",
    ),
    [
        # threshold_dim absent from fcst
        (
            crps_test_data.DA_FCST_REFORMAT1,
            crps_test_data.DA_OBS_REFORMAT1,
            "y",
            None,
            "linear",
            None,
            "exact",
            None,
            "'y' is not a dimension of `fcst`",
        ),
        # threshold_dim absent from threshold_weight
        (
            crps_test_data.DA_FCST_REFORMAT1,
            crps_test_data.DA_OBS_REFORMAT1,
            "x",
            crps_test_data.DA_WT_CHECK_CRPS1,
            "linear",
            "forward",
            "exact",
            None,
            "'x' is not a dimension of `threshold_weight`",
        ),
        # threshold_dim present in obs
        (
            crps_test_data.DA_FCST_REFORMAT1,
            crps_test_data.DA_OBS_REFORMAT1,
            "station",
            None,
            "linear",
            None,
            "exact",
            None,
            "'station' is a dimension of `obs`",
        ),
        # obs has a dimension that fcst lacks
        (
            crps_test_data.DA_FCST_REFORMAT1,
            crps_test_data.DA_OBS_CHECK_CRPS,
            "x",
            None,
            "linear",
            None,
            "exact",
            None,
            "Dimensions of `obs` must be a subset of dimensions of `fcst`",
        ),
        # threshold_weight has a dimension that fcst lacks
        (
            crps_test_data.DA_FCST_REFORMAT1,
            crps_test_data.DA_OBS_REFORMAT1,
            "x",
            crps_test_data.DA_WT_CHECK_CRPS2,
            "linear",
            "backward",
            "exact",
            None,
            "Dimensions of `threshold_weight` must be a subset of dimensions of `fcst`",
        ),
        # TODO: Revisit if still needed after more handling in gather_dimensions
        # (
        #     crps_test_data.DA_FCST_REFORMAT1,
        #     crps_test_data.DA_OBS_REFORMAT1,
        #     "x",
        #     None,
        #     "linear",
        #     None,
        #     "exact",
        #     ["y", "x"],
        #     "`dims` must be a subset of `fcst` dimensions",
        # ),
        # unknown fcst_fill_method
        (
            crps_test_data.DA_FCST_REFORMAT1,
            crps_test_data.DA_OBS_REFORMAT1,
            "x",
            None,
            "fat",
            None,
            "exact",
            None,
            "`fcst_fill_method` must be 'linear', 'step', 'forward' or 'backward'",
        ),
        # unknown threshold_weight_fill_method
        (
            crps_test_data.DA_FCST_REFORMAT1,
            crps_test_data.DA_OBS_REFORMAT1,
            "x",
            crps_test_data.DA_WT_REFORMAT1,
            "linear",
            "fat",
            "exact",
            None,
            "`threshold_weight_fill_method` must be 'linear', 'step', 'forward' or 'backward'",
        ),
        # unknown integration_method
        (
            crps_test_data.DA_FCST_REFORMAT1,
            crps_test_data.DA_OBS_REFORMAT1,
            "x",
            crps_test_data.DA_WT_REFORMAT1,
            "linear",
            "step",
            "waffly",
            None,
            "`integration_method` must be 'exact' or 'trapz'",
        ),
        # too few threshold values in fcst
        (
            crps_test_data.DA_FCST_CHECK_CRPS,
            crps_test_data.DA_OBS_REFORMAT1,
            "x",
            crps_test_data.DA_WT_REFORMAT1,
            "linear",
            "step",
            "exact",
            None,
            "`threshold_dim` in `fcst` must have at least 2 values to calculate CRPS",
        ),
        # non-increasing threshold coordinates in fcst
        (
            crps_test_data.DA_FCST_CHECK_CRPS2,
            crps_test_data.DA_OBS_CHECK_CRPS2,
            "x",
            None,
            "linear",
            "step",
            "exact",
            None,
            "`threshold_dim` coordinates in `fcst` must be increasing",
        ),
        # non-increasing threshold coordinates in threshold_weight
        (
            crps_test_data.DA_FCST_CHECK_CRPS2A,
            crps_test_data.DA_OBS_CHECK_CRPS2,
            "x",
            crps_test_data.DA_FCST_CHECK_CRPS2,  # used as the weight
            "linear",
            "step",
            "exact",
            None,
            "`threshold_dim` coordinates in `threshold_weight` must be increasing",
        ),
        # negative values in threshold_weight
        (
            crps_test_data.DA_FCST_CHECK_CRPS2A,
            crps_test_data.DA_OBS_CHECK_CRPS2,
            "x",
            crps_test_data.DA_WT_CHECK_CRPS3,  # used as the weight
            "linear",
            "step",
            "exact",
            None,
            "`threshold_weight` has negative values",
        ),
    ],
)
# pylint: disable=too-many-arguments
def test_crps_cdf_raises(
    fcst,
    obs,
    threshold_dim,
    threshold_weight,
    fcst_fill_method,
    threshold_weight_fill_method,
    integration_method,
    dims,
    error_msg_snippet,
):
    """
    Check that `crps_cdf` raises a `ValueError` for each invalid input combination.

    `error_msg_snippet` is used as a `pytest.raises(match=...)` pattern.
    """
    # Arguments are passed positionally. The `[]` and `True` sit in the 5th and 6th
    # slots, presumably `additional_thresholds` and `propagate_nans` (per the keyword
    # order in `test_crps_cdf`) - confirm against the `crps_cdf` signature.
    positional_args = (
        fcst,
        obs,
        threshold_dim,
        threshold_weight,
        [],
        True,
        fcst_fill_method,
        threshold_weight_fill_method,
        integration_method,
        dims,
    )
    with pytest.raises(ValueError, match=error_msg_snippet):
        crps_cdf(*positional_args)
@pytest.mark.parametrize(
    (
        "fcst",
        "threshold_weight",
        "propagate_nan",
        "integration_method",
        "dims",
        "expected_and_dec",  # tuple: expected and decimals
    ),
    [
        # exact, preserve station dim
        (
            crps_test_data.DA_FCST_CRPS,
            crps_test_data.DA_WT_CRPS,
            True,
            "exact",
            ["station"],
            (crps_test_data.EXP_CRPS1, 7),
        ),
        # exact, preserve no dims
        (
            crps_test_data.DA_FCST_CRPS,
            crps_test_data.DA_WT_CRPS,
            True,
            "exact",
            None,
            (crps_test_data.EXP_CRPS2, 7),
        ),
        # trapz, preserve station dim (dense inputs, so only 4 decimals)
        (
            crps_test_data.DA_FCST_CRPS_DENSE,
            crps_test_data.DA_WT_CRPS_DENSE,
            True,
            "trapz",
            ["station"],
            (crps_test_data.EXP_CRPS1, 4),
        ),
        # exact, preserve station dim, weight is None
        (
            crps_test_data.DA_FCST_CRPS,
            None,
            True,
            "exact",
            ["station"],
            (crps_test_data.EXP_CRPS3, 7),
        ),
        # exact, preserve station dim, weight is None, don't propagate nan
        (
            crps_test_data.DA_FCST_CRPS,
            None,
            False,
            "exact",
            ["station"],
            (crps_test_data.EXP_CRPS4, 7),
        ),
        # exact, preserve station dim, don't propagate nan
        (
            crps_test_data.DA_FCST_CRPS,
            crps_test_data.DA_WT_CRPS,
            False,
            "exact",
            ["station"],
            (crps_test_data.EXP_CRPS5, 7),
        ),
    ],
)
# pylint: disable=too-many-arguments
def test_crps_cdf(
    fcst,
    threshold_weight,
    propagate_nan,
    integration_method,
    dims,
    expected_and_dec,
):
    """
    Tests `crps_cdf` with a variety of inputs.

    Uses the shared observations `DA_OBS_CRPS`, threshold dimension "x", linear
    forecast fill and forward weight fill. Components are included in the result.
    """
    expected, decimals = expected_and_dec

    crps_result = crps_cdf(
        fcst,
        crps_test_data.DA_OBS_CRPS,
        "x",
        threshold_weight=threshold_weight,
        additional_thresholds=None,
        propagate_nans=propagate_nan,
        fcst_fill_method="linear",
        threshold_weight_fill_method="forward",
        integration_method=integration_method,
        preserve_dims=dims,
        include_components=True,
    )

    assert_dataset_equal(crps_result, expected, decimals=decimals)
@pytest.mark.parametrize(
    (
        "fcst",
        "obs",
        "decreasing_tolerance",
        "expected",
    ),
    [
        (  # zero decreasing tolerance
            crps_test_data.DA_FCST_ADJUST1,
            crps_test_data.DA_OBS_ADJUST1,
            0,
            crps_test_data.EXP_FCST_ADJUST1,
        ),
        (  # non-zero decreasing tolerance
            crps_test_data.DA_FCST_ADJUST2,
            crps_test_data.DA_OBS_ADJUST2,
            10,
            crps_test_data.EXP_FCST_ADJUST2,
        ),
    ],
)
def test_adjust_fcst_for_crps(
    fcst,
    obs,
    decreasing_tolerance,
    expected,
):
    """
    Tests `adjust_fcst_for_crps` with a variety of inputs.

    The adjusted forecast along threshold dimension "x" must match `expected`
    to 7 decimal places, for both a zero and a non-zero `decreasing_tolerance`.
    """
    result = adjust_fcst_for_crps(fcst, "x", obs, decreasing_tolerance)
    assert_dataarray_equal(result, expected, decimals=7)
@pytest.mark.parametrize(
    ("threshold_dim", "decreasing_tolerance", "error_msg_snippet"),
    [
        # threshold_dim absent from fcst
        ("aj", 0, "'aj' is not a dimension of `fcst`"),
        # negative tolerance
        ("x", -10, "`decreasing_tolerance` must be nonnegative"),
    ],
)
def test_adjust_fcst_raises(threshold_dim, decreasing_tolerance, error_msg_snippet):
    """
    Check that `adjust_fcst_for_crps` raises a `ValueError` for invalid arguments.

    The same forecast/observation pair (`DA_FCST_ADJUST1`, `DA_OBS_ADJUST1`) is
    used for every case; only the threshold dimension and tolerance vary.
    """
    fcst = crps_test_data.DA_FCST_ADJUST1
    obs = crps_test_data.DA_OBS_ADJUST1
    with pytest.raises(ValueError, match=error_msg_snippet):
        adjust_fcst_for_crps(fcst, threshold_dim, obs, decreasing_tolerance)
@pytest.mark.parametrize(
    (
        "fcst",
        "obs",
        "threshold_dim",
        "fcst_fill_method",
        "dims",
        "error_msg_snippet",
    ),
    [
        # threshold_dim absent from fcst
        (
            crps_test_data.DA_FCST_REFORMAT1,
            crps_test_data.DA_OBS_REFORMAT1,
            "y",
            "linear",
            None,
            "'y' is not a dimension of `fcst`",
        ),
        # threshold_dim present in obs
        (
            crps_test_data.DA_FCST_REFORMAT1,
            crps_test_data.DA_OBS_REFORMAT1,
            "station",
            "linear",
            None,
            "'station' is a dimension of `obs`",
        ),
        # obs has a dimension that fcst lacks
        (
            crps_test_data.DA_FCST_REFORMAT1,
            crps_test_data.DA_OBS_CHECK_CRPS,
            "x",
            "linear",
            None,
            "Dimensions of `obs` must be a subset of dimensions of `fcst`",
        ),
        # unknown fcst_fill_method
        (
            crps_test_data.DA_FCST_REFORMAT1,
            crps_test_data.DA_OBS_REFORMAT1,
            "x",
            "fat",
            None,
            "`fcst_fill_method` must be 'linear', 'step', 'forward' or 'backward'",
        ),
        # non-increasing threshold coordinates in fcst
        (
            crps_test_data.DA_FCST_CHECK_CRPS2,
            crps_test_data.DA_OBS_CHECK_CRPS2,
            "x",
            "linear",
            None,
            "`threshold_dim` coordinates in `fcst` must be increasing",
        ),
    ],
)
# pylint: disable=too-many-arguments
def test_crps_cdf_brier_raises(
    fcst,
    obs,
    threshold_dim,
    fcst_fill_method,
    dims,
    error_msg_snippet,
):
    """
    Check that `crps_cdf_brier_decomposition` raises a `ValueError` for invalid input.

    NOTE(review): `dims` is always None in the cases above and is forwarded as
    `reduce_dims`, whereas `test_crps_cdf_brier_decomposition` uses `preserve_dims`.
    """
    with pytest.raises(ValueError, match=error_msg_snippet):
        crps_cdf_brier_decomposition(
            fcst,
            obs,
            threshold_dim,
            fcst_fill_method=fcst_fill_method,
            reduce_dims=dims,
        )
@pytest.mark.parametrize(
    ("dims", "expected"),
    [
        (None, crps_test_data.EXP_CRPS_BD1),  # preserve no dims
        (["station"], crps_test_data.EXP_CRPS_BD2),  # preserve station dim
    ],
)
def test_crps_cdf_brier_decomposition(dims, expected):
    """
    Tests `crps_cdf_brier_decomposition` with a variety of inputs.

    Decomposes along threshold dimension "x", preserving `dims`. The result must
    match `expected` to 7 decimal places.
    """
    fcst = crps_test_data.DA_FCST_CRPS_BD
    obs = crps_test_data.DA_OBS_CRPS_BD
    decomposition = crps_cdf_brier_decomposition(fcst, obs, "x", preserve_dims=dims)
    assert_dataset_equal(decomposition, expected, decimals=7)
def test_crps_for_ensemble():
    """
    Tests `crps_for_ensemble` returns as expected.

    Covers the "ecdf" and "fair" methods with all dims preserved. It also covers
    a weighted "ecdf" call that passes `weights` and no `preserve_dims`.
    """
    fcst = crps_test_data.DA_FCST_CRPSENS
    obs = crps_test_data.DA_OBS_CRPSENS

    unreduced_cases = (
        ("ecdf", crps_test_data.EXP_CRPSENS_ECDF),
        ("fair", crps_test_data.EXP_CRPSENS_FAIR),
    )
    for method, expected in unreduced_cases:
        per_point = crps_for_ensemble(fcst, obs, "ens_member", method=method, preserve_dims="all")
        assert_dataarray_equal(per_point, expected, decimals=7)

    weighted_mean = crps_for_ensemble(
        fcst,
        obs,
        "ens_member",
        method="ecdf",
        weights=crps_test_data.DA_WT_CRPSENS,
    )
    assert_dataarray_equal(weighted_mean, crps_test_data.EXP_CRPSENS_WT, decimals=7)
def test_crps_for_ensemble_raises():
    """
    Tests `crps_for_ensemble` raises a `ValueError` for an unknown `method`.

    The expected message contains no regex metacharacters, so `match=` is
    equivalent to a substring check here.
    """
    dummy = xr.DataArray(data=[1])
    with pytest.raises(ValueError, match="`method` must be one of 'ecdf' or 'fair'"):
        crps_for_ensemble(dummy, dummy, "ens_member", "unfair")
def test_crps_for_ensemble_dask():
    """
    Tests `crps_for_ensemble` works with dask.

    With chunked inputs the result must be a Dask array until `.compute()`, then
    a NumPy array. Its values must equal the eager "ecdf" expectation.
    """
    lazy_crps = crps_for_ensemble(
        fcst=crps_test_data.DA_FCST_CRPSENS.chunk(),
        obs=crps_test_data.DA_OBS_CRPSENS.chunk(),
        ensemble_member_dim="ens_member",
        method="ecdf",
        preserve_dims="all",
    )
    assert isinstance(lazy_crps.data, dask.array.Array)

    computed_crps = lazy_crps.compute()
    assert isinstance(computed_crps.data, np.ndarray)
    assert_dataarray_equal(computed_crps, crps_test_data.EXP_CRPSENS_ECDF, decimals=7)