Coverage for src/scores/continuous/flip_flop_impl.py: 100%

86 statements  


1""" 

2This module contains functions for calculating flip flop indices 

3""" 

4 

5from collections.abc import Generator, Iterable, Sequence 

6from typing import Optional, Union, overload 

7 

8import numpy as np 

9import xarray as xr 

10 

11from scores.functions import angular_difference 

12from scores.processing import proportion_exceeding 

13from scores.typing import FlexibleDimensionTypes, XarrayLike 

14from scores.utils import DimensionError, check_dims, dims_complement 

15 

16 

def _flip_flop_index(data: xr.DataArray, sampling_dim: str, is_angular: bool = False) -> xr.DataArray:
    """
    Calculates the flip-flop index by collapsing the dimension specified by
    `sampling_dim`.

    Args:
        data: Data from which to draw subsets.
        sampling_dim: The name of the dimension along which to calculate
            the flip-flop index.
        is_angular: specifies whether `data` is directional data (e.g. wind
            direction).

    Returns:
        An xarray.DataArray of the flip-flop index with the dimensions of
        `data`, except for the `sampling_dim` dimension, which is collapsed.

    See also:
        `scores.continuous.flip_flop.flip_flop_index`
    """
    # check that `sampling_dim` is in `data`
    check_dims(data, [sampling_dim], mode="superset")
    # the maximum possible number of discrete flip-flops
    sequence_length = len(data[sampling_dim])
    max_possible_flip_flop_count = sequence_length - 2

    # calculate the range
    # skipna=False guarantees that if there is a nan in that row,
    # it will show up as nan in the end
    if is_angular:
        # get complementary dimensions, as `encompassing_sector_size` takes
        # dimensions to be preserved, not collapsed
        dims_to_preserve = dims_complement(data, [sampling_dim])
        # get the maximum forecast range; if > 180, clip to 180, as this is the
        # maximum possible angular difference between two forecasts
        enc_size = encompassing_sector_size(data=data, dims=dims_to_preserve)
        range_val = np.clip(enc_size, a_min=None, a_max=180.0)
        flip_flop = angular_difference(data.shift({sampling_dim: 1}), data)
    else:
        max_val = data.max(dim=sampling_dim, skipna=False)
        min_val = data.min(dim=sampling_dim, skipna=False)
        range_val = max_val - min_val
        # subtract each consecutive 'row' from each other
        flip_flop = data.shift({sampling_dim: 1}) - data

    # take the absolute value and sum.
    # skipna=False is not needed here because .shift introduces a row of nan
    flip_flop = abs(flip_flop).sum(dim=sampling_dim)
    # adjust based on the range. This is where nan will be introduced.
    flip_flop = flip_flop - range_val
    # normalise
    return flip_flop / max_possible_flip_flop_count
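
# A worked example of the arithmetic above (illustrative only, not part of the
# module): for the sequence [50, 20, 40, 80], the summed absolute consecutive
# differences are |50-20| + |20-40| + |40-80| = 90, the range is 80 - 20 = 60,
# and N - 2 = 2, so the flip-flop index is (90 - 60) / 2 = 15.0. This matches
# the doctest for `flip_flop_index` below.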


# If there are selections, a Dataset is always returned
@overload
def flip_flop_index(
    data: xr.DataArray, sampling_dim: str, is_angular: bool = False, **selections: Iterable[int]
) -> xr.Dataset:
    ...


# If there are no selections, a DataArray is always returned
@overload
def flip_flop_index(
    data: xr.DataArray, sampling_dim: str, is_angular: bool = False, **selections: None
) -> xr.DataArray:
    ...


# The return type is more precise at runtime, once it is known whether selections are being used
def flip_flop_index(
    data: xr.DataArray, sampling_dim: str, is_angular: bool = False, **selections: Optional[Iterable[int]]
) -> XarrayLike:
    """
    Calculates the Flip-flop Index along the dimension `sampling_dim`.

    Args:
        data: Data from which to draw subsets.
        sampling_dim: The name of the dimension along which to calculate
            the flip-flop index.
        is_angular: specifies whether `data` is directional data (e.g. wind
            direction).
        **selections: Additional keyword arguments specify
            subsets to draw from the dimension `sampling_dim` of the supplied `data`
            before calculation of the flip-flop index, e.g. days123=[1, 2, 3]

    Returns:
        If `selections` are not supplied: an xarray.DataArray, the Flip-flop
        Index calculated by collapsing the dimension `sampling_dim`.

        If `selections` are supplied: an xarray.Dataset. Each data variable
        is a supplied keyword argument, and corresponds to selecting the
        specified values from `sampling_dim` of `data`. The Flip-flop Index
        is calculated for each of these selections.

    Notes:

        .. math::

            \\text{Flip-Flop Index} = \\frac{1}{N-2}
            \\left [
            \\left(\\sum\\limits_{i=1}^{N-1}|x_i - x_{i+1}|\\right)
            - \\left(\\max_{j}\\{x_j\\} - \\min_{j}\\{x_j\\}\\right)
            \\right ]

        where :math:`N` is the number of data points, and :math:`x_i` is the
        :math:`i^{\\text{th}}` data point.

    Examples:
        >>> data = xr.DataArray([50, 20, 40, 80], coords={'lead_day': [1, 2, 3, 4]})

        >>> flip_flop_index(data, 'lead_day')
        <xarray.DataArray ()>
        array(15.0)
        Attributes:
            sampling_dim: lead_day

        >>> flip_flop_index(data, 'lead_day', days123=[1, 2, 3], all_days=[1, 2, 3, 4])
        <xarray.Dataset>
        Dimensions:  ()
        Coordinates:
            *empty*
        Data variables:
            days123   float64 20.0
            all_days  float64 15.0
        Attributes:
            selections: {'days123': [1, 2, 3], 'all_days': [1, 2, 3, 4]}
            sampling_dim: lead_day

    """
    if not selections and isinstance(data, xr.DataArray):
        result = _flip_flop_index(data, sampling_dim, is_angular=is_angular)
    else:
        result = xr.Dataset()
        result.attrs["selections"] = selections
        for key, data_subset in iter_selections(data, sampling_dim, **selections):
            result[key] = _flip_flop_index(data_subset, sampling_dim, is_angular=is_angular)
    result.attrs["sampling_dim"] = sampling_dim

    return result
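
# Illustrative sketch (not part of the module): directional data goes through
# the angular pathway in `_flip_flop_index`. Assuming wind directions that
# flip between 350 and 10 degrees,
#
#     >>> direction = xr.DataArray([350, 10, 350, 10], coords={'lead_day': [1, 2, 3, 4]})
#     >>> flip_flop_index(direction, 'lead_day', is_angular=True)
#
# each of the three consecutive angular differences is 20 degrees (summing to
# 60) and the encompassing sector is 20 degrees, so the result is
# (60 - 20) / 2 = 20.0.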


# DataArray input types lead to DataArray output types
@overload
def iter_selections(
    data: xr.DataArray, sampling_dim: str, **selections: Optional[Iterable[int]]
) -> Generator[tuple[str, xr.DataArray], None, None]:
    ...


# Dataset input types lead to Dataset output types
@overload
def iter_selections(
    data: xr.Dataset, sampling_dim: str, **selections: Optional[Iterable[int]]
) -> Generator[tuple[str, xr.Dataset], None, None]:
    ...


def iter_selections(
    data: XarrayLike, sampling_dim: str, **selections: Optional[Iterable[int]]
) -> Generator[tuple[str, XarrayLike], None, None]:
    """
    Selects subsets of data along the dimension `sampling_dim` according to
    `selections`.

    Args:
        data: The data to sample from.
        sampling_dim: The dimension from which to sample.
        selections: Each supplied keyword corresponds to a
            selection of `data` from the dimension `sampling_dim`. The
            key is the first element of the yielded tuple.

    Yields:
        A tuple (key, data_subset), where key is the supplied `**selections`
        keyword, and data_subset is the `data` at the values along
        `sampling_dim` specified by `**selections`.

    Raises:
        KeyError: if values in `selections` are not in `data[sampling_dim]`

    Examples:
        >>> data = xr.DataArray(
        ...     [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.7],
        ...     coords={'lead_day': [1, 2, 3, 4, 5, 6, 7]}
        ... )
        >>> for key, data_subset in iter_selections(
        ...     data, 'lead_day', days123=[1, 2, 3], all_days=[1, 2, 3, 4, 5, 6, 7]
        ... ):
        ...     print(key, ':', data_subset)
        all_days : <xarray.DataArray (lead_day: 7)>
        array([ 0. ,  0.1,  0.2,  0.3,  0.4,  0.5,  0.7])
        Coordinates:
          * lead_day (lead_day) int64 1 2 3 4 5 6 7
        days123 : <xarray.DataArray (lead_day: 3)>
        array([ 0. ,  0.1,  0.2])
        Coordinates:
          * lead_day (lead_day) int64 1 2 3

    """
    check_dims(data, [sampling_dim], "superset")

    for key, values in selections.items():
        try:
            # Need a copy so that attributes added in _iter_selections_with_attrs
            # don't affect the whole dataset, just the subset
            data_subset = data.sel({sampling_dim: values}).copy(deep=False)
        except KeyError as ex:
            raise KeyError(
                f"for `selections` item {str({key: values})}, not all values found in "
                f"dimension '{sampling_dim}'",
            ) from ex

        yield key, data_subset
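
# For example (illustrative only): requesting a label that is absent from the
# sampling dimension of the doctest data above, e.g.
#     list(iter_selections(data, 'lead_day', bad_days=[8, 9]))
# raises KeyError("for `selections` item {'bad_days': [8, 9]}, not all values
# found in dimension 'lead_day'"), per the re-raise in the loop above.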


def encompassing_sector_size(data: xr.DataArray, dims: Sequence[str], skipna: bool = False) -> xr.DataArray:
    """
    Calculates the minimum angular distance which encompasses all data points
    within an xarray.DataArray along a specified dimension. Assumes data is in
    degrees.
    Only one dimension may be collapsed at a time, so `dims` must have length
    one less than `data.dims`, otherwise an exception will be raised.

    Args:
        data: direction data in degrees
        dims: Strings corresponding to the dimensions in the input
            xarray data objects that we wish to preserve in the output. All other
            dimensions in the input data objects are collapsed.
        skipna: specifies whether to ignore nans in the data. If False
            (default), will return a nan if one or more nans are present

    Returns:
        An xarray.DataArray of minimum encompassing sector sizes with
        dimensions `dims`.

    Raises:
        scores.utils.DimensionError: raised if

            - the set of data dimensions is not a proper superset of `dims`
            - the number of dimensions to be collapsed isn't 1
    """
    check_dims(data, dims, "proper superset")
    dims_to_collapse = dims_complement(data, dims)
    if len(dims_to_collapse) != 1:
        raise DimensionError("can only collapse one dimension")
    dim_to_collapse = dims_to_collapse[0]
    axis_to_collapse = data.get_axis_num(dim_to_collapse)
    values = _encompassing_sector_size_np(
        data=data.values,
        axis_to_collapse=axis_to_collapse,
        skipna=skipna,
    )
    new_dims = [dim for dim in data.dims if dim in dims]
    coords = [data.coords[dim] for dim in new_dims]
    result = xr.DataArray(values, dims=new_dims, coords=coords)
    return result
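
# Two worked examples (illustrative only): directions [350, 10] straddle north,
# so the smallest sector containing both is 20 degrees, not 340; directions
# [0, 90, 180] fit inside a 180-degree sector. Both follow from the
# adjacent-difference logic in `_encompassing_sector_size_np` below.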


@np.errstate(invalid="ignore")
def _encompassing_sector_size_np(
    data: np.ndarray, axis_to_collapse: Union[int, tuple[int, ...]] = 0, skipna: bool = False
) -> np.ndarray:
    """
    Calculates the minimum angular distance which encompasses all data points
    within a numpy.ndarray along a specified axis. Assumes data is in
    degrees.

    Args:
        data: direction data in degrees
        axis_to_collapse: number of the axis to collapse in `data`, the numpy.ndarray
        skipna: specifies whether to ignore nans in the data. If False
            (default), will return a nan if one or more nans are present

    Returns:
        A numpy.ndarray of minimum encompassing sector sizes
    """
    # the code is simpler, and broadcasting easier, if we deal with axis=0
    data = np.moveaxis(data, axis_to_collapse, 0)
    # put data in the range [0, 360)
    data = data % 360
    data = np.sort(data, axis=0)
    if skipna:
        # rotate so one angle is at zero, then we can replace NaNs with zeroes
        if data.ndim == 1:
            data = (data - data[0]) % 360
        else:
            data = (data - data[0, :]) % 360
        all_nans = np.all(np.isnan(data), axis=0)
        # if all NaNs, we don't want to change anything, and still want the end
        # result to be NaN.
        # if some NaNs but not all, then set them to zero, which will just end up
        # being a duplicate value after we've rotated so at least one value is zero
        data[np.isnan(data) & ~all_nans] = 0
    # make a back-shifted copy of `data`
    data_rolled = np.roll(data, shift=-1, axis=0)
    # determine the absolute angular difference between all adjacent angles
    angular_diffs = np.abs(data - data_rolled)
    angular_diffs = np.where(
        # nan_to_num so it doesn't complain about comparing with NaN
        np.nan_to_num(angular_diffs) > 180,
        360 - angular_diffs,
        angular_diffs,
    )
    # the max difference between adjacent angles, or its complement, is
    # equivalent to the smallest sector size which encompasses all angles in
    # `data`.
    max_args = np.argmax(angular_diffs, axis=0)
    max_indices = tuple([max_args] + list(np.indices(max_args.shape)))
    # determine the first of the two angles resulting in the max difference
    first_bounding_angle = data[max_indices]
    # rotate all angles by `first_bounding_angle` (anticlockwise), and make any
    # resulting negative angles positive. This ensures that the rotated
    # `first_bounding_angle` is 0, and is therefore the smallest angle in the
    # rotated set
    rotated = (data_rolled - first_bounding_angle) % 360
    # determine the second of the two angles, now rotated, resulting in the max
    # difference
    second_bound_angle_rotated = rotated[max_indices]
    max_of_rotated = np.max(rotated, axis=0)
    # if `second_bound_angle_rotated` is the largest element, then the
    # sector size is the clockwise span 0 -> `second_bound_angle_rotated`,
    # otherwise it's the anticlockwise span
    result = np.where(
        max_of_rotated == second_bound_angle_rotated,
        second_bound_angle_rotated,
        360 - second_bound_angle_rotated,
    )
    # if there are only one or two distinct angles, return the unique difference
    # calculated
    n_unique_angles = (angular_diffs != 0).sum(axis=0)
    result = np.where(n_unique_angles <= 2, np.max(angular_diffs, axis=0), result)
    return result
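
# A short trace of the branch logic above (illustrative only): for angles
# [0, 90, 180], the adjacent differences after sorting and rolling are
# [90, 90, 180]; the max (180) is attained between 180 and 0, and since three
# distinct angles are present, the rotated-complement branch returns
# 360 - 180 = 180. For [350, 10] only two distinct angles exist, so the max
# adjacent difference, 20, is returned directly.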


def flip_flop_index_proportion_exceeding(
    data: xr.DataArray,
    sampling_dim: str,
    thresholds: Iterable,
    is_angular: bool = False,
    preserve_dims: FlexibleDimensionTypes = None,
    reduce_dims: FlexibleDimensionTypes = None,
    **selections: Iterable[int],
):
    """
    Calculates the flip-flop index and returns the proportion exceeding
    (or equal to) each of the supplied `thresholds`.

    Args:
        data: Data from which to draw subsets.
        sampling_dim: The name of the dimension along which to calculate
            the flip-flop index.
        thresholds: The proportion of Flip-Flop Index results
            equal to or exceeding these thresholds will be calculated.
        is_angular: specifies whether `data` is directional data (e.g. wind
            direction).
        reduce_dims: Dimensions to reduce.
        preserve_dims: Dimensions to preserve.
        **selections: Additional keyword arguments specify
            subsets to draw from the dimension `sampling_dim` of the supplied `data`
            before calculation of the flip-flop index, e.g. days123=[1, 2, 3]

    Returns:
        If `selections` are not supplied: an xarray.DataArray with the preserved
        dimensions plus 'threshold'. The DataArray is the proportion of the
        Flip-flop Index, calculated by collapsing the dimension `sampling_dim`,
        exceeding or equal to `thresholds`.

        If `selections` are supplied: an xarray.Dataset with the preserved
        dimensions plus 'threshold'. There is a data variable for each keyword
        in `selections`, corresponding to the Flip-Flop Index proportion
        exceeding for the subset of data specified by the keyword values.

    Examples:
        >>> data = xr.DataArray(
        ...     [[50, 20, 40, 80], [10, 50, 10, 100], [0, 30, 20, 50]],
        ...     dims=['station_number', 'lead_day'],
        ...     coords=[[10001, 10002, 10003], [1, 2, 3, 4]]
        ... )

        >>> flip_flop_index_proportion_exceeding(data, 'lead_day', [20])
        <xarray.DataArray (threshold: 1)>
        array([ 0.33333333])
        Coordinates:
          * threshold (threshold) int64 20
        Attributes:
            sampling_dim: lead_day

        >>> flip_flop_index_proportion_exceeding(
        ...     data, 'lead_day', [20], days123=[1, 2, 3], all_days=[1, 2, 3, 4]
        ... )
        <xarray.Dataset>
        Dimensions:  (threshold: 1)
        Coordinates:
          * threshold (threshold) int64 20
        Data variables:
            days123   (threshold) float64 0.6667
            all_days  (threshold) float64 0.3333
        Attributes:
            selections: {'days123': [1, 2, 3], 'all_days': [1, 2, 3, 4]}
            sampling_dim: lead_day

    See also:
        `scores.continuous.flip_flop_index`

    """
    if preserve_dims is not None and sampling_dim in list(preserve_dims):
        raise DimensionError(
            f"`sampling_dim`: '{sampling_dim}' must not be in dimensions to preserve "
            f"`preserve_dims`: {list(preserve_dims)}"
        )
    if reduce_dims is not None and sampling_dim in list(reduce_dims):
        raise DimensionError(
            f"`sampling_dim`: '{sampling_dim}' must not be in dimensions to reduce "
            f"`reduce_dims`: {list(reduce_dims)}"
        )
    # calculate the flip-flop index
    flip_flop_data = flip_flop_index(data, sampling_dim, is_angular=is_angular, **selections)
    # calculate the proportion exceeding each threshold
    flip_flop_exceeding = proportion_exceeding(flip_flop_data, thresholds, reduce_dims, preserve_dims)
    # overwrite the attributes
    flip_flop_exceeding.attrs = flip_flop_data.attrs

    return flip_flop_exceeding
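
# Checking the doctest numbers above by hand (illustrative only): the
# per-station flip-flop indices over all four lead days are 15.0, 40.0 and
# 10.0, so one of three stations meets the threshold of 20 (0.3333); over lead
# days 1-3 they are 20.0, 40.0 and 10.0, so two of three stations meet it
# (0.6667).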