docs for muutils v0.8.3
View Source on GitHub

muutils.parallel


  1import multiprocessing
  2import functools
  3from typing import (
  4    Any,
  5    Callable,
  6    Iterable,
  7    Literal,
  8    Optional,
  9    Tuple,
 10    TypeVar,
 11    Dict,
 12    List,
 13    Union,
 14    Protocol,
 15)
 16
 17# for no tqdm fallback
 18from muutils.spinner import SpinnerContext
 19from muutils.validate_type import get_fn_allowed_kwargs
 20
 21
 22InputType = TypeVar("InputType")
 23OutputType = TypeVar("OutputType")
 24# typevars for our iterable and map
 25
 26
 27class ProgressBarFunction(Protocol):
 28    "a protocol for a progress bar function"
 29
 30    def __call__(self, iterable: Iterable, **kwargs: Any) -> Iterable: ...
 31
 32
 33ProgressBarOption = Literal["tqdm", "spinner", "none", None]
 34# type for the progress bar option
 35
 36
 37DEFAULT_PBAR_FN: ProgressBarOption
 38# default progress bar function
 39
 40try:
 41    # use tqdm if it's available
 42    import tqdm  # type: ignore[import-untyped]
 43
 44    DEFAULT_PBAR_FN = "tqdm"
 45
 46except ImportError:
 47    # use progress bar as fallback
 48    DEFAULT_PBAR_FN = "spinner"
 49
 50
 51def spinner_fn_wrap(x: Iterable, **kwargs) -> List:
 52    "spinner wrapper"
 53    spinnercontext_allowed_kwargs: set[str] = get_fn_allowed_kwargs(
 54        SpinnerContext.__init__
 55    )
 56    mapped_kwargs: dict = {
 57        k: v for k, v in kwargs.items() if k in spinnercontext_allowed_kwargs
 58    }
 59    if "desc" in kwargs and "message" not in mapped_kwargs:
 60        mapped_kwargs["message"] = kwargs["desc"]
 61
 62    if "message" not in mapped_kwargs and "total" in kwargs:
 63        mapped_kwargs["message"] = f"Processing {kwargs['total']} items"
 64
 65    with SpinnerContext(**mapped_kwargs):
 66        output = list(x)
 67
 68    return output
 69
 70
 71def map_kwargs_for_tqdm(kwargs: dict) -> dict:
 72    "map kwargs for tqdm, cant wrap because the pbar dissapears?"
 73    tqdm_allowed_kwargs: set[str] = get_fn_allowed_kwargs(tqdm.tqdm.__init__)
 74    print(f"{tqdm_allowed_kwargs = }")
 75    mapped_kwargs: dict = {k: v for k, v in kwargs.items() if k in tqdm_allowed_kwargs}
 76
 77    if "desc" not in kwargs:
 78        if "message" in kwargs:
 79            mapped_kwargs["desc"] = kwargs["message"]
 80
 81        elif "total" in kwargs:
 82            mapped_kwargs["desc"] = f"Processing {kwargs.get('total')} items"
 83    print(f"{mapped_kwargs = }")
 84    return mapped_kwargs
 85
 86
 87def no_progress_fn_wrap(x: Iterable, **kwargs) -> Iterable:
 88    "fallback to no progress bar"
 89    return x
 90
 91
 92def set_up_progress_bar_fn(
 93    pbar: Union[ProgressBarFunction, ProgressBarOption],
 94    pbar_kwargs: Optional[Dict[str, Any]] = None,
 95    **extra_kwargs,
 96) -> Tuple[ProgressBarFunction, dict]:
 97    """set up the progress bar function and its kwargs
 98
 99    # Parameters:
100     - `pbar : Union[ProgressBarFunction, ProgressBarOption]`
101       progress bar function or option. if a function, we return as-is. if a string, we figure out which progress bar to use
102     - `pbar_kwargs : Optional[Dict[str, Any]]`
103       kwargs passed to the progress bar function (default to `None`)
104       (defaults to `None`)
105
106    # Returns:
107     - `Tuple[ProgressBarFunction, dict]`
108         a tuple of the progress bar function and its kwargs
109
110    # Raises:
111     - `ValueError` : if `pbar` is not one of the valid options
112    """
113    pbar_fn: ProgressBarFunction
114
115    if pbar_kwargs is None:
116        pbar_kwargs = dict()
117
118    pbar_kwargs = {**extra_kwargs, **pbar_kwargs}
119
120    # dont use a progress bar if `pbar` is None or "none", or if `disable` is set to True in `pbar_kwargs`
121    if (pbar is None) or (pbar == "none") or pbar_kwargs.get("disable", False):
122        pbar_fn = no_progress_fn_wrap  # type: ignore[assignment]
123
124    # if `pbar` is a different string, figure out which progress bar to use
125    elif isinstance(pbar, str):
126        if pbar == "tqdm":
127            pbar_fn = tqdm.tqdm
128            pbar_kwargs = map_kwargs_for_tqdm(pbar_kwargs)
129        elif pbar == "spinner":
130            pbar_fn = functools.partial(spinner_fn_wrap, **pbar_kwargs)
131            pbar_kwargs = dict()
132        else:
133            raise ValueError(
134                f"`pbar` must be either 'tqdm' or 'spinner' if `str`, or a valid callable, got {type(pbar) = } {pbar = }"
135            )
136    else:
137        # the default value is a callable which will resolve to tqdm if available or spinner as a fallback. we pass kwargs to this
138        pbar_fn = pbar
139
140    return pbar_fn, pbar_kwargs
141
142
def run_maybe_parallel(
    func: Callable[[InputType], OutputType],
    iterable: Iterable[InputType],
    parallel: Union[bool, int],
    pbar_kwargs: Optional[Dict[str, Any]] = None,
    chunksize: Optional[int] = None,
    keep_ordered: bool = True,
    use_multiprocess: bool = False,
    pbar: Union[ProgressBarFunction, ProgressBarOption] = DEFAULT_PBAR_FN,
) -> List[OutputType]:
    """a function to make it easier to sometimes parallelize an operation

    - if `parallel` is `False`, then the function will run in serial, running `map(func, iterable)`
    - if `parallel` is `True`, then the function will run in parallel, running in parallel with the maximum number of processes
    - if `parallel` is an `int`, it must be greater than 1, and the function will run in parallel with the number of processes specified by `parallel`

    the number of processes is capped at `len(iterable)`: it is
    `min(len(iterable), multiprocessing.cpu_count())` for `parallel=True`, and
    `min(len(iterable), parallel)` for an integer `parallel`. if this works out
    to a single process, the function runs in serial.

    # Parameters:
     - `func : Callable[[InputType], OutputType]`
       function passed to either `map` or `Pool.imap`
     - `iterable : Iterable[InputType]`
       iterable passed to either `map` or `Pool.imap`. iterables without
       `__len__` (e.g. generators) are first materialized into a list
     - `parallel : bool | int`
       whether to run in parallel, and how many processes to use
     - `pbar_kwargs : Optional[Dict[str, Any]]`
       kwargs passed to the progress bar function
       (defaults to `None`)
     - `chunksize : Optional[int]`
       chunksize passed to `Pool.imap` / `Pool.imap_unordered`. if `None`,
       uses `max(1, len(iterable) // num_processes)`
       (defaults to `None`)
     - `keep_ordered : bool`
       if `True` use `Pool.imap` (outputs in input order), otherwise
       `Pool.imap_unordered`. irrelevant when running in serial
       (defaults to `True`)
     - `use_multiprocess : bool`
       use the `multiprocess` package instead of `multiprocessing` -- mostly
       useful when `func` is something `pickle` can't handle, like a lambda
       (defaults to `False`)
     - `pbar : Union[ProgressBarFunction, ProgressBarOption]`
       which progress bar to use, see `set_up_progress_bar_fn`
       (defaults to `DEFAULT_PBAR_FN`)

    # Returns:
     - `List[OutputType]`
       a list of the output of `func` for each element in `iterable`

    # Raises:
     - `ValueError` : if `parallel` is not a boolean or an integer greater than 1
     - `ValueError` : if `use_multiprocess=True` and `parallel=False`
     - `ImportError` : if `use_multiprocess=True` and `multiprocess` is not available
    """
    # `len()` is needed for the progress bar total and the process count, so
    # materialize iterables that don't support it (e.g. generators)
    if not hasattr(iterable, "__len__"):
        iterable = list(iterable)

    # number of inputs in iterable
    n_inputs: int = len(iterable)  # type: ignore[arg-type]
    if n_inputs == 0:
        # Return immediately if there is no input
        return list()

    # which progress bar to use
    pbar_fn: ProgressBarFunction
    pbar_kwargs_processed: dict
    pbar_fn, pbar_kwargs_processed = set_up_progress_bar_fn(
        pbar=pbar,
        pbar_kwargs=pbar_kwargs,
        # extra kwargs
        total=n_inputs,
    )

    # number of processes requested
    num_processes: int
    if isinstance(parallel, bool):
        num_processes = multiprocessing.cpu_count() if parallel else 1
    elif isinstance(parallel, int):
        if parallel < 2:
            raise ValueError(
                f"`parallel` must be a boolean, or be an integer greater than 1, got {type(parallel) = } {parallel = }"
            )
        num_processes = parallel
    else:
        raise ValueError(
            f"The 'parallel' parameter must be a boolean or an integer, got {type(parallel) = } {parallel = }"
        )

    # validate against what the caller asked for, *before* the process count is
    # capped -- otherwise `parallel=True` with a single input (or a single CPU)
    # would spuriously raise here
    if use_multiprocess and not parallel:
        raise ValueError("`use_multiprocess=True` requires `parallel=True`")

    # make sure we don't have more processes than inputs, and don't bother
    # with a pool if there's only one process
    num_processes = min(num_processes, n_inputs)
    if num_processes == 1:
        return list(pbar_fn(map(func, iterable), **pbar_kwargs_processed))

    # use `mp.Pool` since we might want to use `multiprocess` instead of `multiprocessing`
    mp = multiprocessing
    if use_multiprocess:
        try:
            import multiprocess  # type: ignore[import-untyped]
        except ImportError as e:
            raise ImportError(
                "`use_multiprocess=True` requires the `multiprocess` package -- this is mostly useful when you need to pickle a lambda. install muutils with `pip install muutils[multiprocess]` or just do `pip install multiprocess`"
            ) from e
        mp = multiprocess

    # figure out a smart chunksize if one is not given
    chunksize_int: int = (
        max(1, n_inputs // num_processes) if chunksize is None else chunksize
    )

    pool = mp.Pool(num_processes)
    try:
        # `imap` keeps the input order, `imap_unordered` yields results as they finish
        do_map = pool.imap if keep_ordered else pool.imap_unordered
        output: List[OutputType] = list(
            pbar_fn(
                do_map(func, iterable, chunksize=chunksize_int),
                **pbar_kwargs_processed,
            )
        )
    except BaseException:
        # `func`, the progress bar, or a KeyboardInterrupt failed: stop the
        # workers right away instead of leaking them, then propagate
        pool.terminate()
        pool.join()
        raise

    # normal completion: let the workers exit cleanly
    pool.close()
    pool.join()

    return output

class ProgressBarFunction(typing.Protocol):
class ProgressBarFunction(Protocol):
    """structural type for a progress bar function

    any callable taking an iterable plus arbitrary keyword arguments and
    returning an iterable satisfies this protocol
    """

    def __call__(self, iterable: Iterable, **kwargs: Any) -> Iterable: ...

a protocol for a progress bar function

ProgressBarFunction(*args, **kwargs)
def _no_init_or_replace_init(self, *args, **kwargs):
    """placeholder `__init__` installed by CPython's `typing` on `Protocol` classes

    (this is `typing` source, shown by the docs generator as the constructor
    that `ProgressBarFunction` inherits from `Protocol`)

    refuses to instantiate a protocol class itself. for a concrete subclass,
    the first instantiation finds the first real `__init__` in the MRO,
    installs it on the class, and calls it.
    """
    cls = type(self)

    # protocol classes themselves can never be instantiated
    if cls._is_protocol:
        raise TypeError('Protocols cannot be instantiated')

    # Already using a custom `__init__`. No need to calculate correct
    # `__init__` to call. This can lead to RecursionError. See bpo-45121.
    if cls.__init__ is not _no_init_or_replace_init:
        return

    # Initially, `__init__` of a protocol subclass is set to `_no_init_or_replace_init`.
    # The first instantiation of the subclass will call `_no_init_or_replace_init` which
    # searches for a proper new `__init__` in the MRO. The new `__init__`
    # replaces the subclass' old `__init__` (ie `_no_init_or_replace_init`). Subsequent
    # instantiation of the protocol subclass will thus use the new
    # `__init__` and no longer call `_no_init_or_replace_init`.
    for base in cls.__mro__:
        init = base.__dict__.get('__init__', _no_init_or_replace_init)
        if init is not _no_init_or_replace_init:
            cls.__init__ = init
            break
    else:
        # should not happen
        # (`object` is always last in the MRO and defines its own `__init__`)
        cls.__init__ = object.__init__

    # run the newly installed `__init__` for this first instance
    cls.__init__(self, *args, **kwargs)
# accepted string/`None` options for choosing a progress bar
ProgressBarOption = Literal["tqdm", "spinner", "none", None]
# value of the default when these docs were built (tqdm was importable)
DEFAULT_PBAR_FN: ProgressBarOption = "tqdm"
def spinner_fn_wrap(x: Iterable, **kwargs) -> List:
def spinner_fn_wrap(x: Iterable, **kwargs) -> List:
    """consume `x` while showing a `SpinnerContext`, returning the items as a list

    only the kwargs accepted by `SpinnerContext.__init__` are forwarded. if no
    `message` is given, one is taken from `desc` (the tqdm-style name) or,
    failing that, built from `total`.

    # Parameters:
     - `x : Iterable`
       the iterable to consume
     - `**kwargs`
       progress bar kwargs; keys `SpinnerContext` does not accept are dropped

    # Returns:
     - `List`
       `list(x)`, evaluated inside the spinner context
    """
    allowed_keys: set[str] = get_fn_allowed_kwargs(SpinnerContext.__init__)
    spinner_kwargs: dict = dict()
    for key, value in kwargs.items():
        if key in allowed_keys:
            spinner_kwargs[key] = value

    if "message" not in spinner_kwargs:
        if "desc" in kwargs:
            spinner_kwargs["message"] = kwargs["desc"]
        elif "total" in kwargs:
            spinner_kwargs["message"] = f"Processing {kwargs['total']} items"

    with SpinnerContext(**spinner_kwargs):
        items = list(x)
    return items

spinner wrapper

def map_kwargs_for_tqdm(kwargs: dict) -> dict:
def map_kwargs_for_tqdm(kwargs: dict) -> dict:
    """map generic progress bar kwargs onto the kwargs accepted by `tqdm.tqdm`

    keeps only the keys that `tqdm.tqdm.__init__` accepts (as reported by
    `get_fn_allowed_kwargs`). if no `desc` is given, one is derived from
    `message` (the name `SpinnerContext` uses) or, failing that, from `total`.

    (tqdm is called with these kwargs directly instead of being wrapped --
    wrapping it apparently made the progress bar disappear)

    # Parameters:
     - `kwargs : dict`
       progress bar kwargs, possibly containing keys tqdm does not accept

    # Returns:
     - `dict`
       a new dict containing only kwargs valid for `tqdm.tqdm`
    """
    tqdm_allowed_kwargs: set[str] = get_fn_allowed_kwargs(tqdm.tqdm.__init__)
    mapped_kwargs: dict = {k: v for k, v in kwargs.items() if k in tqdm_allowed_kwargs}

    if "desc" not in kwargs:
        if "message" in kwargs:
            mapped_kwargs["desc"] = kwargs["message"]
        elif "total" in kwargs:
            mapped_kwargs["desc"] = f"Processing {kwargs['total']} items"

    return mapped_kwargs

map kwargs for tqdm; can't wrap tqdm because the progress bar disappears?

def no_progress_fn_wrap(x: Iterable, **kwargs) -> Iterable:
def no_progress_fn_wrap(x: Iterable, **kwargs) -> Iterable:
    """progress bar stand-in used when no progress bar is wanted

    hands back `x` itself, unchanged. keyword arguments (e.g. `desc`, `total`)
    are accepted only so the call signature matches the other progress bar
    functions, and are ignored.
    """
    del kwargs  # accepted for signature compatibility only
    return x

fallback to no progress bar

def set_up_progress_bar_fn( pbar: Union[ProgressBarFunction, Literal['tqdm', 'spinner', 'none', None]], pbar_kwargs: Optional[Dict[str, Any]] = None, **extra_kwargs) -> Tuple[ProgressBarFunction, dict]:
 93def set_up_progress_bar_fn(
 94    pbar: Union[ProgressBarFunction, ProgressBarOption],
 95    pbar_kwargs: Optional[Dict[str, Any]] = None,
 96    **extra_kwargs,
 97) -> Tuple[ProgressBarFunction, dict]:
 98    """set up the progress bar function and its kwargs
 99
100    # Parameters:
101     - `pbar : Union[ProgressBarFunction, ProgressBarOption]`
102       progress bar function or option. if a function, we return as-is. if a string, we figure out which progress bar to use
103     - `pbar_kwargs : Optional[Dict[str, Any]]`
104       kwargs passed to the progress bar function (default to `None`)
105       (defaults to `None`)
106
107    # Returns:
108     - `Tuple[ProgressBarFunction, dict]`
109         a tuple of the progress bar function and its kwargs
110
111    # Raises:
112     - `ValueError` : if `pbar` is not one of the valid options
113    """
114    pbar_fn: ProgressBarFunction
115
116    if pbar_kwargs is None:
117        pbar_kwargs = dict()
118
119    pbar_kwargs = {**extra_kwargs, **pbar_kwargs}
120
121    # dont use a progress bar if `pbar` is None or "none", or if `disable` is set to True in `pbar_kwargs`
122    if (pbar is None) or (pbar == "none") or pbar_kwargs.get("disable", False):
123        pbar_fn = no_progress_fn_wrap  # type: ignore[assignment]
124
125    # if `pbar` is a different string, figure out which progress bar to use
126    elif isinstance(pbar, str):
127        if pbar == "tqdm":
128            pbar_fn = tqdm.tqdm
129            pbar_kwargs = map_kwargs_for_tqdm(pbar_kwargs)
130        elif pbar == "spinner":
131            pbar_fn = functools.partial(spinner_fn_wrap, **pbar_kwargs)
132            pbar_kwargs = dict()
133        else:
134            raise ValueError(
135                f"`pbar` must be either 'tqdm' or 'spinner' if `str`, or a valid callable, got {type(pbar) = } {pbar = }"
136            )
137    else:
138        # the default value is a callable which will resolve to tqdm if available or spinner as a fallback. we pass kwargs to this
139        pbar_fn = pbar
140
141    return pbar_fn, pbar_kwargs

set up the progress bar function and its kwargs

Parameters:

  • pbar : Union[ProgressBarFunction, ProgressBarOption] progress bar function or option. if a function, we return as-is. if a string, we figure out which progress bar to use
  • pbar_kwargs : Optional[Dict[str, Any]] kwargs passed to the progress bar function (defaults to None)

Returns:

  • Tuple[ProgressBarFunction, dict] a tuple of the progress bar function and its kwargs

Raises:

  • ValueError : if pbar is not one of the valid options
def run_maybe_parallel( func: Callable[[~InputType], ~OutputType], iterable: Iterable[~InputType], parallel: Union[bool, int], pbar_kwargs: Optional[Dict[str, Any]] = None, chunksize: Optional[int] = None, keep_ordered: bool = True, use_multiprocess: bool = False, pbar: Union[ProgressBarFunction, Literal['tqdm', 'spinner', 'none', None]] = 'tqdm') -> List[~OutputType]:
def run_maybe_parallel(
    func: Callable[[InputType], OutputType],
    iterable: Iterable[InputType],
    parallel: Union[bool, int],
    pbar_kwargs: Optional[Dict[str, Any]] = None,
    chunksize: Optional[int] = None,
    keep_ordered: bool = True,
    use_multiprocess: bool = False,
    pbar: Union[ProgressBarFunction, ProgressBarOption] = DEFAULT_PBAR_FN,
) -> List[OutputType]:
    """a function to make it easier to sometimes parallelize an operation

    - if `parallel` is `False`, then the function will run in serial, running `map(func, iterable)`
    - if `parallel` is `True`, then the function will run in parallel, running in parallel with the maximum number of processes
    - if `parallel` is an `int`, it must be greater than 1, and the function will run in parallel with the number of processes specified by `parallel`

    the number of processes is capped at `len(iterable)`: it is
    `min(len(iterable), multiprocessing.cpu_count())` for `parallel=True`, and
    `min(len(iterable), parallel)` for an integer `parallel`. if this works out
    to a single process, the function runs in serial.

    # Parameters:
     - `func : Callable[[InputType], OutputType]`
       function passed to either `map` or `Pool.imap`
     - `iterable : Iterable[InputType]`
       iterable passed to either `map` or `Pool.imap`. iterables without
       `__len__` (e.g. generators) are first materialized into a list
     - `parallel : bool | int`
       whether to run in parallel, and how many processes to use
     - `pbar_kwargs : Optional[Dict[str, Any]]`
       kwargs passed to the progress bar function
       (defaults to `None`)
     - `chunksize : Optional[int]`
       chunksize passed to `Pool.imap` / `Pool.imap_unordered`. if `None`,
       uses `max(1, len(iterable) // num_processes)`
       (defaults to `None`)
     - `keep_ordered : bool`
       if `True` use `Pool.imap` (outputs in input order), otherwise
       `Pool.imap_unordered`. irrelevant when running in serial
       (defaults to `True`)
     - `use_multiprocess : bool`
       use the `multiprocess` package instead of `multiprocessing` -- mostly
       useful when `func` is something `pickle` can't handle, like a lambda
       (defaults to `False`)
     - `pbar : Union[ProgressBarFunction, ProgressBarOption]`
       which progress bar to use, see `set_up_progress_bar_fn`
       (defaults to `DEFAULT_PBAR_FN`)

    # Returns:
     - `List[OutputType]`
       a list of the output of `func` for each element in `iterable`

    # Raises:
     - `ValueError` : if `parallel` is not a boolean or an integer greater than 1
     - `ValueError` : if `use_multiprocess=True` and `parallel=False`
     - `ImportError` : if `use_multiprocess=True` and `multiprocess` is not available
    """
    # `len()` is needed for the progress bar total and the process count, so
    # materialize iterables that don't support it (e.g. generators)
    if not hasattr(iterable, "__len__"):
        iterable = list(iterable)

    # number of inputs in iterable
    n_inputs: int = len(iterable)  # type: ignore[arg-type]
    if n_inputs == 0:
        # Return immediately if there is no input
        return list()

    # which progress bar to use
    pbar_fn: ProgressBarFunction
    pbar_kwargs_processed: dict
    pbar_fn, pbar_kwargs_processed = set_up_progress_bar_fn(
        pbar=pbar,
        pbar_kwargs=pbar_kwargs,
        # extra kwargs
        total=n_inputs,
    )

    # number of processes requested
    num_processes: int
    if isinstance(parallel, bool):
        num_processes = multiprocessing.cpu_count() if parallel else 1
    elif isinstance(parallel, int):
        if parallel < 2:
            raise ValueError(
                f"`parallel` must be a boolean, or be an integer greater than 1, got {type(parallel) = } {parallel = }"
            )
        num_processes = parallel
    else:
        raise ValueError(
            f"The 'parallel' parameter must be a boolean or an integer, got {type(parallel) = } {parallel = }"
        )

    # validate against what the caller asked for, *before* the process count is
    # capped -- otherwise `parallel=True` with a single input (or a single CPU)
    # would spuriously raise here
    if use_multiprocess and not parallel:
        raise ValueError("`use_multiprocess=True` requires `parallel=True`")

    # make sure we don't have more processes than inputs, and don't bother
    # with a pool if there's only one process
    num_processes = min(num_processes, n_inputs)
    if num_processes == 1:
        return list(pbar_fn(map(func, iterable), **pbar_kwargs_processed))

    # use `mp.Pool` since we might want to use `multiprocess` instead of `multiprocessing`
    mp = multiprocessing
    if use_multiprocess:
        try:
            import multiprocess  # type: ignore[import-untyped]
        except ImportError as e:
            raise ImportError(
                "`use_multiprocess=True` requires the `multiprocess` package -- this is mostly useful when you need to pickle a lambda. install muutils with `pip install muutils[multiprocess]` or just do `pip install multiprocess`"
            ) from e
        mp = multiprocess

    # figure out a smart chunksize if one is not given
    chunksize_int: int = (
        max(1, n_inputs // num_processes) if chunksize is None else chunksize
    )

    pool = mp.Pool(num_processes)
    try:
        # `imap` keeps the input order, `imap_unordered` yields results as they finish
        do_map = pool.imap if keep_ordered else pool.imap_unordered
        output: List[OutputType] = list(
            pbar_fn(
                do_map(func, iterable, chunksize=chunksize_int),
                **pbar_kwargs_processed,
            )
        )
    except BaseException:
        # `func`, the progress bar, or a KeyboardInterrupt failed: stop the
        # workers right away instead of leaking them, then propagate
        pool.terminate()
        pool.join()
        raise

    # normal completion: let the workers exit cleanly
    pool.close()
    pool.join()

    return output

a function to make it easier to sometimes parallelize an operation

  • if parallel is False, then the function will run in serial, running map(func, iterable)
  • if parallel is True, then the function will run in parallel, running in parallel with the maximum number of processes
  • if parallel is an int, it must be greater than 1, and the function will run in parallel with the number of processes specified by parallel

the number of processes is capped at len(iterable): it is min(len(iterable), multiprocessing.cpu_count()) when parallel is True, and min(len(iterable), parallel) when parallel is an integer

Parameters:

  • func : Callable[[InputType], OutputType] function passed to either map or Pool.imap
  • iterable : Iterable[InputType] iterable passed to either map or Pool.imap
  • parallel : bool | int whether to run in parallel, and how many processes to use
  • pbar_kwargs : Dict[str, Any] kwargs passed to the progress bar function

Returns:

  • List[OutputType] a list of the output of func for each element in iterable

Raises:

  • ValueError : if parallel is not a boolean or an integer greater than 1
  • ValueError : if use_multiprocess=True and parallel=False
  • ImportError : if use_multiprocess=True and multiprocess is not available