Coverage for muutils\parallel.py: 94%

94 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-02-21 19:08 -0700

1import multiprocessing 

2import functools 

3from typing import ( 

4 Any, 

5 Callable, 

6 Iterable, 

7 Literal, 

8 Optional, 

9 Tuple, 

10 TypeVar, 

11 Dict, 

12 List, 

13 Union, 

14 Protocol, 

15) 

16 

17# for no tqdm fallback 

18from muutils.spinner import SpinnerContext 

19from muutils.validate_type import get_fn_allowed_kwargs 

20 

21 

# type variables for the elements of the input iterable and of the mapped output
InputType = TypeVar("InputType")
OutputType = TypeVar("OutputType")

25 

26 

class ProgressBarFunction(Protocol):
    """Structural type for a progress bar callable.

    Anything callable with an iterable plus arbitrary keyword arguments,
    and returning an iterable, satisfies this protocol.
    """

    def __call__(self, iterable: Iterable, **kwargs: Any) -> Iterable:
        ...

31 

32 

# accepted string (or `None`) values for choosing a progress bar
ProgressBarOption = Literal["tqdm", "spinner", "none", None]

# progress bar option used when the caller does not specify one
DEFAULT_PBAR_FN: ProgressBarOption

try:
    import tqdm  # type: ignore[import-untyped]
except ImportError:
    # tqdm is not installed: fall back to the spinner
    DEFAULT_PBAR_FN = "spinner"
else:
    # tqdm imported successfully, so prefer it
    DEFAULT_PBAR_FN = "tqdm"

49 

50 

def spinner_fn_wrap(x: Iterable, **kwargs) -> List:
    """Consume `x` into a list while a `SpinnerContext` is active.

    # Parameters:
     - `x : Iterable`
       iterable to exhaust inside the spinner context
     - `**kwargs`
       only the names that `get_fn_allowed_kwargs` reports for
       `SpinnerContext.__init__` are forwarded to `SpinnerContext`.
       if no `message` is forwarded, `desc` is used as the message; failing
       that, a message is built from `total`

    # Returns:
     - `List`
       the elements of `x`, in iteration order
    """
    accepted_names = get_fn_allowed_kwargs(SpinnerContext.__init__)

    # keep only what the spinner understands
    spinner_kwargs = {}
    for name, value in kwargs.items():
        if name in accepted_names:
            spinner_kwargs[name] = value

    # fill in a message from the tqdm-style kwargs when none was given
    if "message" not in spinner_kwargs:
        if "desc" in kwargs:
            spinner_kwargs["message"] = kwargs["desc"]
        elif "total" in kwargs:
            spinner_kwargs["message"] = f"Processing {kwargs['total']} items"

    with SpinnerContext(**spinner_kwargs):
        consumed = list(x)

    return consumed

69 

70 

def map_kwargs_for_tqdm(kwargs: dict) -> dict:
    """Translate generic progress bar kwargs into kwargs for `tqdm.tqdm`.

    Only the kwargs are mapped here; `tqdm.tqdm` itself is not wrapped
    (the original author noted the bar seemed to disappear when wrapped).

    # Parameters:
     - `kwargs : dict`
       progress bar kwargs. only the names that `get_fn_allowed_kwargs`
       reports for `tqdm.tqdm.__init__` are kept

    # Returns:
     - `dict`
       the filtered kwargs. if `desc` was not given, it is taken from
       `message`, or failing that built from `total`
    """
    tqdm_allowed_kwargs: set[str] = get_fn_allowed_kwargs(tqdm.tqdm.__init__)
    mapped_kwargs: dict = {k: v for k, v in kwargs.items() if k in tqdm_allowed_kwargs}

    # tqdm calls its label `desc`; accept the spinner-style `message` as well
    if "desc" not in kwargs:
        if "message" in kwargs:
            mapped_kwargs["desc"] = kwargs["message"]
        elif "total" in kwargs:
            mapped_kwargs["desc"] = f"Processing {kwargs.get('total')} items"

    return mapped_kwargs

85 

86 

def no_progress_fn_wrap(x: Iterable, **kwargs) -> Iterable:
    """Return `x` untouched; used when no progress bar is wanted.

    Keyword arguments are accepted and ignored, so this can be called the
    same way as a real progress bar function.
    """
    return x

90 

91 

def set_up_progress_bar_fn(
    pbar: Union[ProgressBarFunction, ProgressBarOption],
    pbar_kwargs: Optional[Dict[str, Any]] = None,
    **extra_kwargs,
) -> Tuple[ProgressBarFunction, dict]:
    """resolve `pbar` into a progress bar function and the kwargs to call it with

    # Parameters:
     - `pbar : Union[ProgressBarFunction, ProgressBarOption]`
       progress bar function or option. a callable is returned as-is.
       `None` or `"none"` selects `no_progress_fn_wrap`, `"tqdm"` selects
       `tqdm.tqdm`, and `"spinner"` selects `spinner_fn_wrap`
     - `pbar_kwargs : Optional[Dict[str, Any]]`
       kwargs passed to the progress bar function. a truthy `disable` entry
       turns the progress bar off regardless of `pbar`
       (defaults to `None`)
     - `**extra_kwargs`
       additional kwargs, merged underneath `pbar_kwargs`: on a name clash
       the value from `pbar_kwargs` wins

    # Returns:
     - `Tuple[ProgressBarFunction, dict]`
       a tuple of the progress bar function and its kwargs. for `"spinner"`
       the kwargs are already bound into the function and the returned dict
       is empty

    # Raises:
     - `ValueError` : if `pbar` is a string that is not one of the valid options
    """
    merged_kwargs: dict = {
        **extra_kwargs,
        **({} if pbar_kwargs is None else pbar_kwargs),
    }

    # progress reporting is switched off
    if pbar is None or pbar == "none" or merged_kwargs.get("disable", False):
        return no_progress_fn_wrap, merged_kwargs  # type: ignore[return-value]

    # anything that is not a string is taken to be a user-supplied callable
    if not isinstance(pbar, str):
        return pbar, merged_kwargs

    if pbar == "tqdm":
        return tqdm.tqdm, map_kwargs_for_tqdm(merged_kwargs)

    if pbar == "spinner":
        # the spinner gets its kwargs bound up front, so none are returned
        return functools.partial(spinner_fn_wrap, **merged_kwargs), dict()

    raise ValueError(
        f"`pbar` must be either 'tqdm' or 'spinner' if `str`, or a valid callable, got {type(pbar) = } {pbar = }"
    )

141 

142 

def run_maybe_parallel(
    func: Callable[[InputType], OutputType],
    iterable: Iterable[InputType],
    parallel: Union[bool, int],
    pbar_kwargs: Optional[Dict[str, Any]] = None,
    chunksize: Optional[int] = None,
    keep_ordered: bool = True,
    use_multiprocess: bool = False,
    pbar: Union[ProgressBarFunction, ProgressBarOption] = DEFAULT_PBAR_FN,
) -> List[OutputType]:
    """a function to make it easier to sometimes parallelize an operation

    - if `parallel` is `False`, then the function will run in serial, running `map(func, iterable)`
    - if `parallel` is `True`, then the function will run in parallel, running in parallel with the maximum number of processes
    - if `parallel` is an `int`, it must be greater than 1, and the function will run in parallel with the number of processes specified by `parallel`

    the maximum number of processes is given by the `min(len(iterable), multiprocessing.cpu_count())`.
    if that leaves only a single process, the work runs in serial even when
    `parallel` was requested.

    # Parameters:
     - `func : Callable[[InputType], OutputType]`
       function passed to either `map` or `Pool.imap`
     - `iterable : Iterable[InputType]`
       iterable passed to either `map` or `Pool.imap`. an iterable without
       `__len__` (e.g. a generator) is first collected into a list
     - `parallel : bool | int`
       whether to run in parallel, and how many processes to use
     - `pbar_kwargs : Optional[Dict[str, Any]]`
       kwargs passed to the progress bar function
       (defaults to `None`)
     - `chunksize : Optional[int]`
       chunksize passed to `Pool.imap`/`Pool.imap_unordered`. if `None`,
       `max(1, len(iterable) // num_processes)` is used
       (defaults to `None`)
     - `keep_ordered : bool`
       use `Pool.imap` if `True`, `Pool.imap_unordered` otherwise.
       only relevant when running in parallel
       (defaults to `True`)
     - `use_multiprocess : bool`
       take `Pool` from the `multiprocess` package instead of `multiprocessing`
       (defaults to `False`)
     - `pbar : Union[ProgressBarFunction, ProgressBarOption]`
       progress bar function or option, resolved by `set_up_progress_bar_fn`
       (defaults to `DEFAULT_PBAR_FN`)

    # Returns:
     - `List[OutputType]`
       a list of the output of `func` for each element in `iterable`

    # Raises:
     - `ValueError` : if `parallel` is not a boolean or an integer greater than 1
     - `ValueError` : if `use_multiprocess=True` and `parallel=False`
     - `ImportError` : if `use_multiprocess=True` and `multiprocess` is not available
    """

    # the length is needed for the progress bar total and the process count,
    # so collect iterables that have no length (e.g. generators) up front
    if not hasattr(iterable, "__len__"):
        iterable = list(iterable)

    # number of inputs in iterable
    n_inputs: int = len(iterable)  # type: ignore[arg-type]
    if n_inputs == 0:
        # Return immediately if there is no input
        return list()

    # which progress bar to use
    pbar_fn: ProgressBarFunction
    pbar_kwargs_processed: dict
    pbar_fn, pbar_kwargs_processed = set_up_progress_bar_fn(
        pbar=pbar,
        pbar_kwargs=pbar_kwargs,
        # extra kwargs
        total=n_inputs,
    )

    # number of processes
    num_processes: int
    if isinstance(parallel, bool):
        num_processes = multiprocessing.cpu_count() if parallel else 1
    elif isinstance(parallel, int):
        if parallel < 2:
            raise ValueError(
                f"`parallel` must be a boolean, or be an integer greater than 1, got {type(parallel) = } {parallel = }"
            )
        num_processes = parallel
    else:
        raise ValueError(
            f"The 'parallel' parameter must be a boolean or an integer, got {type(parallel) = } {parallel = }"
        )

    # make sure we don't have more processes than inputs, and don't bother
    # with a pool if there's only one process
    num_processes = min(num_processes, n_inputs)
    run_parallel: bool = bool(parallel) and num_processes > 1

    mp = multiprocessing
    if use_multiprocess:
        # validate against what the caller asked for, not against
        # `run_parallel`: a single input must not turn `parallel=True` into
        # an error about `parallel` being false
        if not parallel:
            raise ValueError("`use_multiprocess=True` requires `parallel=True`")

        try:
            import multiprocess  # type: ignore[import-untyped]
        except ImportError as e:
            raise ImportError(
                "`use_multiprocess=True` requires the `multiprocess` package -- this is mostly useful when you need to pickle a lambda. install muutils with `pip install muutils[multiprocess]` or just do `pip install multiprocess`"
            ) from e

        mp = multiprocess

    # serial path: plain `map` wrapped in the progress bar
    if not run_parallel:
        return list(pbar_fn(map(func, iterable), **pbar_kwargs_processed))

    # figure out a smart chunksize if one is not given
    chunksize_int: int = (
        max(1, n_inputs // num_processes) if chunksize is None else chunksize
    )

    # use `mp.Pool` since we might want to use `multiprocess` instead of `multiprocessing`
    pool = mp.Pool(num_processes)
    try:
        # use `imap` if we want to keep the order, otherwise use `imap_unordered`
        pool_map = pool.imap if keep_ordered else pool.imap_unordered

        # run the map function with a progress bar
        output: List[OutputType] = list(
            pbar_fn(
                pool_map(func, iterable, chunksize=chunksize_int),
                **pbar_kwargs_processed,
            )
        )
        pool.close()
    except BaseException:
        # `func` or the progress bar failed: stop the workers instead of
        # leaking them
        pool.terminate()
        raise
    finally:
        # `close()` or `terminate()` has been called on every path by now
        pool.join()

    # return the output as a list
    return output