muutils.parallel
"""helpers for optionally running a map in parallel, with a progress bar

the main entry point is `run_maybe_parallel`; the other functions set up the
progress bar that wraps the (possibly parallel) map
"""

import multiprocessing
import functools
from typing import (
    Any,
    Callable,
    Iterable,
    Literal,
    Optional,
    Tuple,
    TypeVar,
    Dict,
    List,
    Union,
    Protocol,
)

# for no tqdm fallback
from muutils.spinner import SpinnerContext
from muutils.validate_type import get_fn_allowed_kwargs


InputType = TypeVar("InputType")
OutputType = TypeVar("OutputType")
# typevars for our iterable and map


class ProgressBarFunction(Protocol):
    "a protocol for a progress bar function"

    def __call__(self, iterable: Iterable, **kwargs: Any) -> Iterable: ...


ProgressBarOption = Literal["tqdm", "spinner", "none", None]
# type for the progress bar option


DEFAULT_PBAR_FN: ProgressBarOption
# default progress bar function

try:
    # use tqdm if it's available
    import tqdm  # type: ignore[import-untyped]

    DEFAULT_PBAR_FN = "tqdm"

except ImportError:
    # use progress bar as fallback
    DEFAULT_PBAR_FN = "spinner"


def spinner_fn_wrap(x: Iterable, **kwargs) -> List:
    """consume `x` into a list while a `SpinnerContext` is active

    # Parameters:
     - `x : Iterable`
       iterable to consume
     - `**kwargs`
       only the kwargs accepted by `SpinnerContext.__init__` are forwarded.
       if no `message` is forwarded, it is taken from `desc` if given,
       otherwise built from `total` if given

    # Returns:
     - `List`
       the elements of `x`
    """
    spinnercontext_allowed_kwargs: set[str] = get_fn_allowed_kwargs(
        SpinnerContext.__init__
    )
    mapped_kwargs: dict = {
        k: v for k, v in kwargs.items() if k in spinnercontext_allowed_kwargs
    }
    if "desc" in kwargs and "message" not in mapped_kwargs:
        mapped_kwargs["message"] = kwargs["desc"]

    if "message" not in mapped_kwargs and "total" in kwargs:
        mapped_kwargs["message"] = f"Processing {kwargs['total']} items"

    with SpinnerContext(**mapped_kwargs):
        output = list(x)

    return output


def map_kwargs_for_tqdm(kwargs: dict) -> dict:
    """translate generic progress bar kwargs into kwargs accepted by `tqdm.tqdm`

    kwargs not accepted by `tqdm.tqdm.__init__` are dropped. if `desc` is not
    given, it is taken from `message` if present, otherwise built from `total`
    if present.

    this only maps kwargs instead of wrapping `tqdm.tqdm` in a function;
    the original author's note suggests the bar disappeared when wrapped

    # Parameters:
     - `kwargs : dict`
       kwargs intended for the progress bar (not modified)

    # Returns:
     - `dict`
       a new dict of kwargs to pass to `tqdm.tqdm`
    """
    tqdm_allowed_kwargs: set[str] = get_fn_allowed_kwargs(tqdm.tqdm.__init__)
    mapped_kwargs: dict = {k: v for k, v in kwargs.items() if k in tqdm_allowed_kwargs}

    if "desc" not in kwargs:
        if "message" in kwargs:
            mapped_kwargs["desc"] = kwargs["message"]

        elif "total" in kwargs:
            mapped_kwargs["desc"] = f"Processing {kwargs['total']} items"

    return mapped_kwargs


def no_progress_fn_wrap(x: Iterable, **kwargs) -> Iterable:
    "fallback to no progress bar: returns `x` unchanged and ignores `kwargs`"
    return x


def set_up_progress_bar_fn(
    pbar: Union[ProgressBarFunction, ProgressBarOption],
    pbar_kwargs: Optional[Dict[str, Any]] = None,
    **extra_kwargs,
) -> Tuple[ProgressBarFunction, dict]:
    """set up the progress bar function and its kwargs

    # Parameters:
     - `pbar : Union[ProgressBarFunction, ProgressBarOption]`
       progress bar function or option. if a function, we return as-is. if a string, we figure out which progress bar to use
     - `pbar_kwargs : Optional[Dict[str, Any]]`
       kwargs passed to the progress bar function
       (defaults to `None`)
     - `**extra_kwargs`
       additional kwargs for the progress bar. on a key collision, the value in `pbar_kwargs` wins

    # Returns:
     - `Tuple[ProgressBarFunction, dict]`
       a tuple of the progress bar function and its kwargs

    # Raises:
     - `ValueError` : if `pbar` is not one of the valid options
     - `ImportError` : if `pbar` is `"tqdm"` but `tqdm` cannot be imported
    """
    pbar_fn: ProgressBarFunction

    # later entries win, so explicit `pbar_kwargs` override `extra_kwargs`
    merged_kwargs: Dict[str, Any] = {**extra_kwargs, **(pbar_kwargs or {})}

    # dont use a progress bar if `pbar` is None or "none", or if `disable` is set to True in `pbar_kwargs`
    if (pbar is None) or (pbar == "none") or merged_kwargs.get("disable", False):
        pbar_fn = no_progress_fn_wrap  # type: ignore[assignment]

    # if `pbar` is a different string, figure out which progress bar to use
    elif isinstance(pbar, str):
        if pbar == "tqdm":
            # the module-level import is allowed to fail, so import again here
            # to give a clear error instead of a `NameError`
            try:
                import tqdm  # type: ignore[import-untyped]
            except ImportError as e:
                raise ImportError(
                    "`pbar='tqdm'` requires the `tqdm` package, which could not be imported -- install it or use `pbar='spinner'`"
                ) from e
            pbar_fn = tqdm.tqdm
            merged_kwargs = map_kwargs_for_tqdm(merged_kwargs)
        elif pbar == "spinner":
            # the spinner wrapper gets its kwargs bound up front, so none are returned
            pbar_fn = functools.partial(spinner_fn_wrap, **merged_kwargs)
            merged_kwargs = dict()
        else:
            raise ValueError(
                f"`pbar` must be either 'tqdm' or 'spinner' if `str`, or a valid callable, got {type(pbar) = } {pbar = }"
            )
    else:
        # `pbar` is a user-supplied callable: use it as-is and hand it the merged kwargs
        pbar_fn = pbar

    return pbar_fn, merged_kwargs


def run_maybe_parallel(
    func: Callable[[InputType], OutputType],
    iterable: Iterable[InputType],
    parallel: Union[bool, int],
    pbar_kwargs: Optional[Dict[str, Any]] = None,
    chunksize: Optional[int] = None,
    keep_ordered: bool = True,
    use_multiprocess: bool = False,
    pbar: Union[ProgressBarFunction, ProgressBarOption] = DEFAULT_PBAR_FN,
) -> List[OutputType]:
    """a function to make it easier to sometimes parallelize an operation

    - if `parallel` is `False`, then the function will run in serial, running `map(func, iterable)`
    - if `parallel` is `True`, then the function will run in parallel, running in parallel with the maximum number of processes
    - if `parallel` is an `int`, it must be greater than 1, and the function will run in parallel with the number of processes specified by `parallel`

    the maximum number of processes is given by the `min(len(iterable), multiprocessing.cpu_count())`.
    if that leaves only one process, the function runs in serial

    # Parameters:
     - `func : Callable[[InputType], OutputType]`
       function passed to either `map` or `Pool.imap`
     - `iterable : Iterable[InputType]`
       iterable passed to either `map` or `Pool.imap`. an iterable without `__len__` (e.g. a generator) is first consumed into a list
     - `parallel : bool | int`
       whether to run in parallel, and how many processes to use
     - `pbar_kwargs : Dict[str, Any]`
       kwargs passed to the progress bar function
     - `chunksize : Optional[int]`
       chunksize for the pool map. if `None`, uses `max(1, len(iterable) // num_processes)`
       (defaults to `None`)
     - `keep_ordered : bool`
       if `True` use `Pool.imap`, otherwise `Pool.imap_unordered`. only relevant when running in parallel
       (defaults to `True`)
     - `use_multiprocess : bool`
       use the `multiprocess` package instead of `multiprocessing`
       (defaults to `False`)
     - `pbar : Union[ProgressBarFunction, ProgressBarOption]`
       progress bar function or option, resolved by `set_up_progress_bar_fn`
       (defaults to `DEFAULT_PBAR_FN`)

    # Returns:
     - `List[OutputType]`
       a list of the output of `func` for each element in `iterable`

    # Raises:
     - `ValueError` : if `parallel` is not a boolean or an integer greater than 1
     - `ValueError` : if `use_multiprocess=True` and `parallel=False`
     - `ImportError` : if `use_multiprocess=True` and `multiprocess` is not available
    """

    # the length is needed for the progress bar total and the process count,
    # so consume iterables that cannot report it
    if not hasattr(iterable, "__len__"):
        iterable = list(iterable)

    # number of inputs in iterable
    n_inputs: int = len(iterable)  # type: ignore[arg-type]
    if n_inputs == 0:
        # Return immediately if there is no input
        return list()

    # which progress bar to use
    pbar_fn: ProgressBarFunction
    pbar_kwargs_processed: dict
    pbar_fn, pbar_kwargs_processed = set_up_progress_bar_fn(
        pbar=pbar,
        pbar_kwargs=pbar_kwargs,
        # extra kwargs
        total=n_inputs,
    )

    # number of processes
    num_processes: int
    if isinstance(parallel, bool):
        num_processes = multiprocessing.cpu_count() if parallel else 1
    elif isinstance(parallel, int):
        if parallel < 2:
            raise ValueError(
                f"`parallel` must be a boolean, or be an integer greater than 1, got {type(parallel) = } {parallel = }"
            )
        num_processes = parallel
    else:
        raise ValueError(
            f"The 'parallel' parameter must be a boolean or an integer, got {type(parallel) = } {parallel = }"
        )

    # what the caller asked for, before the process count is capped below
    parallel_requested: bool = bool(parallel)

    # make sure we don't have more processes than iterable, and don't bother with parallel if there's only one process
    num_processes = min(num_processes, n_inputs)
    run_parallel: bool = num_processes > 1

    mp = multiprocessing
    if use_multiprocess:
        # validate against the caller's request, not the capped process count,
        # so `parallel=True` with a single input does not raise
        if not parallel_requested:
            raise ValueError("`use_multiprocess=True` requires `parallel=True`")

        try:
            import multiprocess  # type: ignore[import-untyped]
        except ImportError as e:
            raise ImportError(
                "`use_multiprocess=True` requires the `multiprocess` package -- this is mostly useful when you need to pickle a lambda. install muutils with `pip install muutils[multiprocess]` or just do `pip install multiprocess`"
            ) from e

        mp = multiprocess

    # set up the map function -- maybe its parallel, maybe it's just `map`
    do_map: Callable[
        [Callable[[InputType], OutputType], Iterable[InputType]],
        Iterable[OutputType],
    ]
    # use `mp.Pool` since we might want to use `multiprocess` instead of `multiprocessing`
    pool = mp.Pool(num_processes) if run_parallel else None
    if pool is not None:
        # use `imap` if we want to keep the order, otherwise use `imap_unordered`
        pool_map = pool.imap if keep_ordered else pool.imap_unordered

        # figure out a chunksize if one is not given
        chunksize_int: int = (
            max(1, n_inputs // num_processes) if chunksize is None else chunksize
        )

        # set the chunksize
        do_map = functools.partial(pool_map, chunksize=chunksize_int)  # type: ignore
    else:
        do_map = map

    # run the map function with a progress bar, always releasing the pool
    try:
        output: List[OutputType] = list(
            pbar_fn(
                do_map(
                    func,
                    iterable,
                ),
                **pbar_kwargs_processed,
            )
        )
    except BaseException:
        # stop the workers right away instead of waiting for queued tasks
        if pool is not None:
            pool.terminate()
        raise
    else:
        if pool is not None:
            pool.close()
    finally:
        if pool is not None:
            pool.join()

    # return the output as a list
    return output
class
ProgressBarFunction(typing.Protocol):
class ProgressBarFunction(Protocol):
    """structural type for anything usable as a progress bar wrapper

    an object matches when it can be called with an iterable plus arbitrary
    keyword arguments and returns an iterable
    """

    def __call__(self, iterable: Iterable, **kwargs: Any) -> Iterable:
        "wrap `iterable`, returning an iterable"
        ...
a protocol for a progress bar function
ProgressBarFunction(*args, **kwargs)
def _no_init_or_replace_init(self, *args, **kwargs):
    """Placeholder `__init__` that refuses protocol classes and otherwise swaps itself out.

    Raises `TypeError` when the instance's class has `_is_protocol` set.
    Otherwise, if the class still has this function as its `__init__`, the
    first `__init__` found in the MRO that is not this function is installed
    on the class and then called with the given arguments.
    """
    cls = type(self)

    if cls._is_protocol:
        raise TypeError('Protocols cannot be instantiated')

    # Already using a custom `__init__`. No need to calculate correct
    # `__init__` to call. This can lead to RecursionError. See bpo-45121.
    if cls.__init__ is not _no_init_or_replace_init:
        return

    # Initially, `__init__` of a protocol subclass is set to `_no_init_or_replace_init`.
    # The first instantiation of the subclass will call `_no_init_or_replace_init` which
    # searches for a proper new `__init__` in the MRO. The new `__init__`
    # replaces the subclass' old `__init__` (ie `_no_init_or_replace_init`). Subsequent
    # instantiation of the protocol subclass will thus use the new
    # `__init__` and no longer call `_no_init_or_replace_init`.
    for base in cls.__mro__:
        # a base without its own `__init__` is treated like one still using the placeholder
        init = base.__dict__.get('__init__', _no_init_or_replace_init)
        if init is not _no_init_or_replace_init:
            cls.__init__ = init
            break
    else:
        # should not happen
        cls.__init__ = object.__init__

    # delegate to the `__init__` that was just installed on the class
    cls.__init__(self, *args, **kwargs)
ProgressBarOption =
typing.Literal['tqdm', 'spinner', 'none', None]
DEFAULT_PBAR_FN: Literal['tqdm', 'spinner', 'none', None] =
'tqdm'
def
spinner_fn_wrap(x: Iterable, **kwargs) -> List:
def spinner_fn_wrap(x: Iterable, **kwargs) -> List:
    """consume `x` into a list while a `SpinnerContext` is active

    # Parameters:
     - `x : Iterable`
       iterable to consume
     - `**kwargs`
       only the kwargs accepted by `SpinnerContext.__init__` are forwarded.
       if no `message` is forwarded, it is taken from `desc` if given,
       otherwise built from `total` if given

    # Returns:
     - `List`
       the elements of `x`
    """
    allowed: set[str] = get_fn_allowed_kwargs(SpinnerContext.__init__)

    # keep only what `SpinnerContext` understands
    spinner_kwargs: dict = {}
    for key, value in kwargs.items():
        if key in allowed:
            spinner_kwargs[key] = value

    # fill in a message when none was passed: prefer `desc`, then `total`
    if "message" not in spinner_kwargs:
        if "desc" in kwargs:
            spinner_kwargs["message"] = kwargs["desc"]
        elif "total" in kwargs:
            spinner_kwargs["message"] = f"Processing {kwargs['total']} items"

    with SpinnerContext(**spinner_kwargs):
        collected = list(x)

    return collected
spinner wrapper
def
map_kwargs_for_tqdm(kwargs: dict) -> dict:
def map_kwargs_for_tqdm(kwargs: dict) -> dict:
    """translate generic progress bar kwargs into kwargs accepted by `tqdm.tqdm`

    kwargs not accepted by `tqdm.tqdm.__init__` are dropped. if `desc` is not
    given, it is taken from `message` if present, otherwise built from `total`
    if present.

    this only maps kwargs instead of wrapping `tqdm.tqdm` in a function;
    the original author's note suggests the bar disappeared when wrapped

    # Parameters:
     - `kwargs : dict`
       kwargs intended for the progress bar (not modified)

    # Returns:
     - `dict`
       a new dict of kwargs to pass to `tqdm.tqdm`
    """
    tqdm_allowed_kwargs: set[str] = get_fn_allowed_kwargs(tqdm.tqdm.__init__)
    mapped_kwargs: dict = {k: v for k, v in kwargs.items() if k in tqdm_allowed_kwargs}

    if "desc" not in kwargs:
        if "message" in kwargs:
            mapped_kwargs["desc"] = kwargs["message"]

        elif "total" in kwargs:
            mapped_kwargs["desc"] = f"Processing {kwargs['total']} items"

    return mapped_kwargs
map kwargs for tqdm, can't wrap because the pbar disappears?
def
no_progress_fn_wrap(x: Iterable, **kwargs) -> Iterable:
def no_progress_fn_wrap(x: Iterable, **kwargs) -> Iterable:
    """progress bar stand-in that displays nothing

    `kwargs` are accepted so this can be called like the other progress bar
    functions, but they are ignored. `x` is returned untouched
    """
    passthrough = x
    return passthrough
fallback to no progress bar
def
set_up_progress_bar_fn( pbar: Union[ProgressBarFunction, Literal['tqdm', 'spinner', 'none', None]], pbar_kwargs: Optional[Dict[str, Any]] = None, **extra_kwargs) -> Tuple[ProgressBarFunction, dict]:
def set_up_progress_bar_fn(
    pbar: Union[ProgressBarFunction, ProgressBarOption],
    pbar_kwargs: Optional[Dict[str, Any]] = None,
    **extra_kwargs,
) -> Tuple[ProgressBarFunction, dict]:
    """set up the progress bar function and its kwargs

    # Parameters:
     - `pbar : Union[ProgressBarFunction, ProgressBarOption]`
       progress bar function or option. if a function, we return as-is. if a string, we figure out which progress bar to use
     - `pbar_kwargs : Optional[Dict[str, Any]]`
       kwargs passed to the progress bar function
       (defaults to `None`)
     - `**extra_kwargs`
       additional kwargs for the progress bar. on a key collision, the value in `pbar_kwargs` wins

    # Returns:
     - `Tuple[ProgressBarFunction, dict]`
       a tuple of the progress bar function and its kwargs

    # Raises:
     - `ValueError` : if `pbar` is not one of the valid options
     - `ImportError` : if `pbar` is `"tqdm"` but `tqdm` cannot be imported
    """
    pbar_fn: ProgressBarFunction

    # later entries win, so explicit `pbar_kwargs` override `extra_kwargs`
    merged_kwargs: Dict[str, Any] = {**extra_kwargs, **(pbar_kwargs or {})}

    # dont use a progress bar if `pbar` is None or "none", or if `disable` is set to True in `pbar_kwargs`
    if (pbar is None) or (pbar == "none") or merged_kwargs.get("disable", False):
        pbar_fn = no_progress_fn_wrap  # type: ignore[assignment]

    # if `pbar` is a different string, figure out which progress bar to use
    elif isinstance(pbar, str):
        if pbar == "tqdm":
            # the module-level import is allowed to fail, so import again here
            # to give a clear error instead of a `NameError`
            try:
                import tqdm  # type: ignore[import-untyped]
            except ImportError as e:
                raise ImportError(
                    "`pbar='tqdm'` requires the `tqdm` package, which could not be imported -- install it or use `pbar='spinner'`"
                ) from e
            pbar_fn = tqdm.tqdm
            merged_kwargs = map_kwargs_for_tqdm(merged_kwargs)
        elif pbar == "spinner":
            # the spinner wrapper gets its kwargs bound up front, so none are returned
            pbar_fn = functools.partial(spinner_fn_wrap, **merged_kwargs)
            merged_kwargs = dict()
        else:
            raise ValueError(
                f"`pbar` must be either 'tqdm' or 'spinner' if `str`, or a valid callable, got {type(pbar) = } {pbar = }"
            )
    else:
        # `pbar` is a user-supplied callable: use it as-is and hand it the merged kwargs
        pbar_fn = pbar

    return pbar_fn, merged_kwargs
set up the progress bar function and its kwargs
Parameters:
pbar : Union[ProgressBarFunction, ProgressBarOption]
progress bar function or option. if a function, we return as-is. if a string, we figure out which progress bar to use
pbar_kwargs : Optional[Dict[str, Any]]
kwargs passed to the progress bar function (defaults to None)
Returns:
Tuple[ProgressBarFunction, dict]
a tuple of the progress bar function and its kwargs
Raises:
ValueError : if pbar is not one of the valid options
def
run_maybe_parallel( func: Callable[[~InputType], ~OutputType], iterable: Iterable[~InputType], parallel: Union[bool, int], pbar_kwargs: Optional[Dict[str, Any]] = None, chunksize: Optional[int] = None, keep_ordered: bool = True, use_multiprocess: bool = False, pbar: Union[ProgressBarFunction, Literal['tqdm', 'spinner', 'none', None]] = 'tqdm') -> List[~OutputType]:
def run_maybe_parallel(
    func: Callable[[InputType], OutputType],
    iterable: Iterable[InputType],
    parallel: Union[bool, int],
    pbar_kwargs: Optional[Dict[str, Any]] = None,
    chunksize: Optional[int] = None,
    keep_ordered: bool = True,
    use_multiprocess: bool = False,
    pbar: Union[ProgressBarFunction, ProgressBarOption] = DEFAULT_PBAR_FN,
) -> List[OutputType]:
    """a function to make it easier to sometimes parallelize an operation

    - if `parallel` is `False`, then the function will run in serial, running `map(func, iterable)`
    - if `parallel` is `True`, then the function will run in parallel, running in parallel with the maximum number of processes
    - if `parallel` is an `int`, it must be greater than 1, and the function will run in parallel with the number of processes specified by `parallel`

    the maximum number of processes is given by the `min(len(iterable), multiprocessing.cpu_count())`.
    if that leaves only one process, the function runs in serial

    # Parameters:
     - `func : Callable[[InputType], OutputType]`
       function passed to either `map` or `Pool.imap`
     - `iterable : Iterable[InputType]`
       iterable passed to either `map` or `Pool.imap`. an iterable without `__len__` (e.g. a generator) is first consumed into a list
     - `parallel : bool | int`
       whether to run in parallel, and how many processes to use
     - `pbar_kwargs : Dict[str, Any]`
       kwargs passed to the progress bar function
     - `chunksize : Optional[int]`
       chunksize for the pool map. if `None`, uses `max(1, len(iterable) // num_processes)`
       (defaults to `None`)
     - `keep_ordered : bool`
       if `True` use `Pool.imap`, otherwise `Pool.imap_unordered`. only relevant when running in parallel
       (defaults to `True`)
     - `use_multiprocess : bool`
       use the `multiprocess` package instead of `multiprocessing`
       (defaults to `False`)
     - `pbar : Union[ProgressBarFunction, ProgressBarOption]`
       progress bar function or option, resolved by `set_up_progress_bar_fn`
       (defaults to `DEFAULT_PBAR_FN`)

    # Returns:
     - `List[OutputType]`
       a list of the output of `func` for each element in `iterable`

    # Raises:
     - `ValueError` : if `parallel` is not a boolean or an integer greater than 1
     - `ValueError` : if `use_multiprocess=True` and `parallel=False`
     - `ImportError` : if `use_multiprocess=True` and `multiprocess` is not available
    """

    # the length is needed for the progress bar total and the process count,
    # so consume iterables that cannot report it
    if not hasattr(iterable, "__len__"):
        iterable = list(iterable)

    # number of inputs in iterable
    n_inputs: int = len(iterable)  # type: ignore[arg-type]
    if n_inputs == 0:
        # Return immediately if there is no input
        return list()

    # which progress bar to use
    pbar_fn: ProgressBarFunction
    pbar_kwargs_processed: dict
    pbar_fn, pbar_kwargs_processed = set_up_progress_bar_fn(
        pbar=pbar,
        pbar_kwargs=pbar_kwargs,
        # extra kwargs
        total=n_inputs,
    )

    # number of processes
    num_processes: int
    if isinstance(parallel, bool):
        num_processes = multiprocessing.cpu_count() if parallel else 1
    elif isinstance(parallel, int):
        if parallel < 2:
            raise ValueError(
                f"`parallel` must be a boolean, or be an integer greater than 1, got {type(parallel) = } {parallel = }"
            )
        num_processes = parallel
    else:
        raise ValueError(
            f"The 'parallel' parameter must be a boolean or an integer, got {type(parallel) = } {parallel = }"
        )

    # what the caller asked for, before the process count is capped below
    parallel_requested: bool = bool(parallel)

    # make sure we don't have more processes than iterable, and don't bother with parallel if there's only one process
    num_processes = min(num_processes, n_inputs)
    run_parallel: bool = num_processes > 1

    mp = multiprocessing
    if use_multiprocess:
        # validate against the caller's request, not the capped process count,
        # so `parallel=True` with a single input does not raise
        if not parallel_requested:
            raise ValueError("`use_multiprocess=True` requires `parallel=True`")

        try:
            import multiprocess  # type: ignore[import-untyped]
        except ImportError as e:
            raise ImportError(
                "`use_multiprocess=True` requires the `multiprocess` package -- this is mostly useful when you need to pickle a lambda. install muutils with `pip install muutils[multiprocess]` or just do `pip install multiprocess`"
            ) from e

        mp = multiprocess

    # set up the map function -- maybe its parallel, maybe it's just `map`
    do_map: Callable[
        [Callable[[InputType], OutputType], Iterable[InputType]],
        Iterable[OutputType],
    ]
    # use `mp.Pool` since we might want to use `multiprocess` instead of `multiprocessing`
    pool = mp.Pool(num_processes) if run_parallel else None
    if pool is not None:
        # use `imap` if we want to keep the order, otherwise use `imap_unordered`
        pool_map = pool.imap if keep_ordered else pool.imap_unordered

        # figure out a chunksize if one is not given
        chunksize_int: int = (
            max(1, n_inputs // num_processes) if chunksize is None else chunksize
        )

        # set the chunksize
        do_map = functools.partial(pool_map, chunksize=chunksize_int)  # type: ignore
    else:
        do_map = map

    # run the map function with a progress bar, always releasing the pool
    try:
        output: List[OutputType] = list(
            pbar_fn(
                do_map(
                    func,
                    iterable,
                ),
                **pbar_kwargs_processed,
            )
        )
    except BaseException:
        # stop the workers right away instead of waiting for queued tasks
        if pool is not None:
            pool.terminate()
        raise
    else:
        if pool is not None:
            pool.close()
    finally:
        if pool is not None:
            pool.join()

    # return the output as a list
    return output
a function to make it easier to sometimes parallelize an operation
- if parallel is False, then the function will run in serial, running map(func, iterable)
- if parallel is True, then the function will run in parallel, running in parallel with the maximum number of processes
- if parallel is an int, it must be greater than 1, and the function will run in parallel with the number of processes specified by parallel
the maximum number of processes is given by the min(len(iterable), multiprocessing.cpu_count())
Parameters:
func : Callable[[InputType], OutputType]
function passed to either map or Pool.imap
iterable : Iterable[InputType]
iterable passed to either map or Pool.imap
parallel : bool | int
whether to run in parallel, and how many processes to use
pbar_kwargs : Dict[str, Any]
kwargs passed to the progress bar function
Returns:
List[OutputType]
a list of the output of func for each element in iterable
Raises:
ValueError : if parallel is not a boolean or an integer greater than 1
ValueError : if use_multiprocess=True and parallel=False
ImportError : if use_multiprocess=True and multiprocess is not available