Coverage for tests/probabilty/test_roc.py: 100%

25 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2024-02-28 12:51 +1100

1""" 

2Contains unit tests for scores.probability.roc_impl 

3""" 

4import dask 

5import numpy as np 

6import pytest 

7import xarray as xr 

8 

9from scores.probability import roc_curve_data 

10from tests.probabilty import roc_test_data as rtd 

11 

12 

@pytest.mark.parametrize(
    ("fcst", "obs", "thresholds", "preserve_dims", "reduce_dims", "weights", "expected"),
    [
        # preserve_dims=['lead_day']
        (
            rtd.FCST_2X3X2_WITH_NAN,
            rtd.OBS_3X3_WITH_NAN,
            [0, 0.3, 1],
            ["lead_day"],
            None,
            None,
            rtd.EXP_ROC_LEADDAY,
        ),
        # reduce_dims=['letter', 'pet']
        (
            rtd.FCST_2X3X2_WITH_NAN,
            rtd.OBS_3X3_WITH_NAN,
            [0, 0.3, 1],
            None,
            ["letter", "pet"],
            None,
            rtd.EXP_ROC_LEADDAY,
        ),
        # Weighting by lead day and reducing all dims except `lead_day` should produce identical results
        (
            rtd.FCST_2X3X2_WITH_NAN,
            rtd.OBS_3X3_WITH_NAN,
            [0, 0.3, 1],
            None,
            ["letter", "pet"],
            rtd.LEAD_DAY_WEIGHTS,
            rtd.EXP_ROC_LEADDAY,
        ),
        # preserve_dims=None, reduce_dims=None
        (
            rtd.FCST_2X3X2_WITH_NAN,
            rtd.OBS_3X3_WITH_NAN,
            [0, 0.3, 1],
            None,
            None,
            None,
            rtd.EXP_ROC_NONE,
        ),
        # preserve_dims=None, reduce_dims=None, weight by lead_day
        (
            rtd.FCST_2X3X2_WITH_NAN,
            rtd.OBS_3X3_WITH_NAN,
            [0, 0.3, 1],
            None,
            None,
            rtd.LEAD_DAY_WEIGHTS,
            rtd.EXP_ROC_NONE_WEIGHTED,
        ),
        # Test AUC works with multiple DIMS preserved
        (
            rtd.FCST_2X3X2_WITH_NAN,
            rtd.OBS_3X3_WITH_NAN,
            [0, 0.3, 1],
            ["lead_day", "letter"],
            None,
            None,
            rtd.EXP_ROC_MULTI_DIMS,
        ),
        # Test AUC works with multiple DIMS preserved (and dim order switched)
        (
            rtd.FCST_2X3X2_WITH_NAN,
            rtd.OBS_3X3_WITH_NAN,
            [0, 0.3, 1],
            ["letter", "lead_day"],
            None,
            None,
            rtd.EXP_ROC_MULTI_DIMS,
        ),
    ],
)
def test_roc_curve_data(fcst, obs, thresholds, preserve_dims, reduce_dims, weights, expected):
    """
    Tests that roc_curve_data returns the expected result for each
    combination of preserve_dims, reduce_dims and weights listed above.
    """
    result = roc_curve_data(
        fcst, obs, thresholds, preserve_dims=preserve_dims, reduce_dims=reduce_dims, weights=weights
    )
    # `broadcast_equals` only returns a bool; without the `assert` the
    # comparison was discarded and this test could never fail.
    # NOTE(review): this is an exact comparison - if a case fails only on
    # floating-point noise, a tolerance-based comparison may be needed instead.
    assert result.broadcast_equals(expected)

96 

97 

def test_roc_curve_data_dask():
    """
    Tests that roc_curve_data works with dask.

    The inputs are chunked, so the POD, POFD and AUC variables of the result
    are expected to be dask arrays until `compute()` is called, and numpy
    arrays afterwards. The computed result is then compared with the
    expected multi-dimension output.
    """
    result = roc_curve_data(
        rtd.FCST_2X3X2_WITH_NAN.chunk(),
        rtd.OBS_3X3_WITH_NAN.chunk(),
        [0, 0.3, 1],
        preserve_dims=["letter", "lead_day"],
        check_args=False,
    )
    # Before computing, every output variable should still be lazy.
    assert isinstance(result.POD.data, dask.array.Array)  # type: ignore
    assert isinstance(result.POFD.data, dask.array.Array)  # type: ignore
    assert isinstance(result.AUC.data, dask.array.Array)  # type: ignore

    result = result.compute()

    # After computing, every output variable should be an in-memory array.
    assert isinstance(result.POD.data, np.ndarray)
    assert isinstance(result.POFD.data, np.ndarray)
    assert isinstance(result.AUC.data, np.ndarray)
    # `broadcast_equals` only returns a bool; without the `assert` the
    # comparison was discarded and the values were never actually checked.
    assert result.broadcast_equals(rtd.EXP_ROC_MULTI_DIMS)

117 

118 

@pytest.mark.parametrize(
    "fcst, obs, thresholds, preserve_dims, error_class, error_msg_snippet",
    [
        # forecast values fall outside [0, 1]
        (
            xr.DataArray([43, 100, 1, 0]),
            xr.DataArray([0, 1, 1, 0]),
            [0.3],
            None,
            ValueError,
            "`fcst` contains values outside of the range [0, 1]",
        ),
        # observations contain a value other than 0, 1 or NaN
        (
            xr.DataArray([0, 0.3, 1, 0]),
            xr.DataArray([0, 1, 1, 0.5]),
            [0.3],
            None,
            ValueError,
            "`obs` contains values that are not in the set {0, 1, np.nan}",
        ),
        # forecast already has a dimension named 'threshold'
        (
            xr.DataArray([0.5], dims=["threshold"]),
            xr.DataArray(0),
            [0.3],
            None,
            ValueError,
            "'threshold' must not be in the supplied data object dimensions",
        ),
        # thresholds include -inf and +inf
        (
            rtd.FCST_2X3X2_WITH_NAN,
            rtd.OBS_3X3_WITH_NAN,
            [-np.inf, 0, 0.3, 1, np.inf],
            None,
            ValueError,
            "`thresholds` contains values outside of the range [0, 1]",
        ),
        # thresholds are out of order
        (
            rtd.FCST_2X3X2_WITH_NAN,
            rtd.OBS_3X3_WITH_NAN,
            [0, 0.3, 0.2, 1],
            None,
            ValueError,
            "`thresholds` is not monotonic increasing between 0 and 1",
        ),
    ],
)
def test_roc_curve_data_raises(fcst, obs, thresholds, preserve_dims, error_class, error_msg_snippet):
    """
    Checks that roc_curve_data rejects invalid input with the expected
    exception type and an error message containing the expected snippet.
    """
    with pytest.raises(error_class) as raised:
        roc_curve_data(fcst, obs, thresholds, preserve_dims=preserve_dims)

    # The snippet only needs to appear somewhere in the message.
    message = str(raised.value)
    assert error_msg_snippet in message