Coverage for tests/continuous/test_quantile_loss.py: 100%

29 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2024-02-28 12:51 +1100

1""" 

Contains unit tests for scores.continuous.quantile_score

3""" 

4import dask 

5import dask.array 

6import numpy as np 

7import pytest 

8import xarray as xr 

9 

10from scores.continuous import quantile_score 

11from scores.utils import DimensionError 

12from tests.assertions import assert_dataarray_equal, assert_dataset_equal 

13from tests.continuous import quantile_loss_test_data as qltd 

14 

15 

@pytest.mark.parametrize("alpha", [1.0, 1.1, 0, -0.8])
def test_qsf_value_error_alpha(alpha):
    """
    quantile_score raises ValueError for each parametrized alpha, all of which
    lie outside the open interval (0, 1).
    """
    # The same 2x2 array serves as both forecast and observation; only alpha is under test.
    data = qltd.DA1_2X2
    with pytest.raises(ValueError):
        quantile_score(fcst=data, obs=data, alpha=alpha)

21 

22 

@pytest.mark.parametrize(
    ("obs", "reduce_dims", "preserve_dims"),
    [
        # fcst and obs have mismatched dims
        (qltd.DA1_2X2X2, None, None),
        # requested reduce dim exists in neither fcst nor obs
        (qltd.DA1_2X2, ["bananas"], None),
        # requested preserve dim exists in neither fcst nor obs
        (qltd.DA1_2X2, None, ["orange"]),
    ],
)
def test_qsf_exceptions(obs, reduce_dims, preserve_dims):
    """Invalid dimension setups make quantile_score raise DimensionError."""
    dim_options = {"reduce_dims": reduce_dims, "preserve_dims": preserve_dims}
    with pytest.raises(DimensionError):
        quantile_score(qltd.DA1_2X2, obs, 0.5, **dim_options)

37 

38 

@pytest.mark.parametrize(
    ("fcst", "obs", "alpha", "preserve_dims", "reduce_dims", "weights", "expected"),
    [
        (
            qltd.DA1_2X2,
            qltd.DA2_2X2,
            0.7,
            ["i", "j"],
            None,
            None,
            xr.DataArray([[0.9, 0], [0.7, 1.4]], coords=[[0, 1], [0, 1]], dims=["i", "j"]),
        ),
        (
            qltd.DA2_2X2,
            qltd.DA1_2X2,
            0.7,
            ["i", "j"],
            None,
            None,
            xr.DataArray([[2.1, 0], [0.3, 0.6]], coords=[[0, 1], [0, 1]], dims=["i", "j"]),
        ),
        (
            qltd.FCST1,
            qltd.OBS1,
            0.2,
            ["valid_start", "station_index", "lead_time"],
            None,
            None,
            qltd.EXPECTED1,
        ),
        (
            qltd.FCST1,
            qltd.OBS1,
            0.1,
            ["lead_time"],
            None,
            None,
            qltd.EXPECTED2,
        ),
        (
            qltd.FCST1,
            qltd.OBS1,
            0.1,
            None,
            ["valid_start", "station_index"],
            None,
            qltd.EXPECTED2,
        ),
        # Weighted case
        (
            qltd.FCST1,
            qltd.OBS1,
            0.1,
            None,
            ["valid_start", "station_index"],
            qltd.WEIGHTS,
            qltd.EXPECTED4,
        ),
        # Missing-data cases
        (qltd.FCST2, qltd.FCST2, 0.8, ["lead_time"], None, None, qltd.EXPECTED3),
        (qltd.FCST2, qltd.FCST2, 0.8, None, None, None, xr.DataArray([0.0]).squeeze()),
        # Dataset inputs, unweighted and weighted
        (
            qltd.FCST_DS,
            qltd.OBS_DS,
            0.1,
            None,
            ["valid_start", "station_index"],
            None,
            qltd.EXPECTED_DS1,
        ),
        (
            qltd.FCST_DS,
            qltd.OBS_DS,
            0.1,
            None,
            ["valid_start", "station_index"],
            qltd.WEIGHTS_DS,
            qltd.EXPECTED_DS2,
        ),
    ],
)
def test_qsf_calculations(fcst, obs, alpha, preserve_dims, reduce_dims, weights, expected):
    """quantile_score output matches the expected DataArray or Dataset for each case."""
    actual = quantile_score(fcst, obs, alpha, preserve_dims=preserve_dims, reduce_dims=reduce_dims, weights=weights)
    # Choose the comparison helper that matches the container type of the forecast input.
    check_equal = assert_dataarray_equal if isinstance(fcst, xr.DataArray) else assert_dataset_equal
    check_equal(actual, expected, decimals=7)

128 

129 

def test_quantile_score_dask():
    """quantile_score stays lazy on chunked inputs and computes the expected values."""
    lazy = quantile_score(
        fcst=qltd.FCST1.chunk(),
        obs=qltd.OBS1.chunk(),
        alpha=0.1,
        reduce_dims=["valid_start", "station_index"],
    )
    # Before compute() the result should still be backed by a dask array.
    assert isinstance(lazy.data, dask.array.Array)

    computed = lazy.compute()
    # After compute() the data should be materialised as a NumPy array.
    assert isinstance(computed.data, np.ndarray)
    assert_dataarray_equal(computed, qltd.EXPECTED2, decimals=7)