1"""
2Provide classes to perform the groupby aggregate operations.
4These are not exposed to the user and provide implementations of the grouping
5operations, primarily in cython. These classes (BaseGrouper and BinGrouper)
6are contained *in* the SeriesGroupBy and DataFrameGroupBy objects.
7"""
import collections
from typing import List, Optional, Sequence, Tuple, Type

import numpy as np

from pandas._libs import NaT, iNaT, lib
import pandas._libs.groupby as libgroupby
import pandas._libs.reduction as libreduction
from pandas._typing import FrameOrSeries
from pandas.errors import AbstractMethodError
from pandas.util._decorators import cache_readonly

from pandas.core.dtypes.common import (
    ensure_float64,
    ensure_int64,
    ensure_int_or_float,
    ensure_platform_int,
    is_bool_dtype,
    is_categorical_dtype,
    is_complex_dtype,
    is_datetime64_any_dtype,
    is_datetime64tz_dtype,
    is_extension_array_dtype,
    is_integer_dtype,
    is_numeric_dtype,
    is_period_dtype,
    is_sparse,
    is_timedelta64_dtype,
    needs_i8_conversion,
)
from pandas.core.dtypes.missing import _maybe_fill, isna

import pandas.core.algorithms as algorithms
from pandas.core.base import SelectionMixin
import pandas.core.common as com
from pandas.core.frame import DataFrame
from pandas.core.generic import NDFrame
from pandas.core.groupby import base, grouper
from pandas.core.indexes.api import Index, MultiIndex, ensure_index
from pandas.core.series import Series
from pandas.core.sorting import (
    compress_group_index,
    decons_obs_group_ids,
    get_flattened_iterator,
    get_group_index,
    get_group_index_sorter,
    get_indexer_dict,
)


class BaseGrouper:
    """
    This is an internal Grouper class, which actually holds
    the generated groups.

    Parameters
    ----------
    axis : Index
    groupings : Sequence[Grouping]
        all the grouping instances to handle in this grouper;
        for example, when grouping by a list of groupers, pass the list here
    sort : bool, default True
        whether this grouper will give a sorted result or not
    group_keys : bool, default True
    mutated : bool, default False
    indexer : intp array, optional
        the indexer created by Grouper; some groupers (e.g. TimeGrouper)
        sort their axis, so group_info is also sorted and the indexer is
        needed to restore the original order
    """

    def __init__(
        self,
        axis: Index,
        groupings: "Sequence[grouper.Grouping]",
        sort: bool = True,
        group_keys: bool = True,
        mutated: bool = False,
        indexer: Optional[np.ndarray] = None,
    ):
        assert isinstance(axis, Index), axis

        self._filter_empty_groups = self.compressed = len(groupings) != 1
        self.axis = axis
        self._groupings: List[grouper.Grouping] = list(groupings)
        self.sort = sort
        self.group_keys = group_keys
        self.mutated = mutated
        self.indexer = indexer

    @property
    def groupings(self) -> List["grouper.Grouping"]:
        return self._groupings

    @property
    def shape(self):
        return tuple(ping.ngroups for ping in self.groupings)

    def __iter__(self):
        return iter(self.indices)

    @property
    def nkeys(self) -> int:
        return len(self.groupings)

    def get_iterator(self, data: FrameOrSeries, axis: int = 0):
        """
        Groupby iterator

        Returns
        -------
        Generator yielding sequence of (name, subsetted object)
        for each group
        """
        splitter = self._get_splitter(data, axis=axis)
        keys = self._get_group_keys()
        for key, (i, group) in zip(keys, splitter):
            yield key, group

    def _get_splitter(self, data: FrameOrSeries, axis: int = 0) -> "DataSplitter":
        comp_ids, _, ngroups = self.group_info
        return get_splitter(data, comp_ids, ngroups, axis=axis)

    def _get_grouper(self):
        """
        We are a grouper as part of another's groupings.

        We have a specific method of grouping, so cannot
        convert to an Index for our grouper.
        """
        return self.groupings[0].grouper

    def _get_group_keys(self):
        if len(self.groupings) == 1:
            return self.levels[0]
        else:
            comp_ids, _, ngroups = self.group_info

            # provide "flattened" iterator for multi-group setting
            return get_flattened_iterator(comp_ids, ngroups, self.levels, self.codes)

    def apply(self, f, data: FrameOrSeries, axis: int = 0):
        mutated = self.mutated
        splitter = self._get_splitter(data, axis=axis)
        group_keys = self._get_group_keys()
        result_values = None

        sdata: FrameOrSeries = splitter._get_sorted_data()
        if sdata.ndim == 2 and np.any(sdata.dtypes.apply(is_extension_array_dtype)):
            # calling splitter.fast_apply will raise TypeError via apply_frame_axis0
            # if we pass EA instead of ndarray
            # TODO: can we have a workaround for EAs backed by ndarray?
            pass

        elif (
            com.get_callable_name(f) not in base.plotting_methods
            and isinstance(splitter, FrameSplitter)
            and axis == 0
            # fast_apply/libreduction doesn't allow non-numpy backed indexes
            and not sdata.index._has_complex_internals
        ):
            try:
                result_values, mutated = splitter.fast_apply(f, group_keys)

            except libreduction.InvalidApply as err:
                # This Exception is raised if `f` triggers an exception
                # but it is preferable to raise the exception in Python.
                if "Let this error raise above us" not in str(err):
                    # TODO: can we infer anything about whether this is
                    # worth-retrying in pure-python?
                    raise

            else:
                # If the fast apply path could be used we can return here.
                # Otherwise we need to fall back to the slow implementation.
                if len(result_values) == len(group_keys):
                    return group_keys, result_values, mutated

        for key, (i, group) in zip(group_keys, splitter):
            object.__setattr__(group, "name", key)

            # result_values is None if fast apply path wasn't taken
            # or fast apply aborted with an unexpected exception.
            # In either case, initialize the result list and perform
            # the slow iteration.
            if result_values is None:
                result_values = []

            # If result_values is not None we're in the case that the
            # fast apply loop was broken prematurely but we already have
            # the result for the first group which we can reuse.
            elif i == 0:
                continue

            # group might be modified
            group_axes = group.axes
            res = f(group)
            if not _is_indexed_like(res, group_axes):
                mutated = True
            result_values.append(res)

        return group_keys, result_values, mutated

    @cache_readonly
    def indices(self):
        """ dict {group name -> group indices} """
        if len(self.groupings) == 1:
            return self.groupings[0].indices
        else:
            codes_list = [ping.codes for ping in self.groupings]
            keys = [com.values_from_object(ping.group_index) for ping in self.groupings]
            return get_indexer_dict(codes_list, keys)

    @property
    def codes(self) -> List[np.ndarray]:
        return [ping.codes for ping in self.groupings]

    @property
    def levels(self) -> List[Index]:
        return [ping.group_index for ping in self.groupings]

    @property
    def names(self):
        return [ping.name for ping in self.groupings]

    def size(self) -> Series:
        """
        Compute group sizes.
        """
        ids, _, ngroup = self.group_info
        ids = ensure_platform_int(ids)
        if ngroup:
            out = np.bincount(ids[ids != -1], minlength=ngroup)
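            # Illustrative (assumed values): ids [0, 0, 1, -1] with
            # ngroup == 2 give out == [2, 1]; -1 marks NA positions,
            # which are dropped before counting.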
        else:
            out = []
        return Series(out, index=self.result_index, dtype="int64")

    @cache_readonly
    def groups(self):
        """ dict {group name -> group labels} """
        if len(self.groupings) == 1:
            return self.groupings[0].groups
        else:
            to_groupby = zip(*(ping.grouper for ping in self.groupings))
            to_groupby = Index(to_groupby)
            return self.axis.groupby(to_groupby)

    @cache_readonly
    def is_monotonic(self) -> bool:
        # return if my group orderings are monotonic
        return Index(self.group_info[0]).is_monotonic

    @cache_readonly
    def group_info(self):
        comp_ids, obs_group_ids = self._get_compressed_codes()

        ngroups = len(obs_group_ids)
        comp_ids = ensure_int64(comp_ids)
        return comp_ids, obs_group_ids, ngroups

    @cache_readonly
    def codes_info(self) -> np.ndarray:
        # return the codes of items in original grouped axis
        codes, _, _ = self.group_info
        if self.indexer is not None:
            sorter = np.lexsort((codes, self.indexer))
            codes = codes[sorter]
        return codes

    def _get_compressed_codes(self) -> Tuple[np.ndarray, np.ndarray]:
        all_codes = self.codes
        if len(all_codes) > 1:
            group_index = get_group_index(all_codes, self.shape, sort=True, xnull=True)
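            # Illustrative (assumed values): for codes [0, 1, 0] and
            # [1, 0, 1] with shape (2, 2), group_index is [1, 2, 1];
            # compressing yields comp_ids [0, 1, 0] and obs_group_ids [1, 2].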
            return compress_group_index(group_index, sort=self.sort)

        ping = self.groupings[0]
        return ping.codes, np.arange(len(ping.group_index))

    @cache_readonly
    def ngroups(self) -> int:
        return len(self.result_index)

    @property
    def reconstructed_codes(self) -> List[np.ndarray]:
        codes = self.codes
        comp_ids, obs_ids, _ = self.group_info
        return decons_obs_group_ids(comp_ids, obs_ids, self.shape, codes, xnull=True)

    @cache_readonly
    def result_index(self) -> Index:
        if not self.compressed and len(self.groupings) == 1:
            return self.groupings[0].result_index.rename(self.names[0])

        codes = self.reconstructed_codes
        levels = [ping.result_index for ping in self.groupings]
        result = MultiIndex(
            levels=levels, codes=codes, verify_integrity=False, names=self.names
        )
        return result

    def get_group_levels(self):
        if not self.compressed and len(self.groupings) == 1:
            return [self.groupings[0].result_index]

        name_list = []
        for ping, codes in zip(self.groupings, self.reconstructed_codes):
            codes = ensure_platform_int(codes)
            levels = ping.result_index.take(codes)

            name_list.append(levels)

        return name_list

    # ------------------------------------------------------------
    # Aggregation functions

    _cython_functions = {
        "aggregate": {
            "add": "group_add",
            "prod": "group_prod",
            "min": "group_min",
            "max": "group_max",
            "mean": "group_mean",
            "median": "group_median",
            "var": "group_var",
            "first": "group_nth",
            "last": "group_last",
            "ohlc": "group_ohlc",
        },
        "transform": {
            "cumprod": "group_cumprod",
            "cumsum": "group_cumsum",
            "cummin": "group_cummin",
            "cummax": "group_cummax",
            "rank": "group_rank",
        },
    }

    _cython_arity = {"ohlc": 4}  # OHLC

    _name_functions = {"ohlc": ["open", "high", "low", "close"]}

    def _is_builtin_func(self, arg):
        """
        if we define a builtin function for this argument, return it,
        otherwise return the arg
        """
        return SelectionMixin._builtin_table.get(arg, arg)

    def _get_cython_function(self, kind: str, how: str, values, is_numeric: bool):

        dtype_str = values.dtype.name
        ftype = self._cython_functions[kind][how]

        # see if there is a fused-type version of function
        # only valid for numeric
        f = getattr(libgroupby, ftype, None)
        if f is not None and is_numeric:
            return f

        # otherwise find dtype-specific version, falling back to object
        for dt in [dtype_str, "object"]:
            f2 = getattr(libgroupby, f"{ftype}_{dt}", None)
            if f2 is not None:
                return f2

        if hasattr(f, "__signatures__"):
            # inspect what fused types are implemented
            if dtype_str == "object" and "object" not in f.__signatures__:
                # disallow this function so we get a NotImplementedError below
                # instead of a TypeError at runtime
                f = None

        func = f

        if func is None:
            raise NotImplementedError(
                f"function is not implemented for this dtype: "
                f"[how->{how},dtype->{dtype_str}]"
            )

        return func
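
    # Note on the lookup order above (illustrative): for how="add" on float64
    # values the fused-type "group_add" kernel is returned directly; for
    # non-numeric or unmatched dtypes, names built as f"{ftype}_{dt}" (the
    # exact dtype, then "object") are probed before raising NotImplementedError.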

    def _get_cython_func_and_vals(
        self, kind: str, how: str, values: np.ndarray, is_numeric: bool
    ):
        """
        Find the appropriate cython function, casting if necessary.

        Parameters
        ----------
        kind : str
        how : str
        values : np.ndarray
        is_numeric : bool

        Returns
        -------
        func : callable
        values : np.ndarray
        """
        try:
            func = self._get_cython_function(kind, how, values, is_numeric)
        except NotImplementedError:
            if is_numeric:
                try:
                    values = ensure_float64(values)
                except TypeError:
                    if lib.infer_dtype(values, skipna=False) == "complex":
                        values = values.astype(complex)
                    else:
                        raise
                func = self._get_cython_function(kind, how, values, is_numeric)
            else:
                raise
        return func, values

    def _cython_operation(
        self, kind: str, values, how: str, axis, min_count: int = -1, **kwargs
    ) -> Tuple[np.ndarray, Optional[List[str]]]:
        """
        Returns the values of a cython operation as a Tuple of [data, names].

        Names is only useful when dealing with 2D results, like ohlc
        (see self._name_functions).
        """

        assert kind in ["transform", "aggregate"]
        orig_values = values

        if values.ndim > 2:
            raise NotImplementedError("number of dimensions is currently limited to 2")
        elif values.ndim == 2:
            # Note: it is *not* the case that axis is always 0 for 1-dim values,
            # as we can have 1D ExtensionArrays that we need to treat as 2D
            assert axis == 1, axis

        # can we do this operation with our cython functions?
        # if not, raise NotImplementedError

        # we raise NotImplementedError if this is an invalid operation
        # entirely, e.g. adding datetimes

        # categoricals are only 1d, so we
        # are not set up for dim transforming
        if is_categorical_dtype(values) or is_sparse(values):
            raise NotImplementedError(f"{values.dtype} dtype not supported")
        elif is_datetime64_any_dtype(values):
            if how in ["add", "prod", "cumsum", "cumprod"]:
                raise NotImplementedError(
                    f"datetime64 type does not support {how} operations"
                )
        elif is_timedelta64_dtype(values):
            if how in ["prod", "cumprod"]:
                raise NotImplementedError(
                    f"timedelta64 type does not support {how} operations"
                )

        if is_datetime64tz_dtype(values.dtype):
            # Cast to naive; we'll cast back at the end of the function
            # TODO: possible need to reshape? kludge can be avoided when
            # 2D EA is allowed.
            values = values.view("M8[ns]")

        is_datetimelike = needs_i8_conversion(values.dtype)
        is_numeric = is_numeric_dtype(values.dtype)

        if is_datetimelike:
            values = values.view("int64")
            is_numeric = True
        elif is_bool_dtype(values.dtype):
            values = ensure_float64(values)
        elif is_integer_dtype(values):
            # we use iNaT for the missing value on ints
            # so pre-convert to guard this condition
            if (values == iNaT).any():
                values = ensure_float64(values)
            else:
                values = ensure_int_or_float(values)
        elif is_numeric and not is_complex_dtype(values):
            values = ensure_float64(values)
        else:
            values = values.astype(object)
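        # Summary of the coercions above: datetimelike -> int64 view,
        # bool -> float64, int with the iNaT sentinel present -> float64
        # (otherwise kept integer-like), other non-complex numerics ->
        # float64, and everything else -> object.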

        arity = self._cython_arity.get(how, 1)

        vdim = values.ndim
        swapped = False
        if vdim == 1:
            values = values[:, None]
            out_shape = (self.ngroups, arity)
        else:
            if axis > 0:
                swapped = True
                assert axis == 1, axis
                values = values.T
            if arity > 1:
                raise NotImplementedError(
                    "arity of more than 1 is not supported for the 'how' argument"
                )
            out_shape = (self.ngroups,) + values.shape[1:]

        func, values = self._get_cython_func_and_vals(kind, how, values, is_numeric)

        if how == "rank":
            out_dtype = "float"
        else:
            if is_numeric:
                out_dtype = f"{values.dtype.kind}{values.dtype.itemsize}"
            else:
                out_dtype = "object"

        codes, _, _ = self.group_info

        if kind == "aggregate":
            result = _maybe_fill(
                np.empty(out_shape, dtype=out_dtype), fill_value=np.nan
            )
            counts = np.zeros(self.ngroups, dtype=np.int64)
            result = self._aggregate(
                result, counts, values, codes, func, is_datetimelike, min_count
            )
        elif kind == "transform":
            result = _maybe_fill(
                np.empty_like(values, dtype=out_dtype), fill_value=np.nan
            )

            # TODO: min_count
            result = self._transform(
                result, values, codes, func, is_datetimelike, **kwargs
            )

        if is_integer_dtype(result) and not is_datetimelike:
            mask = result == iNaT
            if mask.any():
                result = result.astype("float64")
                result[mask] = np.nan
        elif (
            how == "add"
            and is_integer_dtype(orig_values.dtype)
            and is_extension_array_dtype(orig_values.dtype)
        ):
            # We need this to ensure that Series[Int64Dtype].resample().sum()
            # remains int64 dtype.
            # Two options for avoiding this special case
            # 1. mask-aware ops and avoid casting to float with NaN above
            # 2. specify the result dtype when calling this method
            result = result.astype("int64")

        if kind == "aggregate" and self._filter_empty_groups and not counts.all():
            assert result.ndim != 2
            result = result[counts > 0]

        if vdim == 1 and arity == 1:
            result = result[:, 0]

        names: Optional[List[str]] = self._name_functions.get(how, None)

        if swapped:
            result = result.swapaxes(0, axis)

        if is_datetime64tz_dtype(orig_values.dtype) or is_period_dtype(
            orig_values.dtype
        ):
            # We need to use the constructors directly for these dtypes
            # since numpy won't recognize them
            # https://github.com/pandas-dev/pandas/issues/31471
            result = type(orig_values)(result.astype(np.int64), dtype=orig_values.dtype)
        elif is_datetimelike and kind == "aggregate":
            result = result.astype(orig_values.dtype)

        return result, names

    def aggregate(
        self, values, how: str, axis: int = 0, min_count: int = -1
    ) -> Tuple[np.ndarray, Optional[List[str]]]:
        return self._cython_operation(
            "aggregate", values, how, axis, min_count=min_count
        )

    def transform(self, values, how: str, axis: int = 0, **kwargs):
        return self._cython_operation("transform", values, how, axis, **kwargs)

    def _aggregate(
        self,
        result,
        counts,
        values,
        comp_ids,
        agg_func,
        is_datetimelike: bool,
        min_count: int = -1,
    ):
        if agg_func is libgroupby.group_nth:
            # different signature from the others
            # TODO: should we be using min_count instead of hard-coding it?
            agg_func(result, counts, values, comp_ids, rank=1, min_count=-1)
        else:
            agg_func(result, counts, values, comp_ids, min_count)

        return result

    def _transform(
        self, result, values, comp_ids, transform_func, is_datetimelike: bool, **kwargs
    ):

        comp_ids, _, ngroups = self.group_info
        transform_func(result, values, comp_ids, ngroups, is_datetimelike, **kwargs)

        return result

    def agg_series(self, obj: Series, func):
        # Caller is responsible for checking ngroups != 0
        assert self.ngroups != 0

        if len(obj) == 0:
            # SeriesGrouper would raise if we were to call _aggregate_series_fast
            return self._aggregate_series_pure_python(obj, func)

        elif is_extension_array_dtype(obj.dtype):
            # _aggregate_series_fast would raise TypeError when
            # calling libreduction.Slider
            # In the datetime64tz case it would incorrectly cast to tz-naive
            # TODO: can we get a performant workaround for EAs backed by ndarray?
            return self._aggregate_series_pure_python(obj, func)

        elif obj.index._has_complex_internals:
            # Pre-empt TypeError in _aggregate_series_fast
            return self._aggregate_series_pure_python(obj, func)

        try:
            return self._aggregate_series_fast(obj, func)
        except ValueError as err:
            if "Function does not reduce" in str(err):
                # raised in libreduction
                pass
            else:
                raise
        return self._aggregate_series_pure_python(obj, func)
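
    # In short: the cython fast path is used only for non-empty, ndarray-backed
    # Series with simple indexes; every other case, and any "Function does not
    # reduce" failure, falls back to the pure-python loop below.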

    def _aggregate_series_fast(self, obj: Series, func):
        # At this point we have already checked that
        #  - obj.index is not a MultiIndex
        #  - obj is backed by an ndarray, not ExtensionArray
        #  - len(obj) > 0
        #  - ngroups != 0
        func = self._is_builtin_func(func)

        group_index, _, ngroups = self.group_info

        # avoids object / Series creation overhead
        dummy = obj._get_values(slice(None, 0))
        indexer = get_group_index_sorter(group_index, ngroups)
        obj = obj.take(indexer)
        group_index = algorithms.take_nd(group_index, indexer, allow_fill=False)
        grouper = libreduction.SeriesGrouper(obj, func, group_index, ngroups, dummy)
        result, counts = grouper.get_result()
        return result, counts

    def _aggregate_series_pure_python(self, obj: Series, func):

        group_index, _, ngroups = self.group_info

        counts = np.zeros(ngroups, dtype=int)
        result = None

        splitter = get_splitter(obj, group_index, ngroups, axis=0)

        for label, group in splitter:
            res = func(group)
            if result is None:
                if isinstance(res, (Series, Index, np.ndarray)):
                    if len(res) == 1:
                        # e.g. test_agg_lambda_with_timezone lambda e: e.head(1)
                        # FIXME: are we potentially losing important res.index info?
                        res = res.item()
                    else:
                        raise ValueError("Function does not reduce")
                result = np.empty(ngroups, dtype="O")

            counts[label] = group.shape[0]
            result[label] = res

        assert result is not None
        result = lib.maybe_convert_objects(result, try_float=0)
        # TODO: try_cast back to EA?

        return result, counts


class BinGrouper(BaseGrouper):
    """
    This is an internal Grouper class.

    Parameters
    ----------
    bins : the split index of binlabels to group the item of axis
    binlabels : the label list
    filter_empty : bool, default False
    mutated : bool, default False
    indexer : an intp array

    Examples
    --------
    bins: [2, 4, 6, 8, 10]
    binlabels: DatetimeIndex(['2005-01-01', '2005-01-03',
        '2005-01-05', '2005-01-07', '2005-01-09'],
        dtype='datetime64[ns]', freq='2D')

    the group_info, which contains the label of each item in the grouped
    axis, the index of each label in the label list, and the number of
    groups, is

    (array([0, 0, 1, 1, 2, 2, 3, 3, 4, 4]), array([0, 1, 2, 3, 4]), 5)

    meaning that the grouped axis has 10 items which can be grouped into
    5 labels: the first and second items belong to the first label, the
    third and fourth items belong to the second label, and so on
    """

    def __init__(
        self,
        bins,
        binlabels,
        filter_empty: bool = False,
        mutated: bool = False,
        indexer=None,
    ):
        self.bins = ensure_int64(bins)
        self.binlabels = ensure_index(binlabels)
        self._filter_empty_groups = filter_empty
        self.mutated = mutated
        self.indexer = indexer

        # These lengths must match, otherwise we could call agg_series
        # with empty self.bins, which would raise in libreduction.
        assert len(self.binlabels) == len(self.bins)

    @cache_readonly
    def groups(self):
        """ dict {group name -> group labels} """

        # this is mainly for compat
        # GH 3881
        result = {
            key: value
            for key, value in zip(self.binlabels, self.bins)
            if key is not NaT
        }
        return result

    @property
    def nkeys(self) -> int:
        return 1

    def _get_grouper(self):
        """
        We are a grouper as part of another's groupings.

        We have a specific method of grouping, so cannot
        convert to an Index for our grouper.
        """
        return self

    def get_iterator(self, data: FrameOrSeries, axis: int = 0):
        """
        Groupby iterator

        Returns
        -------
        Generator yielding sequence of (name, subsetted object)
        for each group
        """
        slicer = lambda start, edge: data._slice(slice(start, edge), axis=axis)
        length = len(data.axes[axis])

        start = 0
        for edge, label in zip(self.bins, self.binlabels):
            if label is not NaT:
                yield label, slicer(start, edge)
            start = edge

        if start < length:
            yield self.binlabels[-1], slicer(start, None)

    @cache_readonly
    def indices(self):
        indices = collections.defaultdict(list)
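        # Illustrative (assumed values): bins [2, 4] with labels
        # ['a', 'b'] produce {'a': [0, 1], 'b': [2, 3]}.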

        i = 0
        for label, bin in zip(self.binlabels, self.bins):
            if i < bin:
                if label is not NaT:
                    indices[label] = list(range(i, bin))
                i = bin
        return indices

    @cache_readonly
    def group_info(self):
        ngroups = self.ngroups
        obs_group_ids = np.arange(ngroups)
        rep = np.diff(np.r_[0, self.bins])

        rep = ensure_platform_int(rep)
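        # Illustrative (assumed values): bins [2, 4] over 4 rows give
        # rep [2, 2] and comp_ids [0, 0, 1, 1]; when a leading NaT label
        # was dropped from result_index, ngroups < len(bins) and the
        # first stretch is marked -1.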
        if ngroups == len(self.bins):
            comp_ids = np.repeat(np.arange(ngroups), rep)
        else:
            comp_ids = np.repeat(np.r_[-1, np.arange(ngroups)], rep)

        return (
            comp_ids.astype("int64", copy=False),
            obs_group_ids.astype("int64", copy=False),
            ngroups,
        )

    @cache_readonly
    def reconstructed_codes(self) -> List[np.ndarray]:
        # get unique result indices, and prepend 0 as groupby starts from the first
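        # Illustrative (assumed values): bins [2, 2, 4] (middle bin empty)
        # give codes [0, 2], i.e. only the first and third labels occur.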
        return [np.r_[0, np.flatnonzero(self.bins[1:] != self.bins[:-1]) + 1]]

    @cache_readonly
    def result_index(self):
        if len(self.binlabels) != 0 and isna(self.binlabels[0]):
            return self.binlabels[1:]

        return self.binlabels

    @property
    def levels(self):
        return [self.binlabels]

    @property
    def names(self):
        return [self.binlabels.name]

    @property
    def groupings(self) -> "List[grouper.Grouping]":
        return [
            grouper.Grouping(lvl, lvl, in_axis=False, level=None, name=name)
            for lvl, name in zip(self.levels, self.names)
        ]

    def agg_series(self, obj: Series, func):
        # Caller is responsible for checking ngroups != 0
        assert self.ngroups != 0
        assert len(self.bins) > 0  # otherwise we'd get IndexError in get_result

        if is_extension_array_dtype(obj.dtype):
            # pre-empt SeriesBinGrouper from raising TypeError
            return self._aggregate_series_pure_python(obj, func)

        dummy = obj[:0]
        grouper = libreduction.SeriesBinGrouper(obj, func, self.bins, dummy)
        return grouper.get_result()


def _is_indexed_like(obj, axes) -> bool:
    if isinstance(obj, Series):
        if len(axes) > 1:
            return False
        return obj.index.equals(axes[0])
    elif isinstance(obj, DataFrame):
        return obj.index.equals(axes[0])

    return False


# ----------------------------------------------------------------------
# Splitting / application


class DataSplitter:
    def __init__(self, data: FrameOrSeries, labels, ngroups: int, axis: int = 0):
        self.data = data
        self.labels = ensure_int64(labels)
        self.ngroups = ngroups

        self.axis = axis
        assert isinstance(axis, int), axis

    @cache_readonly
    def slabels(self):
        # Sorted labels
        return algorithms.take_nd(self.labels, self.sort_idx, allow_fill=False)

    @cache_readonly
    def sort_idx(self):
        # Counting sort indexer
        return get_group_index_sorter(self.labels, self.ngroups)
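
    # Illustrative (assumed values): labels [1, 0, 1] with ngroups == 2
    # give sort_idx [1, 0, 2] and hence slabels [0, 1, 1]; the sort is
    # stable, so within-group row order is preserved.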

    def __iter__(self):
        sdata = self._get_sorted_data()

        if self.ngroups == 0:
            # we are inside a generator, rather than raise StopIteration
            # we merely return to signal the end
            return

        starts, ends = lib.generate_slices(self.slabels, self.ngroups)

        for i, (start, end) in enumerate(zip(starts, ends)):
            yield i, self._chop(sdata, slice(start, end))

    def _get_sorted_data(self) -> FrameOrSeries:
        return self.data.take(self.sort_idx, axis=self.axis)

    def _chop(self, sdata, slice_obj: slice) -> NDFrame:
        raise AbstractMethodError(self)


class SeriesSplitter(DataSplitter):
    def _chop(self, sdata: Series, slice_obj: slice) -> Series:
        return sdata._get_values(slice_obj)


class FrameSplitter(DataSplitter):
    def fast_apply(self, f, names):
        # must return keys::list, values::list, mutated::bool
        starts, ends = lib.generate_slices(self.slabels, self.ngroups)

        sdata = self._get_sorted_data()
        return libreduction.apply_frame_axis0(sdata, f, names, starts, ends)

    def _chop(self, sdata: DataFrame, slice_obj: slice) -> DataFrame:
        if self.axis == 0:
            return sdata.iloc[slice_obj]
        else:
            return sdata._slice(slice_obj, axis=1)


def get_splitter(data: FrameOrSeries, *args, **kwargs) -> DataSplitter:
    if isinstance(data, Series):
        klass: Type[DataSplitter] = SeriesSplitter
    else:
        # i.e. DataFrame
        klass = FrameSplitter

    return klass(data, *args, **kwargs)