Coverage for zanj\serializing.py: 93%

59 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-02-14 12:57 -0700

1from __future__ import annotations 

2 

3import json 

4import sys 

5from dataclasses import dataclass 

6from typing import IO, Any, Callable, Iterable, Sequence 

7 

8import numpy as np 

9from muutils.json_serialize.array import arr_metadata 

10from muutils.json_serialize.json_serialize import ( # JsonSerializer, 

11 DEFAULT_HANDLERS, 

12 ObjectPath, 

13 SerializerHandler, 

14) 

15from muutils.json_serialize.util import JSONdict, JSONitem, MonoTuple, _FORMAT_KEY, _REF_KEY 

16from muutils.tensor_utils import NDArray 

17 

18from zanj.externals import ExternalItem, ExternalItemType, _ZANJ_pre 

19 

# extra keyword arguments unpacked into `@dataclass(...)` further down:
# `kw_only=True` is only passed when running on python >= 3.10,
# on older interpreters no extra arguments are passed
KW_ONLY_KWARGS: dict = {"kw_only": True} if sys.version_info >= (3, 10) else {}

23 

24# pylint: disable=unused-argument, protected-access, unexpected-keyword-arg 

25# for some reason pylint complains about kwargs to ZANJSerializerHandler 

26 

27 

def jsonl_metadata(data: list[JSONdict]) -> dict:
    """metadata about a jsonl object

    # Parameters:
     - `data: list[JSONdict]`
        rows of the jsonl object, each row a dict. may be empty

    # Returns:
     - `dict`
        with the keys
        - `"data[0]"`: the first row, or `None` when `data` is empty
        - `"len(data)"`: number of rows
        - `"columns"`: for every key other than `_FORMAT_KEY` that appears in
          any row, a dict holding the sorted names of the value types seen
          under that key (`"types"`) and the number of rows which contain
          that key (`"len"`)
    """
    type_names: dict[str, set[str]] = {}
    row_counts: dict[str, int] = {}
    # one pass over the rows, instead of rescanning `data` for every column
    for item in data:
        for col, value in item.items():
            if col == _FORMAT_KEY:
                continue
            type_names.setdefault(col, set()).add(type(value).__name__)
            row_counts[col] = row_counts.get(col, 0) + 1

    return {
        # an empty `data` has no first row to report
        "data[0]": data[0] if data else None,
        "len(data)": len(data),
        "columns": {
            col: {
                # sorted so the result does not depend on set iteration order
                "types": sorted(names),
                "len": row_counts[col],
            }
            for col, names in type_names.items()
        },
    }

45 

46 

def store_npy(self: _ZANJ_pre, fp: IO[bytes], data: NDArray) -> None:
    """write `data` to the open binary file `fp` in the `.npy` format

    `data` is first passed through `np.asanyarray`, and pickling is disabled
    (`allow_pickle=False`). `self` is not used by this function.
    """
    as_array = np.asanyarray(data)
    np.lib.format.write_array(fp, as_array, allow_pickle=False)

54 

55 

def store_jsonl(self: _ZANJ_pre, fp: IO[bytes], data: Sequence[JSONitem]) -> None:
    """write every element of `data` to `fp` as one utf-8 encoded json line

    each element is dumped with `json.dumps` and followed by a newline.
    `self` is not used by this function.
    """
    line_end: bytes = "\n".encode("utf-8")
    for entry in data:
        encoded: bytes = json.dumps(entry).encode("utf-8")
        fp.write(encoded)
        fp.write(line_end)

62 

63 

# maps an external item type to the function which writes such an item to a file.
# every function is called as `(zanj_instance, open_binary_file, data)` and returns nothing
EXTERNAL_STORE_FUNCS: dict[
    ExternalItemType,
    Callable[[_ZANJ_pre, IO[bytes], Any], None],
] = {
    "npy": store_npy,
    "jsonl": store_jsonl,
}

70 

71 

@dataclass(**KW_ONLY_KWARGS)
class ZANJSerializerHandler(SerializerHandler):
    """a handler for ZANJ serialization

    adds a `source_pckg` field to `SerializerHandler`, and declares `check` and
    `serialize_func` as callables which receive the ZANJ instance, the object
    and the path of the object.

    `uid` (the identifier saved in the `_FORMAT_KEY` field) and `desc` (an
    optional description) are passed as keyword arguments wherever handlers
    are constructed in this module, but are not declared here -- presumably
    they are inherited from `SerializerHandler` (TODO confirm).
    """

    # source package of the handler -- note that this might be overridden by ZANJ
    source_pckg: str

    # (self_config, object, path) -> whether to use this handler
    check: Callable[[_ZANJ_pre, Any, ObjectPath], bool]

    # (self_config, object, path) -> serialized object
    serialize_func: Callable[[_ZANJ_pre, Any, ObjectPath], JSONitem]

86 

87 

def zanj_external_serialize(
    jser: _ZANJ_pre,
    data: Any,
    path: ObjectPath,
    item_type: ExternalItemType,
    _format: str,
) -> JSONitem:
    """stores a numpy array or jsonl externally in a ZANJ object

    # Parameters:
     - `jser: ZANJ`
        serializer whose `_externals` mapping receives the item
     - `data: Any`
        for `item_type == "npy"`: a `numpy.ndarray` (subclasses included) or a
        `torch.Tensor`. for item types starting with `"jsonl"`: a pandas
        `DataFrame` or any iterable
     - `path: ObjectPath`
        path of the object, must be a tuple
     - `item_type: ExternalItemType`
        used as the extension of the archive path
     - `_format: str`
        value stored under `_FORMAT_KEY` in the returned dict

    # Returns:
     - `JSONitem`
        json data with reference

    # Raises:
     - `AssertionError` : if `path` is not a tuple
     - `ValueError` : if the archive path is already registered, or if an
        already registered item lives at the same path (with a different
        extension) or nested below it
     - `TypeError` : if `data` is of a type which is not supported for `item_type`

    # Modifies:
     - modifies `jser._externals`
    """
    # get the path, make sure its unique
    assert isinstance(path, tuple), (
        f"path must be a tuple, got {type(path) = } {path = }"
    )
    joined_path: str = "/".join([str(p) for p in path])
    archive_path: str = f"{joined_path}.{item_type}"

    if archive_path in jser._externals:
        raise ValueError(f"external path {archive_path} already exists!")
    # compare whole path components: a sibling which merely starts with the
    # same characters (`weights.npy` when storing `w`) is not a conflict, but
    # an item nested below this path, or the same path stored under another
    # extension, is
    nested_prefix: str = f"{joined_path}/"
    if any(
        existing.startswith(nested_prefix)
        or existing.rsplit(".", 1)[0] == joined_path
        for existing in jser._externals
    ):
        raise ValueError(f"external path {joined_path} is a prefix of another path!")

    # process the data if needed, assemble metadata
    data_new: Any = data
    output: dict = {
        _FORMAT_KEY: _format,
        _REF_KEY: archive_path,
    }
    if item_type == "npy":
        # torch is matched by name so that it does not have to be imported
        data_type_str: str = str(type(data))
        if data_type_str == "<class 'torch.Tensor'>":
            # detach and convert
            data_new = data.detach().cpu().numpy()
        elif not isinstance(data, np.ndarray):
            # `isinstance` so that ndarray subclasses are accepted, matching
            # the `isinstance(obj, np.ndarray)` check of the numpy handler
            raise TypeError(f"expected numpy.ndarray, got {data_type_str}")
        # NOTE(review): metadata is computed from `data` as passed in, not from
        # the converted `data_new` -- confirm this is intended for tensors
        output.update(arr_metadata(data))
    elif item_type.startswith("jsonl"):
        # check via mro to avoid importing pandas
        if any("pandas.core.frame.DataFrame" in str(t) for t in data.__class__.__mro__):
            output["columns"] = data.columns.tolist()
            data_new = data.to_dict(orient="records")
        elif isinstance(data, Iterable):
            # every element is serialized with its index appended to the path
            data_new = [
                jser.json_serialize(item, tuple(path) + (i,))
                for i, item in enumerate(data)
            ]
        else:
            raise TypeError(
                f"expected list or pandas.DataFrame for jsonl, got {type(data)}"
            )

        # row metadata only makes sense for a non-empty list of dicts;
        # `all(...)` alone is also true for an empty list
        if data_new and all(isinstance(item, dict) for item in data_new):
            output.update(jsonl_metadata(data_new))

    # store the item for external serialization
    jser._externals[archive_path] = ExternalItem(
        item_type=item_type,
        data=data_new,
        path=path,
    )

    return output

167 

168 

# handlers which store large objects externally, followed by `DEFAULT_HANDLERS`.
# the thresholds are read from the instance passed to `check` as `self`
DEFAULT_SERIALIZER_HANDLERS_ZANJ: MonoTuple[ZANJSerializerHandler] = tuple(
    [
        # numpy arrays with at least `external_array_threshold` elements
        ZANJSerializerHandler(
            uid="numpy.ndarray:external",
            desc="external numpy array",
            source_pckg="zanj",
            check=lambda self, obj, path: (
                isinstance(obj, np.ndarray)
                and obj.size >= self.external_array_threshold
            ),
            serialize_func=lambda self, obj, path: zanj_external_serialize(
                self, obj, path, item_type="npy", _format="numpy.ndarray:external"
            ),
        ),
        # torch tensors with at least `external_array_threshold` elements,
        # matched by type name so that torch does not have to be imported
        ZANJSerializerHandler(
            uid="torch.Tensor:external",
            desc="external torch tensor",
            source_pckg="zanj",
            check=lambda self, obj, path: (
                str(type(obj)) == "<class 'torch.Tensor'>"
                and int(obj.nelement()) >= self.external_array_threshold
            ),
            serialize_func=lambda self, obj, path: zanj_external_serialize(
                self, obj, path, item_type="npy", _format="torch.Tensor:external"
            ),
        ),
        # lists with at least `external_list_threshold` items
        ZANJSerializerHandler(
            uid="list:external",
            desc="external list",
            source_pckg="zanj",
            check=lambda self, obj, path: (
                isinstance(obj, list) and len(obj) >= self.external_list_threshold
            ),
            serialize_func=lambda self, obj, path: zanj_external_serialize(
                self, obj, path, item_type="jsonl", _format="list:external"
            ),
        ),
        # tuples with at least `external_list_threshold` items
        ZANJSerializerHandler(
            uid="tuple:external",
            desc="external tuple",
            source_pckg="zanj",
            check=lambda self, obj, path: (
                isinstance(obj, tuple) and len(obj) >= self.external_list_threshold
            ),
            serialize_func=lambda self, obj, path: zanj_external_serialize(
                self, obj, path, item_type="jsonl", _format="tuple:external"
            ),
        ),
        # pandas dataframes with at least `external_list_threshold` rows,
        # matched via the mro so that pandas does not have to be imported
        ZANJSerializerHandler(
            uid="pandas.DataFrame:external",
            desc="external pandas DataFrame",
            source_pckg="zanj",
            check=lambda self, obj, path: (
                any(
                    "pandas.core.frame.DataFrame" in str(t)
                    for t in obj.__class__.__mro__
                )
                and len(obj) >= self.external_list_threshold
            ),
            serialize_func=lambda self, obj, path: zanj_external_serialize(
                self, obj, path, item_type="jsonl", _format="pandas.DataFrame:external"
            ),
        ),
        # disabled: fallback handler for torch modules
        # ZANJSerializerHandler(
        #     check=lambda self, obj, path: "<class 'torch.nn.modules.module.Module'>"
        #     in [str(t) for t in obj.__class__.__mro__],
        #     serialize_func=lambda self, obj, path: zanj_serialize_torchmodule(
        #         self, obj, path,
        #     ),
        #     uid="torch.nn.Module",
        #     source_pckg="zanj",
        #     desc="fallback torch serialization",
        # ),
    ]
) + tuple(
    DEFAULT_HANDLERS  # type: ignore[arg-type]
)

244 

245# the complaint above is: 

246# error: Argument 1 to "tuple" has incompatible type "Sequence[SerializerHandler]"; expected "Iterable[ZANJSerializerHandler]" [arg-type]