Coverage for tests/probabilty/test_crps.py: 100%

87 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2024-02-28 12:51 +1100

1# pylint disable: too-many-arguments 

2""" 

3Contains unit tests for scores.probability.crps 

4""" 

5import dask 

6import dask.array 

7import numpy as np 

8import pytest 

9import xarray as xr 

10 

11from scores.probability import ( 

12 adjust_fcst_for_crps, 

13 crps_cdf, 

14 crps_cdf_brier_decomposition, 

15 crps_for_ensemble, 

16) 

17from scores.probability.crps_impl import ( 

18 crps_cdf_exact, 

19 crps_cdf_reformat_inputs, 

20 crps_cdf_trapz, 

21 crps_step_threshold_weight, 

22) 

23from tests.assertions import assert_dataarray_equal, assert_dataset_equal 

24from tests.probabilty import crps_test_data 

25 

26 

@pytest.mark.parametrize(
    ("weight_upper", "expected"),
    [
        (True, crps_test_data.EXP_STEP_WEIGHT_UPPER),
        (False, crps_test_data.EXP_STEP_WEIGHT_LOWER),
    ],
)
def test_crps_stepweight(weight_upper, expected):
    """Check `crps_step_threshold_weight` output for both values of `weight_upper`."""
    thresholds = [1, 2, 3, 4, 5, 6]
    actual = crps_step_threshold_weight(
        crps_test_data.DA_STEP_WEIGHT,
        "x",
        thresholds,
        True,
        0.2,
        weight_upper,
    )
    assert_dataarray_equal(actual, expected, decimals=7)

48 

49 

def test_crps_cdf_exact():
    """Tests `crps_cdf_exact`, both with and without the CRPS components.

    With ``include_components=True`` the whole dataset is compared against the
    expected values. With ``include_components=False`` only a ``total``
    variable should be returned, and its values must match the expected total.
    """
    result = crps_cdf_exact(
        crps_test_data.DA_FCST_CRPS_EXACT,
        crps_test_data.DA_OBS_CRPS_EXACT,
        crps_test_data.DA_WT_CRPS_EXACT,
        "x",
        include_components=True,
    )
    assert_dataset_equal(result, crps_test_data.EXP_CRPS_EXACT, decimals=7)

    result2 = crps_cdf_exact(
        crps_test_data.DA_FCST_CRPS_EXACT,
        crps_test_data.DA_OBS_CRPS_EXACT,
        crps_test_data.DA_WT_CRPS_EXACT,
        "x",
        include_components=False,
    )

    assert list(result2.data_vars) == ["total"]
    # Previously only the variable names were checked; also verify the values.
    assert_dataset_equal(result2, crps_test_data.EXP_CRPS_EXACT[["total"]], decimals=7)

70 

71 

def test_crps_cdf_exact_dask():
    """Check that `crps_cdf_exact` stays lazy on Dask inputs and computes the right values."""
    lazy_fcst = crps_test_data.DA_FCST_CRPS_EXACT.chunk()
    lazy_obs = crps_test_data.DA_OBS_CRPS_EXACT.chunk()

    # The weight is deliberately left un-chunked here so mixed inputs are covered.
    with_components = crps_cdf_exact(
        lazy_fcst,
        lazy_obs,
        crps_test_data.DA_WT_CRPS_EXACT,
        "x",
        include_components=True,
    )
    assert isinstance(with_components.total.data, dask.array.Array)
    with_components = with_components.compute()
    assert isinstance(with_components.total.data, np.ndarray)
    assert_dataset_equal(with_components, crps_test_data.EXP_CRPS_EXACT, decimals=7)

    # Here every input, including the weight, is Dask-backed.
    total_only = crps_cdf_exact(
        lazy_fcst,
        lazy_obs,
        crps_test_data.DA_WT_CRPS_EXACT.chunk(),
        "x",
        include_components=False,
    )
    assert isinstance(total_only.total.data, dask.array.Array)
    total_only = total_only.compute()
    assert isinstance(total_only.total.data, np.ndarray)
    assert list(total_only.data_vars) == ["total"]

97 

98 

@pytest.mark.parametrize(
    ("weight_upper", "expected"),
    [
        (True, crps_test_data.EXP_STEP_WEIGHT_UPPER),
        (False, crps_test_data.EXP_STEP_WEIGHT_LOWER),
    ],
)
def test_crps_stepweight2(weight_upper, expected):
    """Check `crps_step_threshold_weight` output for both values of `weight_upper`.

    NOTE(review): this test uses exactly the same inputs and expectations as
    `test_crps_stepweight`; consider removing it or pointing it at different
    inputs.
    """
    step_weights = crps_step_threshold_weight(
        crps_test_data.DA_STEP_WEIGHT,
        "x",
        [1, 2, 3, 4, 5, 6],
        True,
        0.2,
        weight_upper,
    )
    assert_dataarray_equal(step_weights, expected, decimals=7)

120 

121 

def test_crps_cdf_trapz():
    """Tests `crps_cdf_trapz` against the exact expected CRPS.

    Dense interpolated inputs make the trapezoidal approximation agree with the
    exact result to 4 decimal places. Covers both ``include_components=True``
    and ``include_components=False``.
    """
    result = crps_cdf_trapz(
        crps_test_data.DA_FCST_CRPS_DENSE,
        crps_test_data.DA_OBS_CRPS_DENSE,
        crps_test_data.DA_WT_CRPS_DENSE,
        "x",
        include_components=True,
    )
    assert_dataset_equal(result, crps_test_data.EXP_CRPS_EXACT, decimals=4)

    result2 = crps_cdf_trapz(
        crps_test_data.DA_FCST_CRPS_DENSE,
        crps_test_data.DA_OBS_CRPS_DENSE,
        crps_test_data.DA_WT_CRPS_DENSE,
        "x",
        include_components=False,
    )
    assert list(result2.data_vars) == ["total"]
    # Bug fix: this previously re-checked `result`, so the values of `result2`
    # were never verified. Only `total` is present when components are excluded.
    assert_dataset_equal(result2, crps_test_data.EXP_CRPS_EXACT[["total"]], decimals=4)

142 

143 

@pytest.mark.parametrize(
    ("threshold_weight", "additional_thresholds", "expected"),
    [
        (crps_test_data.DA_WT_REFORMAT1, None, crps_test_data.EXP_REFORMAT1),
        (crps_test_data.DA_WT_REFORMAT1, [0, 1.5], crps_test_data.EXP_REFORMAT2),
        (None, None, crps_test_data.EXP_REFORMAT3),
    ],
)
def test_crps_cdf_reformat_inputs(threshold_weight, additional_thresholds, expected):
    """Check every array returned by `crps_cdf_reformat_inputs` against its expected counterpart."""
    reformatted = crps_cdf_reformat_inputs(
        crps_test_data.DA_FCST_REFORMAT1,
        crps_test_data.DA_OBS_REFORMAT1,
        "x",
        threshold_weight,
        additional_thresholds,
        "linear",
        "forward",
    )
    assert len(reformatted) == len(expected)

    for actual_da, expected_da in zip(reformatted, expected):
        assert_dataarray_equal(actual_da, expected_da, decimals=7)

187 

188 

@pytest.mark.parametrize(
    (
        "fcst",
        "obs",
        "threshold_dim",
        "threshold_weight",
        "fcst_fill_method",
        "threshold_weight_fill_method",
        "integration_method",
        "dims",
        "error_msg_snippet",
    ),
    [
        (
            crps_test_data.DA_FCST_REFORMAT1,
            crps_test_data.DA_OBS_REFORMAT1,
            "y",
            None,
            "linear",
            None,
            "exact",
            None,
            "'y' is not a dimension of `fcst`",
        ),
        (
            crps_test_data.DA_FCST_REFORMAT1,
            crps_test_data.DA_OBS_REFORMAT1,
            "x",
            crps_test_data.DA_WT_CHECK_CRPS1,
            "linear",
            "forward",
            "exact",
            None,
            "'x' is not a dimension of `threshold_weight`",
        ),
        (
            crps_test_data.DA_FCST_REFORMAT1,
            crps_test_data.DA_OBS_REFORMAT1,
            "station",
            None,
            "linear",
            None,
            "exact",
            None,
            "'station' is a dimension of `obs`",
        ),
        (
            crps_test_data.DA_FCST_REFORMAT1,
            crps_test_data.DA_OBS_CHECK_CRPS,
            "x",
            None,
            "linear",
            None,
            "exact",
            None,
            "Dimensions of `obs` must be a subset of dimensions of `fcst`",
        ),
        (
            crps_test_data.DA_FCST_REFORMAT1,
            crps_test_data.DA_OBS_REFORMAT1,
            "x",
            crps_test_data.DA_WT_CHECK_CRPS2,
            "linear",
            "backward",
            "exact",
            None,
            "Dimensions of `threshold_weight` must be a subset of dimensions of `fcst`",
        ),
        # TODO: Revisit if still needed after more handling in gather_dimensions
        # (
        #     crps_test_data.DA_FCST_REFORMAT1,
        #     crps_test_data.DA_OBS_REFORMAT1,
        #     "x",
        #     None,
        #     "linear",
        #     None,
        #     "exact",
        #     ["y", "x"],
        #     "`dims` must be a subset of `fcst` dimensions",
        # ),
        (
            crps_test_data.DA_FCST_REFORMAT1,
            crps_test_data.DA_OBS_REFORMAT1,
            "x",
            None,
            "fat",
            None,
            "exact",
            None,
            "`fcst_fill_method` must be 'linear', 'step', 'forward' or 'backward'",
        ),
        (
            crps_test_data.DA_FCST_REFORMAT1,
            crps_test_data.DA_OBS_REFORMAT1,
            "x",
            crps_test_data.DA_WT_REFORMAT1,
            "linear",
            "fat",
            "exact",
            None,
            "`threshold_weight_fill_method` must be 'linear', 'step', 'forward' or 'backward'",
        ),
        (
            crps_test_data.DA_FCST_REFORMAT1,
            crps_test_data.DA_OBS_REFORMAT1,
            "x",
            crps_test_data.DA_WT_REFORMAT1,
            "linear",
            "step",
            "waffly",
            None,
            "`integration_method` must be 'exact' or 'trapz'",
        ),
        (
            crps_test_data.DA_FCST_CHECK_CRPS,
            crps_test_data.DA_OBS_REFORMAT1,
            "x",
            crps_test_data.DA_WT_REFORMAT1,
            "linear",
            "step",
            "exact",
            None,
            "`threshold_dim` in `fcst` must have at least 2 values to calculate CRPS",
        ),
        (
            crps_test_data.DA_FCST_CHECK_CRPS2,
            crps_test_data.DA_OBS_CHECK_CRPS2,
            "x",
            None,
            "linear",
            "step",
            "exact",
            None,
            "`threshold_dim` coordinates in `fcst` must be increasing",
        ),
        (
            crps_test_data.DA_FCST_CHECK_CRPS2A,
            crps_test_data.DA_OBS_CHECK_CRPS2,
            "x",
            crps_test_data.DA_FCST_CHECK_CRPS2,  # used as the weight
            "linear",
            "step",
            "exact",
            None,
            "`threshold_dim` coordinates in `threshold_weight` must be increasing",
        ),
        (
            crps_test_data.DA_FCST_CHECK_CRPS2A,
            crps_test_data.DA_OBS_CHECK_CRPS2,
            "x",
            crps_test_data.DA_WT_CHECK_CRPS3,  # used as the weight
            "linear",
            "step",
            "exact",
            None,
            "`threshold_weight` has negative values",
        ),
    ],
)
# pylint: disable=too-many-arguments
def test_crps_cdf_raises(
    fcst,
    obs,
    threshold_dim,
    threshold_weight,
    fcst_fill_method,
    threshold_weight_fill_method,
    integration_method,
    dims,
    error_msg_snippet,
):
    """Check that `crps_cdf` rejects each invalid input with a matching ValueError."""
    # Arguments are passed positionally, in the same order as the original call.
    call_args = (
        fcst,
        obs,
        threshold_dim,
        threshold_weight,
        [],
        True,
        fcst_fill_method,
        threshold_weight_fill_method,
        integration_method,
        dims,
    )
    with pytest.raises(ValueError, match=error_msg_snippet):
        crps_cdf(*call_args)

374 

375 

@pytest.mark.parametrize(
    (
        "fcst",
        "threshold_weight",
        "propagate_nan",
        "integration_method",
        "dims",
        "expected_and_dec",  # tuple: expected and decimals
    ),
    [
        # exact, preserve station dim
        (crps_test_data.DA_FCST_CRPS, crps_test_data.DA_WT_CRPS, True, "exact", ["station"], (crps_test_data.EXP_CRPS1, 7)),
        # exact, preserve no dims
        (crps_test_data.DA_FCST_CRPS, crps_test_data.DA_WT_CRPS, True, "exact", None, (crps_test_data.EXP_CRPS2, 7)),
        # trapz, preserve station dim
        (
            crps_test_data.DA_FCST_CRPS_DENSE,
            crps_test_data.DA_WT_CRPS_DENSE,
            True,
            "trapz",
            ["station"],
            (crps_test_data.EXP_CRPS1, 4),
        ),
        # exact, preserve station dim, weight is None
        (crps_test_data.DA_FCST_CRPS, None, True, "exact", ["station"], (crps_test_data.EXP_CRPS3, 7)),
        # exact, preserve station dim, weight is None, don't propagate nan
        (crps_test_data.DA_FCST_CRPS, None, False, "exact", ["station"], (crps_test_data.EXP_CRPS4, 7)),
        # exact, preserve station dim, don't propagate nan
        (crps_test_data.DA_FCST_CRPS, crps_test_data.DA_WT_CRPS, False, "exact", ["station"], (crps_test_data.EXP_CRPS5, 7)),
    ],
)
# pylint: disable=too-many-arguments
def test_crps_cdf(
    fcst,
    threshold_weight,
    propagate_nan,
    integration_method,
    dims,
    expected_and_dec,
):
    """Compare `crps_cdf` output to expected datasets across weights, NaN handling and integration methods."""
    expected, decimals = expected_and_dec

    actual = crps_cdf(
        fcst,
        crps_test_data.DA_OBS_CRPS,
        "x",
        threshold_weight=threshold_weight,
        additional_thresholds=None,
        propagate_nans=propagate_nan,
        fcst_fill_method="linear",
        threshold_weight_fill_method="forward",
        integration_method=integration_method,
        preserve_dims=dims,
        include_components=True,
    )

    assert_dataset_equal(actual, expected, decimals=decimals)

461 

462 

@pytest.mark.parametrize(
    (
        "fcst",
        "obs",
        "decreasing_tolerance",
        "expected",
    ),
    [
        (  # zero decreasing_tolerance
            crps_test_data.DA_FCST_ADJUST1,
            crps_test_data.DA_OBS_ADJUST1,
            0,
            crps_test_data.EXP_FCST_ADJUST1,
        ),
        (  # positive decreasing_tolerance (was mislabelled "exact, preserve station dim")
            crps_test_data.DA_FCST_ADJUST2,
            crps_test_data.DA_OBS_ADJUST2,
            10,
            crps_test_data.EXP_FCST_ADJUST2,
        ),
    ],
)
def test_adjust_fcst_for_crps(
    fcst,
    obs,
    decreasing_tolerance,
    expected,
):
    """Tests `adjust_fcst_for_crps` with zero and positive `decreasing_tolerance`.

    The forecast is adjusted along threshold dimension ``"x"`` and compared
    against the expected adjusted forecast to 7 decimal places.
    """
    result = adjust_fcst_for_crps(fcst, "x", obs, decreasing_tolerance)
    assert_dataarray_equal(result, expected, decimals=7)

494 

495 

@pytest.mark.parametrize(
    ("threshold_dim", "decreasing_tolerance", "error_msg_snippet"),
    [
        ("aj", 0, "'aj' is not a dimension of `fcst`"),
        ("x", -10, "`decreasing_tolerance` must be nonnegative"),
    ],
)
def test_adjust_fcst_raises(threshold_dim, decreasing_tolerance, error_msg_snippet):
    """Check that `adjust_fcst_for_crps` rejects a bad threshold dim or a negative tolerance."""
    fcst = crps_test_data.DA_FCST_ADJUST1
    obs = crps_test_data.DA_OBS_ADJUST1
    with pytest.raises(ValueError, match=error_msg_snippet):
        adjust_fcst_for_crps(fcst, threshold_dim, obs, decreasing_tolerance)

528 

529 

@pytest.mark.parametrize(
    (
        "fcst",
        "obs",
        "threshold_dim",
        "fcst_fill_method",
        "dims",
        "error_msg_snippet",
    ),
    [
        (
            crps_test_data.DA_FCST_REFORMAT1,
            crps_test_data.DA_OBS_REFORMAT1,
            "y",
            "linear",
            None,
            "'y' is not a dimension of `fcst`",
        ),
        (
            crps_test_data.DA_FCST_REFORMAT1,
            crps_test_data.DA_OBS_REFORMAT1,
            "station",
            "linear",
            None,
            "'station' is a dimension of `obs`",
        ),
        (
            crps_test_data.DA_FCST_REFORMAT1,
            crps_test_data.DA_OBS_CHECK_CRPS,
            "x",
            "linear",
            None,
            "Dimensions of `obs` must be a subset of dimensions of `fcst`",
        ),
        (
            crps_test_data.DA_FCST_REFORMAT1,
            crps_test_data.DA_OBS_REFORMAT1,
            "x",
            "fat",
            None,
            "`fcst_fill_method` must be 'linear', 'step', 'forward' or 'backward'",
        ),
        (
            crps_test_data.DA_FCST_CHECK_CRPS2,
            crps_test_data.DA_OBS_CHECK_CRPS2,
            "x",
            "linear",
            None,
            "`threshold_dim` coordinates in `fcst` must be increasing",
        ),
    ],
)
# pylint: disable=too-many-arguments
def test_crps_cdf_brier_raises(
    fcst,
    obs,
    threshold_dim,
    fcst_fill_method,
    dims,
    error_msg_snippet,
):
    """Check that `crps_cdf_brier_decomposition` raises a matching ValueError for each invalid input."""
    options = {"fcst_fill_method": fcst_fill_method, "reduce_dims": dims}
    with pytest.raises(ValueError, match=error_msg_snippet):
        crps_cdf_brier_decomposition(fcst, obs, threshold_dim, **options)

594 

595 

@pytest.mark.parametrize(
    ("dims", "expected"),
    [
        (None, crps_test_data.EXP_CRPS_BD1),
        (["station"], crps_test_data.EXP_CRPS_BD2),
    ],
)
def test_crps_cdf_brier_decomposition(dims, expected):
    """Compare `crps_cdf_brier_decomposition` output with and without a preserved station dim."""
    fcst = crps_test_data.DA_FCST_CRPS_BD
    obs = crps_test_data.DA_OBS_CRPS_BD
    decomposition = crps_cdf_brier_decomposition(fcst, obs, "x", preserve_dims=dims)
    assert_dataset_equal(decomposition, expected, decimals=7)

609 

610 

def test_crps_for_ensemble():
    """Check `crps_for_ensemble` for the 'ecdf' and 'fair' methods and for a weighted mean."""
    fcst = crps_test_data.DA_FCST_CRPSENS
    obs = crps_test_data.DA_OBS_CRPSENS

    # Compute every result first, then run the assertions.
    per_method = {
        method: crps_for_ensemble(fcst, obs, "ens_member", method=method, preserve_dims="all")
        for method in ("ecdf", "fair")
    }
    weighted_mean = crps_for_ensemble(
        fcst,
        obs,
        "ens_member",
        method="ecdf",
        weights=crps_test_data.DA_WT_CRPSENS,
    )

    assert_dataarray_equal(per_method["ecdf"], crps_test_data.EXP_CRPSENS_ECDF, decimals=7)
    assert_dataarray_equal(per_method["fair"], crps_test_data.EXP_CRPSENS_FAIR, decimals=7)
    assert_dataarray_equal(weighted_mean, crps_test_data.EXP_CRPSENS_WT, decimals=7)

630 

631 

def test_crps_for_ensemble_raises():
    """Check that `crps_for_ensemble` rejects an unknown `method`."""
    dummy = xr.DataArray(data=[1])
    # The expected message contains no regex metacharacters, so `match`
    # performs a plain substring search.
    with pytest.raises(ValueError, match="`method` must be one of 'ecdf' or 'fair'"):
        crps_for_ensemble(dummy, xr.DataArray(data=[1]), "ens_member", "unfair")

637 

638 

def test_crps_for_ensemble_dask():
    """Check that `crps_for_ensemble` stays lazy on Dask inputs and computes the right values."""
    lazy_result = crps_for_ensemble(
        fcst=crps_test_data.DA_FCST_CRPSENS.chunk(),
        obs=crps_test_data.DA_OBS_CRPSENS.chunk(),
        ensemble_member_dim="ens_member",
        method="ecdf",
        preserve_dims="all",
    )
    assert isinstance(lazy_result.data, dask.array.Array)

    computed = lazy_result.compute()
    assert isinstance(computed.data, np.ndarray)
    assert_dataarray_equal(computed, crps_test_data.EXP_CRPSENS_ECDF, decimals=7)