Coverage for src/image_utils/image_utils.py: 80%
412 statements
« prev ^ index » next coverage.py v7.2.5, created at 2023-05-19 16:15 -0700
1from __future__ import annotations
3import colorsys
4import copy
5import string
6import tempfile
7from enum import auto
8from io import BytesIO
9from pathlib import Path
10from typing import (Callable, Iterable, Optional, Tuple, Type, TypeAlias,
11 Union, cast)
13import cv2
14import numpy as np
15import torch
16import torch.nn.functional as F
17import torchvision.transforms.functional as T
18from einops import pack, rearrange, repeat
19from PIL import Image, ImageOps
20from strenum import StrEnum
21from torchvision.transforms.functional import InterpolationMode
23from image_utils.file_utils import get_date_time_str
# Pillow 9.1 introduced the ``Image.Resampling`` enum; older releases expose the filters
# (BILINEAR, NEAREST, ...) directly on ``Image``. Feature-detect instead of parsing
# ``Image.__version__`` so every release (including 10.x and later) takes the right branch.
resampling_module = getattr(Image, "Resampling", Image)

# Cache of random 1x1 conv weights used by ``Im.colorize``, keyed by input channel count, so a
# given channel count is mapped to the same colors for the lifetime of the process.
colorize_weights: dict = {}

# Image data accepted by most helpers in this module.
ImArr: TypeAlias = Union[np.ndarray, torch.Tensor]
# The *class* of an image array (used to request a conversion target, e.g. ``np.ndarray``).
ImArrType: TypeAlias = Type[Union[np.ndarray, torch.Tensor]]
# Element dtype of an image array.
ImDtype: TypeAlias = Union[torch.dtype, np.dtype]
def is_tensor(obj: ImArr):
    """Return ``True`` when *obj* is a PyTorch tensor, ``False`` otherwise."""
    answer = torch.is_tensor(obj)
    return answer
def is_ndarray(obj: ImArr):
    """Return ``True`` when *obj* is a NumPy ``ndarray`` (subclasses included)."""
    ndarray_types = (np.ndarray,)
    return isinstance(obj, ndarray_types)
def is_pil(obj: ImArr):
    """Return ``True`` when *obj* is a Pillow ``Image.Image`` instance."""
    pil_types = (Image.Image,)
    return isinstance(obj, pil_types)
def is_arr(obj: ImArr) -> bool:
    """Return ``True`` when *obj* is a PyTorch tensor or a NumPy ``ndarray``.

    Uses short-circuiting ``or`` (rather than bitwise ``|`` on booleans), so the NumPy check
    is skipped once *obj* is known to be a tensor.
    """
    return torch.is_tensor(obj) or isinstance(obj, np.ndarray)
def dispatch_op(obj: ImArr, np_op, torch_op, *args):
    """Apply the backend-specific implementation of an operation to *obj*.

    ``np_op(obj, *args)`` is used for NumPy arrays and ``torch_op(obj, *args)`` for torch tensors.

    Raises:
        ValueError: If *obj* is neither a NumPy array nor a torch tensor.
    """
    handlers = ((is_ndarray, np_op), (is_tensor, torch_op))
    for matches, op in handlers:
        if matches(obj):
            return op(obj, *args)
    raise ValueError(f'obj must be numpy array or torch tensor, not {type(obj)}')
class ChannelOrder(StrEnum):
    """Axis layout of an image array: channels-last (HWC) or channels-first (CHW)."""
    HWC = auto()  # (..., height, width, channels)
    CHW = auto()  # (..., channels, height, width)
class ChannelRange(StrEnum):
    """Value range / encoding of the pixel data held by an ``Im``."""
    UINT8 = auto()  # integer pixel values in [0, 255]
    FLOAT = auto()  # floating point; converted to/from UINT8 by a factor of 255, so nominally [0, 1]
    BOOL = auto()  # binary mask data (bool dtype, or uint8 holding only 0/1)
70from jaxtyping import Float
class Im:
    """
    Helper class to easily convert between formats (PIL/NumPy ndarray/PyTorch Tensor)
    and perform common operations, regardless of input dtype, batching, normalization, etc.

    Internally the data is stored as a 4D array ``(N, a, b, c)``: any leading batch dimensions
    are flattened into ``N`` and restored on output through ``arr_transform``.

    Note: Be careful when using this class directly as part of a training pipeline. Many operations will cause the underlying data to convert between formats (e.g., Tensor -> Pillow) and move the data back to system memory and/or incur loss of precision (e.g., float -> uint8)
    """

    # Per-channel statistics used by ``normalize``/``denormalize`` when none are given.
    default_normalize_mean = [0.4265, 0.4489, 0.4769]
    default_normalize_std = [0.2053, 0.2206, 0.2578]

    def __init__(self, arr: Union['Im', torch.Tensor, Image.Image, np.ndarray], channel_range: Optional[ChannelRange] = None, **kwargs):
        """
        Args:
            arr: Image data. 2D arrays are treated as single-channel images. 3D+ arrays may be HWC
                or CHW (inferred from which axis is smallest) with any number of leading batch dims.
                Passing an existing ``Im`` makes a shallow copy of its state.
            channel_range: Value range of the data; inferred from dtype (and values) when omitted.

        Raises:
            ValueError: For unsupported dtypes, out-of-range float data, or arrays with < 2 dims.
        """
        if isinstance(arr, Im):
            # Copy instance state only. Walking ``dir(arr)`` would also visit read-only properties
            # (``channels``, ``pil``, ...) whose assignment raises, and would rebind methods to ``arr``.
            self.__dict__.update(vars(arr))
            return

        self.device: torch.device
        self.arr_type: ImArrType

        # PIL images are always 0-255, even if every pixel happens to be 0 or 1 (e.g. a black image).
        from_pil = isinstance(arr, Image.Image)
        if from_pil:
            arr = np.array(arr)

        assert isinstance(arr, (np.ndarray, torch.Tensor)), f'arr must be numpy array, pillow image, or torch tensor, not {type(arr)}'
        self.arr: ImArr = arr
        if isinstance(self.arr, np.ndarray):
            self.arr_type = np.ndarray
            self.device = torch.device('cpu')
        elif isinstance(self.arr, torch.Tensor):
            self.device = self.arr.device
            self.arr_type = torch.Tensor
        else:
            raise ValueError('Must be numpy array, pillow image, or torch tensor')

        if len(self.arr.shape) == 2:
            # Single-channel image without an explicit channel axis.
            self.arr = self.arr[..., None]
        if len(self.arr.shape) < 3:
            raise ValueError(f'Image array must have at least 2 dimensions, got shape {tuple(self.arr.shape)}')

        # Heuristic: HWC when the last axis is smaller than both axes before it, otherwise CHW.
        self.channel_order: ChannelOrder = ChannelOrder.HWC if self.arr.shape[-1] < min(self.arr.shape[-3:-1]) else ChannelOrder.CHW
        self.dtype: ImDtype = self.arr.dtype
        self.shape = self.arr.shape

        # ``arr_transform`` maps the internal flattened 4D array back to the caller's original layout.
        if len(self.shape) == 3:
            self.arr_transform = lambda x: rearrange(x, '() a b c -> a b c')
        elif len(self.shape) == 4:
            self.arr_transform = lambda x: x
        else:
            extra_dims = self.shape[:-3]
            mapping = {k: v for k, v in zip(string.ascii_uppercase, extra_dims)}
            batch_axes = " ".join(sorted(mapping.keys()))
            self.arr_transform = lambda x: rearrange(x, f'({batch_axes}) a b c -> {batch_axes} a b c', **mapping)

        self.arr = rearrange(self.arr, '... a b c -> (...) a b c')

        if channel_range is not None:
            self.channel_range = channel_range
        elif self.dtype == np.uint8 or self.dtype == torch.uint8:
            if from_pil or self.arr.max() > 1:
                self.channel_range = ChannelRange.UINT8
            else:
                # Only 0/1 values (and not from PIL): treat as a mask.
                self.channel_range = ChannelRange.BOOL
        elif self.dtype in (np.float16, np.float32, np.float64, torch.float16, torch.bfloat16, torch.float32, torch.float64):
            if -128 <= self.arr.min() <= self.arr.max() <= 128:
                self.channel_range = ChannelRange.FLOAT
            else:
                raise ValueError('Not supported')
        elif self.dtype == np.bool_ or self.dtype == torch.bool:
            self.channel_range = ChannelRange.BOOL
        else:
            raise ValueError('Invalid Type')

    def __repr__(self):
        if self.arr_type == np.ndarray:
            arr_name = 'ndarray'
        elif self.arr_type == torch.Tensor:
            arr_name = 'tensor'
        else:
            raise ValueError('Must be numpy array, pillow image, or torch tensor')

        if is_pil(self.arr):
            shape_str = repr(self.arr)
        else:
            shape_str = f'type: {arr_name}, shape: {self.shape}'

        return f'Im of {shape_str}, device: {self.device}'

    def convert(self, desired_datatype: ImArrType, desired_order: ChannelOrder = ChannelOrder.HWC, desired_range: ChannelRange = ChannelRange.UINT8) -> Im:
        """Return an ``Im`` whose data is a ``desired_datatype`` array in the given order and range.

        Returns ``self`` unchanged when it already matches (or the datatype is unknown). Otherwise
        returns a new ``Im`` that keeps the original device, batch layout (``arr_transform``),
        dtype and shape metadata.
        """
        if self.arr_type == desired_datatype and self.channel_order == desired_order and self.channel_range == desired_range:
            return self

        # We preserve the original dtype, shape, and device
        orig_shape, orig_transform, orig_device, orig_dtype = self.shape, self.arr_transform, self.device, self.arr.dtype

        if desired_datatype == np.ndarray:
            arr = self.get_np(order=desired_order, range=desired_range)
        elif desired_datatype == torch.Tensor:
            arr = self.get_torch(order=desired_order, range=desired_range)
        else:
            return self

        # The data was just produced in ``desired_range``/``desired_order``; state that explicitly
        # rather than re-inferring it (e.g. a dark image cast to uint8 could otherwise look like a mask).
        converted = Im(arr, channel_range=desired_range)
        converted.channel_order = desired_order
        converted.device = orig_device
        converted.arr_transform = orig_transform
        converted.dtype = orig_dtype
        converted.shape = orig_shape
        return converted

    @staticmethod
    def convert_to_datatype(desired_datatype: ImArrType, desired_order=ChannelOrder.HWC, desired_range=ChannelRange.UINT8):
        """Method decorator: run the method on ``self.convert(desired_datatype, desired_order, desired_range)``."""
        import functools

        def custom_decorator(func):
            @functools.wraps(func)
            def wrapper(self, *args, **kwargs):
                self = self.convert(desired_datatype, desired_order, desired_range)
                return func(self, *args, **kwargs)
            return wrapper
        return custom_decorator

    def handle_order_transform(self, im, desired_order: ChannelOrder, desired_range: ChannelRange, select_batch=None):
        """Convert ``im`` (laid out like ``self.arr``) to ``desired_order`` and ``desired_range``.

        Args:
            im: NumPy array or torch tensor with the same layout as ``self.arr``.
            desired_order: Target channel order.
            desired_range: Target value range.
            select_batch: If not None, index into the flattened batch instead of restoring the
                original batch dims.

        Raises:
            ValueError: If the requested range conversion is not supported.
        """
        if select_batch is not None:
            im = im[select_batch]
        else:
            im = self.arr_transform(im)

        if desired_order == ChannelOrder.CHW and self.channel_order == ChannelOrder.HWC:
            im = rearrange(im, '... h w c -> ... c h w')
        elif desired_order == ChannelOrder.HWC and self.channel_order == ChannelOrder.CHW:
            im = rearrange(im, '... c h w -> ... h w c')

        if self.channel_range == desired_range:
            return im

        # einops pattern broadcasting a single channel to 3 channels in the target order.
        start_cur_order = 'h w ()' if desired_order == ChannelOrder.HWC else '() h w'
        end_cur_order = start_cur_order.replace('()', 'c')
        expand_to_rgb = f"... {start_cur_order} -> ... {end_cur_order}"
        conversion = (self.channel_range, desired_range)

        if is_ndarray(im):
            if conversion == (ChannelRange.FLOAT, ChannelRange.UINT8):
                # Clip first: out-of-range floats would otherwise wrap around when cast to uint8.
                im = np.clip(im * 255, 0, 255).astype(np.uint8)
            elif conversion == (ChannelRange.UINT8, ChannelRange.FLOAT):
                im = (im / 255.0).astype(np.float32)
            elif conversion == (ChannelRange.BOOL, ChannelRange.UINT8):
                assert self.channels == 1
                im = (repeat(im, expand_to_rgb, c=3) * 255).astype(np.uint8)
            elif conversion == (ChannelRange.BOOL, ChannelRange.FLOAT):
                assert self.channels == 1
                im = repeat(im, expand_to_rgb, c=3).astype(np.float32)
            else:
                raise ValueError(f"Not supported: {self.channel_range} -> {desired_range}")
        elif is_tensor(im):
            if conversion == (ChannelRange.FLOAT, ChannelRange.UINT8):
                im = (im * 255).clamp(0, 255).to(torch.uint8)
            elif conversion == (ChannelRange.UINT8, ChannelRange.FLOAT):
                im = (im / 255.0).to(torch.float32)
            elif conversion == (ChannelRange.BOOL, ChannelRange.UINT8):
                assert self.channels == 1
                im = (repeat(im, expand_to_rgb, c=3) * 255).to(torch.uint8)
            elif conversion == (ChannelRange.BOOL, ChannelRange.FLOAT):
                assert self.channels == 1
                im = repeat(im, expand_to_rgb, c=3).to(torch.float32)
            else:
                raise ValueError(f"Not supported: {self.channel_range} -> {desired_range}")

        return im

    def get_np(self, order=ChannelOrder.HWC, range=ChannelRange.UINT8) -> np.ndarray:
        """Return the data as a NumPy array (original batch dims restored) in ``order``/``range``."""
        arr = self.arr
        if is_tensor(arr):
            arr = torch_to_numpy(arr)
        return self.handle_order_transform(arr, order, range)

    def get_torch(self, order=ChannelOrder.CHW, range=ChannelRange.FLOAT) -> torch.Tensor:
        """Return the data as a tensor on ``self.device`` (original batch dims restored) in ``order``/``range``."""
        arr = self.arr
        if is_ndarray(arr):
            arr = torch.from_numpy(arr)

        arr = self.handle_order_transform(arr, order, range)
        if self.device is not None:
            arr = arr.to(self.device)
        return arr

    def get_pil(self) -> Union[Image.Image, list[Image.Image]]:
        """Return a PIL image, or a list of PIL images when the batch holds more than one image."""
        def to_pil(frame: np.ndarray) -> Image.Image:
            # Pillow cannot build an image from an (H, W, 1) array; single-channel data must be 2D.
            if frame.shape[-1] == 1:
                frame = frame[..., 0]
            return Image.fromarray(frame)

        if len(self.shape) == 3:
            return to_pil(self.get_np())
        img = rearrange(self.get_np(), '... h w c -> (...) h w c')
        if img.shape[0] == 1:
            return to_pil(img[0])
        return [to_pil(img[i]) for i in range(img.shape[0])]

    @property
    def copy(self):
        """Deep copy of this ``Im``."""
        return copy.deepcopy(self)

    @property
    def height(self):
        return self.image_shape[0]

    @property
    def width(self):
        return self.image_shape[1]

    @property
    def channels(self):
        return self.arr.shape[-1] if self.channel_order == ChannelOrder.HWC else self.arr.shape[-3]

    @property
    def image_shape(self):  # returns h,w
        return (self.arr.shape[-3], self.arr.shape[-2]) if self.channel_order == ChannelOrder.HWC else (self.arr.shape[-2], self.arr.shape[-1])

    @staticmethod
    def open(filepath: Path, use_imageio=False) -> Im:
        """Load an image from disk with Pillow (default) or imageio."""
        if use_imageio:
            from imageio import v3 as iio
            img = iio.imread(filepath)
        else:
            img = Image.open(filepath)
        return Im(img)

    @convert_to_datatype(desired_datatype=torch.Tensor, desired_order=ChannelOrder.CHW, desired_range=ChannelRange.FLOAT)
    def resize(self, height, width, resampling_mode=InterpolationMode.BILINEAR):
        """Resize to ``(height, width)`` with torchvision; returns a new ``Im`` holding a float CHW tensor."""
        assert isinstance(self.arr, torch.Tensor)
        return Im(T.resize(self.arr, [height, width], resampling_mode))

    def scale(self, scale) -> Im:
        """Resize both sides by the factor ``scale``."""
        width, height = self.width, self.height
        return self.resize(int(height * scale), int(width * scale))

    def scale_to_width(self, new_width) -> Im:
        """Resize to ``new_width`` while keeping the aspect ratio."""
        width, height = self.width, self.height
        wpercent = new_width / float(width)
        hsize = int(float(height) * float(wpercent))
        return self.resize(hsize, new_width)

    def scale_to_height(self, new_height) -> Im:
        """Resize to ``new_height`` while keeping the aspect ratio."""
        width, height = self.width, self.height
        hpercent = new_height / float(height)
        wsize = int(float(width) * float(hpercent))
        return self.resize(new_height, wsize)

    @staticmethod
    def _save_data(filepath: Optional[Path] = None, filetype='png'):
        """Resolve an output path and create its parent directory.

        Defaults to a timestamped name (computed per call), appends ``.{filetype}`` when there is
        no suffix, and places paths without a directory component under ``output/``.
        """
        if filepath is None:
            # Computed per call: a default in the signature would be frozen at import time,
            # making every default save in a session overwrite the same file.
            filepath = get_date_time_str()
        filepath = Path(filepath)
        if filepath.suffix == '':
            filepath = filepath.with_suffix(f'.{filetype}')

        if len(filepath.parents) == 1:
            filepath = Path('output') / filepath
        filepath.parent.mkdir(parents=True, exist_ok=True)

        return filepath

    def save(self, filepath: Optional[Path] = None, filetype='png', optimize=False, quality=None):
        """Save to disk; batched data is tiled into a single grid image first.

        Args:
            filepath: Destination (see ``_save_data``); a timestamped name when omitted.
            filetype: Extension used when ``filepath`` has no suffix.
            optimize: Pass ``optimize=True`` to Pillow.
            quality: Pillow quality setting (0-95 for JPEG); defaults to 95 when ``optimize`` is set.
        """
        img = self.get_torch()

        filepath = Im._save_data(filepath, filetype)

        if len(img.shape) > 3:
            from torchvision import utils
            img = rearrange(img, '... c h w -> (...) c h w')
            img = utils.make_grid(img)
            img = Im(img).get_pil()
        else:
            img = self.get_pil()

        assert isinstance(img, Image.Image)

        # Pillow's JPEG ``quality`` is an integer scale (0-95), not a fraction.
        flags = {'optimize': True, 'quality': quality if quality else 95} if optimize or quality else {}

        img.save(filepath, **flags)

    @convert_to_datatype(desired_datatype=np.ndarray, desired_order=ChannelOrder.HWC, desired_range=ChannelRange.UINT8)
    def write_text(self, text: Union[str, list[str]]) -> Im:
        """Draw ``text`` (or ``text[i]`` for frame ``i`` when a list) near the bottom-left of each frame."""
        for i in range(self.arr.shape[0]):
            text_to_write = text[i] if isinstance(text, list) else text
            assert isinstance(self.arr[i], np.ndarray)
            im = cv2.cvtColor(cast(np.ndarray, self.arr[i]), cv2.COLOR_RGB2BGR)
            # Font scale and thickness grow with the shorter image side.
            im = cv2.putText(im, text_to_write, (0, im.shape[0] - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.002 *
                             min(self.arr.shape[-3:-1]), (255, 0, 0), max(1, round(min(self.arr.shape[-3:-1]) / 150)), cv2.LINE_AA)
            self.arr[i] = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)

        return self

    def add_border(self, border: int, color: Tuple[int, int, int]):
        """Replace the outermost ``border`` pixels of every image with ``color``; returns a new ``Im``."""
        imgs = self.pil
        if isinstance(imgs, list):
            framed = [Im(ImageOps.expand(ImageOps.crop(img, border=border), border=border, fill=color)).np for img in imgs]
            return Im(np.stack(framed, axis=0))
        return Im(ImageOps.expand(ImageOps.crop(imgs, border=border), border=border, fill=color))

    def normalize_setup(self, mean=default_normalize_mean, std=default_normalize_std):
        """Return ``(float HWC Im, mean, std)`` with mean/std as arrays matching the data's backend."""
        def convert_instance_np(arr_1, arr_2):
            assert isinstance(self.dtype, np.dtype)
            return np.array(arr_2).astype(self.dtype)

        def convert_instance_torch(arr_1, arr_2):
            assert isinstance(self.dtype, torch.dtype)
            if self.dtype in (torch.float16, torch.bfloat16, torch.half):
                return torch.tensor(arr_2).to(dtype=torch.float, device=self.device)
            else:
                return torch.tensor(arr_2).to(dtype=self.dtype, device=self.device)

        if is_ndarray(self.arr):
            self = Im(self.get_np(ChannelOrder.HWC, ChannelRange.FLOAT))
        elif is_tensor(self.arr):
            self = Im(self.get_torch(ChannelOrder.HWC, ChannelRange.FLOAT))

        mean = dispatch_op(self.arr, convert_instance_np, convert_instance_torch, mean)
        std = dispatch_op(self.arr, convert_instance_np, convert_instance_torch, std)

        return self, mean, std

    def normalize(self, **kwargs) -> Im:
        """Return ``(arr - mean) / std`` as a float HWC ``Im``."""
        self, mean, std = self.normalize_setup(**kwargs)
        self.arr = (self.arr - mean) / std
        return self

    def denormalize(self, clamp: Union[bool, tuple[float, float]] = (0, 1.0), **kwargs) -> Im:
        """Undo ``normalize`` (``arr * std + mean``) and optionally clip.

        Args:
            clamp: ``(min, max)`` bounds to clip to; ``True`` means ``(0, 1.0)``; a falsy value disables clipping.
        """
        self, mean, std = self.normalize_setup(**kwargs)
        self.arr = (self.arr * std) + mean
        if clamp is True:
            clamp = (0, 1.0)
        if clamp:
            if isinstance(self.arr, np.ndarray):
                self.arr = self.arr.clip(*clamp)
            elif isinstance(self.arr, torch.Tensor):
                self.arr = self.arr.clamp(*clamp)
        return self

    @convert_to_datatype(desired_datatype=np.ndarray, desired_order=ChannelOrder.HWC, desired_range=ChannelRange.UINT8)
    def get_opencv(self):
        """Return the flattened ``(N, H, W, C)`` uint8 array. No RGB->BGR swap is performed."""
        return self.arr

    @convert_to_datatype(desired_datatype=np.ndarray, desired_order=ChannelOrder.HWC, desired_range=ChannelRange.UINT8)
    def convert_opencv_color(self, color: int) -> Im:
        """Apply ``cv2.cvtColor(frame, color)`` to every frame (e.g. ``cv2.COLOR_RGB2BGR``) and return the result."""
        assert isinstance(self.arr, np.ndarray)
        # cvtColor works on single 2D/3D images, so convert frame by frame.
        frames = [cv2.cvtColor(frame, color) for frame in self.arr]
        # Conversions such as RGB2GRAY yield 2D frames; keep an explicit channel axis.
        self.arr = np.stack([frame if frame.ndim == 3 else frame[..., None] for frame in frames], axis=0)
        return self

    @staticmethod
    def concat_vertical(*args, **kwargs) -> Im:
        """Concatenates images vertically (i.e. stacked on top of each other)"""
        return concat_variable(concat_vertical_, *args, **kwargs)

    @staticmethod
    def concat_horizontal(*args, **kwargs) -> Im:
        """Concatenates images horizontally (i.e. left to right)"""
        return concat_variable(concat_horizontal_, *args, **kwargs)

    def save_video(self, filepath: Path, fps: int, format='mp4'):
        """Encode the frames as a video (see ``encode_video``) and write it to ``filepath``."""
        filepath = Im._save_data(filepath, format)
        byte_stream = self.encode_video(fps, format)
        with open(filepath, "wb") as f:
            f.write(byte_stream.getvalue())

    @convert_to_datatype(desired_datatype=np.ndarray, desired_order=ChannelOrder.HWC, desired_range=ChannelRange.UINT8)
    def encode_video(self, fps: int, format='mp4') -> BytesIO:
        """Encode the frames as ``webm``, ``gif`` or ``mp4`` and return the bytes, rewound to 0.

        Raises:
            NotImplementedError: For any other ``format``.
        """
        assert len(self.arr.shape) == 4, "Video data must be 4D (time, height, width, channels)"
        import imageio
        byte_stream = BytesIO()

        # TODO: We shouldn't need to write -> read. An imageio/ffmpeg issue is causing this.
        with tempfile.NamedTemporaryFile(suffix=f'.{format}') as ntp:
            if format == 'webm':
                writer = imageio.get_writer(ntp.name, format='webm', codec='libvpx-vp9', pixelformat='yuv420p', output_params=['-lossless', '1'], fps=fps)
            elif format == 'gif':
                writer = imageio.get_writer(ntp.name, format='GIF', mode="I", duration=(1000 * 1/fps))
            elif format == 'mp4':
                writer = imageio.get_writer(ntp.name, quality=10, pixelformat='yuv420p', codec='libx264', fps=fps)
            else:
                raise NotImplementedError(f'Format {format} not implemented.')

            try:
                for frame in self.arr:
                    writer.append_data(frame)
            finally:
                # Always release the encoder, even if a frame fails to append.
                writer.close()

            with open(ntp.name, 'rb') as f:
                byte_stream.write(f.read())

        byte_stream.seek(0)
        return byte_stream

    def to(self, device: torch.device):
        """Move the underlying tensor to ``device`` in place and return ``self``."""
        assert isinstance(self.arr, torch.Tensor), "Can only convert to device if array is a torch.Tensor"
        self.arr = self.arr.to(device)
        return self

    @staticmethod
    def stack_imgs(*args: Im):
        """Stack images (each in HWC, keeping its own value range) along a new leading axis."""
        frames = [img.handle_order_transform(img.arr, desired_order=ChannelOrder.HWC, desired_range=img.channel_range) for img in args]
        return Im(rearrange(frames, 'b ... -> b ...'))

    @convert_to_datatype(desired_datatype=torch.Tensor, desired_order=ChannelOrder.CHW, desired_range=ChannelRange.FLOAT)
    def colorize(self) -> Im:
        """Project any channel count to 3 channels using a cached random 1x1 conv; rescale to [0, 1]."""
        if self.channels not in colorize_weights:
            colorize_weights[self.channels] = torch.randn(3, self.channels, 1, 1)

        assert isinstance(self.arr, torch.Tensor)
        # Match the cached CPU/float32 weights to the data (which may be on GPU or in half precision).
        weight = colorize_weights[self.channels].to(device=self.arr.device, dtype=self.arr.dtype)
        self.arr = F.conv2d(self.arr, weight=weight)
        lo, hi = self.arr.min(), self.arr.max()
        # Guard against a constant image (hi == lo), which would otherwise divide by zero.
        self.arr = (self.arr - lo) / (hi - lo).clamp_min(torch.finfo(self.arr.dtype).eps)
        return self

    pil = property(get_pil)
    np = property(get_np)
    torch = property(get_torch)
    opencv = property(get_opencv)
def torch_to_numpy(arr):
    """Detach *arr*, move it to the CPU and return it as a NumPy array.

    Half-precision tensors (``float16``/``bfloat16``) are upcast to ``float32`` first.
    """
    needs_upcast = arr.dtype in (torch.bfloat16, torch.float16)
    if needs_upcast:
        arr = arr.float()
    return arr.cpu().detach().numpy()
def pil_to_numpy(arr):
    """Convert a PIL image to a NumPy array after forcing it into RGB mode."""
    rgb_image = arr.convert('RGB')
    return np.array(rgb_image)
def concat_variable(concat_func: Callable[..., Im], *args: Im, **kwargs) -> Im:
    """Fold ``concat_func`` left-to-right over ``args`` and return the combined ``Im``.

    ``kwargs`` are forwarded to every ``concat_func`` call. A ``None`` accumulator is replaced
    by the next image instead of being concatenated.

    Raises:
        AssertionError: If no ``Im`` results (e.g. ``args`` is empty).
    """
    result = None
    for img in args:
        result = img if result is None else concat_func(result, img, **kwargs)

    assert isinstance(result, Im)
    return result
def get_arr_hwc(im: Im):
    """Return ``im``'s data in HWC order with batch dims restored, keeping its current value range."""
    return im.handle_order_transform(
        im.arr,
        desired_order=ChannelOrder.HWC,
        desired_range=im.channel_range,
    )
def concat_horizontal_(im1: Im, im2: Im, spacing=0) -> Im:
    """Place ``im2`` to the right of ``im1``.

    Both images are taken in HWC order (each keeping its own value range) and joined along the
    width axis, so batched inputs are concatenated frame by frame.

    Args:
        im1: Left image.
        im2: Right image; must have the same height as ``im1``.
        spacing: Number of zero-filled columns inserted between the two images.

    Raises:
        ValueError: If the heights differ.
    """
    if im1.height != im2.height:
        raise ValueError(f'Images must have same height. Got {im1.height} and {im2.height}')
    left, right = get_arr_hwc(im1), get_arr_hwc(im2)
    parts = [left, right]
    if spacing > 0:
        gap_shape = (*left.shape[:-2], spacing, left.shape[-1])
        if is_tensor(left):
            gap = torch.zeros(gap_shape, dtype=left.dtype, device=left.device)
        else:
            gap = np.zeros(gap_shape, dtype=left.dtype)
        parts = [left, gap, right]
    # Width is always axis -2 in HWC, regardless of how many leading batch dims there are.
    if is_tensor(left):
        return Im(torch.cat(parts, dim=-2))
    return Im(np.concatenate(parts, axis=-2))
def concat_vertical_(im1: Im, im2: Im, spacing=0) -> Im:
    """Place ``im2`` below ``im1``.

    Both images are taken in HWC order (each keeping its own value range) and joined along the
    height axis, so batched inputs are concatenated frame by frame.

    Args:
        im1: Top image.
        im2: Bottom image; must have the same width as ``im1``.
        spacing: Number of zero-filled rows inserted between the two images.

    Raises:
        ValueError: If the widths differ.
    """
    if im1.width != im2.width:
        raise ValueError(f'Images must have same width. Got {im1.width} and {im2.width}')
    top, bottom = get_arr_hwc(im1), get_arr_hwc(im2)
    parts = [top, bottom]
    if spacing > 0:
        gap_shape = (*top.shape[:-3], spacing, top.shape[-2], top.shape[-1])
        if is_tensor(top):
            gap = torch.zeros(gap_shape, dtype=top.dtype, device=top.device)
        else:
            gap = np.zeros(gap_shape, dtype=top.dtype)
        parts = [top, gap, bottom]
    # Height is always axis -3 in HWC, regardless of how many leading batch dims there are.
    if is_tensor(top):
        return Im(torch.cat(parts, dim=-3))
    return Im(np.concatenate(parts, axis=-3))
def get_layered_image_from_binary_mask(masks, flip=False):
    """Render a stack of binary masks (indexed along axis 2) as an RGB PIL image.

    Mask ``k`` is painted with the ``k``-th color from ``get_n_distinct_colors``. Where masks
    overlap, the later mask wins; pixels covered by no mask stay black.

    Args:
        masks: NumPy array or torch tensor of shape (H, W, K); cast to bool.
        flip: Flip vertically (``np.flipud``) before rendering.
    """
    if torch.is_tensor(masks):
        masks = torch_to_numpy(masks)
    if flip:
        masks = np.flipud(masks)

    masks = masks.astype(np.bool_)
    palette = np.asarray(list(get_n_distinct_colors(masks.shape[2])))
    canvas = np.zeros((*masks.shape[:2], 3))
    for idx, color in enumerate(palette):
        canvas[masks[..., idx]] = color

    return Image.fromarray(canvas.astype(np.uint8))
def get_img_from_binary_masks(masks, flip=False):
    """Render every combination of overlapping binary masks (H W C) in its own color.

    Masks are packed into integer labels by ``encode_binary_labels``, and each of the
    ``2 ** C`` possible labels is mapped to a distinct color.

    Args:
        masks: Array or tensor of shape (H, W, C).
        flip: Flip the label map vertically (``np.flipud``) before coloring.
    """
    label_map = encode_binary_labels(masks)
    if flip:
        label_map = np.flipud(label_map)

    num_labels = 2 ** masks.shape[2]
    palette = np.asarray(list(get_n_distinct_colors(num_labels)))
    rgb = palette[label_map].astype(np.uint8)
    return Image.fromarray(rgb)
def encode_binary_labels(masks):
    """Pack C binary masks of shape (H, W, C) into one (H, W) integer label map.

    Channel ``k`` contributes bit ``2 ** k``, so each pixel's label encodes which masks cover it.
    Torch tensors are converted to NumPy first.
    """
    if torch.is_tensor(masks):
        masks = torch_to_numpy(masks)

    channels_first = masks.transpose(2, 0, 1)
    bit_values = np.power(2, np.arange(len(channels_first), dtype=np.int32))
    weighted = channels_first.astype(np.int32) * bit_values[:, None, None]
    return weighted.sum(0)
def get_n_distinct_colors(n):
    """Lazily yield ``n`` fully saturated, full-brightness RGB colors as 0-255 int tuples.

    Hues are spaced ``1 / (n + 1)`` apart starting at 0 (red), so the first and last colors
    never wrap around to the same hue.
    """
    step = 1.0 / (n + 1)

    def _to_rgb255(hue):
        red, green, blue = colorsys.hsv_to_rgb(hue, 1.0, 1.0)
        return (int(255 * red), int(255 * green), int(255 * blue))

    return (_to_rgb255(step * idx) for idx in range(n))
def square_pad(image, h, w):
    """Resize ``image`` to ``(h, w)``, zero-padding first so the content keeps its aspect ratio.

    Args:
        image: Tensor accepted by ``torchvision.transforms.functional`` whose last two
            dimensions are (height, width).
        h: Target height.
        w: Target width.

    Returns:
        The padded-then-resized image. When the aspect ratios already agree to two decimals, or
        the required padding truncates to nothing, the image is resized directly.
    """
    h_1, w_1 = image.shape[-2:]
    ratio_f = w / h
    ratio_1 = w_1 / h_1

    # check if the original and final aspect ratios are the same within a margin
    if round(ratio_1, 2) != round(ratio_f, 2):
        # Rows (hp) or columns (wp) needed to reach the target ratio; at most one is positive.
        hp = int(w_1 / ratio_f - h_1)
        wp = int(ratio_f * h_1 - w_1)
        # A 4-element padding is (left, top, right, bottom).
        if hp > 0 and wp < 0:
            hp = hp // 2
            image = T.pad(image, (0, hp, 0, hp), 0, "constant")
        elif hp < 0 and wp > 0:
            wp = wp // 2
            image = T.pad(image, (wp, 0, wp, 0), 0, "constant")

    # Every path ends with a resize, so the function never returns None.
    return T.resize(image, [h, w])