Coverage for src/image_utils/custom_library_ops.py: 0%
36 statements
« prev ^ index » next coverage.py v7.2.5, created at 2023-05-19 16:15 -0700
1import torch
2import numpy as np
3import random
4from . import is_tensor, is_ndarray, is_arr
def generic_print(self, arr_values):
    """Wrap an array's regular repr text with a summary header and footer.

    Args:
        self: A NumPy ndarray or torch Tensor (validated with ``is_arr``).
            The parameter is named ``self`` so the function can be used as a
            ``__repr__`` replacement.
        arr_values: The array's normal repr text, embedded unchanged.

    Returns:
        For 0-d arrays, ``arr_values`` unchanged. Otherwise::

            [d0,d1,...] <dtype> <device> finite|non-finite
            elems: N,<dtype-specific stats> min: x.xxx, max: x.xxx
            <arr_values>
            [d0,d1,...] <dtype> <device> finite|non-finite

        The dtype-specific stats are ``sum`` and ``unique`` for bool data,
        ``unique`` for integer data, and ``avg`` otherwise. For arrays with
        zero elements, only ``elems: 0`` is reported, because
        ``min``/``max``/``mean`` raise on empty inputs.
    """
    assert is_arr(self)

    # Scalars are left to the library's own repr.
    if len(self.shape) == 0:
        return arr_values

    if is_ndarray(self):
        lib = np
        num_elements = int(self.size)
        device = ''
        is_bool = self.dtype == np.bool_
        is_integer = np.issubdtype(self.dtype, np.integer)
    else:
        lib = torch
        num_elements = self.numel()
        device = self.device.type
        is_bool = self.dtype == torch.bool
        is_integer = is_tensor(self) and not torch.is_floating_point(self)

    shape_str = ",".join(str(dim) for dim in self.shape)
    finite_str = "finite" if lib.isfinite(self).all() else "non-finite"
    basic_info = f'[{shape_str}] {self.dtype} {device} {finite_str}'

    if num_elements == 0:
        # Reductions (min/max/mean) raise on empty arrays/tensors,
        # so only the element count can be reported.
        numerical_info = f'\nelems: {num_elements}'
    else:
        # Bool is tested first: torch bool tensors are also non-floating.
        if is_bool:
            specific_data = f' sum: {self.sum().item()}, unique: {len(lib.unique(self))},'
        elif is_integer:
            specific_data = f' unique: {len(lib.unique(self))},'
        else:
            specific_data = f' avg: {self.mean().item():.3f},'
        # .item() turns the 0-d reduction result into a Python scalar,
        # so numpy and torch values format identically.
        numerical_info = (
            f'\nelems: {num_elements},{specific_data}'
            f' min: {self.min().item():.3f}, max: {self.max().item():.3f}'
        )

    return basic_info + numerical_info + f'\n{arr_values}\n' + basic_info
44# torch.set_printoptions(sci_mode=False, precision=3, threshold=10, edgeitems=2, linewidth=120)
45# normal_repr = torch.Tensor.__repr__
46# torch.Tensor.__repr__ = lambda self: generic_print(self, normal_repr(self))
48# np.set_printoptions(suppress=True, precision=3, threshold=10, edgeitems=2, linewidth=120)
49# np.set_string_function(lambda self: generic_print(self, np.ndarray.__repr__(self)), repr=False)
def set_random_seeds(seed=0):
    """Seed the torch, Python ``random`` and NumPy global RNGs.

    Args:
        seed: Seed value applied to all three generators. Defaults to 0,
            which matches the original hard-coded behavior.
    """
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)