Coverage for tests/continuous/test_quantile_loss.py: 100%
29 statements
« prev ^ index » next coverage.py v7.3.2, created at 2024-02-28 12:51 +1100
« prev ^ index » next coverage.py v7.3.2, created at 2024-02-28 12:51 +1100
1"""
Contains unit tests for scores.continuous.quantile_score
3"""
4import dask
5import dask.array
6import numpy as np
7import pytest
8import xarray as xr
10from scores.continuous import quantile_score
11from scores.utils import DimensionError
12from tests.assertions import assert_dataarray_equal, assert_dataset_equal
13from tests.continuous import quantile_loss_test_data as qltd
@pytest.mark.parametrize("alpha", [1.0, 1.1, 0, -0.8])
def test_qsf_value_error_alpha(alpha):
    """quantile_score raises ValueError for boundary (0, 1.0) and out-of-range (1.1, -0.8) alpha values."""
    data = qltd.DA1_2X2
    # Using identical fcst and obs isolates the failure to the alpha argument.
    with pytest.raises(ValueError):
        quantile_score(data, data, alpha)
@pytest.mark.parametrize(
    ("obs", "reduce_dims", "preserve_dims"),
    [
        # fcst and obs have mismatched dimensions
        (qltd.DA1_2X2X2, None, None),
        # requested dimension is present in neither fcst nor obs
        (qltd.DA1_2X2, ["bananas"], None),
        (qltd.DA1_2X2, None, ["orange"]),
    ],
)
def test_qsf_exceptions(obs, reduce_dims, preserve_dims):
    """quantile_score raises DimensionError for incompatible or unknown dimensions."""
    fcst = qltd.DA1_2X2
    with pytest.raises(DimensionError):
        quantile_score(fcst, obs, 0.5, reduce_dims=reduce_dims, preserve_dims=preserve_dims)
@pytest.mark.parametrize(
    ("fcst", "obs", "alpha", "preserve_dims", "reduce_dims", "weights", "expected"),
    [
        # Element-wise scores on a small 2x2 grid (all dims preserved)
        (
            qltd.DA1_2X2,
            qltd.DA2_2X2,
            0.7,
            ["i", "j"],
            None,
            None,
            xr.DataArray([[0.9, 0], [0.7, 1.4]], coords=[[0, 1], [0, 1]], dims=["i", "j"]),
        ),
        # Same grid with fcst and obs swapped
        (
            qltd.DA2_2X2,
            qltd.DA1_2X2,
            0.7,
            ["i", "j"],
            None,
            None,
            xr.DataArray([[2.1, 0], [0.3, 0.6]], coords=[[0, 1], [0, 1]], dims=["i", "j"]),
        ),
        # Preserve every dimension of the realistic fixture
        (
            qltd.FCST1,
            qltd.OBS1,
            0.2,
            ["valid_start", "station_index", "lead_time"],
            None,
            None,
            qltd.EXPECTED1,
        ),
        # Preserving lead_time ...
        (
            qltd.FCST1,
            qltd.OBS1,
            0.1,
            ["lead_time"],
            None,
            None,
            qltd.EXPECTED2,
        ),
        # ... must match reducing the other two dimensions
        (
            qltd.FCST1,
            qltd.OBS1,
            0.1,
            None,
            ["valid_start", "station_index"],
            None,
            qltd.EXPECTED2,
        ),
        # Weighted reduction
        (
            qltd.FCST1,
            qltd.OBS1,
            0.1,
            None,
            ["valid_start", "station_index"],
            qltd.WEIGHTS,
            qltd.EXPECTED4,
        ),
        # Missing data
        (qltd.FCST2, qltd.FCST2, 0.8, ["lead_time"], None, None, qltd.EXPECTED3),
        (qltd.FCST2, qltd.FCST2, 0.8, None, None, None, xr.DataArray([0.0]).squeeze()),
        # Dataset inputs, unweighted and weighted
        (
            qltd.FCST_DS,
            qltd.OBS_DS,
            0.1,
            None,
            ["valid_start", "station_index"],
            None,
            qltd.EXPECTED_DS1,
        ),
        (
            qltd.FCST_DS,
            qltd.OBS_DS,
            0.1,
            None,
            ["valid_start", "station_index"],
            qltd.WEIGHTS_DS,
            qltd.EXPECTED_DS2,
        ),
    ],
)
def test_qsf_calculations(fcst, obs, alpha, preserve_dims, reduce_dims, weights, expected):
    """Compare quantile_score output against precomputed expectations.

    Covers DataArray and Dataset inputs, preserve_dims / reduce_dims handling,
    weighting, and missing data.
    """
    result = quantile_score(
        fcst,
        obs,
        alpha,
        reduce_dims=reduce_dims,
        preserve_dims=preserve_dims,
        weights=weights,
    )
    # DataArray inputs use the DataArray comparison helper; anything else
    # (the Dataset cases) uses the Dataset helper.
    compare = assert_dataarray_equal if isinstance(fcst, xr.DataArray) else assert_dataset_equal
    compare(result, expected, decimals=7)
def test_quantile_score_dask():
    """quantile_score stays lazy on dask-backed inputs and, once computed, matches the eager result."""
    lazy = quantile_score(
        fcst=qltd.FCST1.chunk(),
        obs=qltd.OBS1.chunk(),
        alpha=0.1,
        reduce_dims=["valid_start", "station_index"],
    )
    # Before compute(), the result should still be backed by a dask graph.
    assert isinstance(lazy.data, dask.array.Array)

    computed = lazy.compute()
    # After compute(), the data is materialised as a numpy array.
    assert isinstance(computed.data, np.ndarray)
    # Same inputs and arguments as the eager EXPECTED2 case in test_qsf_calculations.
    assert_dataarray_equal(computed, qltd.EXPECTED2, decimals=7)