Coverage for tests/test_im_utils.py: 99%
116 statements
« prev ^ index » next coverage.py v7.2.5, created at 2023-05-19 16:12 -0700
« prev ^ index » next coverage.py v7.2.5, created at 2023-05-19 16:12 -0700
1from typing import Iterable, Union
2from image_utils import Im, strip_unsafe
3from PIL import Image
4import torch
5import numpy as np
6import pytest
7from pathlib import Path
8from einops import rearrange, repeat
# Both paths are anchored to this test file's directory so the suite does not
# depend on the working directory pytest is launched from.
img_path = Path(__file__).parent / 'flower.jpg'
save_path = Path(__file__).parent / 'output'
def _is_float_dtype(dtype) -> bool:
    """Return True if ``dtype`` (a torch dtype or anything NumPy understands) is floating point."""
    if isinstance(dtype, torch.dtype):
        return dtype.is_floating_point
    return np.issubdtype(dtype, np.floating)


def get_img(img_type: type, hwc_order=True, dtype=None, normalize=False, device=None, bw_img=False, batch_shape=None):
    """Build a test image in the requested container type, layout and dtype.

    Args:
        img_type: The target container class: ``Image.Image``, ``np.ndarray`` or
            ``torch.Tensor`` (the class itself, compared with ``==``).
        hwc_order: If False, move channels first (``h w c -> c h w``). The
            rearrange pattern requires a 3-D image.
        dtype: Target dtype for non-PIL outputs. Floating dtypes are rescaled
            from [0, 255] to [0, 1] before the cast; integer dtypes are cast
            without rescaling. When ``bw_img`` is set, this is also the dtype of
            the random source array (``None`` yields a boolean mask instead).
        normalize: Accepted for config compatibility; currently has no effect.
        device: Device to move the result to; only applied to ``torch.Tensor`` outputs.
        bw_img: Use a random 128x128 single-channel image instead of ``img_path``.
        batch_shape: Optional mapping of einops axis name -> size. The image is
            repeated along new leading axes named by the keys, in sorted order.

    Returns:
        The image as a PIL image, NumPy array or torch tensor.
    """
    if bw_img:
        if dtype is None:
            img = Image.fromarray(np.random.rand(128, 128) > 0.5)
        else:
            img = Image.fromarray(np.random.randint(256, size=(128, 128)).astype(dtype))
    else:
        img = Image.open(img_path)

    if img_type == Image.Image:
        return img

    img = np.array(img)
    if img_type == torch.Tensor:
        img = torch.from_numpy(img)

    if not hwc_order:
        img = rearrange(img, 'h w c -> c h w')

    if dtype is not None:
        # Only floating targets are rescaled to [0, 1]: dividing an integer
        # image by 255 and casting back would truncate nearly every pixel to 0.
        if _is_float_dtype(dtype):
            img = img / 255.0
        if img_type == torch.Tensor:
            img = img.to(dtype=dtype)
        else:
            img = img.astype(dtype)

    if device is not None and img_type == torch.Tensor:
        img = img.to(device=device)

    if batch_shape is not None:
        # 2-D (single-channel) images get an explicit leading axis of size 1
        # before batching.
        if len(img.shape) == 2:
            img = img[None]
        batch_axes = " ".join(sorted(list(batch_shape)))
        img = repeat(img, f'... -> {batch_axes} ...', **batch_shape)

    return img
@pytest.mark.parametrize("dim_size", [4, 10, 100])
def test_single_arg_even(dim_size):
    """Create square random float/bool/int samples (torch then NumPy) and print each one.

    NOTE(review): this test makes no assertions and never touches ``Im``; it
    only checks that generating and printing the samples does not raise.
    """
    shape = (dim_size, dim_size)
    samples = [
        torch.FloatTensor(*shape).uniform_(),
        torch.FloatTensor(*shape).uniform_() > 0.5,
        torch.randint(0, 100, shape),
        np.random.rand(*shape),
        np.random.rand(*shape) > 0.5,
        np.random.randint(100, size=shape),
    ]
    for sample in samples:
        print(sample)
# Keyword-argument sets for get_img(). Each dict is one parametrized test case,
# and its key/value pairs (in order) also form the saved output file names via
# get_file_path, so key order within each dict is significant.
valid_configs = [
    # PIL image straight from disk.
    {'img_type': Image.Image},

    # NumPy arrays: layout and dtype variations.
    {'img_type': np.ndarray},
    {'img_type': np.ndarray, 'hwc_order': False},
    {'img_type': np.ndarray, 'dtype': np.float16},
    {'img_type': np.ndarray, 'hwc_order': False, 'dtype': np.float16},
    {'img_type': np.ndarray, 'hwc_order': False, 'dtype': np.float32, 'normalize': True},

    # torch tensors: layout and dtype variations.
    {'img_type': torch.Tensor},
    {'img_type': torch.Tensor, 'hwc_order': False},
    {'img_type': torch.Tensor, 'dtype': torch.float32},
    {'img_type': torch.Tensor, 'hwc_order': False, 'dtype': torch.float16},
    {'img_type': torch.Tensor, 'hwc_order': False, 'dtype': torch.bfloat16},
    {'img_type': torch.Tensor, 'hwc_order': False, 'dtype': torch.float, 'normalize': True},
    {'img_type': torch.Tensor, 'hwc_order': False, 'dtype': torch.float16, 'normalize': True},
    {'img_type': torch.Tensor, 'hwc_order': False, 'dtype': torch.bfloat16, 'normalize': True},

    # Random single-channel images.
    {'img_type': np.ndarray, 'bw_img': True},
    {'img_type': np.ndarray, 'bw_img': True, 'dtype': np.uint8},

    # Batched NumPy images with one to three leading batch axes.
    {'img_type': np.ndarray, 'batch_shape': {'a': 2}},
    {'img_type': np.ndarray, 'batch_shape': {'a': 2, 'b': 3, 'c': 4}},
    {'img_type': np.ndarray, 'batch_shape': {'a': 2, 'b': 3}},
]
def get_file_path(img_params: dict, name: str):
    """Return the output path for one config and one operation.

    The config's ``key_value`` pairs are joined with ``__``, passed through
    ``strip_unsafe`` and placed under ``save_path``; ``_<name>`` is then
    appended to the final path component.
    """
    stem = strip_unsafe('__'.join(f'{key}_{value}' for key, value in img_params.items()))
    base = save_path / stem
    final_component = f"{base.name}_{name}"
    return base.parent / final_component
@pytest.mark.parametrize("img_params", valid_configs)
def test_save(img_params):
    """Save an unmodified copy of each configured image."""
    image = Im(get_img(**img_params))
    duplicate = image.copy
    duplicate.save(get_file_path(img_params, 'save'))
@pytest.mark.parametrize("img_params", valid_configs)
def test_write_text(img_params):
    """Draw a short text label on a copy of each image and save it."""
    image = Im(get_img(**img_params))
    annotated = image.copy.write_text('test')
    annotated.save(get_file_path(img_params, 'text'))
@pytest.mark.parametrize("img_params", valid_configs)
def test_add_border(img_params):
    """Add a 5-pixel grey border to a copy of each image and save it."""
    grey = (128, 128, 128)
    image = Im(get_img(**img_params))
    bordered = image.copy.add_border(border=5, color=grey)
    bordered.save(get_file_path(img_params, 'border'))
@pytest.mark.parametrize("img_params", valid_configs)
def test_resize(img_params):
    """Apply each resizing helper to a fresh copy of the image and save every result."""
    img = Im(get_img(**img_params))
    resizers = (
        ('resize', lambda im: im.resize(128, 128)),
        ('downscale', lambda im: im.scale(0.25)),
        ('scale_width', lambda im: im.scale_to_width(128)),
        ('scale_height', lambda im: im.scale_to_height(128)),
        ('multiple_resize',
         lambda im: im.scale(0.5).scale_to_width(128).resize(512, 1024).scale_to_width(512)),
    )
    for suffix, transform in resizers:
        transform(img.copy).save(get_file_path(img_params, suffix))
@pytest.mark.parametrize("img_params", valid_configs)
def test_normalization(img_params):
    """Round-trip normalize/denormalize in both orders and save both results.

    Black-and-white configs are skipped (after the image is constructed).
    Each round trip starts from its own ``img.copy``, like the other tests'
    chained operations. This keeps the second round trip independent of
    whatever state the first one leaves behind.
    """
    img = Im(get_img(**img_params))
    if img_params.get('bw_img', False):
        return
    img.copy.normalize().denormalize().save(get_file_path(img_params, 'normalize0'))
    img.copy.denormalize().normalize().save(get_file_path(img_params, 'normalize1'))
@pytest.mark.parametrize("img_params", valid_configs)
def test_format(img_params):
    """Access every format-conversion property; the test passes if none of them raises."""
    image = Im(get_img(**img_params))
    for conversion in ('pil', 'torch', 'np', 'opencv'):
        getattr(image, conversion)
@pytest.mark.parametrize("img_params", valid_configs)
def test_concat(img_params):
    """Concatenate three copies of the image horizontally and vertically.

    Configs with a non-empty ``batch_shape`` are skipped.
    """
    image = Im(get_img(**img_params))
    frames = [image] * 3
    if not img_params.get('batch_shape'):
        Im.concat_horizontal(*frames, spacing=5)
        Im.concat_vertical(*frames, spacing=0)
@pytest.mark.parametrize("img_params", valid_configs)
@pytest.mark.parametrize("format", ['webm', 'mp4', 'gif'])
def test_encode_video(img_params, format):
    """Encode and save a two-frame video (frame rate 2) for every non-PIL config.

    ``img_params`` is one of the shared ``valid_configs`` dicts. pytest hands
    the same object to every parametrized test, so ``batch_shape`` is
    overridden on a local copy. Mutating the shared dict would leak the
    override into tests that run later.
    """
    if img_params['img_type'] == Image.Image:
        return
    video_params = {**img_params, 'batch_shape': {'a': 2}}
    img = Im(get_img(**video_params))
    img.encode_video(2, format)
    img.save_video(get_file_path(video_params, 'video'), 2, format)
@pytest.mark.parametrize("img_params", valid_configs)
def test_complicated(img_params):
    """Chain resizing, bordering, normalization, a torch round trip and colorize, then save."""
    mean = (0.5, 0.75, 0.5)
    std = (0.1, 0.01, 0.01)

    resized = Im(get_img(**img_params)).scale(0.5).resize(128, 128)
    bordered = resized.add_border(border=5, color=(128, 128, 128))
    as_tensor = bordered.normalize(mean=mean, std=std).torch

    restored = Im(as_tensor).denormalize(mean=mean, std=std)
    colored = restored.colorize()
    colored.save(get_file_path(img_params, 'complicated'))