Coverage for tests/test_processing.py: 100%
41 statements
« prev ^ index » next coverage.py v7.3.2, created at 2024-02-28 12:51 +1100
« prev ^ index » next coverage.py v7.3.2, created at 2024-02-28 12:51 +1100
1"""Tests for processing"""
2import numpy as np
3import pytest
4import xarray as xr
6from scores.processing import (
7 binary_discretise,
8 broadcast_and_match_nan,
9 check_binary,
10 comparative_discretise,
11)
12from tests import test_processing_data as xtd
@pytest.mark.parametrize(
    ("args", "expected"),
    [
        # DataArrays
        ([], tuple()),
        ([xtd.DA_2], (xtd.DA_2,)),
        ([xtd.DA_1, xtd.DA_2], xtd.EXPECTED_12),
        ([xtd.DA_2, xtd.DA_1], xtd.EXPECTED_21),
        ([xtd.DA_1, xtd.DA_2, xtd.DA_3], xtd.EXPECTED_123),
        ([xtd.DA_2, xtd.DA_4], xtd.EXPECTED_24),
        ([xtd.DA_1, xtd.DA_5], xtd.EXPECTED_15),
        ([xtd.DA_2, xtd.DA_6], xtd.EXPECTED_26),
        ([xtd.DA_3, xtd.DA_7], xtd.EXPECTED_37),
        # Datasets
        ([xtd.DS_12], xtd.EXPECTED_DS12),
        ([xtd.DS_123], xtd.EXPECTED_DS123),
        ([xtd.DS_12, xtd.DS_123], xtd.EXPECTED_DS12_DS123),
        ([xtd.DS_12, xtd.DS_3], xtd.EXPECTED_DS12_DS3),
        # Datasets and DataArrays
        ([xtd.DS_3, xtd.DA_7], xtd.EXPECTED_DS3_7),
        ([xtd.DS_7, xtd.DA_3], xtd.EXPECTED_DS7_3),
        ([xtd.DS_12, xtd.DA_3], xtd.EXPECTED_DS12_3),
    ],
)
def test_broadcast_and_match_nan(args, expected):
    """
    Tests that broadcast_and_match_nan calculates the correct result

    Args:
        args: a list of the args that will be *-ed into broadcast_and_match_nan
        expected: a tuple, the expected output of broadcast_and_match_nan
    """
    calculated = broadcast_and_match_nan(*args)
    # zip() stops at the shorter of its inputs, so without this check a result
    # with missing (or surplus) elements would pass without being compared.
    assert len(calculated) == len(expected)
    for calculated_element, expected_element in zip(calculated, expected):
        assert calculated_element.equals(expected_element)
@pytest.mark.parametrize(
    "args, error_msg_snippet",
    [
        ([xr.Dataset({"DA_1": xtd.DA_1}), xtd.DA_1, np.arange(4)], "Argument 2"),
        ([np.arange(5)], "Argument 0"),
    ],
)
def test_broadcast_and_match_nan_rasies(args, error_msg_snippet):
    """
    Checks that broadcast_and_match_nan raises a ValueError whose message
    contains the expected snippet when given a numpy array as an argument.
    """
    with pytest.raises(ValueError) as raised:
        broadcast_and_match_nan(*args)

    message = str(raised.value)
    assert error_msg_snippet in message
@pytest.mark.parametrize(
    ("data", "comparison", "mode", "abs_tolerance", "expected"),
    [
        ####################################################
        # Tests with 1-D comparison, testing abs_tolerance #
        ####################################################
        # TESTS FOR MODE='>='
        # 0: mode='>=', 0-D, value greater than threshold
        (xr.DataArray(0.5), xtd.THRESH_DA_0, ">=", None, xtd.EXP_CDIS_0),
        # 1: mode='>=', 0-D, equal value and threshold
        (xr.DataArray(0.3), xtd.THRESH_DA_0, ">=", None, xtd.EXP_CDIS_1),
        # 2: mode='>=', 0-D, value less than threshold
        (xr.DataArray(0.2), xtd.THRESH_DA_0, ">=", None, xtd.EXP_CDIS_2),
        # 3: mode='>=', 0-D, value less than threshold but within tolerance
        (xr.DataArray(0.299999999), xtd.THRESH_DA_0, ">=", 1e-8, xtd.EXP_CDIS_3),
        # 4: mode='>=', 0-D, value less than threshold and outside tolerance
        (xr.DataArray(0.299999999), xtd.THRESH_DA_0, ">=", 1e-10, xtd.EXP_CDIS_4),
        # 5: mode='>=', 0-D, threshold of length 3
        (xr.DataArray(0.3), xtd.THRESH_DA_1, ">=", 1e-10, xtd.EXP_CDIS_5),
        # 6: mode='>=', 0-D, integers, threshold of length 3
        (xr.DataArray(5), xtd.THRESH_DA_2, ">=", 0, xtd.EXP_CDIS_6),
        # 7: mode='>=', 0-D, NaN, threshold of length 3
        (xr.DataArray(np.nan), xtd.THRESH_DA_2, ">=", 1e-10, xtd.EXP_CDIS_7),
        # 8: mode='>=', 1-D, one threshold
        (
            xr.DataArray([0.4, 0.2, 0.7], dims=["day"], coords={"day": [0, 1, 2]}),
            xtd.THRESH_DA_4,
            ">=",
            None,
            xtd.EXP_CDIS_8,
        ),
        # 9: mode='>=', 1-D, two thresholds
        (
            xr.DataArray([0.4, 0.2, 0.7], dims=["day"], coords={"day": [0, 1, 2]}),
            xtd.THRESH_DA_3,
            ">=",
            None,
            xtd.EXP_CDIS_9,
        ),
        # 10: mode='>=', 1-D, two thresholds, tolerance=1e-8
        (xtd.DATA_4X1, xtd.THRESH_DA_3, ">=", 1e-8, xtd.EXP_CDIS_10),
        # 11: mode='>=', 1-D, with NaN, two thresholds, tolerance=1e-8
        (xtd.DATA_4X1_NAN, xtd.THRESH_DA_3, ">=", 1e-8, xtd.EXP_CDIS_11),
        # 12: mode='>=', 1-D, with NaNs, two thresholds, tolerance=1e-8
        (xtd.DATA_4X1_2NAN, xtd.THRESH_DA_3, ">=", 1e-8, xtd.EXP_CDIS_12),
        # 13: mode='>=', 1-D, with NaN, two thresholds, tolerance=1e-10
        (xtd.DATA_4X1_NAN, xtd.THRESH_DA_3, ">=", 1e-10, xtd.EXP_CDIS_13),
        # 14: mode='>=', 2-D, with NaN, two thresholds, tolerance=1e-8
        (xtd.DATA_4X2_POINT4_POINT5_NAN, xtd.THRESH_DA_3, ">=", 1e-8, xtd.EXP_CDIS_14),
        ############
        # Datasets #
        ############
        # 15: Data 1-D dataset, comparison 1-D
        (
            xr.Dataset({"zero": xtd.DATA_4X1, "one": xtd.DATA_4X1_NAN}),
            xtd.THRESH_DA_3,
            ">=",
            1e-8,
            xtd.EXP_CDIS_DS_1D_1D,
        ),
        # 16: Dataset input, mode='>=', 0-D data, 0_D comparison
        (
            xr.Dataset({"zero": xr.DataArray(0.5)}),
            xtd.THRESH_DA_0,
            ">=",
            None,
            xtd.EXP_CDIS_DS_0D_1D,
        ),
        #############################
        # Tests with n-D comparison #
        #############################
        # 17. 0-D data, 0-D comparison
        (
            xr.DataArray(0.3),
            xr.DataArray(2),
            ">",
            1e-8,
            xr.DataArray(
                0.0,
                attrs={"discretisation_tolerance": 1e-8, "discretisation_mode": ">"},
            ),
        ),
        # 18. 0-D data, 1-D comparison
        (xr.DataArray(0.3), xtd.COMP_1D, ">", None, xtd.EXP_CDIS_0D_1D),
        # 19. 1-D data, 0-D comparison
        (xtd.DA_1D, xr.DataArray(1.3), ">", None, xtd.EXP_CDIS_1D_0D),
        # 20. 1-D data, 1-D comparison with shared dimension
        (xtd.DA_1D, xtd.COMP_1DA, ">", None, xtd.EXP_CDIS_1D_1DA),
        # 21. 1-D data, 2-D comparison, one dimension shared
        (xtd.DA_1D, xtd.COMP_2DA, ">=", None, xtd.EXP_CDIS_1D_2DA),
        # 22. 1-D data, 2-D comparison, no shared dimension
        (xtd.DA_1D, xtd.COMP_2D, ">", None, xtd.EXP_CDIS_1D_2D),
        # 23. 2-D data, 0-D comparison
        (xtd.DA_2D, xr.DataArray(1.2), ">", None, xtd.EXP_CDIS_2D_0D),
        # 24. 2-D data, 1-D comparison with shared dimension
        (xtd.DA_2D, xtd.COMP_1DA, ">", None, xtd.EXP_CDIS_2D_1DA),
        # 25. 2-D data, 1-D comparison
        (xtd.DA_2D, xtd.COMP_1D, ">", None, xtd.EXP_CDIS_2D_1D),
        # 26. 3-D data, 1-D comparison with shared dimension
        (xtd.DA_3D, xtd.COMP_1DA, ">", None, xtd.EXP_CDIS_3D_1DA),
        # 27. 3-D data compared against itself (3-D comparison)
        (xtd.DA_3D, xtd.DA_3D, ">", None, xtd.EXP_CDIS_3D_3D),
        ################################################################
        # SMOKE TESTS FOR ALL MODES ['<', '>', '>=', '<=', '==', '!='] #
        ################################################################
        # 28. mode='>=', tolerance=1e-8
        (xtd.DATA_5X1_POINT4, xr.DataArray(0.4), ">=", 1e-8, xtd.EXP_CDIS_GE0),
        # 29. mode='>=', tolerance=0
        (xtd.DATA_5X1_POINT4, xr.DataArray(0.4), ">=", None, xtd.EXP_CDIS_GE1),
        # 30. mode='>', tolerance=1e-8
        (xtd.DATA_5X1_POINT4, xr.DataArray(0.4), ">", 1e-8, xtd.EXP_CDIS_GT0),
        # 31. mode='>', tolerance=0
        (xtd.DATA_5X1_POINT4, xr.DataArray(0.4), ">", None, xtd.EXP_CDIS_GT1),
        # 32. mode='<=', tolerance=1e-8
        (xtd.DATA_5X1_POINT4, xr.DataArray(0.4), "<=", 1e-8, xtd.EXP_CDIS_LE0),
        # 33. mode='<=', tolerance=0
        (xtd.DATA_5X1_POINT4, xr.DataArray(0.4), "<=", None, xtd.EXP_CDIS_LE1),
        # 34. mode='<', tolerance=1e-8
        (xtd.DATA_5X1_POINT4, xr.DataArray(0.4), "<", 1e-8, xtd.EXP_CDIS_LT0),
        # 35. mode='<', tolerance=0
        (xtd.DATA_5X1_POINT4, xr.DataArray(0.4), "<", None, xtd.EXP_CDIS_LT1),
        # 36. mode='==', tolerance=1e-8
        (xtd.DATA_5X1_POINT4, xr.DataArray(0.4), "==", 1e-8, xtd.EXP_CDIS_EQ0),
        # 37. mode='==', tolerance=0
        (xtd.DATA_5X1_POINT4, xr.DataArray(0.4), "==", None, xtd.EXP_CDIS_EQ1),
        # 38. mode='!=', tolerance=1e-8
        (xtd.DATA_5X1_POINT4, xr.DataArray(0.4), "!=", 1e-8, xtd.EXP_CDIS_NE0),
        # 39. mode='!=', tolerance=0
        (xtd.DATA_5X1_POINT4, xr.DataArray(0.4), "!=", None, xtd.EXP_CDIS_NE1),
        ##################################
        # 0-D Integer & float comparison #
        ##################################
        # 40. Integer DataArray, float comparison
        (
            xr.DataArray([5], dims=["a"], coords={"a": [0]}),
            4.0,
            ">",
            1e-8,
            xtd.EXP_CDIS_ONE,
        ),
        # 41. Integer DataArray, integer comparison (output should still be float!)
        (
            xr.DataArray([5], dims=["a"], coords={"a": [0]}),
            4,
            ">",
            1e-8,
            xtd.EXP_CDIS_ONE,
        ),
        # 42. Float DataArray, integer comparison
        # (comparison is deliberately an int so the float/int path is exercised)
        (
            xr.DataArray([5.0], dims=["a"], coords={"a": [0]}),
            4,
            ">",
            1e-8,
            xtd.EXP_CDIS_ONE,
        ),
        # 43. Float DataArray data, 0-D float comparison
        (
            xr.DataArray(2.3),
            2.0,
            ">",
            1e-8,
            xr.DataArray(
                1.0,
                attrs={"discretisation_tolerance": 1e-8, "discretisation_mode": ">"},
            ),
        ),
        # 44. NaN in comparison
        (
            xr.DataArray(2),
            np.nan,
            ">",
            0,
            xr.DataArray(
                np.nan,
                attrs={"discretisation_tolerance": 0, "discretisation_mode": ">"},
            ),
        ),
        ##########################
        # Float & int DataArrays #
        ##########################
        # 45. Integer DataArray, float comparison
        (
            xr.DataArray([5], dims=["a"], coords={"a": [0]}),
            xr.DataArray([4.0], dims=["a"], coords={"a": [0]}),
            ">",
            1e-8,
            xtd.EXP_CDIS_ONE,
        ),
        # 46. Integer DataArray, integer comparison (output should still be float!)
        (
            xr.DataArray([5], dims=["a"], coords={"a": [0]}),
            xr.DataArray([4], dims=["a"], coords={"a": [0]}),
            ">",
            1e-8,
            xtd.EXP_CDIS_ONE,
        ),
        # 47. Float DataArray, integer comparison
        # (comparison is deliberately an int array so the float/int path is exercised)
        (
            xr.DataArray([5.0], dims=["a"], coords={"a": [0]}),
            xr.DataArray([4], dims=["a"], coords={"a": [0]}),
            ">",
            1e-8,
            xtd.EXP_CDIS_ONE,
        ),
        # 48. 0-D float DataArray data, 1-D float comparison
        (
            xr.DataArray(2.3),
            xr.DataArray([2.0], dims=["a"], coords={"a": [0]}),
            ">",
            1e-8,
            xtd.EXP_CDIS_ONE,
        ),
    ],
)
def test_comparative_discretise(data, comparison, mode, abs_tolerance, expected):
    """
    Tests that comparative_discretise returns the expected discretised object.

    Args:
        data: the xarray object to discretise
        comparison: the value(s) that `data` is compared against
        mode: the comparison operator, as a string
        abs_tolerance: the absolute tolerance passed to comparative_discretise
        expected: the object the result is compared with

    NOTE(review): xr.testing.assert_equal does not compare attributes, so the
    `discretisation_*` attrs set on some expected objects above are not
    verified here - confirm whether xr.testing.assert_identical is wanted.
    """
    calculated = comparative_discretise(data, comparison, mode, abs_tolerance=abs_tolerance)
    xr.testing.assert_equal(calculated, expected)
@pytest.mark.parametrize(
    "data, comparison, mode, abs_tolerance, error_class, error_msg_snippet",
    [
        # negative abs_tolerance is rejected
        (
            xr.DataArray(0.5),
            xr.DataArray(0.3),
            ">=",
            -1e-8,
            ValueError,
            "value -1e-08 of abs_tolerance is invalid, it must be a non-negative float",
        ),
        # unknown mode string is rejected
        (
            xr.DataArray(0.5),
            xr.DataArray(0.3),
            "&",
            1e-8,
            ValueError,
            "'&' is not a valid mode. Available modes are: ['<', '<=', '>', '>=', '!=', '==']",
        ),
        # comparison given as a numpy array is rejected
        (
            xr.DataArray(0.5),
            np.array(0.5),
            ">",
            None,
            TypeError,
            "comparison must be a float, int or xarray.DataArray",
        ),
    ],
)
def test_comparative_discretise_raises(data, comparison, mode, abs_tolerance, error_class, error_msg_snippet):
    """
    Checks that comparative_discretise raises the expected exception type and
    that the exception message contains the expected snippet.
    """
    with pytest.raises(error_class) as raised:
        comparative_discretise(data, comparison, mode, abs_tolerance=abs_tolerance)

    message = str(raised.value)
    assert error_msg_snippet in message
@pytest.mark.parametrize(
    "data, thresholds, mode, abs_tolerance, autosqueeze, expected",
    [
        # --- autosqueeze behaviour ---
        # 0. mode='>=', 0-D, value greater than threshold, autosqueeze=False
        (xr.DataArray(0.5), [0.3], ">=", None, False, xtd.EXP_DIS_0),
        # 1. mode='>=', 0-D, autosqueeze=True
        (xr.DataArray(0.5), [0.3], ">=", None, True, xtd.EXP_DIS_1),
        # --- smoke tests for every mode: '<', '>', '>=', '<=', '==', '!=' ---
        # 2. mode='>=', tolerance=1e-8
        (xtd.DATA_5X1_POINT4, [0.4], ">=", 1e-8, True, xtd.EXP_DIS_GE0),
        # 3. mode='>=', tolerance=0
        (xtd.DATA_5X1_POINT4, [0.4], ">=", None, True, xtd.EXP_DIS_GE1),
        # 4. mode='>', tolerance=1e-8
        (xtd.DATA_5X1_POINT4, [0.4], ">", 1e-8, True, xtd.EXP_DIS_GT0),
        # 5. mode='>', tolerance=0
        (xtd.DATA_5X1_POINT4, [0.4], ">", None, True, xtd.EXP_DIS_GT1),
        # 6. mode='<=', tolerance=1e-8
        (xtd.DATA_5X1_POINT4, [0.4], "<=", 1e-8, True, xtd.EXP_DIS_LE0),
        # 7. mode='<=', tolerance=0
        (xtd.DATA_5X1_POINT4, [0.4], "<=", None, True, xtd.EXP_DIS_LE1),
        # 8. mode='<', tolerance=1e-8
        (xtd.DATA_5X1_POINT4, [0.4], "<", 1e-8, True, xtd.EXP_DIS_LT0),
        # 9. mode='<', tolerance=0
        (xtd.DATA_5X1_POINT4, [0.4], "<", None, True, xtd.EXP_DIS_LT1),
        # 10. mode='==', tolerance=1e-8
        (xtd.DATA_5X1_POINT4, [0.4], "==", 1e-8, True, xtd.EXP_DIS_EQ0),
        # 11. mode='==', tolerance=0
        (xtd.DATA_5X1_POINT4, [0.4], "==", None, True, xtd.EXP_DIS_EQ1),
        # 12. mode='!=', tolerance=1e-8
        (xtd.DATA_5X1_POINT4, [0.4], "!=", 1e-8, True, xtd.EXP_DIS_NE0),
        # 13. mode='!=', tolerance=0
        (xtd.DATA_5X1_POINT4, [0.4], "!=", None, True, xtd.EXP_DIS_NE1),
        # --- Dataset input ---
        # 14. 1-D data
        (
            xr.Dataset({"zero": xtd.DATA_4X1, "one": xtd.DATA_4X1_NAN}),
            [0.4, 0.5],
            ">=",
            1e-8,
            True,
            xtd.EXP_DIS_DS_1D,
        ),
        # 15. autosqueeze=True
        (
            xr.Dataset({"zero": xr.DataArray(0.5)}),
            [0.3],
            ">=",
            None,
            True,
            xtd.EXP_DIS_DS_0D,
        ),
        # 16. 2-D with NaN
        (xtd.DATA_4X2_POINT4_POINT5_NAN, [0.4, 0.5], ">=", 1e-8, True, xtd.EXP_CDIS_14),
        # --- thresholds supplied as a 0-D value ---
        # 17. float
        (xtd.DATA_5X1_POINT4, 0.4, ">=", 1e-8, True, xtd.EXP_DIS_GE0),
        # 18. np.array
        (xtd.DATA_5X1_POINT4, np.array(0.4), "<=", 1e-8, True, xtd.EXP_DIS_LE0),
        # 19. xr.DataArray
        (xtd.DATA_5X1_POINT4, xr.DataArray(0.4), "==", 1e-8, True, xtd.EXP_DIS_EQ0),
        # 20. autosqueeze=False with a float threshold
        (xtd.DATA_5X1_POINT4, 0.4, ">=", 1e-8, False, xtd.EXP_DIS_GE0),
    ],
)
def test_binary_discretise(data, thresholds, mode, abs_tolerance, autosqueeze, expected):
    """
    Checks that binary_discretise returns an object equal to `expected` for
    the given data, thresholds, mode, tolerance and autosqueeze setting.
    """
    discretised = binary_discretise(
        data,
        thresholds,
        mode,
        abs_tolerance=abs_tolerance,
        autosqueeze=autosqueeze,
    )
    xr.testing.assert_equal(discretised, expected)
@pytest.mark.parametrize(
    "data, thresholds, mode, abs_tolerance, autosqueeze, error_class, error_msg_snippet",
    [
        # thresholds given in decreasing order
        (
            xr.DataArray(0.5),
            [0.3, 0.2],
            ">=",
            None,
            False,
            ValueError,
            "Values in `thresholds` are not montonic increasing",
        ),
        # negative abs_tolerance
        (
            xr.DataArray(0.5),
            [0.2, 0.5],
            ">=",
            -1e-8,
            False,
            ValueError,
            "value -1e-08 of abs_tolerance is invalid, it must be a non-negative float",
        ),
        # unknown mode string
        (
            xr.DataArray(0.5),
            [0.2, 0.5],
            "&",
            1e-8,
            False,
            ValueError,
            "'&' is not a valid mode. Available modes are: ['<', '<=', '>', '>=', '!=', '==']",
        ),
        # data already has a dimension called 'threshold'
        (
            xr.DataArray([0.5], dims=["threshold"]),
            [0.3],
            ">=",
            None,
            False,
            ValueError,
            "'threshold' must not be in the supplied data object dimensions",
        ),
    ],
)
def test_binary_discretise_raises(data, thresholds, mode, abs_tolerance, autosqueeze, error_class, error_msg_snippet):
    """
    Checks that binary_discretise raises the expected exception type and that
    the exception message contains the expected snippet.
    """
    with pytest.raises(error_class) as raised:
        binary_discretise(
            data,
            thresholds,
            mode,
            abs_tolerance=abs_tolerance,
            autosqueeze=autosqueeze,
        )

    message = str(raised.value)
    assert error_msg_snippet in message
@pytest.mark.parametrize(
    "da",
    [
        xr.DataArray([0, 1, 2]),
        xr.DataArray([0, 1, -1]),
        xr.DataArray([0, 1, 0.5]),
        xr.DataArray([[0, 1, 1.0000001], [0, 1, 1]]),
    ],
)
def test_check_binary_raises(da):
    """Checks that check_binary raises a ValueError for arrays holding values other than 0, 1 or NaN."""
    expected_snippet = "`my name` contains values that are not in the set {0, 1, np.nan}"
    with pytest.raises(ValueError) as raised:
        check_binary(da, "my name")

    assert expected_snippet in str(raised.value)
@pytest.mark.parametrize(
    "da",
    [
        xr.DataArray([0, 1]),
        xr.DataArray([0, 0]),
        xr.DataArray([1, 1]),
        xr.DataArray([0, 1, np.nan]),
        xr.DataArray([[0, 1, np.nan], [0, 1, np.nan]]),
    ],
)
def test_check_binary_doesnt_raise(da):
    """Checks that check_binary accepts arrays made up only of 0, 1 and NaN."""
    # The test passes as long as this call completes without raising.
    check_binary(da, "my name")