Coverage for src/scores/utils.py: 100%

121 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2024-02-28 12:51 +1100

1""" 

2Contains frequently-used functions of a general nature within scores 

3""" 

import inspect
import warnings
from collections.abc import Hashable, Iterable, Sequence
from typing import Optional, Union

import xarray as xr

from scores.typing import FlexibleDimensionTypes, XarrayLike

12 

# Warning emitted when the string "all" is passed as reduce_dims/preserve_dims while
# "all" is also the name of a dimension in the data.
WARN_ALL_DATA_CONFLICT_MSG = """
You are requesting to reduce or preserve every dimension by specifying the string 'all'.
In this case, 'all' is also a named dimension in your data, leading to an ambiguity.
In order to reduce or preserve the named data dimension, specify ['all'] as a list item
rather than relying on string interpretation. The program will continue to interpret the
string as an instruction to reduce or preserve every dimension.
"""

# Raised by `gather_dimensions` when preserve_dims names a dimension absent from fcst and obs.
ERROR_SPECIFIED_NONPRESENT_PRESERVE_DIMENSION = """
You are requesting to preserve a dimension which does not appear in your data (fcst or obs).
It is ambiguous how to proceed therefore an exception has been raised instead.
"""

# Raised by `gather_dimensions2` when preserve_dims names a dimension absent from the data.
ERROR_SPECIFIED_NONPRESENT_PRESERVE_DIMENSION2 = """
You are requesting to preserve a dimension which does not appear in your data
(fcst, obs or weights). It is ambiguous how to proceed therefore an exception has been
raised instead.
"""

# Raised by `gather_dimensions` when reduce_dims names a dimension absent from fcst and obs.
ERROR_SPECIFIED_NONPRESENT_REDUCE_DIMENSION = """
You are requesting to reduce a dimension which does not appear in your data (fcst or obs).
It is ambiguous how to proceed therefore an exception has been raised instead.
"""

# Raised by `gather_dimensions2` when reduce_dims names a dimension absent from the data.
ERROR_SPECIFIED_NONPRESENT_REDUCE_DIMENSION2 = """
You are requesting to reduce a dimension which does not appear in your data
(fcst, obs or weights). It is ambiguous how to proceed therefore an exception has been
raised instead.
"""

# Raised when both preserve_dims and reduce_dims are supplied.
ERROR_OVERSPECIFIED_PRESERVE_REDUCE = """
You have specified both preserve_dims and reduce_dims. This method doesn't know how
to properly interpret that, therefore an exception has been raised.
"""

47 

48 

class DimensionError(Exception):
    """
    Signals that xarray DataArray or Dataset objects taking part in an operation
    have dimensions that are not compatible with what the operation requires.
    """

54 

55 

def gather_dimensions(  # pylint: disable=too-many-branches
    fcst_dims: Iterable[Hashable],
    obs_dims: Iterable[Hashable],
    reduce_dims: FlexibleDimensionTypes = None,
    preserve_dims: FlexibleDimensionTypes = None,
) -> set[Hashable]:
    """
    Establish which dimensions to reduce when calculating errors but before taking means.

    Note: `scores.utils.gather_dimensions` and `scores.utils.gather_dimensions2` will be
    integrated at some point in the future. `scores.utils.gather_dimensions2` offers
    more comprehensive and less restrictive dimension checking and should be preferred in
    the meantime. See `scores.probability.crps_cdf` for an example of
    `scores.utils.gather_dimensions2` usage.

    Args:
        fcst_dims: Forecast dimensions inputs.
        obs_dims: Observation dimensions inputs.
        reduce_dims: Dimensions to reduce. Can be the string "all" to reduce all dimensions,
            or a single dimension name given as a string.
        preserve_dims: Dimensions to preserve. Can be the string "all" to preserve all
            dimensions, or a single dimension name given as a string.

    Returns:
        The set of dimensions (drawn from the union of `fcst_dims` and `obs_dims`) to reduce.
        If neither `reduce_dims` nor `preserve_dims` is given, every dimension is reduced;
        if `preserve_dims` is "all", the empty set is returned.

    Raises:
        ValueError: When `preserve_dims` and `reduce_dims` are both specified.
        ValueError: When `preserve_dims` or `reduce_dims` names a dimension that does not
            appear in `fcst_dims` or `obs_dims`.

    Warns:
        UserWarning: When the string "all" is given and "all" is also a dimension name.

    See also:
        `scores.utils.gather_dimensions2`
    """
    all_dims = set(fcst_dims).union(set(obs_dims))

    # Handle error conditions related to specified dimensions
    if preserve_dims is not None and reduce_dims is not None:
        raise ValueError(ERROR_OVERSPECIFIED_PRESERVE_REDUCE)

    # Explicit None test rather than `preserve_dims or reduce_dims`: truthiness is
    # ambiguous for array-like inputs and would silently skip validation of falsy values.
    specified = preserve_dims if preserve_dims is not None else reduce_dims
    # Only a genuine string can be the "all" shorthand; `==` on array-likes is elementwise.
    specified_all = isinstance(specified, str) and specified == "all"

    # Explicitly named dimensions, materialised once so single-pass iterables are not
    # exhausted by repeated use. Stays empty when nothing is named or "all" is used.
    specified_set: set[Hashable] = set()
    if specified_all:
        if "all" in all_dims:
            warnings.warn(WARN_ALL_DATA_CONFLICT_MSG)
    elif specified is not None:
        # A lone string other than "all" names a single dimension
        specified_set = {specified} if isinstance(specified, str) else set(specified)
        if not specified_set.issubset(all_dims):
            if preserve_dims is not None:
                raise ValueError(ERROR_SPECIFIED_NONPRESENT_PRESERVE_DIMENSION)
            raise ValueError(ERROR_SPECIFIED_NONPRESENT_REDUCE_DIMENSION)

    # Neither argument given: reduce everything
    if specified is None:
        return all_dims

    # "all" shorthand: preserve everything (reduce nothing) or reduce everything
    if specified_all:
        return set() if preserve_dims is not None else all_dims

    if preserve_dims is not None:
        return all_dims.difference(specified_set)
    return specified_set

134 

135 

def gather_dimensions2(
    fcst: xr.DataArray,
    obs: xr.DataArray,
    weights: Optional[xr.DataArray] = None,
    reduce_dims: FlexibleDimensionTypes = None,
    preserve_dims: FlexibleDimensionTypes = None,
    special_fcst_dims: FlexibleDimensionTypes = None,
) -> set[Hashable]:
    """
    Performs standard dimensions checks for inputs of functions that calculate (mean) scores.
    Returns a set of the dimensions to reduce.

    Note: `scores.utils.gather_dimensions` and `scores.utils.gather_dimensions2` will be
    integrated at some point in the future. `scores.utils.gather_dimensions2` offers
    more comprehensive and less restrictive dimension checking and should be preferred in
    the meantime. See `scores.probability.crps_cdf` for an example of
    `scores.utils.gather_dimensions2` usage.

    Args:
        fcst: Forecast data
        obs: Observation data
        weights: Weights for calculating a weighted mean of scores
        reduce_dims: Dimensions to reduce. Can be "all" to reduce all dimensions.
        preserve_dims: Dimensions to preserve. Can be "all" to preserve all dimensions.
        special_fcst_dims: Dimension(s) in `fcst` that are reduced to calculate individual scores.
            Must not appear as a dimension in `obs`, `weights`, `reduce_dims` or `preserve_dims`.
            e.g. the ensemble member dimension if calculating CRPS for ensembles, or the
            threshold dimension if calculating CRPS for CDFs.

    Returns:
        Set of dimensions over which to take the mean once the checks are passed.
        `special_fcst_dims` are never included. If neither `reduce_dims` nor
        `preserve_dims` is given, every remaining data dimension is returned.

    Raises:
        ValueError: when `preserve_dims` and `reduce_dims` are both specified.
        ValueError: when `special_fcst_dims` is not a subset of `fcst.dims`.
        ValueError: when `obs.dims`, `weights.dims`, `reduce_dims` or `preserve_dims`
            contains elements from `special_fcst_dims`.
        ValueError: when `preserve_dims` or `reduce_dims` contains elements not among the
            dimensions of the data (`fcst`, `obs` or `weights`).

    Warns:
        UserWarning: when the string "all" is given and "all" is also a data dimension name.

    See also:
        `scores.utils.gather_dimensions`
    """
    all_data_dims = set(fcst.dims).union(set(obs.dims))
    if weights is not None:
        all_data_dims = all_data_dims.union(set(weights.dims))

    # all_scoring_dims is the set of dims remaining after individual scores are computed.
    all_scoring_dims = all_data_dims.copy()

    # Handle error conditions related to specified dimensions
    if preserve_dims is not None and reduce_dims is not None:
        raise ValueError(ERROR_OVERSPECIFIED_PRESERVE_REDUCE)

    # Explicit None test rather than `preserve_dims or reduce_dims`: truthiness is
    # ambiguous for array-like inputs and would silently skip validation of falsy values.
    specified_dims = preserve_dims if preserve_dims is not None else reduce_dims
    # Only a genuine string can be the "all" shorthand; `==` on array-likes is elementwise.
    specified_all = isinstance(specified_dims, str) and specified_dims == "all"

    # Explicitly named dimensions, materialised once so single-pass iterables are not
    # exhausted by the several checks below. Stays empty when nothing is named or "all" is used.
    specified_set: set[Hashable] = set()
    if specified_all:
        if "all" in all_data_dims:
            warnings.warn(WARN_ALL_DATA_CONFLICT_MSG)
    elif isinstance(specified_dims, str):
        # A lone string other than "all" names a single dimension
        specified_set = {specified_dims}
    elif specified_dims is not None:
        specified_set = set(specified_dims)

    # check that special_fcst_dims are in fcst.dims only
    if special_fcst_dims is not None:
        special_set = {special_fcst_dims} if isinstance(special_fcst_dims, str) else set(special_fcst_dims)
        if not special_set.issubset(set(fcst.dims)):
            raise ValueError("`special_fcst_dims` must be a subset of `fcst` dimensions")
        if set(obs.dims) & special_set:
            raise ValueError("`obs.dims` must not contain any `special_fcst_dims`")
        if weights is not None and set(weights.dims) & special_set:
            raise ValueError("`weights.dims` must not contain any `special_fcst_dims`")
        if specified_set & special_set:
            raise ValueError("`reduce_dims` and `preserve_dims` must not contain any `special_fcst_dims`")
        # remove special_fcst_dims from all_scoring_dims
        all_scoring_dims = all_scoring_dims.difference(special_set)

    if not specified_set.issubset(all_scoring_dims):
        if preserve_dims is not None:
            raise ValueError(ERROR_SPECIFIED_NONPRESENT_PRESERVE_DIMENSION2)
        raise ValueError(ERROR_SPECIFIED_NONPRESENT_REDUCE_DIMENSION2)

    # all errors have been captured, so now return the set of dims to reduce
    if specified_dims is None:
        return all_scoring_dims
    if specified_all:
        # reduce "all" -> every scoring dim; preserve "all" -> nothing to reduce
        return all_scoring_dims if reduce_dims is not None else set()
    if reduce_dims is not None:
        return specified_set
    return all_scoring_dims.difference(specified_set)

232 

233 

def dims_complement(data, dims=None) -> list[str]:
    """
    List the dimensions of `data` that are not named in `dims`.

    Args:
        data: Input xarray object.
        dims: An iterable of dimension names, every one of which must be a dimension
            of `data`. When omitted, no dimensions are excluded.

    Returns:
        The dimension names of `data` that are absent from `dims`, in sorted order.

    Raises:
        Propagates the exceptions raised by `check_dims` when `dims` is malformed or is
        not contained in `data.dims`.
    """
    excluded = [] if dims is None else dims

    # `data.dims` must contain everything in `excluded`; this also validates its form
    check_dims(data, excluded, mode="superset")

    remaining = set(data.dims).difference(excluded)
    return sorted(remaining)

253 

254 

def check_dims(xr_data: XarrayLike, expected_dims: Sequence[str], mode: Optional[str] = None):
    """
    Checks the dimensions of `xr_data` against `expected_dims`, according to `mode`.

    Args:
        xr_data: The xarray object to check. If a Dataset is supplied, the Dataset
            itself and each of its data variables (DataArray objects) are checked.
        expected_dims: an iterable of dimension names (not a single string). Single-pass
            iterables such as generators are accepted.
        mode: one of 'equal' (default), 'subset', 'superset', 'proper subset',
            'proper superset' or 'disjoint'.
            If 'equal', checks that the data object has the same dimensions
            as `expected_dims`.
            If 'subset', checks that the dimensions of the data object are a
            subset of `expected_dims`.
            If 'superset', checks that the dimensions of the data object are a
            superset of `expected_dims` (i.e. contain `expected_dims`).
            If 'proper subset', checks that the dimensions of the data object are a
            subset of, but not equal to, `expected_dims`.
            If 'proper superset', checks that the dimensions of the data object are
            a superset of, but not equal to, `expected_dims`.
            If 'disjoint', checks that the dimensions of the data object share no
            elements with `expected_dims`.

    Raises:
        scores.utils.DimensionError: `xr_data` has no `dims` attribute, or the dimensions
            of `xr_data` (or, for a Dataset, of any of its data variables) do not pass
            the check specified by `mode`.
        TypeError: `expected_dims` is a string rather than an iterable of names.
        ValueError: `expected_dims` contains duplicate values.
        ValueError: `expected_dims` cannot be coerced into a set.
        ValueError: `mode` is not one of 'equal', 'subset', 'superset',
            'proper subset', 'proper superset', or 'disjoint'.
    """

    if isinstance(expected_dims, str):
        raise TypeError(f"Supplied dimensions '{expected_dims}' must be an iterable of strings, not a string itself.")

    try:
        # Materialise once: `len()` on the original object would fail for single-pass
        # iterables (e.g. generators), which `set()` would also have exhausted.
        dims_list = list(expected_dims)
        dims_set = set(dims_list)
    except Exception as exc:
        raise ValueError(
            f"Cannot convert supplied dims {expected_dims} into a set. Check debug log for more information."
        ) from exc

    if len(dims_set) != len(dims_list):
        raise ValueError(f"Supplied dimensions {expected_dims} contains duplicate values.")

    if not hasattr(xr_data, "dims"):
        raise DimensionError("Supplied object has no dimensions")

    # internal predicates comparing a data object's dims with the expected set, per mode
    check_modes = {
        "equal": lambda da, dims_set: set(da.dims) == dims_set,
        "subset": lambda da, dims_set: set(da.dims) <= dims_set,
        "superset": lambda da, dims_set: set(da.dims) >= dims_set,
        "proper subset": lambda da, dims_set: set(da.dims) < dims_set,
        "proper superset": lambda da, dims_set: set(da.dims) > dims_set,
        "disjoint": lambda da, dims_set: len(set(da.dims) & dims_set) == 0,
    }

    if mode is None:
        mode = "equal"
    if mode not in check_modes:
        raise ValueError(f"No such mode {mode}, mode must be one of: {list(check_modes.keys())}")

    check_fn = check_modes[mode]

    # check the dims
    if not check_fn(xr_data, dims_set):
        raise DimensionError(
            f"Dimensions {list(xr_data.dims)} of data object are not {mode} to the "
            f"dimensions {sorted(list(dims_set))}."
        )

    if isinstance(xr_data, xr.Dataset):
        # every data variable must pass the dims check too!
        for data_var in xr_data.data_vars:
            if not check_fn(xr_data[data_var], dims_set):
                raise DimensionError(
                    f"Dimensions {list(xr_data[data_var].dims)} of data variable "
                    f"'{data_var}' are not {mode} to the dimensions {sorted(dims_set)}"
                )

337 

338 

def tmp_coord_name(xr_data: xr.DataArray, count: int = 1) -> Union[str, list[str]]:
    """
    Generates temporary coordinate names that are not among the coordinate or dimension
    names of `xr_data`.

    Each name is built from "new" followed by the string form of every dimension and
    coordinate name of `xr_data`. It is therefore strictly longer than any existing
    string name (each of which it contains), so it cannot collide with one.

    Args:
        xr_data: Input xarray data array.
        count: Number of unique names to generate.

    Returns:
        If count = 1 (the default), a string which is the concatenation of 'new' with all
        coordinate and dimension names in the input array.
        Otherwise, a list of `count` such strings, each prefixed with its index
        (0, 1, ...) so that they are unique from one another.
    """
    # str() each name: xarray dimension/coordinate names may be any Hashable
    # (e.g. ints), and str.join only accepts strings.
    all_names = ["new"] + [str(name) for name in xr_data.dims] + [str(name) for name in xr_data.coords]
    result = "".join(all_names)

    if count == 1:
        return result

    return [str(i) + result for i in range(count)]