Coverage for zanj\zanj.py: 94%
83 statements
coverage.py v7.6.12, created at 2025-02-14 12:57 -0700
1"""
2an HDF5/exdir file alternative, which uses json for attributes, allows serialization of arbitrary data
4for large arrays, the output is a .tar.gz file with most data in a json file, but with sufficiently large arrays stored in binary .npy files
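
a minimal usage sketch (assuming `ZANJ` is re-exported at the package root; the object and filename here are just placeholders):

    import numpy as np
    from zanj import ZANJ

    z = ZANJ()
    path = z.save({"name": "test", "weights": np.ones((1024, 8))}, "my_data.zanj")
    recovered = z.read(path)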
7"ZANJ" is an acronym that the AI tool [Elicit](https://elicit.org) came up with for me. not to be confused with:
9- https://en.wikipedia.org/wiki/Zanj
10- https://www.plutojournals.com/zanj/
12"""

from __future__ import annotations

import json
import os
import time
import zipfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Union

import numpy as np
from muutils.errormode import ErrorMode
from muutils.json_serialize.array import ArrayMode, arr_metadata
from muutils.json_serialize.json_serialize import (
    JsonSerializer,
    SerializerHandler,
    json_serialize,
)
from muutils.json_serialize.util import JSONitem, MonoTuple
from muutils.sysinfo import SysInfo

from zanj.externals import ZANJ_MAIN, ZANJ_META, ExternalItem
import zanj.externals
from zanj.loading import LOADER_MAP, LoadedZANJ, load_item_recursive
from zanj.serializing import (
    DEFAULT_SERIALIZER_HANDLERS_ZANJ,
    EXTERNAL_STORE_FUNCS,
    KW_ONLY_KWARGS,
)

# pylint: disable=protected-access, unused-import, dangerous-default-value, line-too-long

ZANJitem = Union[
    JSONitem,
    np.ndarray,
    "pd.DataFrame",  # type: ignore # noqa: F821
]


@dataclass(**KW_ONLY_KWARGS)
class _ZANJ_GLOBAL_DEFAULTS_CLASS:
    error_mode: ErrorMode = ErrorMode.EXCEPT
    internal_array_mode: ArrayMode = "array_list_meta"
    external_array_threshold: int = 256
    external_list_threshold: int = 256
    compress: bool | int = True
    custom_settings: dict[str, Any] | None = None


ZANJ_GLOBAL_DEFAULTS: _ZANJ_GLOBAL_DEFAULTS_CLASS = _ZANJ_GLOBAL_DEFAULTS_CLASS()


class ZANJ(JsonSerializer):
    """Zip up: Arrays in Numpy, JSON for everything else

    given an arbitrary object, throw it into a zip file, with arrays stored in .npy files and everything else stored in a json file

    (basically an npz file with json)

    - numpy (or pytorch) arrays are stored in paths according to their name and structure in the object
    - everything else about the object is stored in a json file `zanj.json` in the root of the archive, via `muutils.json_serialize.JsonSerializer`
    - metadata about ZANJ configuration, and optionally packages and versions, is stored in a `__zanj_meta__.json` file in the root of the archive

    create a serializer via `z = ZANJ()`, then save and read objects via `z.save(obj, path)` and `z.read(path)`. be sure to pass an **instance** of the object, so that its attributes can be correctly recognized
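
    a rough sketch of using the size thresholds (the values and object here are illustrative, not required settings):

        z = ZANJ(external_array_threshold=512, external_list_threshold=512)
        z.save({"config": {"lr": 0.1}, "embeddings": np.zeros((4096, 64))}, "run.zanj")
        loaded = z.read("run.zanj")

    sufficiently large arrays (like `embeddings` here) end up in their own .npy files inside the archive, while small values stay in the main json file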
79 """

    def __init__(
        self,
        error_mode: ErrorMode = ZANJ_GLOBAL_DEFAULTS.error_mode,
        internal_array_mode: ArrayMode = ZANJ_GLOBAL_DEFAULTS.internal_array_mode,
        external_array_threshold: int = ZANJ_GLOBAL_DEFAULTS.external_array_threshold,
        external_list_threshold: int = ZANJ_GLOBAL_DEFAULTS.external_list_threshold,
        compress: bool | int = ZANJ_GLOBAL_DEFAULTS.compress,
        custom_settings: dict[str, Any] | None = ZANJ_GLOBAL_DEFAULTS.custom_settings,
        handlers_pre: MonoTuple[SerializerHandler] = tuple(),
        handlers_default: MonoTuple[
            SerializerHandler
        ] = DEFAULT_SERIALIZER_HANDLERS_ZANJ,
    ) -> None:
        super().__init__(
            array_mode=internal_array_mode,
            error_mode=error_mode,
            handlers_pre=handlers_pre,
            handlers_default=handlers_default,
        )

        self.external_array_threshold: int = external_array_threshold
        self.external_list_threshold: int = external_list_threshold
        self.custom_settings: dict = (
            custom_settings if custom_settings is not None else dict()
        )

        # process compression to int if bool given
        self.compress = compress
        if isinstance(compress, bool):
            if compress:
                self.compress = zipfile.ZIP_DEFLATED
            else:
                self.compress = zipfile.ZIP_STORED
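        # a non-bool value is assumed to already be a `zipfile` compression
        # constant (e.g. `zipfile.ZIP_LZMA`) and is passed through unchanged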

        # create the externals, leave it empty
        self._externals: dict[str, ExternalItem] = dict()

    def externals_info(self) -> dict[str, dict[str, str | int | list[int]]]:
        """return information about the current externals"""
        output: dict[str, dict] = dict()

        key: str
        item: ExternalItem
        for key, item in self._externals.items():
            data = item.data
            output[key] = {
                "item_type": item.item_type,
                "path": item.path,
                "type(data)": str(type(data)),
                "len(data)": len(data),
            }

            if item.item_type == "ndarray":
                output[key].update(arr_metadata(data))
            elif item.item_type.startswith("jsonl"):
                output[key]["data[0]"] = data[0]

        return {
            key: val
            for key, val in sorted(output.items(), key=lambda x: len(x[1]["path"]))
        }

    def meta(self) -> JSONitem:
        """return the metadata of the ZANJ archive"""

        serialization_handlers = {h.uid: h.serialize() for h in self.handlers}
        load_handlers = {h.uid: h.serialize() for h in LOADER_MAP.values()}

        return dict(
            # configuration of this ZANJ instance
            zanj_cfg=dict(
                error_mode=str(self.error_mode),
                array_mode=str(self.array_mode),
                external_array_threshold=self.external_array_threshold,
                external_list_threshold=self.external_list_threshold,
                compress=self.compress,
                serialization_handlers=serialization_handlers,
                load_handlers=load_handlers,
            ),
            # system info (python, pip packages, torch & cuda, platform info, git info)
            sysinfo=json_serialize(SysInfo.get_all(include=("python", "pytorch"))),
            externals_info=self.externals_info(),
            timestamp=time.time(),
        )

    def save(self, obj: Any, file_path: str | Path) -> str:
        """save the object to a ZANJ archive. returns the path to the archive"""

        # adjust extension
        file_path = str(file_path)
        if not file_path.endswith(".zanj"):
            file_path += ".zanj"

        # make directory
        dir_path: str = os.path.dirname(file_path)
        if dir_path != "":
            if not os.path.exists(dir_path):
                os.makedirs(dir_path, exist_ok=False)

        # clear the externals!
        self._externals = dict()

        # serialize the object -- this will populate self._externals
        # TODO: calling self.json_serialize again here might be slow
        json_data: JSONitem = self.json_serialize(self.json_serialize(obj))

        # open the zip file
        zipf: zipfile.ZipFile = zipfile.ZipFile(
            file=file_path, mode="w", compression=self.compress
        )

        # store base json data and metadata
        zipf.writestr(
            ZANJ_META,
            json.dumps(
                self.json_serialize(self.meta()),
                indent="\t",
            ),
        )
        zipf.writestr(
            ZANJ_MAIN,
            json.dumps(
                json_data,
                indent="\t",
            ),
        )

        # store externals
        for key, (ext_type, ext_data, ext_path) in self._externals.items():
            # why force zip64? numpy.savez does it
            with zipf.open(key, "w", force_zip64=True) as fp:
                EXTERNAL_STORE_FUNCS[ext_type](self, fp, ext_data)

        zipf.close()

        # clear the externals, again
        self._externals = dict()

        return file_path

    def read(
        self,
        file_path: Union[str, Path],
    ) -> Any:
        """load the object from a ZANJ archive
        # TODO: load only some part of the zanj file by passing an ObjectPath
        """
        file_path = Path(file_path)
        if not file_path.exists():
            raise FileNotFoundError(f"file not found: {file_path}")
        if not file_path.is_file():
            raise FileNotFoundError(f"not a file: {file_path}")

        loaded_zanj: LoadedZANJ = LoadedZANJ(
            path=file_path,
            zanj=self,
        )

        loaded_zanj.populate_externals()

        return load_item_recursive(
            loaded_zanj._json_data,
            path=tuple(),
            zanj=self,
            error_mode=self.error_mode,
            # lh_map=loader_handlers,
        )


# give `zanj.externals` a reference to the real ZANJ class; `_ZANJ_pre` appears to
# be a placeholder there (likely to avoid a circular import between the modules)
zanj.externals._ZANJ_pre = ZANJ  # type: ignore