Coverage for tests/probabilty/test_roc.py: 100%
25 statements
« prev ^ index » next coverage.py v7.3.2, created at 2024-02-28 12:51 +1100
1"""
2Contains unit tests for scores.probability.roc_impl
3"""
4import dask
5import numpy as np
6import pytest
7import xarray as xr
9from scores.probability import roc_curve_data
10from tests.probabilty import roc_test_data as rtd
def _assert_roc_close(result, expected):
    """
    Assert that a ROC result dataset matches the expected dataset.

    Both datasets must contain the same data variables. Each variable must have the same
    set of dimensions, identical coordinates and values that agree to within
    ``xr.testing.assert_allclose``'s default tolerance.

    The result variable is transposed to the expected variable's dimension order before
    comparing. The check therefore does not depend on the order in which dimensions were
    requested (e.g. ``preserve_dims=["letter", "lead_day"]`` vs ``["lead_day", "letter"]``).
    ``transpose`` raises if the dimension *sets* differ, so extra or missing dims still fail.
    """
    assert set(result.data_vars) == set(expected.data_vars)
    for name, expected_var in expected.data_vars.items():
        xr.testing.assert_allclose(result[name].transpose(*expected_var.dims), expected_var)


@pytest.mark.parametrize(
    ("fcst", "obs", "thresholds", "preserve_dims", "reduce_dims", "weights", "expected"),
    [
        # preserve_dims=['lead_day']
        (
            rtd.FCST_2X3X2_WITH_NAN,
            rtd.OBS_3X3_WITH_NAN,
            [0, 0.3, 1],
            ["lead_day"],
            None,
            None,
            rtd.EXP_ROC_LEADDAY,
        ),
        # reduce_dims=['letter', 'pet']
        (
            rtd.FCST_2X3X2_WITH_NAN,
            rtd.OBS_3X3_WITH_NAN,
            [0, 0.3, 1],
            None,
            ["letter", "pet"],
            None,
            rtd.EXP_ROC_LEADDAY,
        ),
        # Weighting by lead day and reducing all dims except `lead_day` should produce identical results
        (
            rtd.FCST_2X3X2_WITH_NAN,
            rtd.OBS_3X3_WITH_NAN,
            [0, 0.3, 1],
            None,
            ["letter", "pet"],
            rtd.LEAD_DAY_WEIGHTS,
            rtd.EXP_ROC_LEADDAY,
        ),
        # preserve_dims=None, reduce_dims=None
        (
            rtd.FCST_2X3X2_WITH_NAN,
            rtd.OBS_3X3_WITH_NAN,
            [0, 0.3, 1],
            None,
            None,
            None,
            rtd.EXP_ROC_NONE,
        ),
        # preserve_dims=None, reduce_dims=None, weight by lead_day
        (
            rtd.FCST_2X3X2_WITH_NAN,
            rtd.OBS_3X3_WITH_NAN,
            [0, 0.3, 1],
            None,
            None,
            rtd.LEAD_DAY_WEIGHTS,
            rtd.EXP_ROC_NONE_WEIGHTED,
        ),
        # Test AUC works with multiple DIMS preserved
        (
            rtd.FCST_2X3X2_WITH_NAN,
            rtd.OBS_3X3_WITH_NAN,
            [0, 0.3, 1],
            ["lead_day", "letter"],
            None,
            None,
            rtd.EXP_ROC_MULTI_DIMS,
        ),
        # Test AUC works with multiple DIMS preserved (and dim order switched)
        (
            rtd.FCST_2X3X2_WITH_NAN,
            rtd.OBS_3X3_WITH_NAN,
            [0, 0.3, 1],
            ["letter", "lead_day"],
            None,
            None,
            rtd.EXP_ROC_MULTI_DIMS,
        ),
    ],
)
def test_roc_curve_data(fcst, obs, thresholds, preserve_dims, reduce_dims, weights, expected):
    """
    Tests that roc_curve_data returns the expected POD/POFD/AUC dataset for a range of
    preserve_dims / reduce_dims / weights combinations.
    """
    result = roc_curve_data(
        fcst, obs, thresholds, preserve_dims=preserve_dims, reduce_dims=reduce_dims, weights=weights
    )
    # Previously `result.broadcast_equals(expected)` was called without `assert`,
    # so the comparison result was discarded and this test could never fail.
    _assert_roc_close(result, expected)


def test_roc_curve_data_dask():
    """
    Tests that roc_curve_data stays lazy on dask-backed inputs and that the computed
    output matches the expected values.
    """
    # check_args=False: presumably so input validation does not force the dask
    # arrays to be computed eagerly — TODO confirm against roc_curve_data.
    result = roc_curve_data(
        rtd.FCST_2X3X2_WITH_NAN.chunk(),
        rtd.OBS_3X3_WITH_NAN.chunk(),
        [0, 0.3, 1],
        preserve_dims=["letter", "lead_day"],
        check_args=False,
    )
    # Before .compute(), every output variable should still be a lazy dask array.
    assert isinstance(result.POD.data, dask.array.Array)  # type: ignore
    assert isinstance(result.POFD.data, dask.array.Array)  # type: ignore
    assert isinstance(result.AUC.data, dask.array.Array)  # type: ignore

    result = result.compute()

    # After .compute(), the variables are materialised as numpy arrays.
    assert isinstance(result.POD.data, np.ndarray)
    assert isinstance(result.POFD.data, np.ndarray)
    assert isinstance(result.AUC.data, np.ndarray)
    # Previously this comparison was also missing its `assert`.
    _assert_roc_close(result, rtd.EXP_ROC_MULTI_DIMS)
@pytest.mark.parametrize(
    ("fcst", "obs", "thresholds", "preserve_dims", "error_class", "error_msg_snippet"),
    [
        # Forecast probabilities fall outside [0, 1]
        (
            xr.DataArray([43, 100, 1, 0]),
            xr.DataArray([0, 1, 1, 0]),
            [0.3],
            None,
            ValueError,
            "`fcst` contains values outside of the range [0, 1]",
        ),
        # Observations are not binary (0.5 present)
        (
            xr.DataArray([0, 0.3, 1, 0]),
            xr.DataArray([0, 1, 1, 0.5]),
            [0.3],
            None,
            ValueError,
            "`obs` contains values that are not in the set {0, 1, np.nan}",
        ),
        # A dimension called 'threshold' clashes with the output dimension name
        (
            xr.DataArray([0.5], dims=["threshold"]),
            xr.DataArray(0),
            [0.3],
            None,
            ValueError,
            "'threshold' must not be in the supplied data object dimensions",
        ),
        # Thresholds include -inf / +inf, which lie outside [0, 1]
        (
            rtd.FCST_2X3X2_WITH_NAN,
            rtd.OBS_3X3_WITH_NAN,
            [-np.inf, 0, 0.3, 1, np.inf],
            None,
            ValueError,
            "`thresholds` contains values outside of the range [0, 1]",
        ),
        # Thresholds are not sorted in increasing order
        (
            rtd.FCST_2X3X2_WITH_NAN,
            rtd.OBS_3X3_WITH_NAN,
            [0, 0.3, 0.2, 1],
            None,
            ValueError,
            "`thresholds` is not monotonic increasing between 0 and 1",
        ),
    ],
)
def test_roc_curve_data_raises(fcst, obs, thresholds, preserve_dims, error_class, error_msg_snippet):
    """
    Checks that invalid inputs make roc_curve_data raise the expected exception type
    with a message containing the expected snippet.

    A plain substring check is used instead of ``pytest.raises(match=...)`` because the
    snippets contain regex metacharacters such as ``[`` and ``{``.
    """
    with pytest.raises(error_class) as excinfo:
        roc_curve_data(fcst, obs, thresholds, preserve_dims=preserve_dims)
    raised_message = str(excinfo.value)
    assert error_msg_snippet in raised_message