Coverage for src/shephex/experiment/options.py: 99%

98 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-03-29 18:45 +0100

1import json 

2from pathlib import Path 

3from typing import Dict, List, Self, Tuple, TypeAlias, Union 

4 

5 

6class Options: 

7 base_types: TypeAlias = Union[int, float, str, bool] 

8 collection_types: TypeAlias = Union[List, Dict, Tuple] 

9 

10 def __init__(self, *args, **kwargs) -> None: 

11 self._args = [] 

12 self._kwargs = {} 

13 

14 for arg in args: 

15 self.add_arg(arg) 

16 

17 for key, value in kwargs.items(): 

18 self.add_kwarg(key, value) 

19 

20 

21 def _check_type(self, value: Union[base_types, collection_types]): 

22 valid_type = False 

23 if isinstance(value, self.base_types): 

24 valid_type = True 

25 elif isinstance(value, self.collection_types): 

26 # Check that keys and values are of base types 

27 valid_type = True 

28 if isinstance(value, Dict): 

29 for k, v in value.items(): 

30 if not isinstance(k, self.base_types) or not isinstance( 

31 v, self.base_types 

32 ): 

33 valid_type = False 

34 break 

35 else: 

36 for v in value: 

37 if not isinstance(v, self.base_types): 

38 valid_type = False 

39 break 

40 

41 if not valid_type: 

42 raise TypeError(f'Invalid type {type(value)}') 

43 

44 def __repr__(self) -> str: 

45 return f'ExperimentOptions(args={self._args}, kwargs={self._kwargs})' 

46 

47 def dump(self, path: Union[str, Path]) -> None: 

48 if isinstance(path, str): 

49 path = Path(path) 

50 

51 all_options = {'kwargs': self._kwargs, 'args': self._args} 

52 

53 with open(path / self.name, 'w') as f: 

54 json.dump(all_options, f, indent=4) 

55 

56 def to_dict(self) -> Dict: 

57 options_dict = {**self._kwargs} 

58 

59 if len(self._args) > 0: 

60 options_dict['args'] = self._args 

61 return options_dict 

62 

63 @property 

64 def name(self) -> str: 

65 return 'options.json' 

66 

67 @classmethod 

68 def load(cls, path: Union[str, Path], name: str = 'options.json') -> Self: 

69 if isinstance(path, str): 

70 path = Path(path) 

71 

72 with open(path / name, 'r') as f: 

73 all_options = json.load(f) 

74 

75 args = all_options.pop('args') 

76 kwargs = all_options.pop('kwargs') 

77 

78 return cls(*args, **kwargs) 

79 

80 def __eq__(self, other: Self) -> bool: 

81 if not isinstance(other, Options): 

82 return False 

83 

84 if self._kwargs != other.kwargs: 

85 return False 

86 

87 if self._args != other._args: 

88 return False 

89 

90 return True 

91 

92 def items(self) -> List[Tuple[str, Union[base_types, collection_types]]]: 

93 all_dict = self._kwargs.copy() 

94 if len(self._args) > 0: 

95 all_dict['args'] = self._args 

96 return all_dict.items() 

97 

98 def keys(self) -> List[str]: 

99 all_keys = list(self._kwargs.keys()) 

100 if len(self._args) > 0: 

101 all_keys.append('args') 

102 return all_keys 

103 

104 def values(self) -> List[Union[base_types, collection_types]]: 

105 for key in self.keys(): 

106 if key == 'args': 

107 yield self._args 

108 else: 

109 yield self._kwargs[key] 

110 

111 def __getitem__(self, key: str) -> Union[base_types, collection_types]: 

112 if key == 'args': 

113 return self._args 

114 return self._kwargs[key] 

115 

116 @property 

117 def kwargs(self) -> Dict: 

118 return self._kwargs 

119 

120 @property 

121 def args(self) -> List: 

122 return self._args 

123 

124 def add_kwarg(self, key: str, value: Union[base_types, collection_types]) -> None: 

125 self._check_type(value) 

126 if key not in self._kwargs.keys(): 

127 self._kwargs[key] = value 

128 else: 

129 raise ValueError(f"Keyword argument {key} already exist") 

130 

131 def add_arg(self, value: Union[base_types, collection_types]) -> None: 

132 self._check_type(value) 

133 self._args.append(value) 

134 

135 def copy(self) -> Self: 

136 return Options(*self.args, **self.kwargs) 

137