Coverage for zanj\zanj.py: 94%

83 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-02-14 12:57 -0700

1""" 

2an HDF5/exdir file alternative, which uses json for attributes, allows serialization of arbitrary data 

3 

4for large arrays, the output is a zip archive (a `.zanj` file) with most data in a json file, but with sufficiently large arrays stored in binary .npy files 

5 

6 

7"ZANJ" is an acronym that the AI tool [Elicit](https://elicit.org) came up with for me. not to be confused with: 

8 

9- https://en.wikipedia.org/wiki/Zanj 

10- https://www.plutojournals.com/zanj/ 

11 

12""" 

13 

14from __future__ import annotations 

15 

16import json 

17import os 

18import time 

19import zipfile 

20from dataclasses import dataclass 

21from pathlib import Path 

22from typing import Any, Union 

23 

24import numpy as np 

25from muutils.errormode import ErrorMode 

26from muutils.json_serialize.array import ArrayMode, arr_metadata 

27from muutils.json_serialize.json_serialize import ( 

28 JsonSerializer, 

29 SerializerHandler, 

30 json_serialize, 

31) 

32from muutils.json_serialize.util import JSONitem, MonoTuple 

33from muutils.sysinfo import SysInfo 

34 

35from zanj.externals import ZANJ_MAIN, ZANJ_META, ExternalItem 

36import zanj.externals 

37from zanj.loading import LOADER_MAP, LoadedZANJ, load_item_recursive 

38from zanj.serializing import ( 

39 DEFAULT_SERIALIZER_HANDLERS_ZANJ, 

40 EXTERNAL_STORE_FUNCS, 

41 KW_ONLY_KWARGS, 

42) 

43 

44# pylint: disable=protected-access, unused-import, dangerous-default-value, line-too-long 

45 

# type alias: a JSON item, a numpy array, or a pandas DataFrame.
# the DataFrame member is a string forward reference, since `pd` is not among
# the imports of this module (hence the `type: ignore` / `noqa: F821` markers)
ZANJitem = Union[
    JSONitem,
    np.ndarray,
    "pd.DataFrame",  # type: ignore # noqa: F821
]

51 

52 

@dataclass(**KW_ONLY_KWARGS)
class _ZANJ_GLOBAL_DEFAULTS_CLASS:
    """container for the default values of the `ZANJ.__init__` arguments

    every field here is used as the default of the `ZANJ.__init__` argument
    with the same name
    """

    # forwarded by `ZANJ` to the `JsonSerializer` base class, also used when reading
    error_mode: ErrorMode = ErrorMode.EXCEPT
    # forwarded by `ZANJ` to the `JsonSerializer` base class as `array_mode`
    internal_array_mode: ArrayMode = "array_list_meta"
    # presumably the size from which arrays / lists are stored as external files
    # instead of inline json -- the comparison itself is not in this module, TODO confirm
    external_array_threshold: int = 256
    external_list_threshold: int = 256
    # a bool is mapped by `ZANJ` to `zipfile.ZIP_DEFLATED` / `zipfile.ZIP_STORED`,
    # an int is used as the zipfile compression constant directly
    compress: bool | int = True
    # `None` is replaced by an empty dict in `ZANJ.__init__`
    custom_settings: dict[str, Any] | None = None


# module-wide defaults instance. note that `ZANJ.__init__` reads these values
# when the method is defined, so mutating this object afterwards does not change
# the defaults of `ZANJ.__init__`
ZANJ_GLOBAL_DEFAULTS: _ZANJ_GLOBAL_DEFAULTS_CLASS = _ZANJ_GLOBAL_DEFAULTS_CLASS()

64 

65 

class ZANJ(JsonSerializer):
    """Zip up: Arrays in Numpy, JSON for everything else

    given an arbitrary object, throw into a zip file, with arrays stored in .npy files, and everything else stored in a json file

    (basically npz file with json)

    - numpy (or pytorch) arrays are stored in paths according to their name and structure in the object
    - everything else about the object is stored in a json file `zanj.json` in the root of the archive, via `muutils.json_serialize.JsonSerializer`
    - metadata about ZANJ configuration, and optionally packages and versions, is stored in a `__zanj_meta__.json` file in the root of the archive

    write an object with `ZANJ().save(obj, path)` and load it back with `ZANJ().read(path)`

    # Parameters:
    - `error_mode`: passed to `JsonSerializer`, and used by `read` when loading items
    - `internal_array_mode`: passed to `JsonSerializer` as `array_mode`
    - `external_array_threshold`, `external_list_threshold`: stored on the instance;
        presumably the sizes from which arrays / lists are written as external files
        (the comparison happens in the handlers, outside this class -- TODO confirm)
    - `compress`: `True` means `zipfile.ZIP_DEFLATED`, `False` means `zipfile.ZIP_STORED`,
        an int is used as the `compression` argument of `zipfile.ZipFile` unchanged
    - `custom_settings`: stored on the instance, `None` becomes an empty dict
    - `handlers_pre`, `handlers_default`: passed to `JsonSerializer`
    """

    def __init__(
        self,
        error_mode: ErrorMode = ZANJ_GLOBAL_DEFAULTS.error_mode,
        internal_array_mode: ArrayMode = ZANJ_GLOBAL_DEFAULTS.internal_array_mode,
        external_array_threshold: int = ZANJ_GLOBAL_DEFAULTS.external_array_threshold,
        external_list_threshold: int = ZANJ_GLOBAL_DEFAULTS.external_list_threshold,
        compress: bool | int = ZANJ_GLOBAL_DEFAULTS.compress,
        custom_settings: dict[str, Any] | None = ZANJ_GLOBAL_DEFAULTS.custom_settings,
        handlers_pre: MonoTuple[SerializerHandler] = tuple(),
        handlers_default: MonoTuple[
            SerializerHandler
        ] = DEFAULT_SERIALIZER_HANDLERS_ZANJ,
    ) -> None:
        super().__init__(
            array_mode=internal_array_mode,
            error_mode=error_mode,
            handlers_pre=handlers_pre,
            handlers_default=handlers_default,
        )

        self.external_array_threshold: int = external_array_threshold
        self.external_list_threshold: int = external_list_threshold
        self.custom_settings: dict = (
            custom_settings if custom_settings is not None else dict()
        )

        # a bool selects between the two standard zipfile modes,
        # an int is taken to already be a zipfile compression constant
        self.compress: int
        if isinstance(compress, bool):
            self.compress = zipfile.ZIP_DEFLATED if compress else zipfile.ZIP_STORED
        else:
            self.compress = compress

        # create the externals, leave it empty
        # (filled while serializing inside `save`, emptied again afterwards)
        self._externals: dict[str, ExternalItem] = dict()

    def externals_info(self) -> dict[str, dict[str, str | int | list[int]]]:
        """return information about the current externals

        one entry per external item with its `item_type`, `path`, the type and
        length of its data, plus array metadata for `"ndarray"` items or the first
        element for items whose type starts with `"jsonl"`.
        entries are ordered by the length of their `path`
        """
        output: dict[str, dict] = dict()

        key: str
        item: ExternalItem
        for key, item in self._externals.items():
            data = item.data
            output[key] = {
                "item_type": item.item_type,
                "path": item.path,
                "type(data)": str(type(data)),
                "len(data)": len(data),
            }

            if item.item_type == "ndarray":
                output[key].update(arr_metadata(data))
            elif item.item_type.startswith("jsonl") and len(data) > 0:
                # an empty sequence has no first element to show, so the
                # "data[0]" key is omitted instead of raising an IndexError
                output[key]["data[0]"] = data[0]

        # `sorted` is stable, so equal path lengths keep their insertion order
        return dict(sorted(output.items(), key=lambda kv: len(kv[1]["path"])))

    def meta(self) -> JSONitem:
        """return the metadata of the ZANJ archive

        the result has the keys `zanj_cfg` (configuration of this instance, including
        the serialized handlers and loaders), `sysinfo`, `externals_info` (see
        `externals_info()`, which reflects the externals present at call time) and
        `timestamp` (`time.time()` at call time)
        """

        serialization_handlers = {h.uid: h.serialize() for h in self.handlers}
        load_handlers = {h.uid: h.serialize() for h in LOADER_MAP.values()}

        return dict(
            # configuration of this ZANJ instance
            zanj_cfg=dict(
                error_mode=str(self.error_mode),
                array_mode=str(self.array_mode),
                external_array_threshold=self.external_array_threshold,
                external_list_threshold=self.external_list_threshold,
                compress=self.compress,
                serialization_handlers=serialization_handlers,
                load_handlers=load_handlers,
            ),
            # system info, restricted to the "python" and "pytorch" sections
            sysinfo=json_serialize(SysInfo.get_all(include=("python", "pytorch"))),
            externals_info=self.externals_info(),
            timestamp=time.time(),
        )

    def save(self, obj: Any, file_path: str | Path) -> str:
        """save the object to a ZANJ archive. returns the path to the archive

        `.zanj` is appended to `file_path` if it does not already end with it, and
        the parent directory is created if needed. the returned path is the one
        actually written to.

        the externals of this instance are emptied before serializing and again
        when the method exits, whether it succeeded or raised
        """

        # adjust extension
        file_path = str(file_path)
        if not file_path.endswith(".zanj"):
            file_path += ".zanj"

        # make directory
        # `exist_ok=True` instead of a separate existence check, so a directory
        # created by someone else in the meantime is not an error
        dir_path: str = os.path.dirname(file_path)
        if dir_path:
            os.makedirs(dir_path, exist_ok=True)

        # clear the externals!
        self._externals = dict()

        try:
            # serialize the object -- this will populate self._externals
            # this happens before the archive is opened, so a serialization
            # error does not leave an empty archive behind
            # TODO: calling self.json_serialize again here might be slow
            json_data: JSONitem = self.json_serialize(self.json_serialize(obj))

            # the context manager closes the archive even if writing fails
            with zipfile.ZipFile(
                file=file_path, mode="w", compression=self.compress
            ) as zipf:
                # store base json data and metadata
                # `meta()` has to run after serialization, it reports the externals
                zipf.writestr(
                    ZANJ_META,
                    json.dumps(
                        self.json_serialize(self.meta()),
                        indent="\t",
                    ),
                )
                zipf.writestr(
                    ZANJ_MAIN,
                    json.dumps(
                        json_data,
                        indent="\t",
                    ),
                )

                # store externals
                for key, (ext_type, ext_data, _ext_path) in self._externals.items():
                    # why force zip64? numpy.savez does it
                    with zipf.open(key, "w", force_zip64=True) as fp:
                        EXTERNAL_STORE_FUNCS[ext_type](self, fp, ext_data)
        finally:
            # clear the externals, again -- also on failure, so that the
            # instance does not keep references to the collected data
            self._externals = dict()

        return file_path

    def read(
        self,
        file_path: Union[str, Path],
    ) -> Any:
        """load the object from a ZANJ archive

        raises `FileNotFoundError` if `file_path` does not exist or is not a file

        # TODO: load only some part of the zanj file by passing an ObjectPath
        """
        file_path = Path(file_path)
        if not file_path.exists():
            raise FileNotFoundError(f"file not found: {file_path}")
        if not file_path.is_file():
            raise FileNotFoundError(f"not a file: {file_path}")

        loaded_zanj: LoadedZANJ = LoadedZANJ(
            path=file_path,
            zanj=self,
        )

        loaded_zanj.populate_externals()

        return load_item_recursive(
            loaded_zanj._json_data,
            path=tuple(),
            zanj=self,
            error_mode=self.error_mode,
        )

248 

249 

# hand the `ZANJ` class over to `zanj.externals` by setting its `_ZANJ_pre` attribute.
# presumably done here because this module already imports `zanj.externals`, so the
# reverse import would be circular -- TODO confirm
zanj.externals._ZANJ_pre = ZANJ  # type: ignore