Coverage for src/shephex/experiment/options.py: 99%
98 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-03-29 18:45 +0100
1import json
2from pathlib import Path
3from typing import Dict, List, Self, Tuple, TypeAlias, Union
6class Options:
7 base_types: TypeAlias = Union[int, float, str, bool]
8 collection_types: TypeAlias = Union[List, Dict, Tuple]
10 def __init__(self, *args, **kwargs) -> None:
11 self._args = []
12 self._kwargs = {}
14 for arg in args:
15 self.add_arg(arg)
17 for key, value in kwargs.items():
18 self.add_kwarg(key, value)
21 def _check_type(self, value: Union[base_types, collection_types]):
22 valid_type = False
23 if isinstance(value, self.base_types):
24 valid_type = True
25 elif isinstance(value, self.collection_types):
26 # Check that keys and values are of base types
27 valid_type = True
28 if isinstance(value, Dict):
29 for k, v in value.items():
30 if not isinstance(k, self.base_types) or not isinstance(
31 v, self.base_types
32 ):
33 valid_type = False
34 break
35 else:
36 for v in value:
37 if not isinstance(v, self.base_types):
38 valid_type = False
39 break
41 if not valid_type:
42 raise TypeError(f'Invalid type {type(value)}')
44 def __repr__(self) -> str:
45 return f'ExperimentOptions(args={self._args}, kwargs={self._kwargs})'
47 def dump(self, path: Union[str, Path]) -> None:
48 if isinstance(path, str):
49 path = Path(path)
51 all_options = {'kwargs': self._kwargs, 'args': self._args}
53 with open(path / self.name, 'w') as f:
54 json.dump(all_options, f, indent=4)
56 def to_dict(self) -> Dict:
57 options_dict = {**self._kwargs}
59 if len(self._args) > 0:
60 options_dict['args'] = self._args
61 return options_dict
63 @property
64 def name(self) -> str:
65 return 'options.json'
67 @classmethod
68 def load(cls, path: Union[str, Path], name: str = 'options.json') -> Self:
69 if isinstance(path, str):
70 path = Path(path)
72 with open(path / name, 'r') as f:
73 all_options = json.load(f)
75 args = all_options.pop('args')
76 kwargs = all_options.pop('kwargs')
78 return cls(*args, **kwargs)
80 def __eq__(self, other: Self) -> bool:
81 if not isinstance(other, Options):
82 return False
84 if self._kwargs != other.kwargs:
85 return False
87 if self._args != other._args:
88 return False
90 return True
92 def items(self) -> List[Tuple[str, Union[base_types, collection_types]]]:
93 all_dict = self._kwargs.copy()
94 if len(self._args) > 0:
95 all_dict['args'] = self._args
96 return all_dict.items()
98 def keys(self) -> List[str]:
99 all_keys = list(self._kwargs.keys())
100 if len(self._args) > 0:
101 all_keys.append('args')
102 return all_keys
104 def values(self) -> List[Union[base_types, collection_types]]:
105 for key in self.keys():
106 if key == 'args':
107 yield self._args
108 else:
109 yield self._kwargs[key]
111 def __getitem__(self, key: str) -> Union[base_types, collection_types]:
112 if key == 'args':
113 return self._args
114 return self._kwargs[key]
116 @property
117 def kwargs(self) -> Dict:
118 return self._kwargs
120 @property
121 def args(self) -> List:
122 return self._args
124 def add_kwarg(self, key: str, value: Union[base_types, collection_types]) -> None:
125 self._check_type(value)
126 if key not in self._kwargs.keys():
127 self._kwargs[key] = value
128 else:
129 raise ValueError(f"Keyword argument {key} already exist")
131 def add_arg(self, value: Union[base_types, collection_types]) -> None:
132 self._check_type(value)
133 self._args.append(value)
135 def copy(self) -> Self:
136 return Options(*self.args, **self.kwargs)