Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1"""Common utility functions for rolling operations""" 

2from collections import defaultdict 

3from typing import Callable, Optional 

4import warnings 

5 

6import numpy as np 

7 

8from pandas.core.dtypes.common import is_integer 

9from pandas.core.dtypes.generic import ABCDataFrame, ABCSeries 

10 

11import pandas.core.common as com 

12from pandas.core.generic import _shared_docs 

13from pandas.core.groupby.base import GroupByMixin 

14from pandas.core.indexes.api import MultiIndex 

15 

# Shadow the imported ``_shared_docs`` with a shallow copy so that doc
# entries added by the window classes in this module do not mutate the
# module-level dict owned by pandas.core.generic.
_shared_docs = dict(**_shared_docs)
# Docstring fragment appended to rolling/expanding methods; ``%(name)s``
# is %-substituted with the method name at class-definition time.
_doc_template = """
        Returns
        -------
        Series or DataFrame
            Return type is determined by the caller.

        See Also
        --------
        Series.%(name)s : Series %(name)s.
        DataFrame.%(name)s : DataFrame %(name)s.
"""

28 

29 

30def _dispatch(name: str, *args, **kwargs): 

31 """ 

32 Dispatch to apply. 

33 """ 

34 

35 def outer(self, *args, **kwargs): 

36 def f(x): 

37 x = self._shallow_copy(x, groupby=self._groupby) 

38 return getattr(x, name)(*args, **kwargs) 

39 

40 return self._groupby.apply(f) 

41 

42 outer.__name__ = name 

43 return outer 

44 

45 

class WindowGroupByMixin(GroupByMixin):
    """
    Provide the groupby facilities.
    """

    def __init__(self, obj, *args, **kwargs):
        kwargs.pop("parent", None)
        grouped = kwargs.pop("groupby", None)
        if grouped is None:
            # called with a GroupBy object directly: it carries the data
            grouped, obj = obj, obj.obj
        self._groupby = grouped
        self._groupby.mutated = True
        self._groupby.grouper.mutated = True
        super().__init__(obj, *args, **kwargs)

    count = _dispatch("count")
    corr = _dispatch("corr", other=None, pairwise=None)
    cov = _dispatch("cov", other=None, pairwise=None)

    def _apply(
        self,
        func: Callable,
        center: bool,
        require_min_periods: int = 0,
        floor: int = 1,
        is_weighted: bool = False,
        name: Optional[str] = None,
        use_numba_cache: bool = False,
        **kwargs,
    ):
        """
        Dispatch to apply; we are stripping all of the _apply kwargs and
        performing the original function call on the grouped object.
        """
        # window-level kwargs are recomputed per group; drop them here
        kwargs.pop("floor", None)

        # TODO: can we de-duplicate with _dispatch?
        def run(group, name=name, *args):
            windowed = self._shallow_copy(group)
            if isinstance(name, str):
                # named aggregation: invoke the method directly
                return getattr(windowed, name)(*args, **kwargs)
            # otherwise ``name`` is a callable handed to apply
            return windowed.apply(name, *args, **kwargs)

        return self._groupby.apply(run)

92 

93 

def _flex_binary_moment(arg1, arg2, f, pairwise=False):
    """
    Apply the binary moment function ``f`` to all required combinations
    of the two inputs.

    Parameters
    ----------
    arg1, arg2 : np.ndarray, Series or DataFrame
    f : callable
        Binary function of two aligned 1-d arguments (e.g. a rolling
        cov/corr kernel).  In the DataFrame paths its results are
        concatenated, so it is expected to return something
        Series-like -- TODO confirm against callers.
    pairwise : bool, default False
        Only consulted when both inputs are DataFrames: False computes
        column against matching column, True computes every column pair.

    Returns
    -------
    Series or DataFrame
        Shape depends on the input types and ``pairwise``.

    Raises
    ------
    TypeError
        If either argument is not an ndarray/Series/DataFrame.
    ValueError
        If ``pairwise`` is not a bool, or (pairwise=False, distinct
        frames) a frame's column labels are not unique.
    """

    if not (
        isinstance(arg1, (np.ndarray, ABCSeries, ABCDataFrame))
        and isinstance(arg2, (np.ndarray, ABCSeries, ABCDataFrame))
    ):
        raise TypeError(
            "arguments to moment function must be of type "
            "np.ndarray/Series/DataFrame"
        )

    # Simplest case: no DataFrame involved -> align once and apply f.
    if isinstance(arg1, (np.ndarray, ABCSeries)) and isinstance(
        arg2, (np.ndarray, ABCSeries)
    ):
        X, Y = prep_binary(arg1, arg2)
        return f(X, Y)

    elif isinstance(arg1, ABCDataFrame):
        from pandas import DataFrame

        def dataframe_from_int_dict(data, frame_template):
            # ``data`` maps positional column index -> result; relabel the
            # integer columns with the template's real column labels.
            result = DataFrame(data, index=frame_template.index)
            if len(result.columns) > 0:
                result.columns = frame_template.columns[result.columns]
            return result

        results = {}
        if isinstance(arg2, ABCDataFrame):
            if pairwise is False:
                if arg1 is arg2:
                    # special case in order to handle duplicate column names
                    for i, col in enumerate(arg1.columns):
                        results[i] = f(arg1.iloc[:, i], arg2.iloc[:, i])
                    return dataframe_from_int_dict(results, arg1)
                else:
                    # distinct frames are matched by column label, so the
                    # labels must be unique on both sides
                    if not arg1.columns.is_unique:
                        raise ValueError("'arg1' columns are not unique")
                    if not arg2.columns.is_unique:
                        raise ValueError("'arg2' columns are not unique")
                    with warnings.catch_warnings(record=True):
                        warnings.simplefilter("ignore", RuntimeWarning)
                        X, Y = arg1.align(arg2, join="outer")
                    # adding 0 * other NaN-masks each side with the other
                    X = X + 0 * Y
                    Y = Y + 0 * X

                    with warnings.catch_warnings(record=True):
                        warnings.simplefilter("ignore", RuntimeWarning)
                        res_columns = arg1.columns.union(arg2.columns)
                    for col in res_columns:
                        if col in X and col in Y:
                            results[col] = f(X[col], Y[col])
                    return DataFrame(results, index=X.index, columns=res_columns)
            elif pairwise is True:
                results = defaultdict(dict)
                for i, k1 in enumerate(arg1.columns):
                    for j, k2 in enumerate(arg2.columns):
                        if j < i and arg2 is arg1:
                            # Symmetric case
                            results[i][j] = results[j][i]
                        else:
                            results[i][j] = f(
                                *prep_binary(arg1.iloc[:, i], arg2.iloc[:, j])
                            )

                from pandas import concat

                result_index = arg1.index.union(arg2.index)
                if len(result_index):

                    # construct result frame
                    result = concat(
                        [
                            concat(
                                [results[i][j] for j, c in enumerate(arg2.columns)],
                                ignore_index=True,
                            )
                            for i, c in enumerate(arg1.columns)
                        ],
                        ignore_index=True,
                        axis=1,
                    )
                    result.columns = arg1.columns

                    # set the index and reorder
                    if arg2.columns.nlevels > 1:
                        result.index = MultiIndex.from_product(
                            arg2.columns.levels + [result_index]
                        )
                        result = result.reorder_levels([2, 0, 1]).sort_index()
                    else:
                        result.index = MultiIndex.from_product(
                            [range(len(arg2.columns)), range(len(result_index))]
                        )
                        result = result.swaplevel(1, 0).sort_index()
                        result.index = MultiIndex.from_product(
                            [result_index] + [arg2.columns]
                        )
                else:

                    # empty result
                    result = DataFrame(
                        index=MultiIndex(
                            levels=[arg1.index, arg2.columns], codes=[[], []]
                        ),
                        columns=arg2.columns,
                        dtype="float64",
                    )

                # reset our index names to arg1 names
                # reset our column names to arg2 names
                # careful not to mutate the original names
                result.columns = result.columns.set_names(arg1.columns.names)
                result.index = result.index.set_names(
                    result_index.names + arg2.columns.names
                )

                return result

            else:
                raise ValueError("'pairwise' is not True/False")
        else:
            # DataFrame vs Series/ndarray: apply column-wise against arg2
            results = {
                i: f(*prep_binary(arg1.iloc[:, i], arg2))
                for i, col in enumerate(arg1.columns)
            }
            return dataframe_from_int_dict(results, arg1)

    else:
        # arg1 is Series/ndarray and arg2 is a DataFrame: swap so the
        # DataFrame branch above handles it.  NOTE(review): ``pairwise``
        # is not forwarded here, so it falls back to False -- confirm
        # that is intended for the swapped call.
        return _flex_binary_moment(arg2, arg1, f)

223 

224 

225def _get_center_of_mass(comass, span, halflife, alpha): 

226 valid_count = com.count_not_none(comass, span, halflife, alpha) 

227 if valid_count > 1: 

228 raise ValueError("comass, span, halflife, and alpha are mutually exclusive") 

229 

230 # Convert to center of mass; domain checks ensure 0 < alpha <= 1 

231 if comass is not None: 

232 if comass < 0: 

233 raise ValueError("comass must satisfy: comass >= 0") 

234 elif span is not None: 

235 if span < 1: 

236 raise ValueError("span must satisfy: span >= 1") 

237 comass = (span - 1) / 2.0 

238 elif halflife is not None: 

239 if halflife <= 0: 

240 raise ValueError("halflife must satisfy: halflife > 0") 

241 decay = 1 - np.exp(np.log(0.5) / halflife) 

242 comass = 1 / decay - 1 

243 elif alpha is not None: 

244 if alpha <= 0 or alpha > 1: 

245 raise ValueError("alpha must satisfy: 0 < alpha <= 1") 

246 comass = (1.0 - alpha) / alpha 

247 else: 

248 raise ValueError("Must pass one of comass, span, halflife, or alpha") 

249 

250 return float(comass) 

251 

252 

def calculate_center_offset(window):
    """
    Return the label offset used when centering window results.

    Parameters
    ----------
    window : int or sized iterable
        Window size, or a weights sequence whose length is the size.

    Returns
    -------
    int
        Half the window span, rounded toward zero.
    """
    size = window if is_integer(window) else len(window)
    return int((size - 1) / 2.0)

257 

258 

def calculate_min_periods(
    window: int,
    min_periods: Optional[int],
    num_values: int,
    required_min_periods: int,
    floor: int,
) -> int:
    """
    Resolve the effective minimum-periods value for a rolling aggregation.

    Parameters
    ----------
    window : int
        Passed window value.
    min_periods : int, optional
        User-supplied min periods; defaults to ``window`` when None.
    num_values : int
        Total number of values in the object being aggregated.
    required_min_periods : int
        Per-aggregation lower bound on min periods.
    floor : int
        Absolute lower bound applied to the final result.

    Returns
    -------
    int
        The resolved min periods, never below ``floor``.

    Raises
    ------
    ValueError
        If the resolved value exceeds ``window`` or is negative.
    """
    if min_periods is None:
        min_periods = window
    else:
        min_periods = max(required_min_periods, min_periods)
        if min_periods > window:
            raise ValueError(f"min_periods {min_periods} must be <= window {window}")
        # more periods than values: make the result all-NaN rather than fail
        if min_periods > num_values:
            min_periods = num_values + 1
        elif min_periods < 0:
            raise ValueError("min_periods must be >= 0")
    return max(min_periods, floor)

292 

293 

def zsqrt(x):
    """
    Square root that maps negative inputs to 0 instead of NaN.

    Works on ndarrays, Series and DataFrames; floating-point warnings
    from the negative entries are suppressed.
    """
    with np.errstate(all="ignore"):
        result = np.sqrt(x)
        mask = x < 0

    # DataFrames need .values to collapse the boolean frame to a scalar
    has_negative = mask.values.any() if isinstance(x, ABCDataFrame) else mask.any()
    if has_negative:
        result[mask] = 0

    return result

307 

308 

def prep_binary(arg1, arg2):
    """
    Mask two like-typed operands against each other onto a shared index.

    Adding ``0 * other`` propagates the other operand's NaNs into each
    side and, for pandas objects, aligns both onto a common index.

    Parameters
    ----------
    arg1, arg2 : Series or ndarray
        ``arg2`` must be an instance of ``type(arg1)``.

    Returns
    -------
    X, Y
        The mutually-masked (and, for pandas objects, aligned) operands.

    Raises
    ------
    TypeError
        If the two arguments are not of the same type.
    """
    if not isinstance(arg2, type(arg1)):
        # Raise TypeError instead of bare Exception so callers can catch
        # it precisely; still backward compatible with `except Exception`.
        raise TypeError("Input arrays must be of the same type!")

    # mask out values, this also makes a common index...
    X = arg1 + 0 * arg2
    Y = arg2 + 0 * arg1

    return X, Y

318 

319 

def get_weighted_roll_func(cfunc: Callable) -> Callable:
    """
    Wrap a weighted-window aggregation so ``min_periods`` has a default.

    The wrapper forwards ``(arg, window, min_periods)`` to ``cfunc``,
    substituting ``len(window)`` when ``min_periods`` is None.
    """

    def wrapped(arg, window, min_periods=None):
        effective_min = len(window) if min_periods is None else min_periods
        return cfunc(arg, window, effective_min)

    return wrapped