Coverage for zanj\loading.py: 80% (126 statements)
coverage.py v7.6.12, created at 2025-02-14 12:57 -0700
from __future__ import annotations

import json
import threading
import typing
import zipfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable

import numpy as np

try:
    import pandas as pd  # type: ignore[import]

    pandas_DataFrame = pd.DataFrame  # type: ignore[no-redef]
except ImportError:

    class pandas_DataFrame:  # type: ignore[no-redef]
        def __init__(self, *args, **kwargs):
            raise ImportError("cannot load pandas DataFrame, pandas is not installed")


import torch
from muutils.errormode import ErrorMode
from muutils.json_serialize.array import load_array
from muutils.json_serialize.json_serialize import ObjectPath
from muutils.json_serialize.util import (
    _FORMAT_KEY,
    _REF_KEY,
    JSONdict,
    JSONitem,
    safe_getsource,
    string_as_lines,
)
from muutils.tensor_utils import DTYPE_MAP, TORCH_DTYPE_MAP

from zanj.externals import (
    GET_EXTERNAL_LOAD_FUNC,
    ZANJ_MAIN,
    ZANJ_META,
    ExternalItem,
    _ZANJ_pre,
)

# pylint: disable=protected-access, dangerous-default-value

def _populate_externals_error_checking(key, item) -> bool:
    """check that `key` is a valid index into `item`

    returns `True` if the caller must first descend into `item["data"]`
    (i.e. `item` is a not-yet-populated external stub)"""

    # special case for a not fully loaded external item which we still need to populate
    if isinstance(item, typing.Mapping):
        if (_FORMAT_KEY in item) and item[_FORMAT_KEY].endswith(":external"):
            if "data" in item:
                return True
            else:
                raise KeyError(
                    f"expected an external item, but could not find data: {list(item.keys())}",
                    f"{item[_FORMAT_KEY]}, {len(item) = }, {item.get('data', '<EMPTY>') = }",
                )

    # if it's a list, make sure the key is an int and that it's in range
    if isinstance(item, typing.Sequence):
        if not isinstance(key, int):
            raise TypeError(f"improper type: '{type(key) = }', expected int")
        if key >= len(item):
            raise IndexError(f"index out of range: '{key = }', expected < {len(item)}")

    # if it's a dict, make sure that the key is a str and that it's in the dict
    elif isinstance(item, typing.Mapping):
        if not isinstance(key, str):
            raise TypeError(f"improper type: '{type(key) = }', expected str")
        if key not in item:
            raise KeyError(f"key not in dict: '{key = }', expected in {item.keys()}")

    # otherwise, raise an error
    else:
        raise TypeError(f"improper type: '{type(item) = }', expected dict or list")

    return False

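# a minimal sketch of the helper's contract (illustrative values, not from a
# real archive): an unpopulated external stub returns True, while a plain
# container returns False after its key is validated
assert _populate_externals_error_checking(
    "x", {_FORMAT_KEY: "numpy.ndarray:external", "data": {"x": 1}}
)
assert not _populate_externals_error_checking(0, [10, 20])
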
@dataclass
class LoaderHandler:
    """handler for loading an object from a json file or a ZANJ archive"""

    # TODO: add a separate "asserts" function?
    # right now, any asserts must happen in `check` or `load`, which is annoying with lambdas

    # (json_data, path, zanj) -> whether to use this handler
    check: Callable[[JSONitem, ObjectPath, _ZANJ_pre], bool]
    # function to load the object: (json_data, path, zanj) -> loaded_obj
    load: Callable[[JSONitem, ObjectPath, _ZANJ_pre], Any]
    # unique identifier for the handler, saved in the __muutils_format__ field
    uid: str
    # source package of the handler -- note that this might be overridden by ZANJ
    source_pckg: str
    # priority of the handler; defaults are all 0
    priority: int = 0
    # description of the handler
    desc: str = "(no description)"

    def serialize(self) -> JSONdict:
        """serialize the handler info"""
        return {
            # get the code and doc of the check function
            "check": {
                "code": safe_getsource(self.check),
                "doc": string_as_lines(self.check.__doc__),
            },
            # get the code and doc of the load function
            "load": {
                "code": safe_getsource(self.load),
                "doc": string_as_lines(self.load.__doc__),
            },
            # get the uid, source_pckg, priority, and desc
            "uid": str(self.uid),
            "source_pckg": str(self.source_pckg),
            "priority": int(self.priority),
            "desc": str(self.desc),
        }

    @classmethod
    def from_formattedclass(cls, fc: type, priority: int = 0):
        """create a loader from a class with `serialize` and `load` methods and a `__muutils_format__` attribute"""
        assert hasattr(fc, "serialize")
        assert callable(fc.serialize)  # type: ignore
        assert hasattr(fc, "load")
        assert callable(fc.load)  # type: ignore
        assert hasattr(fc, _FORMAT_KEY)
        assert isinstance(fc.__muutils_format__, str)  # type: ignore

        return cls(
            check=lambda json_item, path=None, z=None: (  # type: ignore[misc]
                json_item[_FORMAT_KEY] == fc.__muutils_format__  # type: ignore[attr-defined]
            ),
            load=lambda json_item, path=None, z=None: fc.load(json_item, path, z),  # type: ignore[misc]
            uid=fc.__muutils_format__,  # type: ignore[attr-defined]
            source_pckg=str(fc.__module__),
            priority=priority,
            desc=f"formatted class loader for {fc.__name__}",
        )

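# a minimal sketch of the "formatted class" protocol that
# `from_formattedclass` expects -- `MyPair` is hypothetical, not part of zanj:
# a string `__muutils_format__` attribute plus `serialize` and `load` methods
class MyPair:
    __muutils_format__ = "MyPair"

    def __init__(self, x: int, y: int) -> None:
        self.x, self.y = x, y

    def serialize(self) -> dict:
        return {_FORMAT_KEY: self.__muutils_format__, "x": self.x, "y": self.y}

    @classmethod
    def load(cls, json_item, path=None, z=None) -> "MyPair":
        return cls(json_item["x"], json_item["y"])


# the resulting handler dispatches on an exact __muutils_format__ match
pair_handler: LoaderHandler = LoaderHandler.from_formattedclass(MyPair)
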
# NOTE: there are type ignores on the loaders, since the type checking should be the responsibility of the check function

LOADER_MAP_LOCK = threading.Lock()

LOADER_MAP: dict[str, LoaderHandler] = {
    lh.uid: lh
    for lh in [
        # array external
        LoaderHandler(
            check=lambda json_item, path=None, z=None: (  # type: ignore[misc]
                isinstance(json_item, typing.Mapping)
                and _FORMAT_KEY in json_item
                and json_item[_FORMAT_KEY].startswith("numpy.ndarray")
                # and json_item["data"].dtype.name == json_item["dtype"]
                # and tuple(json_item["data"].shape) == tuple(json_item["shape"])
            ),
            load=lambda json_item, path=None, z=None: np.array(  # type: ignore[misc]
                load_array(json_item), dtype=DTYPE_MAP[json_item["dtype"]]
            ),
            uid="numpy.ndarray",
            source_pckg="zanj",
            desc="numpy.ndarray loader",
        ),
        LoaderHandler(
            check=lambda json_item, path=None, z=None: (  # type: ignore[misc]
                isinstance(json_item, typing.Mapping)
                and _FORMAT_KEY in json_item
                and json_item[_FORMAT_KEY].startswith("torch.Tensor")
                # and json_item["data"].dtype.name == json_item["dtype"]
                # and tuple(json_item["data"].shape) == tuple(json_item["shape"])
            ),
            load=lambda json_item, path=None, z=None: torch.tensor(  # type: ignore[misc]
                load_array(json_item), dtype=TORCH_DTYPE_MAP[json_item["dtype"]]
            ),
            uid="torch.Tensor",
            source_pckg="zanj",
            desc="torch.Tensor loader",
        ),
        # pandas
        LoaderHandler(
            check=lambda json_item, path=None, z=None: (  # type: ignore[misc]
                isinstance(json_item, typing.Mapping)
                and _FORMAT_KEY in json_item
                and json_item[_FORMAT_KEY].startswith("pandas.DataFrame")
                and "data" in json_item
                and isinstance(json_item["data"], typing.Sequence)
            ),
            load=lambda json_item, path=None, z=None: pandas_DataFrame(  # type: ignore[misc]
                json_item["data"]
            ),
            uid="pandas.DataFrame",
            source_pckg="zanj",
            desc="pandas.DataFrame loader",
        ),
        # list/tuple external
        LoaderHandler(
            check=lambda json_item, path=None, z=None: (  # type: ignore[misc]
                isinstance(json_item, typing.Mapping)
                and _FORMAT_KEY in json_item
                and json_item[_FORMAT_KEY].startswith("list")
                and "data" in json_item
                and isinstance(json_item["data"], typing.Sequence)
            ),
            load=lambda json_item, path=None, z=None: [  # type: ignore[misc]
                load_item_recursive(x, path, z) for x in json_item["data"]
            ],
            uid="list",
            source_pckg="zanj",
            desc="list loader, for externals",
        ),
        LoaderHandler(
            check=lambda json_item, path=None, z=None: (  # type: ignore[misc]
                isinstance(json_item, typing.Mapping)
                and _FORMAT_KEY in json_item
                and json_item[_FORMAT_KEY].startswith("tuple")
                and "data" in json_item
                and isinstance(json_item["data"], typing.Sequence)
            ),
            load=lambda json_item, path=None, z=None: tuple(  # type: ignore[misc]
                [load_item_recursive(x, path, z) for x in json_item["data"]]
            ),
            uid="tuple",
            source_pckg="zanj",
            desc="tuple loader, for externals",
        ),
    ]
}

def register_loader_handler(handler: LoaderHandler):
    """register a custom loader handler"""
    global LOADER_MAP, LOADER_MAP_LOCK
    with LOADER_MAP_LOCK:
        LOADER_MAP[handler.uid] = handler

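# sketch of registration with a hypothetical "set" format (not a format zanj
# itself emits): once registered, any json item whose __muutils_format__ is
# "set" round-trips back into a python set
register_loader_handler(
    LoaderHandler(
        check=lambda json_item, path=None, z=None: (
            isinstance(json_item, typing.Mapping)
            and json_item.get(_FORMAT_KEY) == "set"
            and "data" in json_item
        ),
        load=lambda json_item, path=None, z=None: set(json_item["data"]),
        uid="set",
        source_pckg="my_package",  # hypothetical source package
        desc="example loader for a hypothetical set format",
    )
)
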
def get_item_loader(
    json_item: JSONitem,
    path: ObjectPath,
    zanj: _ZANJ_pre | None = None,
    error_mode: ErrorMode = ErrorMode.WARN,
    # lh_map: dict[str, LoaderHandler] = LOADER_MAP,
) -> LoaderHandler | None:
    """get the loader for a json item"""
    global LOADER_MAP

    # check if we recognize the format
    if isinstance(json_item, typing.Mapping) and _FORMAT_KEY in json_item:
        if not isinstance(json_item[_FORMAT_KEY], str):
            raise TypeError(
                f"invalid __muutils_format__ type '{type(json_item[_FORMAT_KEY])}' in '{path=}': '{json_item[_FORMAT_KEY] = }'"
            )
        if json_item[_FORMAT_KEY] in LOADER_MAP:
            return LOADER_MAP[json_item[_FORMAT_KEY]]  # type: ignore[index]

    # if we don't recognize the format, try to find a loader that can handle it
    for key, lh in LOADER_MAP.items():
        if lh.check(json_item, path, zanj):
            return lh

    # if we still don't have a loader, return None
    return None

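# sketch of the two lookup paths (illustrative items): a known format resolves
# directly through LOADER_MAP; an item with no recognized format, and for which
# no `check` function fires, yields None
assert get_item_loader({_FORMAT_KEY: "list", "data": []}, path=()) is LOADER_MAP["list"]
assert get_item_loader({"plain": "dict"}, path=()) is None
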
def load_item_recursive(
    json_item: JSONitem,
    path: ObjectPath,
    zanj: _ZANJ_pre | None = None,
    error_mode: ErrorMode = ErrorMode.WARN,
    allow_not_loading: bool = True,
) -> Any:
    """recursively load a json item, dispatching to a `LoaderHandler` where one matches"""
    lh: LoaderHandler | None = get_item_loader(
        json_item=json_item,
        path=path,
        zanj=zanj,
        error_mode=error_mode,
        # lh_map=lh_map,
    )

    if lh is not None:
        # special case for serializable dataclasses
        if (
            isinstance(json_item, typing.Mapping)
            and (_FORMAT_KEY in json_item)
            and ("SerializableDataclass" in json_item[_FORMAT_KEY])  # type: ignore[operator]
        ):
            # why this horribleness?
            # a SerializableDataclass with a field `x` which is also a SerializableDataclass
            # will automatically call `x.__class__.load()`
            # however, we still need to load items in containers, as well as arrays
            processed_json_item: dict = {
                key: (
                    val
                    if (
                        isinstance(val, typing.Mapping)
                        and (_FORMAT_KEY in val)
                        and ("SerializableDataclass" in val[_FORMAT_KEY])
                    )
                    else load_item_recursive(
                        json_item=val,
                        path=tuple(path) + (key,),
                        zanj=zanj,
                        error_mode=error_mode,
                    )
                )
                for key, val in json_item.items()
            }

            return lh.load(processed_json_item, path, zanj)

        else:
            return lh.load(json_item, path, zanj)
    else:
        if isinstance(json_item, dict):
            return {
                key: load_item_recursive(
                    json_item=json_item[key],
                    path=tuple(path) + (key,),
                    zanj=zanj,
                    error_mode=error_mode,
                    # lh_map=lh_map,
                )
                for key in json_item
            }
        elif isinstance(json_item, list):
            return [
                load_item_recursive(
                    json_item=x,
                    path=tuple(path) + (i,),
                    zanj=zanj,
                    error_mode=error_mode,
                    # lh_map=lh_map,
                )
                for i, x in enumerate(json_item)
            ]
        elif isinstance(json_item, (str, int, float, bool, type(None))):
            return json_item
        else:
            if allow_not_loading:
                return json_item
            else:
                raise ValueError(
                    f"unknown type {type(json_item)} at {path}\n{json_item}"
                )

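# sketch of recursive loading on illustrative data: formatted sub-items are
# dispatched to their handlers, plain containers are walked, scalars pass through
_demo_loaded = load_item_recursive(
    {"xs": {_FORMAT_KEY: "tuple", "data": [1, 2]}, "name": "demo"},
    path=(),
)
assert _demo_loaded == {"xs": (1, 2), "name": "demo"}
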
def _each_item_in_externals(
    externals: dict[str, ExternalItem],
    json_data: JSONitem,
) -> typing.Iterable[tuple[str, ExternalItem, Any, ObjectPath]]:
    """yield `(fname, ext_item, item, path)` for each external, shallowest path first

    note that you MUST consume the raw iterator: the caller mutates `json_data`
    between yields, so each path must be resolved against the already-populated
    tree -- do not collect the results into a list first"""

    sorted_externals: list[tuple[str, ExternalItem]] = sorted(
        externals.items(), key=lambda x: len(x[1].path)
    )

    for ext_path, ext_item in sorted_externals:
        # get the path to the item
        path: ObjectPath = tuple(ext_item.path)
        assert len(path) > 0
        assert all(isinstance(key, (str, int)) for key in path), (
            f"improper types in path {path=}"
        )
        # get the item
        item = json_data
        for i, key in enumerate(path):
            try:
                # ignores in this block are because we cannot know the type is indexable in static analysis
                # but the types are checked in `_populate_externals_error_checking`
                external_unloaded: bool = _populate_externals_error_checking(key, item)
                if external_unloaded:
                    item = item["data"]  # type: ignore
                item = item[key]  # type: ignore[index]
            except (KeyError, IndexError, TypeError) as e:
                raise KeyError(
                    f"could not find '{key = }' at path '{ext_path = }', specifically at index '{i = }'",
                    f"'{type(item) = }', '{len(item) = }', '{item.keys() if isinstance(item, dict) else None = }'",  # type: ignore
                    f"from error: {e = }",
                    f"\n\n{item=}\n\n{ext_item=}",
                ) from e

        yield (ext_path, ext_item, item, path)

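# sketch on illustrative data (assuming ExternalItem accepts the same fields
# used in `LoadedZANJ.__init__` below): the stub dict holding _REF_KEY is what
# gets yielded, so `populate_externals` can write the loaded data into it in place
_demo_json: dict = {"arr": {_FORMAT_KEY: "numpy.ndarray:external", _REF_KEY: "arr.npy"}}
_demo_externals: dict = {
    "arr.npy": ExternalItem(item_type="ndarray", data=[1, 2], path=("arr",))
}
for _fname, _ext, _stub, _path in _each_item_in_externals(_demo_externals, _demo_json):
    assert _stub is _demo_json["arr"] and _fname == "arr.npy"
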
class LoadedZANJ:
    """for loading a zanj file"""

    def __init__(
        self,
        path: str | Path,
        zanj: _ZANJ_pre,
    ) -> None:
        # path and zanj object
        self._path: str = str(path)
        self._zanj: _ZANJ_pre = zanj

        # load zip file
        _zipf: zipfile.ZipFile = zipfile.ZipFile(file=self._path, mode="r")

        # load data
        self._meta: JSONdict = json.load(_zipf.open(ZANJ_META, "r"))
        self._json_data: JSONitem = json.load(_zipf.open(ZANJ_MAIN, "r"))

        # read externals
        self._externals: dict[str, ExternalItem] = dict()
        for fname, ext_item in self._meta["externals_info"].items():  # type: ignore
            item_type: str = ext_item["item_type"]  # type: ignore
            with _zipf.open(fname, "r") as fp:
                self._externals[fname] = ExternalItem(
                    item_type=item_type,  # type: ignore[arg-type]
                    data=GET_EXTERNAL_LOAD_FUNC(item_type)(self, fp),
                    path=ext_item["path"],  # type: ignore
                )

        # close zip file
        _zipf.close()
        del _zipf

    def populate_externals(self) -> None:
        """put all external items into the main json data"""

        # loop over the externals once, populating as we go
        for ext_path, ext_item, item, path in _each_item_in_externals(
            self._externals, self._json_data
        ):
            # write the loaded external data into the stub dict
            assert _REF_KEY in item  # type: ignore
            assert item[_REF_KEY] == ext_path  # type: ignore
            item["data"] = ext_item.data  # type: ignore