Coverage for muutils\parallel.py: 93%

86 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-01-17 00:00 -0700

1import multiprocessing 

2import functools 

3from typing import ( 

4 Any, 

5 Callable, 

6 Iterable, 

7 Literal, 

8 Optional, 

9 TypeVar, 

10 Dict, 

11 List, 

12 Union, 

13 Protocol, 

14) 

15 

16# for no tqdm fallback 

17from muutils.spinner import SpinnerContext 

18from muutils.validate_type import get_fn_allowed_kwargs 

19 

20 

# typevars for our iterable and map:
# `InputType` is the element type of the iterable given to `run_maybe_parallel`,
# `OutputType` is the type returned by the function mapped over it
InputType = TypeVar("InputType")
OutputType = TypeVar("OutputType")

24 

25 

class ProgressBarFunction(Protocol):
    """Structural type for a progress bar function.

    A conforming callable takes an iterable positionally, plus arbitrary
    keyword arguments, and returns an iterable.
    """

    def __call__(self, iterable: Iterable, **kwargs: Any) -> Iterable:
        ...

30 

31 

# values accepted for `pbar` besides a callable:
# "tqdm" and "spinner" select a progress bar, while "none" and `None` disable it
ProgressBarOption = Literal["tqdm", "spinner", "none", None]

33 

34 

# default progress bar function -- declared here, assigned further down:
# the tqdm wrapper if `tqdm` can be imported, otherwise `spinner_fn_wrap`
DEFAULT_PBAR_FN: Callable

37 

38 

39# fallback to spinner option 

def spinner_fn_wrap(x: Iterable, **kwargs) -> List:
    """consume `x` into a list while a `SpinnerContext` is active

    # Parameters:
     - `x : Iterable`
        iterable to consume; it is fully materialized inside the spinner context
     - `**kwargs`
        only the keyword arguments accepted by `SpinnerContext.__init__`
        (according to `get_fn_allowed_kwargs`) are forwarded, the rest are dropped.
        if no `message` is forwarded, one is taken from `desc` when present,
        otherwise built from `total` when present

    # Returns:
     - `List`
        the elements of `x`, in iteration order
    """
    # the accepted keyword names do not depend on the kwarg being tested,
    # so look them up once instead of once per item
    allowed_kwargs = get_fn_allowed_kwargs(SpinnerContext.__init__)
    mapped_kwargs: dict = {k: v for k, v in kwargs.items() if k in allowed_kwargs}

    # an explicit `message` wins; then a tqdm-style `desc`; then a generic message from `total`
    if "message" not in mapped_kwargs:
        if "desc" in kwargs:
            mapped_kwargs["message"] = kwargs["desc"]
        elif "total" in kwargs:
            mapped_kwargs["message"] = f"Processing {kwargs['total']} items"

    with SpinnerContext(**mapped_kwargs):
        output = list(x)

    return output

56 

57 

def no_progress_fn_wrap(x: Iterable, **kwargs) -> Iterable:
    """Fallback used when no progress bar is wanted.

    Hands back `x` itself, untouched; any keyword arguments are accepted
    and ignored.
    """
    return x

61 

62 

63# set the default progress bar function 

try:
    # use tqdm if it's available
    import tqdm  # type: ignore[import-untyped]

    @functools.wraps(tqdm.tqdm)
    def tqdm_wrap(x: Iterable, **kwargs) -> Iterable:
        # NOTE: `functools.wraps` copies tqdm's own docstring onto this function,
        # so the wrapper is documented with comments instead.
        #
        # forwards only the keyword arguments that `get_fn_allowed_kwargs`
        # reports for `tqdm.tqdm`; looked up once rather than once per kwarg
        allowed_kwargs = get_fn_allowed_kwargs(tqdm.tqdm)
        mapped_kwargs: dict = {k: v for k, v in kwargs.items() if k in allowed_kwargs}

        # translate the spinner-style `message` into tqdm's `desc`.
        # this must read from `kwargs`: `desc` is known to be absent from
        # `mapped_kwargs` here, so reading it from there would always give `None`
        if "message" in kwargs and "desc" not in mapped_kwargs:
            mapped_kwargs["desc"] = kwargs["message"]

        return tqdm.tqdm(x, **mapped_kwargs)

    DEFAULT_PBAR_FN = tqdm_wrap

except ImportError:
    # use the spinner as fallback when tqdm is not installed
    DEFAULT_PBAR_FN = spinner_fn_wrap

82 

83 

def set_up_progress_bar_fn(
    pbar: Union[ProgressBarFunction, ProgressBarOption],
    pbar_kwargs: Optional[Dict[str, Any]] = None,
    **extra_kwargs,
) -> ProgressBarFunction:
    """resolve `pbar` into a progress bar function with its keyword arguments bound

    # Parameters:
     - `pbar : ProgressBarFunction | ProgressBarOption`
        either a callable, or one of `"tqdm"`, `"spinner"`, `"none"`, `None`
     - `pbar_kwargs : Dict[str, Any] | None`
        keyword arguments to bind to the progress bar function.
        if `disable` is truthy here, no progress bar is used
        (defaults to `None`, treated as an empty dict)
     - `**extra_kwargs`
        additional keyword arguments to bind; on a key collision the value in `pbar_kwargs` wins

    # Returns:
     - `ProgressBarFunction`
        `no_progress_fn_wrap` when the progress bar is disabled,
        otherwise a `functools.partial` of the selected function with the merged kwargs

    # Raises:
     - `ValueError` : if `pbar` is a string other than `"tqdm"`, `"spinner"` or `"none"`
     - `ImportError` : if `pbar == "tqdm"` but `tqdm` cannot be imported
    """
    # explicit `pbar_kwargs` take precedence over `extra_kwargs`
    merged_kwargs: Dict[str, Any] = {**extra_kwargs, **(pbar_kwargs or {})}

    # dont use a progress bar if `pbar` is None or "none", or if `disable` is set to True in the kwargs
    if pbar is None or pbar == "none" or merged_kwargs.get("disable", False):
        return no_progress_fn_wrap  # type: ignore[return-value]

    # a callable (such as the default, which resolves to tqdm if available or
    # the spinner as a fallback) just gets the kwargs bound to it
    if not isinstance(pbar, str):
        return functools.partial(pbar, **merged_kwargs)

    # `pbar` is a string: figure out which progress bar to use
    if pbar == "tqdm":
        # import locally: the module-level `tqdm` name is unbound when the optional
        # import failed, which would surface here as a confusing `NameError`.
        # this way a missing package raises `ImportError` instead
        import tqdm  # type: ignore[import-untyped]

        return functools.partial(tqdm.tqdm, **merged_kwargs)

    if pbar == "spinner":
        return functools.partial(spinner_fn_wrap, **merged_kwargs)

    raise ValueError(
        f"`pbar` must be either 'tqdm' or 'spinner' if `str`, or a valid callable, got {type(pbar) = } {pbar = }"
    )

115 

116 

def run_maybe_parallel(
    func: Callable[[InputType], OutputType],
    iterable: Iterable[InputType],
    parallel: Union[bool, int],
    pbar_kwargs: Optional[Dict[str, Any]] = None,
    chunksize: Optional[int] = None,
    keep_ordered: bool = True,
    use_multiprocess: bool = False,
    pbar: Union[ProgressBarFunction, ProgressBarOption] = DEFAULT_PBAR_FN,
) -> List[OutputType]:
    """a function to make it easier to sometimes parallelize an operation

    - if `parallel` is `False`, then the function will run in serial, running `map(func, iterable)`
    - if `parallel` is `True`, then the function will run in parallel, running in parallel with the maximum number of processes
    - if `parallel` is an `int`, it must be greater than 1, and the function will run in parallel with the number of processes specified by `parallel`

    the maximum number of processes is given by the `min(len(iterable), multiprocessing.cpu_count())`.
    if only one process would be used, the function runs in serial regardless of `parallel`

    # Parameters:
     - `func : Callable[[InputType], OutputType]`
        function passed to either `map` or `Pool.imap`
     - `iterable : Iterable[InputType]`
        iterable passed to either `map` or `Pool.imap`.
        if it does not support `len` (e.g. a generator) it is first collected into a list
     - `parallel : bool | int`
        whether to run in parallel, and with how many processes (see above)
     - `pbar_kwargs : Dict[str, Any] | None`
        keyword arguments passed on to `set_up_progress_bar_fn` for the progress bar;
        `total` is set to the number of inputs
        (defaults to `None`)
     - `chunksize : int | None`
        `chunksize` given to `Pool.imap` / `Pool.imap_unordered`.
        if `None`, `max(1, n_inputs // num_processes)` is used. ignored when running in serial
        (defaults to `None`)
     - `keep_ordered : bool`
        if `True` use `Pool.imap`, otherwise `Pool.imap_unordered`. ignored when running in serial
        (defaults to `True`)
     - `use_multiprocess : bool`
        take `Pool` from the `multiprocess` package instead of `multiprocessing`.
        requires `parallel` to be truthy
        (defaults to `False`)
     - `pbar : ProgressBarFunction | ProgressBarOption`
        progress bar to use, resolved by `set_up_progress_bar_fn`
        (defaults to `DEFAULT_PBAR_FN`)

    # Returns:
     - `List[OutputType]`
        results of applying `func` to each element of `iterable`; an empty list if there are no inputs

    # Raises:
     - `ValueError` : if `parallel` is an `int` less than 2, is neither `bool` nor `int`,
        or if `use_multiprocess` is set while `parallel` is `False`
     - `ImportError` : if `use_multiprocess` is set but `multiprocess` is not installed
    """

    # number of inputs in iterable
    try:
        n_inputs: int = len(iterable)  # type: ignore[arg-type]
    except TypeError:
        # generators and other unsized iterables: collect them so they can be counted
        iterable = list(iterable)
        n_inputs = len(iterable)

    if n_inputs == 0:
        # Return immediately if there is no input
        return list()

    # which progress bar to use
    pbar_fn: ProgressBarFunction = set_up_progress_bar_fn(
        pbar=pbar,
        pbar_kwargs=pbar_kwargs,
        # extra kwargs
        total=n_inputs,
    )

    # number of processes -- `bool` is checked first since it is a subclass of `int`
    num_processes: int
    if isinstance(parallel, bool):
        num_processes = multiprocessing.cpu_count() if parallel else 1
    elif isinstance(parallel, int):
        if parallel < 2:
            raise ValueError(
                f"`parallel` must be a boolean, or be an integer greater than 1, got {type(parallel) = } {parallel = }"
            )
        num_processes = parallel
    else:
        raise ValueError(
            f"The 'parallel' parameter must be a boolean or an integer, got {type(parallel) = } {parallel = }"
        )

    # validate against what the caller *asked for*: a parallel request that later
    # collapses to a single process (one input, one cpu) must not be rejected here
    if use_multiprocess and not parallel:
        raise ValueError("`use_multiprocess=True` requires `parallel=True`")

    # make sure we don't have more processes than iterable, and don't bother with parallel if there's only one process
    num_processes = min(num_processes, n_inputs)
    run_in_parallel: bool = num_processes > 1

    mp = multiprocessing
    if use_multiprocess:
        try:
            import multiprocess  # type: ignore[import-untyped]
        except ImportError as e:
            raise ImportError(
                "`use_multiprocess=True` requires the `multiprocess` package -- this is mostly useful when you need to pickle a lambda. install muutils with `pip install muutils[multiprocess]` or just do `pip install multiprocess`"
            ) from e

        mp = multiprocess

    # serial case: plain `map`, wrapped in the progress bar
    if not run_in_parallel:
        return list(pbar_fn(map(func, iterable)))

    # figure out a smart chunksize if one is not given
    chunksize_int: int = (
        max(1, n_inputs // num_processes) if chunksize is None else chunksize
    )

    # use `mp.Pool` since we might want to use `multiprocess` instead of `multiprocessing`
    pool = mp.Pool(num_processes)
    output: List[OutputType]
    try:
        # use `imap` if we want to keep the order, otherwise use `imap_unordered`
        pool_map = pool.imap if keep_ordered else pool.imap_unordered
        # run the map function with a progress bar
        output = list(pbar_fn(pool_map(func, iterable, chunksize=chunksize_int)))
    except BaseException:
        # `func` or the progress bar raised (or we were interrupted):
        # stop the workers instead of leaking the pool
        pool.terminate()
        raise
    else:
        pool.close()
    finally:
        # valid after either `close()` or `terminate()`
        pool.join()

    # return the output as a list
    return output