Coverage for src/image_utils/image_utils.py: 80%

412 statements  

« prev     ^ index     » next       coverage.py v7.2.5, created at 2023-05-19 16:15 -0700

1from __future__ import annotations 

2 

3import colorsys 

4import copy 

5import string 

6import tempfile 

7from enum import auto 

8from io import BytesIO 

9from pathlib import Path 

10from typing import (Callable, Iterable, Optional, Tuple, Type, TypeAlias, 

11 Union, cast) 

12 

13import cv2 

14import numpy as np 

15import torch 

16import torch.nn.functional as F 

17import torchvision.transforms.functional as T 

18from einops import pack, rearrange, repeat 

19from PIL import Image, ImageOps 

20from strenum import StrEnum 

21from torchvision.transforms.functional import InterpolationMode 

22 

23from image_utils.file_utils import get_date_time_str 

24 

# Pick the namespace that holds Pillow's resampling filters. Newer Pillow
# releases expose them on ``Image.Resampling``; older ones keep them directly
# on the ``Image`` module. Feature detection is used instead of parsing
# ``Image.__version__``: the previous ``major >= 9 and minor > 0`` test was
# false for every ``X.0`` release (e.g. 10.0.0), wrongly selecting ``Image``.
resampling_module = getattr(Image, "Resampling", Image)

# Cache of the random 1x1 projection kernels used by ``Im.colorize``,
# keyed by the number of input channels.
colorize_weights = {}

# Array-like containers that ``Im`` can wrap.
ImArr: TypeAlias = Union[np.ndarray, torch.Tensor]
# The container *class* (``np.ndarray`` or ``torch.Tensor``), not an instance.
ImArrType: TypeAlias = Type[Union[np.ndarray, torch.Tensor]]
# Element dtype of either backend.
ImDtype: TypeAlias = Union[torch.dtype, np.dtype]

35 

def is_tensor(obj: ImArr):
    """Return True when ``obj`` is a ``torch.Tensor``."""
    return isinstance(obj, torch.Tensor)

38 

39 

def is_ndarray(obj: ImArr):
    """Return True when ``obj`` is a NumPy ``ndarray`` (or subclass)."""
    result = isinstance(obj, np.ndarray)
    return result

42 

43 

def is_pil(obj: ImArr):
    """Return True when ``obj`` is a Pillow ``Image.Image`` instance."""
    result = isinstance(obj, Image.Image)
    return result

46 

47 

def is_arr(obj: ImArr):
    """Return True when ``obj`` is either a torch tensor or a NumPy array."""
    return isinstance(obj, (torch.Tensor, np.ndarray))

50 

def dispatch_op(obj: ImArr, np_op, torch_op, *args):
    """
    Call the backend-specific operation that matches the type of ``obj``.

    ``np_op(obj, *args)`` is used for NumPy arrays and ``torch_op(obj, *args)``
    for torch tensors; the chosen callable's result is returned unchanged.

    Raises:
        ValueError: if ``obj`` is neither a NumPy array nor a torch tensor.
    """
    # NumPy is checked first, then torch.
    if isinstance(obj, np.ndarray):
        return np_op(obj, *args)
    if isinstance(obj, torch.Tensor):
        return torch_op(obj, *args)
    raise ValueError(f'obj must be numpy array or torch tensor, not {type(obj)}')

58 

59 

class ChannelOrder(StrEnum):
    """Memory layout of an image array: channels-last or channels-first."""

    # height, width, channels
    HWC = auto()
    # channels, height, width
    CHW = auto()

63 

64 

class ChannelRange(StrEnum):
    """Value range an image array is stored in; drives range conversions in ``Im``."""

    # integer pixel values, scaled by 255 when converting to/from FLOAT
    UINT8 = auto()
    # floating point pixel values
    FLOAT = auto()
    # binary mask data (single channel expected when converting)
    BOOL = auto()

69 

70from jaxtyping import Float 

71 

72 

class Im:
    """
    This class is a helper class to easily convert between formats (PIL/NumPy ndarray/PyTorch Tensor)
    and perform common operations, regardless of input dtype, batching, normalization, etc.

    Internally the data lives in ``self.arr`` with all leading (batch) dimensions
    flattened into one, i.e. always 4 dims. ``self.arr_transform`` restores the
    caller's original leading dimensions, ``self.shape`` / ``self.dtype`` record
    the shape and dtype as they were at construction time.

    Note: Be careful when using this class directly as part of a training pipeline. Many operations will cause the underlying data to convert between formats (e.g., Tensor -> Pillow) and move the data back to system memory and/or incur loss of precision (e.g., float -> uint8)
    """

    default_normalize_mean = [0.4265, 0.4489, 0.4769]
    default_normalize_std = [0.2053, 0.2206, 0.2578]

    def __init__(self, arr: Union['Im', torch.Tensor, Image.Image, np.ndarray], channel_range: Optional[ChannelRange] = None, **kwargs):
        """
        Wrap ``arr`` and infer its container type, device, channel order and range.

        Args:
            arr: an existing ``Im`` (its state is shallow-copied), a Pillow image,
                a NumPy array or a torch tensor. 2-D input gets a trailing
                channel axis added.
            channel_range: overrides the automatic range detection when given.
            **kwargs: accepted and ignored.

        Raises:
            ValueError: for fewer than 2 dims, unsupported float value ranges
                or unsupported dtypes.
        """
        if isinstance(arr, Im):
            # Copy only the instance state. Iterating ``dir(arr)`` would also hit
            # the read-only properties (``channels``, ``np``, ...): reading them
            # triggers conversions and assigning to them raises AttributeError.
            self.__dict__.update(vars(arr))
            return

        self.device: torch.device
        self.arr_type: ImArrType

        # Remember the PIL origin: after ``np.array`` the information is gone,
        # and uint8 PIL data must never be mistaken for a 0/1 mask.
        from_pil = isinstance(arr, Image.Image)
        if from_pil:
            arr = np.array(arr)

        assert isinstance(arr, (np.ndarray, torch.Tensor)), f'arr must be numpy array, pillow image, or torch tensor, not {type(arr)}'
        self.arr: ImArr = arr
        if isinstance(self.arr, np.ndarray):
            self.arr_type = np.ndarray
            self.device = torch.device('cpu')
        elif isinstance(self.arr, torch.Tensor):
            self.device = self.arr.device
            self.arr_type = torch.Tensor
        else:
            raise ValueError('Must be numpy array, pillow image, or torch tensor')

        if len(self.arr.shape) == 2:
            self.arr = self.arr[..., None]

        # Checked before the channel-order heuristic, which needs >= 3 dims.
        if len(self.arr.shape) < 3:
            raise ValueError('Must be between 3-5 dims')

        # Heuristic: the channel axis is assumed to be smaller than both spatial axes.
        self.channel_order: ChannelOrder = ChannelOrder.HWC if self.arr.shape[-1] < min(self.arr.shape[-3:-1]) else ChannelOrder.CHW
        self.dtype: ImDtype = self.arr.dtype
        self.shape = self.arr.shape

        num_dims = len(self.shape)
        if num_dims == 3:
            # Drop the singleton batch axis added by the flattening below.
            self.arr_transform = lambda x: rearrange(x, '() a b c -> a b c')
        elif num_dims == 4:
            self.arr_transform = lambda x: x
        else:
            # Name every leading axis (A, B, ...) so the flat batch can be unpacked again.
            extra_dims = self.shape[:-3]
            mapping = dict(zip(string.ascii_uppercase, extra_dims))
            axes = " ".join(sorted(mapping))
            self.arr_transform = lambda x: rearrange(x, f'({axes}) a b c -> {axes} a b c', **mapping)

        self.arr = rearrange(self.arr, '... a b c -> (...) a b c')

        if channel_range is not None:
            self.channel_range = channel_range
        elif self.dtype == np.uint8 or self.dtype == torch.uint8:
            if from_pil or self.arr.max() > 1:
                self.channel_range = ChannelRange.UINT8
            else:
                # uint8 data containing only 0/1 is treated as a mask.
                self.channel_range = ChannelRange.BOOL
        elif self.dtype in (np.float16, np.float32, torch.float16, torch.bfloat16, torch.float32):
            if -128 <= self.arr.min() <= self.arr.max() <= 128:
                self.channel_range = ChannelRange.FLOAT
            else:
                raise ValueError('Not supported')
        elif self.dtype == np.bool_ or self.dtype == torch.bool:
            self.channel_range = ChannelRange.BOOL
        else:
            raise ValueError('Invalid Type')

    def __repr__(self):
        """Return a short description: container type, original shape and device."""
        if self.arr_type == np.ndarray:
            arr_name = 'ndarray'
        elif self.arr_type == torch.Tensor:
            arr_name = 'tensor'
        else:
            raise ValueError('Must be numpy array, pillow image, or torch tensor')

        if is_pil(self.arr):
            shape_str = repr(self.arr)
        else:
            shape_str = f'type: {arr_name}, shape: {self.shape}'

        return f'Im of {shape_str}, device: {self.device}'

    def convert(self, desired_datatype: ImArrType, desired_order: ChannelOrder = ChannelOrder.HWC, desired_range: ChannelRange = ChannelRange.UINT8) -> Im:
        """
        Return an ``Im`` in the requested container type, channel order and range.

        Returns ``self`` unchanged when it already matches; otherwise a new ``Im``
        built from the converted data that carries over the original device,
        batch transform, dtype and shape attributes.
        """
        if self.arr_type != desired_datatype or self.channel_order != desired_order or self.channel_range != desired_range:

            # We preserve the original dtype, shape, and device
            orig_shape, orig_transform, orig_device, orig_dtype = self.shape, self.arr_transform, self.device, self.arr.dtype

            if desired_datatype == np.ndarray:
                self = Im(self.get_np(order=desired_order, range=desired_range))
            elif desired_datatype == torch.Tensor:
                self = Im(self.get_torch(order=desired_order, range=desired_range))

            self.device = orig_device
            self.arr_transform = orig_transform
            self.dtype = orig_dtype
            self.shape = orig_shape

        return self

    @staticmethod
    def convert_to_datatype(desired_datatype: ImArrType, desired_order=ChannelOrder.HWC, desired_range=ChannelRange.UINT8):
        """
        Decorator factory: convert ``self`` via ``convert`` before calling the method.

        The wrapped method receives the converted ``Im`` as ``self``, which may be
        a different object than the one the caller invoked the method on.
        """
        from functools import wraps

        def custom_decorator(func):
            @wraps(func)  # keep the wrapped method's name and docstring
            def wrapper(self, *args, **kwargs):
                self = self.convert(desired_datatype, desired_order, desired_range)
                return func(self, *args, **kwargs)
            return wrapper
        return custom_decorator

    def handle_order_transform(self, im, desired_order: ChannelOrder, desired_range: ChannelRange, select_batch=None):
        """
        Restore batching, then convert ``im`` to the requested order and range.

        Args:
            im: the flattened array/tensor (normally ``self.arr``).
            desired_order: target channel layout.
            desired_range: target value range.
            select_batch: when not None, index applied to the flat batch axis
                instead of restoring the original leading dimensions.

        Raises:
            ValueError: when no conversion exists between the two ranges.
        """
        # ``is not None`` so that batch index 0 is honoured.
        if select_batch is not None:
            im = im[select_batch]
        else:
            im = self.arr_transform(im)

        if desired_order == ChannelOrder.CHW and self.channel_order == ChannelOrder.HWC:
            im = rearrange(im, '... h w c -> ... c h w')
        elif desired_order == ChannelOrder.HWC and self.channel_order == ChannelOrder.CHW:
            im = rearrange(im, '... c h w -> ... h w c')

        if self.channel_range == desired_range:
            return im

        # Pick backend-specific dtypes and cast so both backends share one code path.
        if is_ndarray(im):
            uint8, float32 = np.uint8, np.float32
            cast_to = lambda x, dtype: x.astype(dtype)
        elif is_tensor(im):
            uint8, float32 = torch.uint8, torch.float32
            cast_to = lambda x, dtype: x.to(dtype)
        else:
            return im

        src_range = self.channel_range
        if src_range == ChannelRange.FLOAT and desired_range == ChannelRange.UINT8:
            return cast_to(im * 255, uint8)
        if src_range == ChannelRange.UINT8 and desired_range == ChannelRange.FLOAT:
            return cast_to(im / 255.0, float32)
        if src_range == ChannelRange.BOOL and desired_range in (ChannelRange.UINT8, ChannelRange.FLOAT):
            assert self.channels == 1
            # Broadcast the single mask channel to 3 channels in the target layout.
            start_cur_order = 'h w ()' if desired_order == ChannelOrder.HWC else '() h w'
            end_cur_order = start_cur_order.replace('()', 'c')
            im = repeat(im, f"... {start_cur_order} -> ... {end_cur_order}", c=3)
            if desired_range == ChannelRange.UINT8:
                return cast_to(im * 255, uint8)
            return cast_to(im, float32)

        raise ValueError(f"Not supported: {src_range} -> {desired_range}")

    def get_np(self, order=ChannelOrder.HWC, range=ChannelRange.UINT8) -> np.ndarray:
        """Return the data as a NumPy array in the given order/range (HWC uint8 by default)."""
        arr = self.arr
        if is_tensor(arr):
            arr = torch_to_numpy(arr)

        return self.handle_order_transform(arr, order, range)

    def get_torch(self, order=ChannelOrder.CHW, range=ChannelRange.FLOAT) -> torch.Tensor:
        """Return the data as a torch tensor on ``self.device`` (CHW float by default)."""
        arr = self.arr
        if is_ndarray(arr):
            arr = torch.from_numpy(arr)

        arr = self.handle_order_transform(arr, order, range)
        if self.device is not None:
            arr = arr.to(self.device)
        return arr

    @staticmethod
    def _array_to_pil(arr: np.ndarray) -> Image.Image:
        """Build a Pillow image, dropping a trailing singleton channel axis first."""
        # ``Image.fromarray`` expects (H, W) for single-channel data, not (H, W, 1).
        if arr.ndim == 3 and arr.shape[-1] == 1:
            arr = arr[..., 0]
        return Image.fromarray(arr)

    def get_pil(self) -> Union[Image.Image, list[Image.Image]]:
        """Return one Pillow image, or a list of them when there is more than one image."""
        if len(self.shape) == 3:
            return Im._array_to_pil(self.get_np())

        frames = rearrange(self.get_np(), '... h w c -> (...) h w c')
        if frames.shape[0] == 1:
            return Im._array_to_pil(frames[0])
        return [Im._array_to_pil(frame) for frame in frames]

    @property
    def copy(self):
        """Deep copy of this ``Im``."""
        return copy.deepcopy(self)

    @property
    def height(self):
        return self.image_shape[0]

    @property
    def width(self):
        return self.image_shape[1]

    @property
    def channels(self):
        """Number of channels, read from the axis indicated by ``channel_order``."""
        return self.arr.shape[-1] if self.channel_order == ChannelOrder.HWC else self.arr.shape[-3]

    @property
    def image_shape(self):  # returns h,w
        return (self.arr.shape[-3], self.arr.shape[-2]) if self.channel_order == ChannelOrder.HWC else (self.arr.shape[-2], self.arr.shape[-1])

    @staticmethod
    def open(filepath: Path, use_imageio=False) -> Im:
        """Load an image from disk with Pillow, or with imageio when ``use_imageio`` is set."""
        if use_imageio:
            from imageio import v3 as iio
            img = iio.imread(filepath)
        else:
            img = Image.open(filepath)
        return Im(img)

    @convert_to_datatype(desired_datatype=torch.Tensor, desired_order=ChannelOrder.CHW, desired_range=ChannelRange.FLOAT)
    def resize(self, height, width, resampling_mode=InterpolationMode.BILINEAR):
        """Return a new ``Im`` resized to ``height`` x ``width`` using torchvision."""
        assert isinstance(self.arr, torch.Tensor)
        return Im(T.resize(self.arr, [height, width], resampling_mode))

    def scale(self, scale) -> Im:
        """Resize both sides by the factor ``scale`` (results truncated to int)."""
        width, height = self.width, self.height
        return self.resize(int(height * scale), int(width * scale))

    def scale_to_width(self, new_width) -> Im:
        """Resize to ``new_width`` keeping the aspect ratio."""
        width, height = self.width, self.height
        wpercent = (new_width / float(width))
        hsize = int((float(height) * float(wpercent)))
        return self.resize(hsize, new_width)

    def scale_to_height(self, new_height) -> Im:
        """Resize to ``new_height`` keeping the aspect ratio."""
        width, height = self.width, self.height
        hpercent = (new_height / float(height))
        wsize = int((float(width) * float(hpercent)))
        return self.resize(new_height, wsize)

    @staticmethod
    def _save_data(filepath: Optional[Path] = None, filetype='png'):
        """
        Normalise an output path and make sure its directory exists.

        A missing path defaults to a timestamp taken at call time, a missing
        suffix is filled in from ``filetype`` and bare file names are placed
        under ``output/``.
        """
        # Resolved per call: a default evaluated in the signature would freeze
        # the timestamp at import time and reuse the same file name forever.
        filepath = Path(get_date_time_str()) if filepath is None else Path(filepath)
        if filepath.suffix == '':
            filepath = filepath.with_suffix(f'.{filetype}')

        if len(filepath.parents) == 1:
            filepath = Path('output') / filepath
        # Always ensure the destination directory exists.
        filepath.parent.mkdir(parents=True, exist_ok=True)

        return filepath

    def save(self, filepath: Optional[Path] = None, filetype='png', optimize=False, quality=None):
        """
        Write the image to disk; batched data is tiled into a grid first.

        Args:
            filepath: destination, see ``_save_data`` for the defaults applied.
            filetype: suffix used when ``filepath`` has none.
            optimize: pass ``optimize=True`` to Pillow's save.
            quality: encoder quality passed to Pillow; 95 is used when only
                ``optimize`` is requested.
        """
        img = self.get_torch()

        filepath = Im._save_data(filepath, filetype)

        if len(img.shape) > 3:
            from torchvision import utils
            img = rearrange(img, '... h w c -> (...) h w c')
            img = utils.make_grid(img)
            img = Im(img).get_pil()
        else:
            img = self.get_pil()

        assert isinstance(img, Image.Image)

        # Pillow's quality is an integer scale, not a 0-1 fraction.
        flags = {'optimize': True, 'quality': quality if quality else 95} if optimize or quality else {}

        img.save(filepath, **flags)

    @convert_to_datatype(desired_datatype=np.ndarray, desired_order=ChannelOrder.HWC, desired_range=ChannelRange.UINT8)
    def write_text(self, text: Union[str, list[str]]) -> Im:
        """
        Draw ``text`` near the bottom-left corner of every image.

        ``text`` may be a single string or a list with one entry per image.
        Font scale and stroke thickness grow with the smaller image side.
        """
        smaller_side = min(self.arr.shape[-3:-1])
        font_scale = 0.002 * smaller_side
        thickness = max(1, round(smaller_side / 150))

        for i in range(self.arr.shape[0]):
            text_to_write = text[i] if isinstance(text, list) else text
            assert isinstance(self.arr[i], np.ndarray)
            im = cv2.cvtColor(cast(np.ndarray, self.arr[i]), cv2.COLOR_RGB2BGR)
            im = cv2.putText(im, text_to_write, (0, im.shape[0] - 10), cv2.FONT_HERSHEY_SIMPLEX,
                             font_scale, (255, 0, 0), thickness, cv2.LINE_AA)
            self.arr[i] = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)

        return self

    def add_border(self, border: int, color: Tuple[int, int, int]):
        """Replace the outer ``border`` pixels of every image with ``color`` (size unchanged)."""
        def _frame(img):
            # Crop the rim away, then grow back to the original size with the fill color.
            return ImageOps.expand(ImageOps.crop(img, border=border), border=border, fill=color)

        imgs = self.pil
        if isinstance(imgs, Iterable):
            return Im(np.stack([Im(_frame(img)).np for img in imgs], axis=0))
        return Im(_frame(imgs))

    def normalize_setup(self, mean=default_normalize_mean, std=default_normalize_std):
        """
        Prepare for (de)normalisation.

        Returns a tuple ``(im, mean, std)`` where ``im`` is a new HWC/FLOAT ``Im``
        and ``mean``/``std`` are converted to the same backend as ``im``.
        """
        # NOTE: both helpers read ``self`` at call time, i.e. the HWC/FLOAT
        # instance assigned below, not the instance this method was called on.
        def convert_instance_np(arr_1, arr_2):
            assert isinstance(self.dtype, np.dtype)
            return np.array(arr_2).astype(self.dtype)

        def convert_instance_torch(arr_1, arr_2):
            assert isinstance(self.dtype, torch.dtype)
            if self.dtype in (torch.float16, torch.bfloat16, torch.half):
                return torch.tensor(arr_2).to(dtype=torch.float, device=self.device)
            else:
                return torch.tensor(arr_2).to(dtype=self.dtype, device=self.device)

        if is_ndarray(self.arr):
            self = Im(self.get_np(ChannelOrder.HWC, ChannelRange.FLOAT))
        elif is_tensor(self.arr):
            self = Im(self.get_torch(ChannelOrder.HWC, ChannelRange.FLOAT))

        mean = dispatch_op(self.arr, convert_instance_np, convert_instance_torch, mean)
        std = dispatch_op(self.arr, convert_instance_np, convert_instance_torch, std)

        return self, mean, std

    def normalize(self, **kwargs) -> Im:
        """Return a new ``Im`` holding ``(x - mean) / std``; kwargs go to ``normalize_setup``."""
        self, mean, std = self.normalize_setup(**kwargs)
        self.arr = (self.arr - mean) / std
        return self

    def denormalize(self, clamp: Union[bool, tuple[float, float]] = (0, 1.0), **kwargs) -> Im:
        """
        Return a new ``Im`` holding ``x * std + mean``.

        Args:
            clamp: ``(low, high)`` bounds to clip the result to, ``True`` for the
                default ``(0, 1.0)`` bounds, or a falsy value to skip clipping.
            **kwargs: forwarded to ``normalize_setup``.
        """
        self, mean, std = self.normalize_setup(**kwargs)
        self.arr = (self.arr * std) + mean
        if clamp:
            # ``True`` cannot be unpacked into bounds, map it to the default range.
            low, high = (0, 1.0) if clamp is True else clamp
            if isinstance(self.arr, np.ndarray):
                self.arr = self.arr.clip(low, high)
            elif isinstance(self.arr, torch.Tensor):
                self.arr = self.arr.clamp(low, high)
        return self

    @convert_to_datatype(desired_datatype=np.ndarray, desired_order=ChannelOrder.HWC, desired_range=ChannelRange.UINT8)
    def get_opencv(self):
        """Return the internal (batch-flattened) HWC uint8 NumPy array."""
        return self.arr

    @convert_to_datatype(desired_datatype=np.ndarray, desired_order=ChannelOrder.HWC, desired_range=ChannelRange.UINT8)
    def convert_opencv_color(self, color: int) -> Im:
        """E.g.,cv2.COLOR_RGB2BGR

        Returns the converted ``Im``. The decorator may run this method on a
        freshly converted copy, so the result has to be returned to be usable.
        """
        assert isinstance(self.arr, np.ndarray)
        # NOTE(review): self.arr carries a leading batch axis here; confirm that
        # cv2.cvtColor accepts that layout for the conversions in use.
        self.arr = cv2.cvtColor(self.arr, color)
        return self

    @staticmethod
    def concat_vertical(*args, **kwargs) -> Im:
        """Concatenates images vertically (i.e. stacked on top of each other)"""
        return concat_variable(concat_vertical_, *args, **kwargs)

    @staticmethod
    def concat_horizontal(*args, **kwargs) -> Im:
        """Concatenates images horizontally (i.e. left to right)"""
        return concat_variable(concat_horizontal_, *args, **kwargs)

    def save_video(self, filepath: Path, fps: int, format='mp4'):
        """Encode the frames with ``encode_video`` and write the bytes to ``filepath``."""
        filepath = Im._save_data(filepath, format)
        byte_stream = self.encode_video(fps, format)
        with open(filepath, "wb") as f:
            f.write(byte_stream.getvalue())

    @convert_to_datatype(desired_datatype=np.ndarray, desired_order=ChannelOrder.HWC, desired_range=ChannelRange.UINT8)
    def encode_video(self, fps: int, format='mp4') -> BytesIO:
        """
        Encode the frames as a video and return the encoded bytes.

        Args:
            fps: frames per second.
            format: one of ``'webm'``, ``'gif'`` or ``'mp4'``.

        Raises:
            NotImplementedError: for any other ``format``.
        """
        assert len(self.arr.shape) == 4, "Video data must be 4D (time, height, width, channels)"
        import imageio
        byte_stream = BytesIO()

        # TODO: We shouldn't need to write -> read. An imageio/ffmpeg issue is causing this.
        with tempfile.NamedTemporaryFile(suffix=f'.{format}') as ntp:
            if format == 'webm':
                writer = imageio.get_writer(ntp.name, format='webm', codec='libvpx-vp9', pixelformat='yuv420p', output_params=['-lossless', '1'], fps=fps)
            elif format == 'gif':
                writer = imageio.get_writer(ntp.name, format='GIF', mode="I", duration=(1000 * 1/fps))
            elif format == 'mp4':
                writer = imageio.get_writer(ntp.name, quality=10, pixelformat='yuv420p', codec='libx264', fps=fps)
            else:
                raise NotImplementedError(f'Format {format} not implemented.')

            try:
                for frame in self.arr:
                    writer.append_data(frame)
            finally:
                # Close the writer even if a frame fails to encode.
                writer.close()

            with open(ntp.name, 'rb') as f:
                byte_stream.write(f.read())

        byte_stream.seek(0)
        return byte_stream

    def to(self, device: torch.device):
        """Move the underlying tensor to ``device`` in place and return ``self``."""
        assert isinstance(self.arr, torch.Tensor), "Can only convert to device if array is a torch.Tensor"
        self.arr = self.arr.to(device)
        # Keep the bookkeeping in sync, otherwise ``get_torch`` moves the data
        # straight back to the previous device.
        self.device = self.arr.device
        return self

    @staticmethod
    def stack_imgs(*args: Im):
        """Stack several ``Im`` objects along a new leading axis (HWC, ranges untouched)."""
        return Im(rearrange([img.handle_order_transform(img.arr, desired_order=ChannelOrder.HWC, desired_range=img.channel_range) for img in args], 'b ... -> b ...'))

    @convert_to_datatype(desired_datatype=torch.Tensor, desired_order=ChannelOrder.CHW, desired_range=ChannelRange.FLOAT)
    def colorize(self) -> Im:
        """
        Project the channels to 3 with a random 1x1 convolution, rescaled to [0, 1].

        The kernel is drawn once per channel count and cached in the module-level
        ``colorize_weights`` so repeated calls use the same projection.
        """
        if self.channels not in colorize_weights:
            colorize_weights[self.channels] = torch.randn(3, self.channels, 1, 1)

        assert isinstance(self.arr, torch.Tensor)
        # Match the cached kernel to the data's device and dtype.
        weight = colorize_weights[self.channels].to(self.arr)
        self.arr = F.conv2d(self.arr, weight=weight)
        low, high = self.arr.min(), self.arr.max()
        # Guard against division by zero for constant inputs.
        span = (high - low).clamp_min(torch.finfo(self.arr.dtype).eps)
        self.arr = (self.arr - low) / span
        return self

    # Attribute-style accessors; these must stay at the end of the class body so
    # the names ``np``/``torch`` still refer to the modules while the decorators
    # above are evaluated.
    pil = property(get_pil)
    np = property(get_np)
    torch = property(get_torch)
    opencv = property(get_opencv)

472 

473 

def torch_to_numpy(arr):
    """
    Convert a torch tensor to a NumPy array on the CPU, detached from autograd.

    Half precision tensors (bfloat16 / float16) are upcast to float32 first.
    """
    if arr.dtype in (torch.bfloat16, torch.float16):
        arr = arr.float()
    return arr.cpu().detach().numpy()

479 

480 

def pil_to_numpy(arr):
    """Convert ``arr`` to RGB via its ``convert`` method and return it as a NumPy array."""
    rgb = arr.convert('RGB')
    return np.array(rgb)

483 

484 

def concat_variable(concat_func: Callable[..., Im], *args: Im, **kwargs) -> Im:
    """
    Fold any number of images into one by applying ``concat_func`` pairwise, left to right.

    The first image seeds the result; every following image is merged with
    ``concat_func(result, image, **kwargs)``. At least one image is required,
    as the final result must be an ``Im``.
    """
    merged = None
    for current in args:
        merged = current if merged is None else concat_func(merged, current, **kwargs)

    assert isinstance(merged, Im)
    return merged

495 

496 

def get_arr_hwc(im: Im):
    """Return the data of ``im`` in HWC order, keeping its current channel range."""
    return im.handle_order_transform(
        im.arr,
        desired_order=ChannelOrder.HWC,
        desired_range=im.channel_range,
    )

498 

499 

def concat_horizontal_(im1: Im, im2: Im, spacing=0) -> Im:
    """
    Place ``im2`` to the right of ``im1``.

    ``spacing`` is accepted but currently unused.

    Raises:
        ValueError: if the two images differ in height.
    """
    left_height, right_height = im1.height, im2.height
    if left_height != right_height:
        raise ValueError(f'Images must have same height. Got {left_height} and {right_height}')

    # Pack along the width axis of the HWC arrays.
    packed, _ = pack([get_arr_hwc(im1), get_arr_hwc(im2)], 'h * c')
    return Im(packed)

504 

def concat_vertical_(im1: Im, im2: Im, spacing=0) -> Im:
    """
    Place ``im2`` below ``im1``.

    ``spacing`` is accepted but currently unused.

    Raises:
        ValueError: if the two images differ in width.
    """
    top_width, bottom_width = im1.width, im2.width
    if top_width != bottom_width:
        raise ValueError(f'Images must have same width. Got {top_width} and {bottom_width}')

    # Pack along the height axis of the HWC arrays.
    packed, _ = pack([get_arr_hwc(im1), get_arr_hwc(im2)], '* w c')
    return Im(packed)

509 

510 

511 

def get_layered_image_from_binary_mask(masks, flip=False):
    """
    Render a stack of binary masks as one color image.

    ``masks`` is indexed as ``masks[..., i]`` for mask ``i``; each mask gets its
    own color and later masks overwrite earlier ones where they overlap.
    With ``flip`` the masks are flipped upside down first. Returns a Pillow image.
    """
    if torch.is_tensor(masks):
        masks = torch_to_numpy(masks)
    if flip:
        masks = np.flipud(masks)

    masks = masks.astype(np.bool_)

    num_masks = masks.shape[2]
    palette = np.asarray(list(get_n_distinct_colors(num_masks)))
    canvas = np.zeros((*masks.shape[:2], 3))
    for layer in range(num_masks):
        canvas[masks[..., layer]] = palette[layer]

    return Image.fromarray(canvas.astype(np.uint8))

526 

527 

def get_img_from_binary_masks(masks, flip=False):
    """
    Render binary masks (H W C) as a color image with one color per mask combination.

    Every pixel's set of active masks is encoded as an integer label, which
    indexes a palette of ``2 ** C`` colors. With ``flip`` the label map is
    flipped upside down. Returns a Pillow image.
    """
    labels = encode_binary_labels(masks)
    if flip:
        labels = np.flipud(labels)

    num_colors = 2 ** masks.shape[2]
    palette = np.asarray(list(get_n_distinct_colors(num_colors)))
    colored = palette[labels]
    return Image.fromarray(colored.astype(np.uint8))

536 

537 

def encode_binary_labels(masks):
    """
    Collapse a stack of masks indexed ``[h, w, i]`` into one integer label map.

    Mask ``i`` contributes ``2 ** i`` wherever it is set, so every combination
    of overlapping masks maps to a distinct label.
    """
    if torch.is_tensor(masks):
        masks = torch_to_numpy(masks)

    # Bring the mask axis to the front so the bit weights broadcast over h, w.
    stacked = np.transpose(masks, (2, 0, 1))
    weights = np.power(2, np.arange(len(stacked), dtype=np.int32)).reshape(-1, 1, 1)
    weighted = stacked.astype(np.int32) * weights
    return weighted.sum(0)

545 

546 

def get_n_distinct_colors(n):
    """
    Return a generator of ``n`` RGB tuples (ints) with evenly spaced hues.

    Hues are spaced ``1 / (n + 1)`` apart at full saturation and value.
    """
    hue_step = 1.0 / (n + 1)

    def _rgb255(hue):
        # colorsys yields floats in [0, 1]; scale and truncate to integers.
        red, green, blue = colorsys.hsv_to_rgb(hue, 1.0, 1.0)
        return (int(255 * red), int(255 * green), int(255 * blue))

    return (_rgb255(hue_step * index) for index in range(n))

554 

555 

def square_pad(image, h, w):
    """
    Pad ``image`` to the aspect ratio of ``h`` x ``w``, then resize it to that size.

    The last two dimensions of ``image`` are read as (height, width). When the
    aspect ratios differ (compared after rounding to 2 decimals), constant zero
    padding is split evenly between the two sides of the shorter dimension.

    The resized image is returned on every path; previously some paths fell
    through and returned ``None``.
    """
    h_1, w_1 = image.shape[-2:]
    ratio_f = w / h
    ratio_1 = w_1 / h_1

    # check if the original and final aspect ratios are the same within a margin
    if round(ratio_1, 2) != round(ratio_f, 2):
        # padding to preserve aspect ratio
        hp = int(w_1 / ratio_f - h_1)
        wp = int(ratio_f * h_1 - w_1)
        if hp > 0 and wp < 0:
            # Too wide: pad top and bottom. Padding order is (left, top, right, bottom).
            hp = hp // 2
            image = T.pad(image, (0, hp, 0, hp), 0, "constant")
        elif hp < 0 and wp > 0:
            # Too tall: pad left and right.
            wp = wp // 2
            image = T.pad(image, (wp, 0, wp, 0), 0, "constant")

    return T.resize(image, [h, w])