Coverage for src/image_utils/custom_library_ops.py: 0%

36 statements  

« prev     ^ index     » next       coverage.py v7.2.5, created at 2023-05-19 16:15 -0700

1import torch 

2import numpy as np 

3import random 

4from . import is_tensor, is_ndarray, is_arr 

5 

def generic_print(self, arr_values):
    """Wrap an array's rendered repr with a summary header and footer.

    The summary line holds shape, dtype, device (empty for NumPy arrays) and
    whether every element is finite. A second line holds the element count
    plus dtype-dependent statistics:
      - bool: sum and number of unique values
      - integer: number of unique values
      - otherwise: mean
    Min and max are shown in all three cases.

    Args:
        self: the NumPy ndarray or torch Tensor being printed.
        arr_values: the already-rendered repr of ``self``.

    Returns:
        ``arr_values`` unchanged for 0-d inputs; otherwise
        ``<info><stats>\\n<arr_values>\\n<info>``. For empty inputs the stats
        line contains only the element count.

    Raises:
        AssertionError: if ``is_arr(self)`` is false.
    """
    assert is_arr(self)

    # 0-d arrays/tensors: the plain repr is already minimal.
    if len(self.shape) == 0:
        return arr_values

    if is_ndarray(self):
        lib = np
        num_elements = self.size
        device = ''
        is_bool = self.dtype == np.bool_
        is_integer = np.issubdtype(self.dtype, np.integer)
    else:
        lib = torch
        num_elements = self.numel()
        device = self.device.type
        is_bool = self.dtype == torch.bool
        is_integer = is_tensor(self) and not torch.is_floating_point(self)

    shape_str = ",".join(str(dim) for dim in self.shape)
    finite_str = "finite" if lib.isfinite(self).all() else "non-finite"
    basic_info = f'[{shape_str}] {self.dtype} {device} {finite_str}'

    if num_elements == 0:
        # min()/max() raise on zero-size inputs (and mean() is nan), so only
        # the element count is meaningful here.
        numerical_info = f'\nelems: {num_elements}'
    else:
        # Bool must be tested first: torch.bool is not floating point and
        # would otherwise fall into the integer branch.
        if is_bool:
            specific_data = f' sum: {self.sum()}, unique: {len(lib.unique(self))},'
        elif is_integer:
            specific_data = f' unique: {len(lib.unique(self))},'
        else:
            specific_data = f' avg: {self.mean():.3f},'
        numerical_info = f'\nelems: {num_elements},{specific_data} min: {self.min():.3f}, max: {self.max().item():.3f}'

    return basic_info + numerical_info + f'\n{arr_values}\n' + basic_info

43 

44# torch.set_printoptions(sci_mode=False, precision=3, threshold=10, edgeitems=2, linewidth=120) 

45# normal_repr = torch.Tensor.__repr__ 

46# torch.Tensor.__repr__ = lambda self: generic_print(self, normal_repr(self)) 

47 

48# np.set_printoptions(suppress=True, precision=3, threshold=10, edgeitems=2, linewidth=120) 

49# np.set_string_function(lambda self: generic_print(self, np.ndarray.__repr__(self)), repr=False) 

50 

def set_random_seeds(seed=0):
    """Seed PyTorch's, Python's and NumPy's global RNGs for reproducibility.

    Args:
        seed: value passed to ``torch.manual_seed``, ``random.seed`` and
            ``np.random.seed``. Defaults to 0, the value previously hard-coded.
    """
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)