Coverage for tests/categorical/test_multicategorical.py: 100%

31 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2024-02-28 12:51 +1100

1""" 

2Contains unit tests for scores.categorical 

3""" 

4import dask 

5import dask.array 

6import numpy as np 

7import pytest 

8import xarray as xr 

9 

10from scores.categorical import firm 

11from scores.categorical.multicategorical_impl import _single_category_score 

12from scores.utils import DimensionError 

13from tests.categorical import multicategorical_test_data as mtd 

14 

15 

@pytest.mark.parametrize(
    "fcst, obs, categorical_threshold, discount_distance, threshold_assignment, expected",
    [
        # Threshold 5, discount = 0, preserve all dims
        (mtd.DA_FCST_SC, mtd.DA_OBS_SC, 5, 0, "lower", mtd.EXP_SC_CASE0),
        # Threshold -200, discount = 0, preserve 1 dim
        (mtd.DA_FCST_SC, mtd.DA_OBS_SC, -200, 0, "lower", mtd.EXP_SC_CASE1),
        # Threshold 200, discount = 0, preserve 1 dim
        (mtd.DA_FCST_SC, mtd.DA_OBS_SC, 200, 0, "lower", mtd.EXP_SC_CASE1),
        # Threshold 5, discount = 7, preserve all dims.
        # discount_distance is maximum for both false alarms and misses
        (mtd.DA_FCST_SC, mtd.DA_OBS_SC, 5, 7, "lower", mtd.EXP_SC_CASE2),
        # Threshold 5, discount = 0.5, preserve all dims.
        # discount_distance is minimum for both false alarms and misses
        (mtd.DA_FCST_SC, mtd.DA_OBS_SC, 5, 0.5, "lower", mtd.EXP_SC_CASE3),
        # Test lower/right assignment
        (mtd.DA_FCST_SC2, mtd.DA_OBS_SC2, 2, None, "lower", mtd.EXP_SC_CASE4),
        # Test upper/left assignment
        (mtd.DA_FCST_SC2, mtd.DA_OBS_SC2, 2, None, "upper", mtd.EXP_SC_CASE5),
    ],
)
def test__single_category_score(fcst, obs, categorical_threshold, discount_distance, threshold_assignment, expected):
    """
    Checks `_single_category_score` against the expected output for each case.

    The risk parameter is held at 0.7 for every case; only the threshold,
    discount distance and threshold assignment vary.
    """
    # Positional argument order expected by `_single_category_score`.
    score_args = (fcst, obs, 0.7, categorical_threshold, discount_distance, threshold_assignment)
    xr.testing.assert_allclose(_single_category_score(*score_args), expected)

45 

46 

@pytest.mark.parametrize(
    (
        "fcst",
        "obs",
        "risk_parameters",
        "categorical_thresholds",
        "weights",
        "reduce_dims",
        "preserve_dims",
        "discount_distance",
        "expected",
    ),
    [
        # Test for single category case identical to CASE0 in
        # test__single_category_score
        (
            mtd.DA_FCST_SC,
            mtd.DA_OBS_SC,
            0.7,
            [5],
            [1],
            None,
            ["i", "j", "k"],
            0.0,
            mtd.EXP_SC_CASE0,
        ),
        # Test for single category with discount distance. Identical to CASE3
        # in test__single_category_score
        (
            mtd.DA_FCST_SC,
            mtd.DA_OBS_SC,
            0.7,
            [5],
            [1],
            None,
            ["i", "j", "k"],
            0.5,
            mtd.EXP_SC_CASE3,
        ),
        # Test for single category case on slightly bigger dataset
        # Preserve dimensions
        (
            mtd.DA_FCST_FIRM,
            mtd.DA_OBS_FIRM,
            0.7,
            [5],
            [1],
            None,
            ["i", "j", "k"],
            0,
            mtd.EXP_FIRM_CASE0,
        ),
        # Single category, only keep one dimension
        (
            mtd.DA_FCST_FIRM,
            mtd.DA_OBS_FIRM,
            0.7,
            [5],
            [1],
            None,
            ["i"],
            0,
            mtd.EXP_FIRM_CASE1,
        ),
        # Single category, no dimensions (reduce_dims and preserve_dims both None)
        (
            mtd.DA_FCST_FIRM,
            mtd.DA_OBS_FIRM,
            0.7,
            [5],
            [1],
            None,
            None,
            0,
            mtd.EXP_FIRM_CASE2,
        ),
        # Single category, no dimensions, requested via reduce_dims="all".
        # Expected output is the same as the previous case.
        (
            mtd.DA_FCST_FIRM,
            mtd.DA_OBS_FIRM,
            0.7,
            [5],
            [1],
            "all",
            None,
            0,
            mtd.EXP_FIRM_CASE2,
        ),
        # 2 categories, same weight
        (
            mtd.DA_FCST_FIRM,
            mtd.DA_OBS_FIRM,
            0.7,
            [0, 5],
            [1, 1],
            None,
            ["i", "j", "k"],
            0,
            mtd.EXP_FIRM_CASE3,
        ),
        # 3 categories, same weight
        (
            mtd.DA_FCST_FIRM,
            mtd.DA_OBS_FIRM,
            0.7,
            [0, 5, 8],
            [1, 1, 1],
            None,
            ["i", "j", "k"],
            0,
            mtd.EXP_FIRM_CASE4,
        ),
        # 2 categories, 2 weights
        (
            mtd.DA_FCST_FIRM,
            mtd.DA_OBS_FIRM,
            0.7,
            [0, 5],
            [2, 1],
            None,
            ["i", "j", "k"],
            0,
            mtd.EXP_FIRM_CASE4,
        ),
        # 2 categories, 2 weights that are xr.DataArrays
        (
            mtd.DA_FCST_FIRM,
            mtd.DA_OBS_FIRM,
            0.7,
            [0, 5],
            mtd.LIST_WEIGHTS_FIRM0,
            None,
            ["i", "j", "k"],
            0,
            mtd.EXP_FIRM_CASE4,
        ),
        # 2 categories, 2 weights with the first a xr.DataArray and the second
        # a float
        (
            mtd.DA_FCST_FIRM,
            mtd.DA_OBS_FIRM,
            0.7,
            [0, 5],
            [mtd.LIST_WEIGHTS_FIRM0[0], 1],
            None,
            ["i", "j", "k"],
            0,
            mtd.EXP_FIRM_CASE4,
        ),
        # 2 categories, 2 weights that are xr.DataArrays with different values
        # for different coords
        (
            mtd.DA_FCST_FIRM,
            mtd.DA_OBS_FIRM,
            0.7,
            [0, 5],
            mtd.LIST_WEIGHTS_FIRM1,
            None,
            ["i", "j", "k"],
            0,
            mtd.EXP_FIRM_CASE5,
        ),
        # 2 categories, 2 weights that are xr.DataArrays with different values
        # for different coords, with NaN in weights
        (
            mtd.DA_FCST_FIRM,
            mtd.DA_OBS_FIRM,
            0.7,
            [0, 5],
            mtd.LIST_WEIGHTS_FIRM2,
            None,
            ["i", "j", "k"],
            0,
            mtd.EXP_FIRM_CASE6,
        ),
    ],
)
def test_firm(
    fcst,
    obs,
    risk_parameters,
    categorical_thresholds,
    weights,
    reduce_dims,
    preserve_dims,
    discount_distance,
    expected,
):
    """
    Tests that `firm` returns the expected result for each parametrized case.

    The arguments are passed to `firm` positionally; note that
    `discount_distance` comes before `reduce_dims` and `preserve_dims` in
    the call, which differs from the order in the parametrize table.
    The comparison uses an absolute tolerance of 0.001.
    """
    calculated = firm(
        fcst,
        obs,
        risk_parameters,
        categorical_thresholds,
        weights,
        discount_distance,
        reduce_dims,
        preserve_dims,
    )
    # Put the preserved dims in the requested order before comparing with
    # `expected`. Skipped when no dims were requested.
    if preserve_dims is not None:
        calculated = calculated.transpose(*preserve_dims)
    xr.testing.assert_allclose(
        calculated,
        expected,
        atol=0.001,
    )

253 

254 

def test_firm_dask():
    """
    Tests that `firm` accepts chunked inputs, returns a dask-backed result,
    and yields the expected values once computed.
    """
    dask_result = firm(
        mtd.DA_FCST_FIRM.chunk(),
        mtd.DA_OBS_FIRM.chunk(),
        0.7,
        [0, 5],
        mtd.LIST_WEIGHTS_FIRM2,
        0,
        reduce_dims=None,
        preserve_dims=["i", "j", "k"],
    ).transpose("i", "j", "k")

    # Before compute: the score variable is still backed by a dask array.
    assert isinstance(dask_result.firm_score.data, dask.array.Array)

    # After compute: the score variable is backed by a numpy array.
    computed_result = dask_result.compute()
    assert isinstance(computed_result.firm_score.data, np.ndarray)

    xr.testing.assert_allclose(computed_result, mtd.EXP_FIRM_CASE6, atol=0.001)

278 

279 

@pytest.mark.parametrize(
    (
        "fcst",
        "obs",
        "risk_parameters",
        "categorical_thresholds",
        "weights",
        "preserve_dims",
        "discount_distance",
        "threshold_assignment",
        "error_type",
        "error_msg_snippet",
    ),
    [
        # len(categorical_thresholds) is 0
        (
            mtd.DA_FCST_FIRM,
            mtd.DA_OBS_FIRM,
            0.5,
            [],
            [],
            ["i", "j", "k"],
            0,
            "upper",
            ValueError,
            "`categorical_thresholds` must have at least",
        ),
        # weights and categories don't match. risk_parameters is a float
        (
            mtd.DA_FCST_FIRM,
            mtd.DA_OBS_FIRM,
            0.5,
            [1],
            [1, 2],
            ["i", "j", "k"],
            0,
            "upper",
            ValueError,
            "`categorical_thresholds` and `weights`",
        ),
        # risk_parameter = 0 (lower bound of the open interval is rejected)
        (
            mtd.DA_FCST_FIRM,
            mtd.DA_OBS_FIRM,
            0.0,
            [5],
            [1],
            ["i", "j", "k"],
            0,
            "upper",
            ValueError,
            "0 < `risk_parameter` < 1 must",
        ),
        # risk_parameter = 1 (upper bound of the open interval is rejected)
        (
            mtd.DA_FCST_FIRM,
            mtd.DA_OBS_FIRM,
            1.0,
            [5],
            [1],
            ["i", "j", "k"],
            0,
            "upper",
            ValueError,
            "0 < `risk_parameter` < 1 must",
        ),
        # negative weight
        (
            mtd.DA_FCST_FIRM,
            mtd.DA_OBS_FIRM,
            0.5,
            [5, 6],
            [1, -1],
            ["i", "j", "k"],
            0,
            "upper",
            ValueError,
            "`weights` must be > 0",
        ),
        # negative weight with weights being a xr.DataArray
        (
            mtd.DA_FCST_FIRM,
            mtd.DA_OBS_FIRM,
            0.5,
            [5, 6],
            mtd.LIST_WEIGHTS_FIRM3,
            ["i", "j", "k"],
            0,
            "upper",
            ValueError,
            "value was found in index 0 of `weights",
        ),
        # zero weight
        (
            mtd.DA_FCST_FIRM,
            mtd.DA_OBS_FIRM,
            0.5,
            [5, 6],
            [1, 0],
            ["i", "j", "k"],
            0,
            "upper",
            ValueError,
            "`weights` must be > 0",
        ),
        # zero weight with weights being a xr.DataArray
        (
            mtd.DA_FCST_FIRM,
            mtd.DA_OBS_FIRM,
            0.5,
            [5, 6],
            mtd.LIST_WEIGHTS_FIRM4,
            ["i", "j", "k"],
            0,
            "upper",
            ValueError,
            "No values <= 0 are allowed in `weights`",
        ),
        # bad dims in weights
        (
            mtd.DA_FCST_FIRM,
            mtd.DA_OBS_FIRM,
            0.5,
            [5, 6],
            mtd.LIST_WEIGHTS_FIRM5,
            ["i", "j", "k"],
            0,
            "upper",
            DimensionError,
            "of data object are not subset to",
        ),
        # negative discount distance
        (
            mtd.DA_FCST_FIRM,
            mtd.DA_OBS_FIRM,
            0.5,
            [5],
            [1],
            ["i", "j", "k"],
            -1,
            "upper",
            ValueError,
            "`discount_distance` must be >= 0",
        ),
        # wrong threshold assignment
        (
            mtd.DA_FCST_FIRM,
            mtd.DA_OBS_FIRM,
            0.5,
            [5],
            [1],
            ["i", "j", "k"],
            0.0,
            "up",
            ValueError,
            """ `threshold_assignment` must be either \"upper\" or \"lower\" """,
        ),
    ],
)
def test_firm_raises(
    fcst,
    obs,
    risk_parameters,
    categorical_thresholds,
    weights,
    preserve_dims,
    discount_distance,
    threshold_assignment,
    error_type,
    error_msg_snippet,
):
    """
    Tests that the firm raises the correct errors

    Each case supplies one invalid input and the exception type plus a
    message snippet it should produce. `error_msg_snippet` is handed to
    `pytest.raises(match=...)`, so it is applied as a regular-expression
    search against the raised message.
    """
    with pytest.raises(error_type, match=error_msg_snippet):
        firm(
            fcst,
            obs,
            risk_parameters,
            categorical_thresholds,
            weights,
            discount_distance,
            None,  # reduce_dims is never set here; only preserve_dims varies
            preserve_dims,
            threshold_assignment=threshold_assignment,
        )