zanj.serializing
1from __future__ import annotations 2 3import json 4import sys 5from dataclasses import dataclass 6from typing import IO, Any, Callable, Iterable, Sequence 7 8import numpy as np 9from muutils.json_serialize.array import arr_metadata 10from muutils.json_serialize.json_serialize import ( # JsonSerializer, 11 DEFAULT_HANDLERS, 12 ObjectPath, 13 SerializerHandler, 14) 15from muutils.json_serialize.util import JSONdict, JSONitem, MonoTuple, _FORMAT_KEY, _REF_KEY 16from muutils.tensor_utils import NDArray 17 18from zanj.externals import ExternalItem, ExternalItemType, _ZANJ_pre 19 20KW_ONLY_KWARGS: dict = dict() 21if sys.version_info >= (3, 10): 22 KW_ONLY_KWARGS["kw_only"] = True 23 24# pylint: disable=unused-argument, protected-access, unexpected-keyword-arg 25# for some reason pylint complains about kwargs to ZANJSerializerHandler 26 27 28def jsonl_metadata(data: list[JSONdict]) -> dict: 29 """metadata about a jsonl object""" 30 all_cols: set[str] = set([col for item in data for col in item.keys()]) 31 return { 32 "data[0]": data[0], 33 "len(data)": len(data), 34 "columns": { 35 col: { 36 "types": list( 37 set([type(item[col]).__name__ for item in data if col in item]) 38 ), 39 "len": len([item[col] for item in data if col in item]), 40 } 41 for col in all_cols 42 if col != _FORMAT_KEY 43 }, 44 } 45 46 47def store_npy(self: _ZANJ_pre, fp: IO[bytes], data: NDArray) -> None: 48 """store numpy array to given file as .npy""" 49 np.lib.format.write_array( 50 fp=fp, 51 array=np.asanyarray(data), 52 allow_pickle=False, 53 ) 54 55 56def store_jsonl(self: _ZANJ_pre, fp: IO[bytes], data: Sequence[JSONitem]) -> None: 57 """store sequence to given file as .jsonl""" 58 59 for item in data: 60 fp.write(json.dumps(item).encode("utf-8")) 61 fp.write("\n".encode("utf-8")) 62 63 64EXTERNAL_STORE_FUNCS: dict[ 65 ExternalItemType, Callable[[_ZANJ_pre, IO[bytes], Any], None] 66] = { 67 "npy": store_npy, 68 "jsonl": store_jsonl, 69} 70 71 72@dataclass(**KW_ONLY_KWARGS) 73class 
ZANJSerializerHandler(SerializerHandler): 74 """a handler for ZANJ serialization""" 75 76 # unique identifier for the handler, saved in _FORMAT_KEY field 77 # uid: str 78 # source package of the handler -- note that this might be overridden by ZANJ 79 source_pckg: str 80 # (self_config, object) -> whether to use this handler 81 check: Callable[[_ZANJ_pre, Any, ObjectPath], bool] 82 # (self_config, object, path) -> serialized object 83 serialize_func: Callable[[_ZANJ_pre, Any, ObjectPath], JSONitem] 84 # optional description of how this serializer works 85 # desc: str = "(no description)" 86 87 88def zanj_external_serialize( 89 jser: _ZANJ_pre, 90 data: Any, 91 path: ObjectPath, 92 item_type: ExternalItemType, 93 _format: str, 94) -> JSONitem: 95 """stores a numpy array or jsonl externally in a ZANJ object 96 97 # Parameters: 98 - `jser: ZANJ` 99 - `data: Any` 100 - `path: ObjectPath` 101 - `item_type: ExternalItemType` 102 103 # Returns: 104 - `JSONitem` 105 json data with reference 106 107 # Modifies: 108 - modifies `jser._externals` 109 """ 110 # get the path, make sure its unique 111 assert isinstance(path, tuple), ( 112 f"path must be a tuple, got {type(path) = } {path = }" 113 ) 114 joined_path: str = "/".join([str(p) for p in path]) 115 archive_path: str = f"{joined_path}.{item_type}" 116 117 if archive_path in jser._externals: 118 raise ValueError(f"external path {archive_path} already exists!") 119 if any([p.startswith(joined_path) for p in jser._externals.keys()]): 120 raise ValueError(f"external path {joined_path} is a prefix of another path!") 121 122 # process the data if needed, assemble metadata 123 data_new: Any = data 124 output: dict = { 125 _FORMAT_KEY: _format, 126 _REF_KEY: archive_path, 127 } 128 if item_type == "npy": 129 # check type 130 data_type_str: str = str(type(data)) 131 if data_type_str == "<class 'torch.Tensor'>": 132 # detach and convert 133 data_new = data.detach().cpu().numpy() 134 elif data_type_str == "<class 'numpy.ndarray'>": 
135 pass 136 else: 137 # if not a numpy array, except 138 raise TypeError(f"expected numpy.ndarray, got {data_type_str}") 139 # get metadata 140 output.update(arr_metadata(data)) 141 elif item_type.startswith("jsonl"): 142 # check via mro to avoid importing pandas 143 if any("pandas.core.frame.DataFrame" in str(t) for t in data.__class__.__mro__): 144 output["columns"] = data.columns.tolist() 145 data_new = data.to_dict(orient="records") 146 elif isinstance(data, (list, tuple, Iterable, Sequence)): 147 data_new = [ 148 jser.json_serialize(item, tuple(path) + (i,)) 149 for i, item in enumerate(data) 150 ] 151 else: 152 raise TypeError( 153 f"expected list or pandas.DataFrame for jsonl, got {type(data)}" 154 ) 155 156 if all([isinstance(item, dict) for item in data_new]): 157 output.update(jsonl_metadata(data_new)) 158 159 # store the item for external serialization 160 jser._externals[archive_path] = ExternalItem( 161 item_type=item_type, 162 data=data_new, 163 path=path, 164 ) 165 166 return output 167 168 169DEFAULT_SERIALIZER_HANDLERS_ZANJ: MonoTuple[ZANJSerializerHandler] = tuple( 170 [ 171 ZANJSerializerHandler( 172 check=lambda self, obj, path: ( 173 isinstance(obj, np.ndarray) 174 and obj.size >= self.external_array_threshold 175 ), 176 serialize_func=lambda self, obj, path: zanj_external_serialize( 177 self, obj, path, item_type="npy", _format="numpy.ndarray:external" 178 ), 179 uid="numpy.ndarray:external", 180 source_pckg="zanj", 181 desc="external numpy array", 182 ), 183 ZANJSerializerHandler( 184 check=lambda self, obj, path: ( 185 str(type(obj)) == "<class 'torch.Tensor'>" 186 and int(obj.nelement()) >= self.external_array_threshold 187 ), 188 serialize_func=lambda self, obj, path: zanj_external_serialize( 189 self, obj, path, item_type="npy", _format="torch.Tensor:external" 190 ), 191 uid="torch.Tensor:external", 192 source_pckg="zanj", 193 desc="external torch tensor", 194 ), 195 ZANJSerializerHandler( 196 check=lambda self, obj, path: 
isinstance(obj, list) 197 and len(obj) >= self.external_list_threshold, 198 serialize_func=lambda self, obj, path: zanj_external_serialize( 199 self, obj, path, item_type="jsonl", _format="list:external" 200 ), 201 uid="list:external", 202 source_pckg="zanj", 203 desc="external list", 204 ), 205 ZANJSerializerHandler( 206 check=lambda self, obj, path: isinstance(obj, tuple) 207 and len(obj) >= self.external_list_threshold, 208 serialize_func=lambda self, obj, path: zanj_external_serialize( 209 self, obj, path, item_type="jsonl", _format="tuple:external" 210 ), 211 uid="tuple:external", 212 source_pckg="zanj", 213 desc="external tuple", 214 ), 215 ZANJSerializerHandler( 216 check=lambda self, obj, path: ( 217 any( 218 "pandas.core.frame.DataFrame" in str(t) 219 for t in obj.__class__.__mro__ 220 ) 221 and len(obj) >= self.external_list_threshold 222 ), 223 serialize_func=lambda self, obj, path: zanj_external_serialize( 224 self, obj, path, item_type="jsonl", _format="pandas.DataFrame:external" 225 ), 226 uid="pandas.DataFrame:external", 227 source_pckg="zanj", 228 desc="external pandas DataFrame", 229 ), 230 # ZANJSerializerHandler( 231 # check=lambda self, obj, path: "<class 'torch.nn.modules.module.Module'>" 232 # in [str(t) for t in obj.__class__.__mro__], 233 # serialize_func=lambda self, obj, path: zanj_serialize_torchmodule( 234 # self, obj, path, 235 # ), 236 # uid="torch.nn.Module", 237 # source_pckg="zanj", 238 # desc="fallback torch serialization", 239 # ), 240 ] 241) + tuple( 242 DEFAULT_HANDLERS # type: ignore[arg-type] 243) 244 245# the complaint above is: 246# error: Argument 1 to "tuple" has incompatible type "Sequence[SerializerHandler]"; expected "Iterable[ZANJSerializerHandler]" [arg-type]
KW_ONLY_KWARGS: dict =
{'kw_only': True}
def
jsonl_metadata( data: list[typing.Dict[str, typing.Union[bool, int, float, str, NoneType, typing.List[typing.Union[bool, int, float, str, NoneType, typing.List[typing.Any], typing.Dict[str, typing.Any]]], typing.Dict[str, typing.Union[bool, int, float, str, NoneType, typing.List[typing.Any], typing.Dict[str, typing.Any]]]]]]) -> dict:
29def jsonl_metadata(data: list[JSONdict]) -> dict: 30 """metadata about a jsonl object""" 31 all_cols: set[str] = set([col for item in data for col in item.keys()]) 32 return { 33 "data[0]": data[0], 34 "len(data)": len(data), 35 "columns": { 36 col: { 37 "types": list( 38 set([type(item[col]).__name__ for item in data if col in item]) 39 ), 40 "len": len([item[col] for item in data if col in item]), 41 } 42 for col in all_cols 43 if col != _FORMAT_KEY 44 }, 45 }
metadata about a jsonl object
def
store_npy( self: Any, fp: IO[bytes], data: muutils.tensor_utils.jaxtype_factory.<locals>._BaseArray) -> None:
48def store_npy(self: _ZANJ_pre, fp: IO[bytes], data: NDArray) -> None: 49 """store numpy array to given file as .npy""" 50 np.lib.format.write_array( 51 fp=fp, 52 array=np.asanyarray(data), 53 allow_pickle=False, 54 )
store numpy array to given file as .npy
def
store_jsonl( self: Any, fp: IO[bytes], data: Sequence[Union[bool, int, float, str, NoneType, List[Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]], Dict[str, Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]]]]) -> None:
57def store_jsonl(self: _ZANJ_pre, fp: IO[bytes], data: Sequence[JSONitem]) -> None: 58 """store sequence to given file as .jsonl""" 59 60 for item in data: 61 fp.write(json.dumps(item).encode("utf-8")) 62 fp.write("\n".encode("utf-8"))
store sequence to given file as .jsonl
EXTERNAL_STORE_FUNCS: dict[typing.Literal['jsonl', 'npy'], typing.Callable[[typing.Any, typing.IO[bytes], typing.Any], NoneType]] =
{'npy': <function store_npy>, 'jsonl': <function store_jsonl>}
@dataclass(**KW_ONLY_KWARGS)
class
ZANJSerializerHandler73@dataclass(**KW_ONLY_KWARGS) 74class ZANJSerializerHandler(SerializerHandler): 75 """a handler for ZANJ serialization""" 76 77 # unique identifier for the handler, saved in _FORMAT_KEY field 78 # uid: str 79 # source package of the handler -- note that this might be overridden by ZANJ 80 source_pckg: str 81 # (self_config, object) -> whether to use this handler 82 check: Callable[[_ZANJ_pre, Any, ObjectPath], bool] 83 # (self_config, object, path) -> serialized object 84 serialize_func: Callable[[_ZANJ_pre, Any, ObjectPath], JSONitem] 85 # optional description of how this serializer works 86 # desc: str = "(no description)"
a handler for ZANJ serialization
ZANJSerializerHandler( uid: str, desc: str, *, check: Callable[[Any, Any, tuple[Union[str, int], ...]], bool], serialize_func: Callable[[Any, Any, tuple[Union[str, int], ...]], Union[bool, int, float, str, NoneType, List[Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]], Dict[str, Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]]]], source_pckg: str)
serialize_func: Callable[[Any, Any, tuple[Union[str, int], ...]], Union[bool, int, float, str, NoneType, List[Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]], Dict[str, Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]]]]
Inherited Members
- muutils.json_serialize.json_serialize.SerializerHandler
- uid
- desc
- serialize
def
zanj_external_serialize( jser: Any, data: Any, path: tuple[typing.Union[str, int], ...], item_type: Literal['jsonl', 'npy'], _format: str) -> Union[bool, int, float, str, NoneType, List[Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]], Dict[str, Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]]]:
89def zanj_external_serialize( 90 jser: _ZANJ_pre, 91 data: Any, 92 path: ObjectPath, 93 item_type: ExternalItemType, 94 _format: str, 95) -> JSONitem: 96 """stores a numpy array or jsonl externally in a ZANJ object 97 98 # Parameters: 99 - `jser: ZANJ` 100 - `data: Any` 101 - `path: ObjectPath` 102 - `item_type: ExternalItemType` 103 104 # Returns: 105 - `JSONitem` 106 json data with reference 107 108 # Modifies: 109 - modifies `jser._externals` 110 """ 111 # get the path, make sure its unique 112 assert isinstance(path, tuple), ( 113 f"path must be a tuple, got {type(path) = } {path = }" 114 ) 115 joined_path: str = "/".join([str(p) for p in path]) 116 archive_path: str = f"{joined_path}.{item_type}" 117 118 if archive_path in jser._externals: 119 raise ValueError(f"external path {archive_path} already exists!") 120 if any([p.startswith(joined_path) for p in jser._externals.keys()]): 121 raise ValueError(f"external path {joined_path} is a prefix of another path!") 122 123 # process the data if needed, assemble metadata 124 data_new: Any = data 125 output: dict = { 126 _FORMAT_KEY: _format, 127 _REF_KEY: archive_path, 128 } 129 if item_type == "npy": 130 # check type 131 data_type_str: str = str(type(data)) 132 if data_type_str == "<class 'torch.Tensor'>": 133 # detach and convert 134 data_new = data.detach().cpu().numpy() 135 elif data_type_str == "<class 'numpy.ndarray'>": 136 pass 137 else: 138 # if not a numpy array, except 139 raise TypeError(f"expected numpy.ndarray, got {data_type_str}") 140 # get metadata 141 output.update(arr_metadata(data)) 142 elif item_type.startswith("jsonl"): 143 # check via mro to avoid importing pandas 144 if any("pandas.core.frame.DataFrame" in str(t) for t in data.__class__.__mro__): 145 output["columns"] = data.columns.tolist() 146 data_new = data.to_dict(orient="records") 147 elif isinstance(data, (list, tuple, Iterable, Sequence)): 148 data_new = [ 149 jser.json_serialize(item, tuple(path) + (i,)) 150 for i, item in 
enumerate(data) 151 ] 152 else: 153 raise TypeError( 154 f"expected list or pandas.DataFrame for jsonl, got {type(data)}" 155 ) 156 157 if all([isinstance(item, dict) for item in data_new]): 158 output.update(jsonl_metadata(data_new)) 159 160 # store the item for external serialization 161 jser._externals[archive_path] = ExternalItem( 162 item_type=item_type, 163 data=data_new, 164 path=path, 165 ) 166 167 return output
stores a numpy array or jsonl externally in a ZANJ object
Parameters:
jser: ZANJ
data: Any
path: ObjectPath
item_type: ExternalItemType
Returns:
JSONitem
json data with reference
Modifies:
- modifies
jser._externals
DEFAULT_SERIALIZER_HANDLERS_ZANJ: MonoTuple[ZANJSerializerHandler] =
(ZANJSerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='numpy.ndarray:external', desc='external numpy array', source_pckg='zanj'), ZANJSerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='torch.Tensor:external', desc='external torch tensor', source_pckg='zanj'), ZANJSerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='list:external', desc='external list', source_pckg='zanj'), ZANJSerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='tuple:external', desc='external tuple', source_pckg='zanj'), ZANJSerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='pandas.DataFrame:external', desc='external pandas DataFrame', source_pckg='zanj'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='base types', desc='base types (bool, int, float, str, None)'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='dictionaries', desc='dictionaries'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='(list, tuple) -> list', desc='lists and tuples as lists'), SerializerHandler(check=<function <lambda>>, serialize_func=<function _serialize_override_serialize_func>, uid='.serialize override', desc='objects with .serialize method'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='namedtuple -> dict', desc='namedtuples as dicts'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='dataclass -> dict', desc='dataclasses as dicts'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='path -> str', desc='Path objects as posix strings'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='obj -> str(obj)', desc='directly serialize objects in `SERIALIZE_DIRECT_AS_STR` to strings'), 
SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='numpy.ndarray', desc='numpy arrays'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='torch.Tensor', desc='pytorch tensors'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='pandas.DataFrame', desc='pandas DataFrames'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='(set, list, tuple, Iterable) -> list', desc='sets, lists, tuples, and Iterables as lists'), SerializerHandler(check=<function <lambda>>, serialize_func=<function <lambda>>, uid='fallback', desc='fallback handler -- serialize object attributes and special functions as strings'))