Coverage for tests\unit\test_parallel.py: 96%

131 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-12-15 17:52 -0700

1import pytest 

2import multiprocessing 

3import time 

4from typing import Any, List, Iterable 

5 

6# Import the function to test 

7from muutils.parallel import DEFAULT_PBAR_FN, run_maybe_parallel 

8 

# Named input datasets shared by the parametrized tests below.
DATA: dict = {
    "empty": [],
    "single": [5],
    "small": list(range(4)),
    "medium": list(range(10)),
    "large": list(range(50)),
}
# Expected outputs per dataset key, computed element-wise from DATA
# (squared values and incremented values respectively).
SQUARE_RESULTS: dict = {k: [x**2 for x in v] for k, v in DATA.items()}
ADD_ONE_RESULTS: dict = {k: [x + 1 for x in v] for k, v in DATA.items()}

18 

19 

20# Basic test functions 

def square(x: int) -> int:
    """Return ``x`` squared."""
    return x**2

23 

24 

def add_one(x: int) -> int:
    """Return ``x`` incremented by one."""
    return x + 1

27 

28 

def raise_value_error(x: int) -> int:
    """Return ``x`` squared, except for the sentinel input 5.

    Raises:
        ValueError: if ``x == 5`` (used to inject a known failure).
    """
    if x == 5:
        raise ValueError("Test error")
    return x**2

33 

34 

def slow_square(x: int) -> int:
    """Return ``x`` squared after a very short sleep (0.0001 s)."""
    time.sleep(0.0001)
    return x**2

38 

39 

def raise_on_negative(x: int) -> int:
    """Return ``x`` unchanged when it is non-negative.

    Raises:
        ValueError: if ``x`` is below zero.
    """
    if x < 0:
        raise ValueError("Negative number")
    return x

44 

45 

def stateful_fn(x: list) -> list:
    """Append ``1`` to ``x`` in place and return the same list object."""
    # Deliberately mutates the caller's list rather than copying it.
    x.append(1)
    return x

49 

50 

class ComplexObject:
    """Small value wrapper used to pass non-primitive objects through the tests.

    Two instances are equal when they carry the same ``value``; hashing is
    kept consistent with that equality.
    """

    def __init__(self, value: int):
        self.value = value

    def __eq__(self, other: Any) -> bool:
        return isinstance(other, ComplexObject) and self.value == other.value

    def __hash__(self) -> int:
        # Defining __eq__ without __hash__ makes a class unhashable; derive the
        # hash from the same field that equality compares.
        return hash((ComplexObject, self.value))

    def __repr__(self) -> str:
        # Readable representation so assertion failures show the wrapped value.
        return f"{type(self).__name__}(value={self.value!r})"

57 

58 

def dataset_decorator(keys: List[str]):
    """Build a decorator that parametrizes a test over named datasets.

    For each key, the decorated test receives ``input_range`` (the entry from
    ``DATA``) and ``expected`` (the matching entry from ``SQUARE_RESULTS``);
    the keys themselves are used as the pytest ids.

    Args:
        keys: dataset names; each must be a key of ``DATA`` and
            ``SQUARE_RESULTS`` (an unknown key raises ``KeyError`` when the
            decorator is built).
    """

    def decorate(test_func):
        cases = [(DATA[k], SQUARE_RESULTS[k]) for k in keys]
        return pytest.mark.parametrize(
            "input_range, expected",
            cases,
            ids=keys,
        )(test_func)

    return decorate

68 

69 

@dataset_decorator(["empty", "single", "small"])
@pytest.mark.parametrize("parallel", [False, True, 2, 4])
@pytest.mark.parametrize("keep_ordered", [True, False])
@pytest.mark.parametrize("use_multiprocess", [True, False])
def test_general_functionality(
    input_range, expected, parallel, keep_ordered, use_multiprocess
):
    """Run ``square`` through ``run_maybe_parallel`` across the option grid.

    The result must always contain the expected values; when ``keep_ordered``
    is set it must also match the expected order.
    """
    # Combinations where multiprocess cannot be used are skipped explicitly so
    # they are reported as skipped rather than as silently passing.
    # NOTE(review): `parallel == 1` is also true for `parallel=True` (bool is
    # an int subclass), so `use_multiprocess=True` with `parallel=True` is
    # never exercised here -- confirm whether that is intended.
    if use_multiprocess and (
        parallel is False or parallel == 1 or len(input_range) == 1
    ):
        pytest.skip("use_multiprocess is not applicable for this combination")

    # run the function
    results = run_maybe_parallel(
        func=square,
        iterable=input_range,
        parallel=parallel,
        pbar_kwargs={},
        keep_ordered=keep_ordered,
        use_multiprocess=use_multiprocess,
    )

    # Order-insensitive check; sorting (unlike a set) also catches missing or
    # duplicated elements.
    assert sorted(results) == sorted(expected)
    if keep_ordered:
        assert results == expected

97 

98 

@dataset_decorator(["small"])
@pytest.mark.parametrize(
    "pbar_type",
    ["tqdm", "spinner", "none", None, "invalid"],
)
@pytest.mark.parametrize("disable_flag", [True, False])
def test_progress_bar_types_and_disable(input_range, expected, pbar_type, disable_flag):
    """Exercise each progress-bar name with the bar enabled and disabled.

    An unknown name is expected to raise ``ValueError`` only while the bar is
    enabled; every other combination must produce the squared values.
    """
    pbar_kwargs = {"disable": disable_flag}
    expect_failure = pbar_type == "invalid" and not disable_flag
    if expect_failure:
        with pytest.raises(ValueError):
            run_maybe_parallel(square, input_range, False, pbar_kwargs, pbar=pbar_type)
    else:
        results = run_maybe_parallel(
            square, input_range, False, pbar_kwargs, pbar=pbar_type
        )
        assert results == expected

115 

116 

@dataset_decorator(["small"])
@pytest.mark.parametrize("chunksize", [None, 1, 5])
@pytest.mark.parametrize("parallel", [False, True, 2])
def test_chunksize_and_parallel(input_range, expected, chunksize, parallel):
    """Results must match for every ``chunksize`` / ``parallel`` pairing."""
    results = run_maybe_parallel(square, input_range, parallel, {}, chunksize=chunksize)
    assert results == expected

123 

124 

@dataset_decorator(["small"])
@pytest.mark.parametrize("invalid_parallel", ["invalid", 0, -1, 1.5])
def test_invalid_parallel_values(input_range, expected, invalid_parallel):
    """Unsupported ``parallel`` arguments are expected to raise ``ValueError``.

    ``expected`` is supplied by ``dataset_decorator`` and is unused here.
    """
    with pytest.raises(ValueError):
        run_maybe_parallel(square, input_range, invalid_parallel)

130 

131 

def test_exception_in_func():
    """A ``ValueError`` raised by the mapped function must reach the caller."""
    # `raise_value_error` fails only for the value 5, so feed it exactly that.
    error_input = [5]
    with pytest.raises(ValueError):
        run_maybe_parallel(raise_value_error, error_input, True, {})

138 

139 

@dataset_decorator(["small"])
@pytest.mark.parametrize(
    "iterable_factory",
    [
        lambda x: list(x),
        lambda x: tuple(x),
        lambda x: set(x),
        lambda x: dict.fromkeys(x, 0),
    ],
)
def test_different_iterables(input_range, expected, iterable_factory):
    """Feed the same values as a list, tuple, set and dict (keys).

    Set input is compared without regard to order; every other container must
    yield the results in the expected order.
    """
    test_input = iterable_factory(input_range)
    result = run_maybe_parallel(square, test_input, False)
    if isinstance(test_input, set):
        # Sets carry no ordering guarantee, so compare as sets.
        assert set(result) == set(expected)
    else:
        assert result == expected

157 

158 

@pytest.mark.parametrize("parallel", [False, True])
def test_error_handling(parallel):
    """A ``ValueError`` from the mapped function must surface in both modes."""
    # Mix of negative and non-negative values; `raise_on_negative` fails on
    # the negative ones.
    input_data = [-1, 0, 1, -2]
    with pytest.raises(ValueError):
        run_maybe_parallel(raise_on_negative, input_data, parallel)

165 

166 

def _process_complex(obj):
    """Return a new ``ComplexObject`` holding twice ``obj.value``."""
    return ComplexObject(obj.value * 2)

169 

170 

# Inputs wrapping 0..4 and the outputs expected after doubling each value.
COMPLEX_DATA: List[ComplexObject] = [ComplexObject(i) for i in range(5)]
EXPECTED_COMPLEX = [ComplexObject(i * 2) for i in range(5)]

173 

174 

@pytest.mark.parametrize("parallel", [False, True])
@pytest.mark.parametrize("pbar_type", [None, DEFAULT_PBAR_FN])
def test_complex_objects(parallel, pbar_type):
    """Map ``_process_complex`` over ``ComplexObject`` inputs and compare."""
    result = run_maybe_parallel(
        _process_complex, COMPLEX_DATA, parallel, pbar=pbar_type
    )
    expected_complex = EXPECTED_COMPLEX
    # zip() stops at the shorter sequence, so without this length check an
    # empty or truncated result would pass the element-wise comparison.
    assert len(result) == len(expected_complex)
    assert all(a == b for a, b in zip(result, expected_complex))

184 

185 

@dataset_decorator(["small"])
def test_resource_cleanup(input_range, expected):
    """Check that a parallel run does not leave extra child processes behind.

    Compares the number of active children before and shortly after the run,
    tolerating a difference of at most two. ``expected`` is unused here.
    """
    initial_processes = len(multiprocessing.active_children())
    run_maybe_parallel(square, input_range, True)
    # Brief pause before counting child processes again.
    time.sleep(0.05)
    final_processes = len(multiprocessing.active_children())
    assert abs(final_processes - initial_processes) <= 2

193 

194 

@dataset_decorator(["small"])
def test_custom_progress_bar(input_range, expected):
    """A user-supplied callable is accepted as the ``pbar`` argument."""

    def custom_progress_bar_fn(iterable: Iterable, **kwargs: Any) -> Iterable:
        # Pass-through "progress bar": hands the iterable back untouched.
        return iterable

    result = run_maybe_parallel(square, input_range, False, pbar=custom_progress_bar_fn)
    assert result == expected

202 

203 

@dataset_decorator(["small"])
@pytest.mark.parametrize(
    "kwargs",
    [
        None,
        dict(),
        dict(desc="Processing"),
        dict(disable=True),
        dict(ascii=True),
        dict(config="default"),
        dict(config="bar"),
        dict(ascii=True, config="bar"),
        dict(message="Processing"),
        dict(message="Processing", desc="Processing"),
    ],
)
def test_progress_bar_kwargs(input_range, expected, kwargs):
    """Each ``pbar_kwargs`` variant must still yield the squared values."""
    result = run_maybe_parallel(square, input_range, False, pbar_kwargs=kwargs)
    assert result == expected

223 

224 

@dataset_decorator(["medium"])
def test_parallel_performance(input_range, expected):
    """Serial and parallel runs of ``slow_square`` must agree and be correct."""
    serial_result = run_maybe_parallel(slow_square, input_range, False)
    parallel_result = run_maybe_parallel(slow_square, input_range, True)
    # Check against the known answer as well, so that two identical but wrong
    # results cannot pass.
    assert serial_result == expected
    assert serial_result == parallel_result

230 

231 

@dataset_decorator(["small"])
def test_reject_pbar_str_when_not_str_or_callable(input_range, expected):
    """A ``pbar`` that is neither a string nor a callable must raise ``TypeError``.

    ``expected`` is supplied by ``dataset_decorator`` and is unused here.
    """
    with pytest.raises(TypeError):
        run_maybe_parallel(square, input_range, False, pbar=12345)

236 

237 

def custom_pbar(iterable: Iterable, **kwargs: Any) -> List:
    """Stand-in progress bar: ignore all keyword options and return a list."""
    return list(iterable)

240 

241 

@dataset_decorator(["small"])
def test_manual_callable_pbar(input_range, expected):
    """The module-level ``custom_pbar`` callable works as the ``pbar`` argument."""
    results = run_maybe_parallel(square, input_range, False, pbar=custom_pbar)
    assert results == expected, "Manual callable pbar test failed."

246 

247 

@pytest.mark.parametrize(
    "input_data, parallel",
    [
        (range(multiprocessing.cpu_count() + 1), True),
        (range(multiprocessing.cpu_count() - 1), True),
    ],
)
def test_edge_cases(input_data, parallel):
    """Inputs one longer and one shorter than the CPU count, run in parallel."""
    result = run_maybe_parallel(square, input_data, parallel)
    assert result == [square(x) for x in input_data]