Coverage for tests/test_im_utils.py: 99%

116 statements  

« prev     ^ index     » next       coverage.py v7.2.5, created at 2023-05-19 16:12 -0700

1from typing import Iterable, Union 

2from image_utils import Im, strip_unsafe 

3from PIL import Image 

4import torch 

5import numpy as np 

6import pytest 

7from pathlib import Path 

8from einops import rearrange, repeat 

9 

# Both paths are anchored to this test module's directory so the suite works no matter
# which directory pytest is launched from (previously img_path was CWD-relative).
img_path = Path(__file__).parent / 'flower.jpg'
save_path = Path(__file__).parent / 'output'

12 

13 

def get_img(img_type: type, hwc_order=True, dtype=None, normalize=False, device=None, bw_img=False, batch_shape=None):
    """Build a test image as the requested container type, layout and dtype.

    Args:
        img_type: The class to return: ``Image.Image``, ``np.ndarray`` or ``torch.Tensor``.
        hwc_order: Keep channels last (``h w c``). When False, 3-D images are rearranged
            to channels first (``c h w``); 2-D single-channel images have no channel axis
            and are left unchanged.
        dtype: Optional target dtype. Floating dtypes get pixel values rescaled from
            [0, 255] to [0, 1]; integer dtypes keep the raw values. With ``bw_img`` it
            also sets the dtype of the random source array.
        normalize: Accepted but currently has no effect (see TODO below).
        device: Device to move the result to; only applied to ``torch.Tensor`` images.
        bw_img: Use a random 128x128 single-channel image instead of ``img_path``.
            Without ``dtype`` the pixels are random booleans.
        batch_shape: Optional mapping of einops axis name -> size. The image is repeated
            along these new leading axes, ordered by sorted axis name.

    Returns:
        The image as an instance of ``img_type``. A PIL image is returned before any of
        the layout/dtype/device/batch options are applied.
    """
    if bw_img:
        if dtype is None:
            img = Image.fromarray(np.random.rand(128, 128) > 0.5)
        else:
            img = Image.fromarray(np.random.randint(256, size=(128, 128)).astype(dtype))
    else:
        img = Image.open(img_path)

    if img_type is Image.Image:
        return img

    img = np.array(img)
    is_torch = img_type is torch.Tensor
    if is_torch:
        img = torch.from_numpy(img)

    # Single-channel images are 2-D, so there is no channel axis to move.
    if not hwc_order and img.ndim == 3:
        img = rearrange(img, 'h w c -> c h w')

    if dtype is not None:
        is_float = dtype.is_floating_point if is_torch else np.issubdtype(dtype, np.floating)
        if is_float:
            # Float images are represented in [0, 1]; dividing before an integer cast
            # would truncate nearly every pixel to 0, so integer dtypes keep 0-255 values.
            img = img / 255.0
        img = img.to(dtype=dtype) if is_torch else img.astype(dtype)

    if normalize:
        # TODO: normalization is not implemented; configs with normalize=True currently
        # yield the same image as their counterparts without it.
        pass

    if device is not None and is_torch:
        img = img.to(device=device)

    if batch_shape is not None:
        if img.ndim == 2:
            # 2-D (single-channel) images get a leading size-1 axis before batching.
            img = img[None]
        batch_axes = " ".join(sorted(batch_shape))
        img = repeat(img, f'... -> {batch_axes} ...', **batch_shape)

    return img

52 

53 

@pytest.mark.parametrize("dim_size", [4, 10, 100])
def test_single_arg_even(dim_size):
    """Generate square random float, bool and int samples as tensors and arrays, and print them."""
    shape = (dim_size,) * 2

    float_t = torch.FloatTensor(*shape).uniform_()
    bool_t = torch.FloatTensor(*shape).uniform_() > 0.5
    int_t = torch.randint(0, 100, shape)

    float_a = np.random.rand(*shape)
    bool_a = np.random.rand(*shape) > 0.5
    int_a = np.random.randint(100, size=shape)

    for sample in (float_t, bool_t, int_t, float_a, bool_a, int_a):
        print(sample)

72 

73 

# Keyword-argument sets passed as get_img(**config) by the parametrized tests below.
# Key order is significant: get_file_path derives output file names from the items in order.
valid_configs = [
    # PIL input
    dict(img_type=Image.Image),
    # NumPy arrays: layout / dtype variants
    dict(img_type=np.ndarray),
    dict(img_type=np.ndarray, hwc_order=False),
    dict(img_type=np.ndarray, dtype=np.float16),
    dict(img_type=np.ndarray, hwc_order=False, dtype=np.float16),
    dict(img_type=np.ndarray, hwc_order=False, dtype=np.float32, normalize=True),
    # torch tensors: layout / dtype variants
    dict(img_type=torch.Tensor),
    dict(img_type=torch.Tensor, hwc_order=False),
    dict(img_type=torch.Tensor, dtype=torch.float32),
    dict(img_type=torch.Tensor, hwc_order=False, dtype=torch.float16),
    dict(img_type=torch.Tensor, hwc_order=False, dtype=torch.bfloat16),
    dict(img_type=torch.Tensor, hwc_order=False, dtype=torch.float, normalize=True),
    dict(img_type=torch.Tensor, hwc_order=False, dtype=torch.float16, normalize=True),
    dict(img_type=torch.Tensor, hwc_order=False, dtype=torch.bfloat16, normalize=True),
    # random single-channel images
    dict(img_type=np.ndarray, bw_img=True),
    dict(img_type=np.ndarray, bw_img=True, dtype=np.uint8),
    # batched arrays (extra leading einops axes)
    dict(img_type=np.ndarray, batch_shape=dict(a=2)),
    dict(img_type=np.ndarray, batch_shape=dict(a=2, b=3, c=4)),
    dict(img_type=np.ndarray, batch_shape=dict(a=2, b=3)),
]

95 

def get_file_path(img_params: dict, name: str):
    """Return an output path in save_path named after img_params, suffixed with ``_<name>``.

    The parameter items are joined as ``key_value`` pairs separated by ``__`` and passed
    through strip_unsafe (presumably to make the text filesystem-safe — confirm).
    """
    encoded = '__'.join(f'{key}_{value}' for key, value in img_params.items())
    base = save_path / strip_unsafe(encoded)
    suffixed_name = f"{base.name}_{name}"
    return base.parent / suffixed_name

99 

@pytest.mark.parametrize("img_params", valid_configs)
def test_save(img_params):
    """Wrap each configured image in Im and save an unmodified copy."""
    snapshot = Im(get_img(**img_params)).copy
    snapshot.save(get_file_path(img_params, 'save'))

104 

@pytest.mark.parametrize("img_params", valid_configs)
def test_write_text(img_params):
    """Draw the label 'test' on a copy of each configured image and save it."""
    labelled = Im(get_img(**img_params)).copy.write_text('test')
    labelled.save(get_file_path(img_params, 'text'))

109 

110 

@pytest.mark.parametrize("img_params", valid_configs)
def test_add_border(img_params):
    """Add a 5-pixel grey border to a copy of each configured image and save it."""
    grey = (128, 128, 128)
    bordered = Im(get_img(**img_params)).copy.add_border(border=5, color=grey)
    bordered.save(get_file_path(img_params, 'border'))

115 

116 

@pytest.mark.parametrize("img_params", valid_configs)
def test_resize(img_params):
    """Apply several resize/scale operations, each to a fresh copy, and save every result."""
    img = Im(get_img(**img_params))
    operations = (
        ('resize', lambda im: im.resize(128, 128)),
        ('downscale', lambda im: im.scale(0.25)),
        ('scale_width', lambda im: im.scale_to_width(128)),
        ('scale_height', lambda im: im.scale_to_height(128)),
        ('multiple_resize',
         lambda im: im.scale(0.5).scale_to_width(128).resize(512, 1024).scale_to_width(512)),
    )
    for suffix, operation in operations:
        operation(img.copy).save(get_file_path(img_params, suffix))

125 

126 

@pytest.mark.parametrize("img_params", valid_configs)
def test_normalization(img_params):
    """Run normalize/denormalize round-trips in both orders and save the results.

    Single-channel configs are reported as skipped rather than silently passing.
    """
    if img_params.get('bw_img', False):
        pytest.skip('normalization is not tested for single-channel images')
    img = Im(get_img(**img_params))
    # Each round-trip works on its own copy (as the other tests do), so the second one
    # starts from the original pixels even if normalize/denormalize modify in place.
    img.copy.normalize().denormalize().save(get_file_path(img_params, 'normalize0'))
    img.copy.denormalize().normalize().save(get_file_path(img_params, 'normalize1'))

134 

135 

@pytest.mark.parametrize("img_params", valid_configs)
def test_format(img_params):
    """Access each format-conversion property of Im; any conversion error fails the test."""
    img = Im(get_img(**img_params))
    for conversion in ('pil', 'torch', 'np', 'opencv'):
        getattr(img, conversion)

143 

144 

145@pytest.mark.parametrize("img_params", valid_configs) 

146def test_concat(img_params): 

147 img = Im(get_img(**img_params)) 

148 

149 input_data = [img, img, img] 

150 if img_params.get('batch_shape', False): 

151 return 

152 

153 Im.concat_horizontal(*input_data, spacing=5) 

154 Im.concat_vertical(*input_data, spacing=0) 

155 

@pytest.mark.parametrize("img_params", valid_configs)
@pytest.mark.parametrize("video_format", ['webm', 'mp4', 'gif'])
def test_encode_video(img_params, video_format):
    """Encode each config, batched to two frames, as a video and save it to disk.

    PIL configs are skipped: get_img returns PIL images before applying batch_shape.
    """
    if img_params['img_type'] == Image.Image:
        pytest.skip('get_img does not apply batch_shape to PIL images')
    # Build a new dict: parametrized values are the shared dicts in valid_configs, so
    # assigning into img_params would leak batch_shape={'a': 2} into every later test.
    video_params = {**img_params, 'batch_shape': {'a': 2}}
    img = Im(get_img(**video_params))
    img.encode_video(2, video_format)
    # Include the format in the file name so the three formats do not share one path.
    img.save_video(get_file_path(video_params, f'video_{video_format}'), 2, video_format)

165 

166@pytest.mark.parametrize("img_params", valid_configs) 

167def test_complicated(img_params): 

168 img = Im(get_img(**img_params)) 

169 img = img.scale(0.5).resize(128, 128) 

170 img = img.add_border(border=5, color=(128, 128, 128)).normalize(mean=(0.5, 0.75, 0.5), std=(0.1, 0.01, 0.01)) 

171 img = img.torch 

172 img = Im(img).denormalize(mean=(0.5, 0.75, 0.5), std=(0.1, 0.01, 0.01)) 

173 img = img.colorize() 

174 img.save(get_file_path(img_params, 'complicated')) 

175