docs for zanj v0.4.0
View Source on GitHub

zanj.serializing


from __future__ import annotations

import json
import sys
from dataclasses import dataclass
from typing import IO, Any, Callable, Iterable, Sequence

import numpy as np
from muutils.json_serialize.array import arr_metadata
from muutils.json_serialize.json_serialize import (  # JsonSerializer,
    DEFAULT_HANDLERS,
    ObjectPath,
    SerializerHandler,
)
from muutils.json_serialize.util import JSONdict, JSONitem, MonoTuple, _FORMAT_KEY, _REF_KEY
from muutils.tensor_utils import NDArray

from zanj.externals import ExternalItem, ExternalItemType, _ZANJ_pre

KW_ONLY_KWARGS: dict = dict()
if sys.version_info >= (3, 10):
    KW_ONLY_KWARGS["kw_only"] = True

# pylint: disable=unused-argument, protected-access, unexpected-keyword-arg
# for some reason pylint complains about kwargs to ZANJSerializerHandler


def jsonl_metadata(data: list[JSONdict]) -> dict:
    """metadata about a jsonl object"""
    all_cols: set[str] = set([col for item in data for col in item.keys()])
    return {
        "data[0]": data[0],
        "len(data)": len(data),
        "columns": {
            col: {
                "types": list(
                    set([type(item[col]).__name__ for item in data if col in item])
                ),
                "len": len([item[col] for item in data if col in item]),
            }
            for col in all_cols
            if col != _FORMAT_KEY
        },
    }


def store_npy(self: _ZANJ_pre, fp: IO[bytes], data: NDArray) -> None:
    """store numpy array to given file as .npy"""
    np.lib.format.write_array(
        fp=fp,
        array=np.asanyarray(data),
        allow_pickle=False,
    )


def store_jsonl(self: _ZANJ_pre, fp: IO[bytes], data: Sequence[JSONitem]) -> None:
    """store sequence to given file as .jsonl"""

    for item in data:
        fp.write(json.dumps(item).encode("utf-8"))
        fp.write("\n".encode("utf-8"))


EXTERNAL_STORE_FUNCS: dict[
    ExternalItemType, Callable[[_ZANJ_pre, IO[bytes], Any], None]
] = {
    "npy": store_npy,
    "jsonl": store_jsonl,
}


@dataclass(**KW_ONLY_KWARGS)
class ZANJSerializerHandler(SerializerHandler):
    """a handler for ZANJ serialization"""

    # unique identifier for the handler, saved in _FORMAT_KEY field
    # uid: str
    # source package of the handler -- note that this might be overridden by ZANJ
    source_pckg: str
    # (self_config, object) -> whether to use this handler
    check: Callable[[_ZANJ_pre, Any, ObjectPath], bool]
    # (self_config, object, path) -> serialized object
    serialize_func: Callable[[_ZANJ_pre, Any, ObjectPath], JSONitem]
    # optional description of how this serializer works
    # desc: str = "(no description)"


def zanj_external_serialize(
    jser: _ZANJ_pre,
    data: Any,
    path: ObjectPath,
    item_type: ExternalItemType,
    _format: str,
) -> JSONitem:
    """stores a numpy array or jsonl externally in a ZANJ object

    # Parameters:
     - `jser: ZANJ`
     - `data: Any`
     - `path: ObjectPath`
     - `item_type: ExternalItemType`

    # Returns:
     - `JSONitem`
       json data with reference

    # Modifies:
     - modifies `jser._externals`
    """
    # get the path, make sure its unique
    assert isinstance(path, tuple), (
        f"path must be a tuple, got {type(path) = } {path = }"
    )
    joined_path: str = "/".join([str(p) for p in path])
    archive_path: str = f"{joined_path}.{item_type}"

    if archive_path in jser._externals:
        raise ValueError(f"external path {archive_path} already exists!")
    if any([p.startswith(joined_path) for p in jser._externals.keys()]):
        raise ValueError(f"external path {joined_path} is a prefix of another path!")

    # process the data if needed, assemble metadata
    data_new: Any = data
    output: dict = {
        _FORMAT_KEY: _format,
        _REF_KEY: archive_path,
    }
    if item_type == "npy":
        # check type
        data_type_str: str = str(type(data))
        if data_type_str == "<class 'torch.Tensor'>":
            # detach and convert
            data_new = data.detach().cpu().numpy()
        elif data_type_str == "<class 'numpy.ndarray'>":
            pass
        else:
            # if not a numpy array, except
            raise TypeError(f"expected numpy.ndarray, got {data_type_str}")
        # get metadata
        output.update(arr_metadata(data))
    elif item_type.startswith("jsonl"):
        # check via mro to avoid importing pandas
        if any("pandas.core.frame.DataFrame" in str(t) for t in data.__class__.__mro__):
            output["columns"] = data.columns.tolist()
            data_new = data.to_dict(orient="records")
        elif isinstance(data, (list, tuple, Iterable, Sequence)):
            data_new = [
                jser.json_serialize(item, tuple(path) + (i,))
                for i, item in enumerate(data)
            ]
        else:
            raise TypeError(
                f"expected list or pandas.DataFrame for jsonl, got {type(data)}"
            )

        if all([isinstance(item, dict) for item in data_new]):
            output.update(jsonl_metadata(data_new))

    # store the item for external serialization
    jser._externals[archive_path] = ExternalItem(
        item_type=item_type,
        data=data_new,
        path=path,
    )

    return output


DEFAULT_SERIALIZER_HANDLERS_ZANJ: MonoTuple[ZANJSerializerHandler] = tuple(
    [
        ZANJSerializerHandler(
            check=lambda self, obj, path: (
                isinstance(obj, np.ndarray)
                and obj.size >= self.external_array_threshold
            ),
            serialize_func=lambda self, obj, path: zanj_external_serialize(
                self, obj, path, item_type="npy", _format="numpy.ndarray:external"
            ),
            uid="numpy.ndarray:external",
            source_pckg="zanj",
            desc="external numpy array",
        ),
        ZANJSerializerHandler(
            check=lambda self, obj, path: (
                str(type(obj)) == "<class 'torch.Tensor'>"
                and int(obj.nelement()) >= self.external_array_threshold
            ),
            serialize_func=lambda self, obj, path: zanj_external_serialize(
                self, obj, path, item_type="npy", _format="torch.Tensor:external"
            ),
            uid="torch.Tensor:external",
            source_pckg="zanj",
            desc="external torch tensor",
        ),
        ZANJSerializerHandler(
            check=lambda self, obj, path: isinstance(obj, list)
            and len(obj) >= self.external_list_threshold,
            serialize_func=lambda self, obj, path: zanj_external_serialize(
                self, obj, path, item_type="jsonl", _format="list:external"
            ),
            uid="list:external",
            source_pckg="zanj",
            desc="external list",
        ),
        ZANJSerializerHandler(
            check=lambda self, obj, path: isinstance(obj, tuple)
            and len(obj) >= self.external_list_threshold,
            serialize_func=lambda self, obj, path: zanj_external_serialize(
                self, obj, path, item_type="jsonl", _format="tuple:external"
            ),
            uid="tuple:external",
            source_pckg="zanj",
            desc="external tuple",
        ),
        ZANJSerializerHandler(
            check=lambda self, obj, path: (
                any(
                    "pandas.core.frame.DataFrame" in str(t)
                    for t in obj.__class__.__mro__
                )
                and len(obj) >= self.external_list_threshold
            ),
            serialize_func=lambda self, obj, path: zanj_external_serialize(
                self, obj, path, item_type="jsonl", _format="pandas.DataFrame:external"
            ),
            uid="pandas.DataFrame:external",
            source_pckg="zanj",
            desc="external pandas DataFrame",
        ),
        # ZANJSerializerHandler(
        #     check=lambda self, obj, path: "<class 'torch.nn.modules.module.Module'>"
        #     in [str(t) for t in obj.__class__.__mro__],
        #     serialize_func=lambda self, obj, path: zanj_serialize_torchmodule(
        #         self, obj, path,
        #     ),
        #     uid="torch.nn.Module",
        #     source_pckg="zanj",
        #     desc="fallback torch serialization",
        # ),
    ]
) + tuple(
    DEFAULT_HANDLERS  # type: ignore[arg-type]
)

# the complaint above is:
# error: Argument 1 to "tuple" has incompatible type "Sequence[SerializerHandler]"; expected "Iterable[ZANJSerializerHandler]"  [arg-type]

KW_ONLY_KWARGS: dict = {'kw_only': True}
def jsonl_metadata(data: list[JSONdict]) -> dict:

metadata about a jsonl object
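
For illustration, a sketch of the metadata this produces for a small, made-up list of records:

from zanj.serializing import jsonl_metadata

rows = [
    {"name": "a", "score": 1.0},
    {"name": "b", "score": 2.5, "tags": ["x"]},
]
meta = jsonl_metadata(rows)
# meta["len(data)"] == 2
# meta["columns"]["score"] == {"types": ["float"], "len": 2}
# meta["columns"]["tags"] == {"types": ["list"], "len": 1}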

def store_npy(self: _ZANJ_pre, fp: IO[bytes], data: NDArray) -> None:

store numpy array to given file as .npy
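
A usage sketch: the self argument is not used by this function, so writing to an in-memory buffer works without a real ZANJ instance (passing None for self is purely for illustration):

import io

import numpy as np

from zanj.serializing import store_npy

buf = io.BytesIO()
store_npy(None, buf, np.arange(6).reshape(2, 3))  # self is ignored here
buf.seek(0)
arr = np.load(buf)  # round-trip via the standard .npy format
assert arr.shape == (2, 3)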

def store_jsonl(self: _ZANJ_pre, fp: IO[bytes], data: Sequence[JSONitem]) -> None:

store sequence to given file as .jsonl
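
A similar sketch for the jsonl writer (again passing None for the unused self argument):

import io
import json

from zanj.serializing import store_jsonl

buf = io.BytesIO()
store_jsonl(None, buf, [{"a": 1}, {"a": 2}])
lines = buf.getvalue().decode("utf-8").splitlines()
assert [json.loads(line) for line in lines] == [{"a": 1}, {"a": 2}]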

EXTERNAL_STORE_FUNCS: dict[ExternalItemType, Callable[[_ZANJ_pre, IO[bytes], Any], None]] = {'npy': <function store_npy>, 'jsonl': <function store_jsonl>}
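
The writer for a given ExternalItemType can be looked up from this mapping. A minimal dispatch sketch (the first argument is the ZANJ instance, which both writers above currently ignore):

import io

from zanj.serializing import EXTERNAL_STORE_FUNCS

item_type = "jsonl"  # or "npy"
buf = io.BytesIO()
EXTERNAL_STORE_FUNCS[item_type](None, buf, [{"a": 1}])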
@dataclass(**KW_ONLY_KWARGS)
class ZANJSerializerHandler(muutils.json_serialize.json_serialize.SerializerHandler):

a handler for ZANJ serialization

ZANJSerializerHandler(uid: str, desc: str, *, check: Callable[[_ZANJ_pre, Any, ObjectPath], bool], serialize_func: Callable[[_ZANJ_pre, Any, ObjectPath], JSONitem], source_pckg: str)
source_pckg: str
check: Callable[[_ZANJ_pre, Any, ObjectPath], bool]
serialize_func: Callable[[_ZANJ_pre, Any, ObjectPath], JSONitem]
Inherited Members
muutils.json_serialize.json_serialize.SerializerHandler
uid
desc
serialize
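
A hedged sketch of a custom handler for a user-defined class. The fields mirror the dataclass above; how the handler is registered with a ZANJ instance (for example via a handlers_pre argument, following muutils' JsonSerializer convention) is an assumption and should be checked against the ZANJ constructor:

from muutils.json_serialize.util import _FORMAT_KEY

from zanj.serializing import ZANJSerializerHandler


class MyThing:
    def __init__(self, values):
        self.values = values


# serialize MyThing inline (no external file), tagging it via _FORMAT_KEY
my_handler = ZANJSerializerHandler(
    check=lambda self, obj, path: isinstance(obj, MyThing),
    serialize_func=lambda self, obj, path: {
        _FORMAT_KEY: "MyThing",
        "values": list(obj.values),
    },
    uid="MyThing",
    source_pckg="my_package",  # hypothetical package name
    desc="inline serialization of MyThing",
)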
def zanj_external_serialize(jser: _ZANJ_pre, data: Any, path: ObjectPath, item_type: ExternalItemType, _format: str) -> JSONitem:

stores a numpy array or jsonl externally in a ZANJ object

Parameters:

  • jser: ZANJ
  • data: Any
  • path: ObjectPath
  • item_type: ExternalItemType

Returns:

  • JSONitem: json data with reference

Modifies:

  • modifies jser._externals
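
An illustrative sketch of the reference dict this returns and of the side effect on jser._externals, using a minimal stand-in object instead of a real ZANJ instance (only the _externals attribute is needed for the npy branch):

import numpy as np

from zanj.serializing import zanj_external_serialize


class _FakeZANJ:
    # minimal stand-in for a ZANJ instance, for illustration only
    def __init__(self):
        self._externals = {}


jser = _FakeZANJ()
out = zanj_external_serialize(
    jser,
    np.zeros((4, 4)),
    path=("model", "weights"),
    item_type="npy",
    _format="numpy.ndarray:external",
)
# out holds the inline reference (_FORMAT_KEY, _REF_KEY) plus array metadata,
# while the array itself is queued under its archive path:
assert "model/weights.npy" in jser._externals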
DEFAULT_SERIALIZER_HANDLERS_ZANJ: MonoTuple[ZANJSerializerHandler] = (
    ZANJSerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='numpy.ndarray:external', desc='external numpy array', source_pckg='zanj'),
    ZANJSerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='torch.Tensor:external', desc='external torch tensor', source_pckg='zanj'),
    ZANJSerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='list:external', desc='external list', source_pckg='zanj'),
    ZANJSerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='tuple:external', desc='external tuple', source_pckg='zanj'),
    ZANJSerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='pandas.DataFrame:external', desc='external pandas DataFrame', source_pckg='zanj'),
    SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='base types', desc='base types (bool, int, float, str, None)'),
    SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='dictionaries', desc='dictionaries'),
    SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='(list, tuple) -> list', desc='lists and tuples as lists'),
    SerializerHandler(check=<function <lambda>>, serialize_func=<function _serialize_override_serialize_func>, uid='.serialize override', desc='objects with .serialize method'),
    SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='namedtuple -> dict', desc='namedtuples as dicts'),
    SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='dataclass -> dict', desc='dataclasses as dicts'),
    SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='path -> str', desc='Path objects as posix strings'),
    SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='obj -> str(obj)', desc='directly serialize objects in `SERIALIZE_DIRECT_AS_STR` to strings'),
    SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='numpy.ndarray', desc='numpy arrays'),
    SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='torch.Tensor', desc='pytorch tensors'),
    SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='pandas.DataFrame', desc='pandas DataFrames'),
    SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='(set, list, tuple, Iterable) -> list', desc='sets, lists, tuples, and Iterables as lists'),
    SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='fallback', desc='fallback handler -- serialize object attributes and special functions as strings'),
)
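
Putting it together: when saving with ZANJ, arrays and lists at or above the size thresholds are routed through these external handlers instead of being inlined in the JSON tree. A sketch, assuming the ZANJ constructor exposes external_array_threshold and external_list_threshold (the attribute names the checks above read from self):

import numpy as np

from zanj import ZANJ

data = {
    "small": np.arange(10),       # below the threshold: inlined as JSON
    "big": np.random.rand(1000),  # above the threshold: stored as an external .npy entry
}

z = ZANJ(external_array_threshold=256, external_list_threshold=256)  # kwarg names assumed
z.save(data, "example.zanj")
loaded = z.read("example.zanj")
assert np.allclose(loaded["big"], data["big"])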