Coverage for muutils\json_serialize\util.py: 44%

112 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-12-08 01:02 -0700

1"""utilities for json_serialize""" 

2 

3from __future__ import annotations 

4 

5import dataclasses 

6import functools 

7import inspect 

8import sys 

9import typing 

10import warnings 

11from typing import Any, Callable, Iterable, Union 

12 

13_NUMPY_WORKING: bool 

14try: 

15 _NUMPY_WORKING = True 

16except ImportError: 

17 warnings.warn("numpy not found, cannot serialize numpy arrays!") 

18 _NUMPY_WORKING = False 

19 

20 

# any single value that may appear in JSON output
JSONitem = Union[
    bool,
    int,
    float,
    str,
    list,
    typing.Dict[str, Any],
    None,
]
# a JSON object: string keys mapped to JSON values
JSONdict = typing.Dict[str, JSONitem]
# the value types `_recursive_hashify` is annotated to produce
Hashableitem = Union[bool, int, float, str, tuple]

24 

# Under static type checking, or on Python < 3.9 (where `tuple[...]` cannot be
# subscripted at runtime), `MonoTuple` is simply an alias for `typing.Sequence`.
if typing.TYPE_CHECKING or sys.version_info < (3, 9):
    MonoTuple = typing.Sequence
else:

    class MonoTuple:
        """tuple type hint, but for a tuple of any length with all the same type

        - `MonoTuple[T]` and `MonoTuple[(T,)]` evaluate to `tuple[T, ...]`
        - `MonoTuple[()]` evaluates to the bare `tuple`
        - more than one type argument raises `TypeError`

        The class itself cannot be instantiated or subclassed.
        """

        __slots__ = ()

        def __new__(cls, *args, **kwargs):
            raise TypeError("Type MonoTuple cannot be instantiated.")

        def __init_subclass__(cls, *args, **kwargs):
            raise TypeError(f"Cannot subclass {cls.__module__}")

        # `typing._tp_cache` is private to `typing` (likely why mypy does not
        # see it); it memoizes the result per `params`
        @typing._tp_cache  # type: ignore
        def __class_getitem__(cls, params):
            if getattr(params, "__origin__", None) == typing.Union:
                return typing.GenericAlias(tuple, (params, Ellipsis))
            elif isinstance(params, type):
                # this branch previously lacked `return`, so `MonoTuple[int]` was None
                return typing.GenericAlias(tuple, (params, Ellipsis))
            # a tuple of arguments, e.g. `MonoTuple[()]` or `MonoTuple[(int,)]`;
            # strings are forward references, not argument lists
            elif isinstance(params, Iterable) and not isinstance(params, str):
                if len(params) == 0:
                    return tuple
                elif len(params) == 1:
                    return typing.GenericAlias(tuple, (params[0], Ellipsis))
                else:
                    raise TypeError(f"MonoTuple expects 1 type argument, got {params = }")
            else:
                # any other type expression (e.g. `List[int]`, `int | str`, a
                # forward-reference string) previously fell through to None
                return typing.GenericAlias(tuple, (params, Ellipsis))

56 

57 

class UniversalContainer:
    """A container that claims to hold every possible value.

    Membership tests always succeed: `x in UniversalContainer()` is True for any `x`.
    """

    def __contains__(self, x: Any) -> bool:
        # the queried value is intentionally ignored -- nothing is ever "missing"
        return True

63 

64 

def isinstance_namedtuple(x: Any) -> bool:
    """Return True when `x` is an instance of a `namedtuple` class.

    A class qualifies when its only direct base is `tuple` and it has a
    `_fields` attribute that is a tuple consisting solely of strings.

    credit to https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple
    """
    cls: type = type(x)
    parents: tuple = cls.__bases__
    if not (len(parents) == 1 and parents[0] is tuple):
        return False
    field_names: Any = getattr(cls, "_fields", None)
    return isinstance(field_names, tuple) and all(
        isinstance(name, str) for name in field_names
    )

78 

79 

def try_catch(func: Callable):
    """Decorate `func` so that exceptions become return values.

    The wrapper returns `func`'s result when it succeeds. If `func` raises any
    `Exception`, the wrapper returns the string `"<ExceptionClassName>: <message>"`
    instead of propagating it.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            result = func(*args, **kwargs)
        except Exception as err:
            return f"{err.__class__.__name__}: {err}"
        return result

    return wrapper

94 

95 

96def _recursive_hashify(obj: Any, force: bool = True) -> Hashableitem: 

97 if isinstance(obj, typing.Mapping): 

98 return tuple((k, _recursive_hashify(v)) for k, v in obj.items()) 

99 elif isinstance(obj, (tuple, list, Iterable)): 

100 return tuple(_recursive_hashify(v) for v in obj) 

101 elif isinstance(obj, (bool, int, float, str)): 

102 return obj 

103 else: 

104 if force: 

105 return str(obj) 

106 else: 

107 raise ValueError(f"cannot hashify:\n{obj}") 

108 

109 

class SerializationException(Exception):
    """Exception type for serialization failures.

    Not raised anywhere in this file; presumably raised by the json_serialize
    serializers -- confirm against callers.
    """

112 

113 

def string_as_lines(s: str | None) -> list[str]:
    """Split `s` into a list of lines (line endings removed).

    Storing long strings as a list of lines makes them easier to read in JSON,
    similar to how jupyter notebooks store cell sources. `None` yields `[]`.
    """
    return [] if s is None else s.splitlines()

123 

124 

def safe_getsource(func) -> list[str]:
    """Return the source of `func` as a list of lines, never raising.

    If `inspect.getsource` fails for any reason, the returned lines instead
    describe the error.
    """
    try:
        source: str = inspect.getsource(func)
    except Exception as err:
        return string_as_lines(f"Error: Unable to retrieve source code:\n{err}")
    return string_as_lines(source)

130 

131 

# credit to https://stackoverflow.com/questions/51743827/how-to-compare-equality-of-dataclasses-holding-numpy-ndarray-boola-b-raises
def array_safe_eq(a: Any, b: Any) -> bool:
    """check if two objects are equal, account for if numpy arrays or torch tensors

    - identical objects are equal; objects of different exact types are not
    - numpy arrays / torch tensors: equal iff same shape and all elements equal
    - pandas DataFrames: compared with `DataFrame.equals`
    - sequences other than `str`/`bytes`: same length and elementwise `array_safe_eq`
    - mappings: same set of keys and `array_safe_eq` values per key (key order
      is ignored, like `dict.__eq__`)
    - anything else: `bool(a == b)`

    If `a == b` raises `TypeError` or `ValueError`, a warning is emitted and
    `NotImplemented` is returned (NOT a bool -- callers may need to check for it).
    """
    if a is b:
        return True

    if type(a) is not type(b):
        return False

    # `a` and `b` have the same exact type from here on, so checking one suffices
    type_name: str = str(type(a))

    if type_name in ("<class 'numpy.ndarray'>", "<class 'torch.Tensor'>"):
        # `==` broadcasts: without the shape check, arrays of different shapes
        # could compare equal (e.g. [1, 1] vs [1]) or raise
        return a.shape == b.shape and bool((a == b).all())

    if type_name == "<class 'pandas.core.frame.DataFrame'>":
        return a.equals(b)

    # `str`/`bytes` are excluded: a 1-character string is a Sequence whose only
    # item is itself, so recursing into unequal strings never terminated
    if (
        isinstance(a, typing.Sequence)
        and isinstance(b, typing.Sequence)
        and not isinstance(a, (str, bytes))
    ):
        return len(a) == len(b) and all(array_safe_eq(a1, b1) for a1, b1 in zip(a, b))

    if isinstance(a, typing.Mapping) and isinstance(b, typing.Mapping):
        # match values by key rather than by position in iteration order
        return a.keys() == b.keys() and all(array_safe_eq(a[k], b[k]) for k in a)

    try:
        return bool(a == b)
    except (TypeError, ValueError) as e:
        warnings.warn(f"Cannot compare {a} and {b} for equality\n{e}")
        return NotImplemented  # type: ignore[return-value]

172 

173 

def dc_eq(
    dc1,
    dc2,
    except_when_class_mismatch: bool = False,
    false_when_class_mismatch: bool = True,
    except_when_field_mismatch: bool = False,
) -> bool:
    """
    checks if two dataclasses which (might) hold numpy arrays are equal

    Field values are compared with `array_safe_eq`; only fields declared with
    `compare=True` (the `dataclasses.field` default) on `dc1` are considered.

    # Parameters:

     - `dc1`: the first dataclass
     - `dc2`: the second dataclass
     - `except_when_class_mismatch: bool`
        if `True`, raise `TypeError` if the classes are different.
        (default: `False`)
     - `false_when_class_mismatch: bool`
        only relevant if the classes differ and `except_when_class_mismatch` is `False`.
        if `True`, return `False` (after the optional field-name check below).
        if `False`, compare the field names, and if they match, compare the field values.
        (default: `True`)
     - `except_when_field_mismatch: bool`
        only relevant if the classes differ and `except_when_class_mismatch` is `False`.
        if `True`, raise `AttributeError` if the two dataclasses have different field names.
        (default: `False`)

    # Returns:
     - `bool`: True if the dataclasses are equal, False otherwise

    # Raises:
     - `TypeError`: if the classes differ and `except_when_class_mismatch` is `True`
     - `AttributeError`: if the classes differ, the field names differ, and
       `except_when_field_mismatch` is `True`

    # Decision order:
    1. `dc1 is dc2` -> `True`
    2. same class -> compare field values
    3. different classes:
        a. `except_when_class_mismatch` -> raise `TypeError`
        b. if `except_when_field_mismatch` or not `false_when_class_mismatch`,
           compare field names; on mismatch raise `AttributeError` (if
           `except_when_field_mismatch`) or return `False`
        c. `false_when_class_mismatch` -> `False`
        d. otherwise compare field values
    """
    if dc1 is dc2:
        return True

    if dc1.__class__ is not dc2.__class__:
        if except_when_class_mismatch:
            # if the classes don't match, raise an error
            raise TypeError(
                f"Cannot compare dataclasses of different classes: `{dc1.__class__}` and `{dc2.__class__}`"
            )
        # field names are needed either to raise on mismatch, or to decide
        # whether a cross-class value comparison is possible at all
        if except_when_field_mismatch or not false_when_class_mismatch:
            dc1_fields: set = {fld.name for fld in dataclasses.fields(dc1)}
            dc2_fields: set = {fld.name for fld in dataclasses.fields(dc2)}
            if dc1_fields != dc2_fields:
                if except_when_field_mismatch:
                    raise AttributeError(
                        f"dataclasses {dc1} and {dc2} have different fields: `{dc1_fields}` and `{dc2_fields}`"
                    )
                return False
        if false_when_class_mismatch:
            return False
        # field names match and the caller asked for a value comparison

    return all(
        array_safe_eq(getattr(dc1, fld.name), getattr(dc2, fld.name))
        for fld in dataclasses.fields(dc1)
        if fld.compare
    )

260 array_safe_eq(getattr(dc1, fld.name), getattr(dc2, fld.name)) 

261 for fld in dataclasses.fields(dc1) 

262 if fld.compare 

263 )