Coverage for zanj/serializing.py: 93%
59 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-14 12:57 -0700
1from __future__ import annotations
3import json
4import sys
5from dataclasses import dataclass
6from typing import IO, Any, Callable, Iterable, Sequence
8import numpy as np
9from muutils.json_serialize.array import arr_metadata
10from muutils.json_serialize.json_serialize import ( # JsonSerializer,
11 DEFAULT_HANDLERS,
12 ObjectPath,
13 SerializerHandler,
14)
15from muutils.json_serialize.util import JSONdict, JSONitem, MonoTuple, _FORMAT_KEY, _REF_KEY
16from muutils.tensor_utils import NDArray
18from zanj.externals import ExternalItem, ExternalItemType, _ZANJ_pre
# dataclass `kw_only` only exists on python >= 3.10; pass it conditionally
KW_ONLY_KWARGS: dict = {"kw_only": True} if sys.version_info >= (3, 10) else {}
24# pylint: disable=unused-argument, protected-access, unexpected-keyword-arg
25# for some reason pylint complains about kwargs to ZANJSerializerHandler
def jsonl_metadata(data: list[JSONdict]) -> dict:
    """metadata about a jsonl object

    returns a preview of the first record, the record count, and per-column
    type/count summaries (the `_FORMAT_KEY` column is excluded)

    # Parameters:
     - `data: list[JSONdict]` list of json-serializable dicts

    # Returns:
     - `dict` with keys `"data[0]"`, `"len(data)"`, `"columns"`
    """
    all_cols: set[str] = set([col for item in data for col in item.keys()])
    return {
        # preview of the first record; None for empty data (previously IndexError)
        "data[0]": data[0] if data else None,
        "len(data)": len(data),
        "columns": {
            col: {
                # names of all types that appear in this column
                "types": list(
                    set([type(item[col]).__name__ for item in data if col in item])
                ),
                # number of records that actually contain this column
                "len": len([item[col] for item in data if col in item]),
            }
            for col in all_cols
            if col != _FORMAT_KEY
        },
    }
def store_npy(self: _ZANJ_pre, fp: IO[bytes], data: NDArray) -> None:
    """write `data` to the already-open binary file `fp` in numpy `.npy` format"""
    arr = np.asanyarray(data)
    # pickled object arrays are an arbitrary-code-execution risk, so forbid them
    np.lib.format.write_array(fp=fp, array=arr, allow_pickle=False)
def store_jsonl(self: _ZANJ_pre, fp: IO[bytes], data: Sequence[JSONitem]) -> None:
    """store sequence to given file as .jsonl (one json document per line)"""
    for record in data:
        line: str = json.dumps(record) + "\n"
        fp.write(line.encode("utf-8"))
# maps an external item type to the function that writes that item's payload
# into an open binary file inside the ZANJ archive
EXTERNAL_STORE_FUNCS: dict[
    ExternalItemType, Callable[[_ZANJ_pre, IO[bytes], Any], None]
] = {
    "npy": store_npy,
    "jsonl": store_jsonl,
}
@dataclass(**KW_ONLY_KWARGS)
class ZANJSerializerHandler(SerializerHandler):
    """a handler for ZANJ serialization

    extends `SerializerHandler` with a `source_pckg` field; fields are
    keyword-only on python >= 3.10 (via `KW_ONLY_KWARGS`)
    """

    # unique identifier for the handler, saved in _FORMAT_KEY field
    # uid: str
    # source package of the handler -- note that this might be overridden by ZANJ
    source_pckg: str
    # (self_config, object, path) -> whether to use this handler for the object
    check: Callable[[_ZANJ_pre, Any, ObjectPath], bool]
    # (self_config, object, path) -> serialized object
    serialize_func: Callable[[_ZANJ_pre, Any, ObjectPath], JSONitem]
    # optional description of how this serializer works
    # desc: str = "(no description)"
def zanj_external_serialize(
    jser: _ZANJ_pre,
    data: Any,
    path: ObjectPath,
    item_type: ExternalItemType,
    _format: str,
) -> JSONitem:
    """stores a numpy array or jsonl externally in a ZANJ object

    # Parameters:
     - `jser: ZANJ`
     - `data: Any`
     - `path: ObjectPath`
     - `item_type: ExternalItemType`
     - `_format: str` value stored in the output's format key

    # Returns:
     - `JSONitem`
       json data with reference to the external archive path

    # Modifies:
     - modifies `jser._externals`

    # Raises:
     - `TypeError` if `path` is not a tuple, or `data` is an unsupported type
     - `ValueError` on archive path collisions
    """
    # get the path, make sure its unique
    # (raise instead of assert: asserts are stripped under `python -O`)
    if not isinstance(path, tuple):
        raise TypeError(f"path must be a tuple, got {type(path) = } {path = }")
    joined_path: str = "/".join([str(p) for p in path])
    archive_path: str = f"{joined_path}.{item_type}"

    if archive_path in jser._externals:
        raise ValueError(f"external path {archive_path} already exists!")
    # a collision only exists at a component (`/`) or extension (`.`) boundary;
    # a bare `startswith(joined_path)` would also (wrongly) flag e.g. an
    # existing "foobar.npy" against a new joined_path of "foo"
    if any(
        p.startswith(joined_path + ".") or p.startswith(joined_path + "/")
        for p in jser._externals.keys()
    ):
        raise ValueError(f"external path {joined_path} is a prefix of another path!")

    # process the data if needed, assemble metadata
    data_new: Any = data
    output: dict = {
        _FORMAT_KEY: _format,
        _REF_KEY: archive_path,
    }
    if item_type == "npy":
        # check type by name to avoid importing torch
        data_type_str: str = str(type(data))
        if data_type_str == "<class 'torch.Tensor'>":
            # detach and convert
            data_new = data.detach().cpu().numpy()
        elif data_type_str == "<class 'numpy.ndarray'>":
            pass
        else:
            # if not a numpy array, except
            raise TypeError(f"expected numpy.ndarray, got {data_type_str}")
        # get metadata
        output.update(arr_metadata(data))
    elif item_type.startswith("jsonl"):
        # check via mro to avoid importing pandas
        if any("pandas.core.frame.DataFrame" in str(t) for t in data.__class__.__mro__):
            output["columns"] = data.columns.tolist()
            data_new = data.to_dict(orient="records")
        elif isinstance(data, (list, tuple, Iterable, Sequence)):
            data_new = [
                jser.json_serialize(item, tuple(path) + (i,))
                for i, item in enumerate(data)
            ]
        else:
            raise TypeError(
                f"expected list or pandas.DataFrame for jsonl, got {type(data)}"
            )

        # `all([])` is True, but `jsonl_metadata([])` would fail on `data[0]`,
        # so only attach column metadata for non-empty all-dict data
        if data_new and all(isinstance(item, dict) for item in data_new):
            output.update(jsonl_metadata(data_new))

    # store the item for external serialization
    jser._externals[archive_path] = ExternalItem(
        item_type=item_type,
        data=data_new,
        path=path,
    )

    return output
# default handler chain: ZANJ-specific external handlers first (arrays, lists,
# tuples, DataFrames above their size thresholds), then muutils DEFAULT_HANDLERS
# as a fallback for everything else
DEFAULT_SERIALIZER_HANDLERS_ZANJ: MonoTuple[ZANJSerializerHandler] = tuple(
    [
        # numpy arrays at or above `external_array_threshold` elements -> .npy
        ZANJSerializerHandler(
            check=lambda self, obj, path: (
                isinstance(obj, np.ndarray)
                and obj.size >= self.external_array_threshold
            ),
            serialize_func=lambda self, obj, path: zanj_external_serialize(
                self, obj, path, item_type="npy", _format="numpy.ndarray:external"
            ),
            uid="numpy.ndarray:external",
            source_pckg="zanj",
            desc="external numpy array",
        ),
        # torch tensors (detected by type name, to avoid importing torch) -> .npy
        ZANJSerializerHandler(
            check=lambda self, obj, path: (
                str(type(obj)) == "<class 'torch.Tensor'>"
                and int(obj.nelement()) >= self.external_array_threshold
            ),
            serialize_func=lambda self, obj, path: zanj_external_serialize(
                self, obj, path, item_type="npy", _format="torch.Tensor:external"
            ),
            uid="torch.Tensor:external",
            source_pckg="zanj",
            desc="external torch tensor",
        ),
        # lists at or above `external_list_threshold` items -> .jsonl
        ZANJSerializerHandler(
            check=lambda self, obj, path: isinstance(obj, list)
            and len(obj) >= self.external_list_threshold,
            serialize_func=lambda self, obj, path: zanj_external_serialize(
                self, obj, path, item_type="jsonl", _format="list:external"
            ),
            uid="list:external",
            source_pckg="zanj",
            desc="external list",
        ),
        # tuples, same threshold as lists -> .jsonl
        ZANJSerializerHandler(
            check=lambda self, obj, path: isinstance(obj, tuple)
            and len(obj) >= self.external_list_threshold,
            serialize_func=lambda self, obj, path: zanj_external_serialize(
                self, obj, path, item_type="jsonl", _format="tuple:external"
            ),
            uid="tuple:external",
            source_pckg="zanj",
            desc="external tuple",
        ),
        # pandas DataFrames (detected via mro string, to avoid importing pandas)
        ZANJSerializerHandler(
            check=lambda self, obj, path: (
                any(
                    "pandas.core.frame.DataFrame" in str(t)
                    for t in obj.__class__.__mro__
                )
                and len(obj) >= self.external_list_threshold
            ),
            serialize_func=lambda self, obj, path: zanj_external_serialize(
                self, obj, path, item_type="jsonl", _format="pandas.DataFrame:external"
            ),
            uid="pandas.DataFrame:external",
            source_pckg="zanj",
            desc="external pandas DataFrame",
        ),
        # ZANJSerializerHandler(
        #     check=lambda self, obj, path: "<class 'torch.nn.modules.module.Module'>"
        #     in [str(t) for t in obj.__class__.__mro__],
        #     serialize_func=lambda self, obj, path: zanj_serialize_torchmodule(
        #         self, obj, path,
        #     ),
        #     uid="torch.nn.Module",
        #     source_pckg="zanj",
        #     desc="fallback torch serialization",
        # ),
    ]
) + tuple(
    DEFAULT_HANDLERS  # type: ignore[arg-type]
)
245# the complaint above is:
246# error: Argument 1 to "tuple" has incompatible type "Sequence[SerializerHandler]"; expected "Iterable[ZANJSerializerHandler]" [arg-type]