docs for muutils v0.6.21
View Source on GitHub

muutils.json_serialize.array

This utility module handles serialization and loading of numpy and torch arrays as JSON.

  • array_list_meta is less efficient (arrays are stored as nested lists), but preserves both metadata and human readability.
  • array_b64_meta is the most efficient, but is not human readable.
  • external is mostly for use in ZANJ

  1"""this utilities module handles serialization and loading of numpy and torch arrays as json
  2
  3- `array_list_meta` is less efficient (arrays are stored as nested lists), but preserves both metadata and human readability.
  4- `array_b64_meta` is the most efficient, but is not human readable.
  5- `external` is mostly for use in [`ZANJ`](https://github.com/mivanit/ZANJ)
  6
  7"""
  8
  9from __future__ import annotations
 10
 11import base64
 12import typing
 13import warnings
 14from typing import Any, Iterable, Literal, Optional, Sequence
 15
 16try:
 17    import numpy as np
 18except ImportError as e:
 19    warnings.warn(
 20        f"numpy is not installed, array serialization will not work: \n{e}",
 21        ImportWarning,
 22    )
 23
 24from muutils.json_serialize.util import JSONitem
 25
 26# pylint: disable=unused-argument
 27
# names of the array serialization modes used in this module:
# - "list", "array_list_meta", "array_hex_meta", "array_b64_meta" can be requested
#   from `serialize_array`
# - "zero_dim" is the format tag `serialize_array` writes for zero-dimensional arrays
# - "external" is only handled when loading (`load_array` returns the stored `data`)
ArrayMode = Literal[
    "list",
    "array_list_meta",
    "array_hex_meta",
    "array_b64_meta",
    "external",
    "zero_dim",
]
 36
 37
 38def array_n_elements(arr) -> int:  # type: ignore[name-defined]
 39    """get the number of elements in an array"""
 40    if isinstance(arr, np.ndarray):
 41        return arr.size
 42    elif str(type(arr)) == "<class 'torch.Tensor'>":
 43        return arr.nelement()
 44    else:
 45        raise TypeError(f"invalid type: {type(arr)}")
 46
 47
 48def arr_metadata(arr) -> dict[str, list[int] | str | int]:
 49    """get metadata for a numpy array"""
 50    return {
 51        "shape": list(arr.shape),
 52        "dtype": (
 53            arr.dtype.__name__ if hasattr(arr.dtype, "__name__") else str(arr.dtype)
 54        ),
 55        "n_elements": array_n_elements(arr),
 56    }
 57
 58
 59def serialize_array(
 60    jser: "JsonSerializer",  # type: ignore[name-defined] # noqa: F821
 61    arr: np.ndarray,
 62    path: str | Sequence[str | int],
 63    array_mode: ArrayMode | None = None,
 64) -> JSONitem:
 65    """serialize a numpy or pytorch array in one of several modes
 66
 67    if the object is zero-dimensional, simply get the unique item
 68
 69    `array_mode: ArrayMode` can be one of:
 70    - `list`: serialize as a list of values, no metadata (equivalent to `arr.tolist()`)
 71    - `array_list_meta`: serialize dict with metadata, actual list under the key `data`
 72    - `array_hex_meta`: serialize dict with metadata, actual hex string under the key `data`
 73    - `array_b64_meta`: serialize dict with metadata, actual base64 string under the key `data`
 74
 75    for `array_list_meta`, `array_hex_meta`, and `array_b64_meta`, the serialized object is:
 76    ```
 77    {
 78        "__format__": <array_list_meta|array_hex_meta>,
 79        "shape": arr.shape,
 80        "dtype": str(arr.dtype),
 81        "data": <arr.tolist()|arr.tobytes().hex()|base64.b64encode(arr.tobytes()).decode()>,
 82    }
 83    ```
 84
 85    # Parameters:
 86     - `arr : Any` array to serialize
 87     - `array_mode : ArrayMode` mode in which to serialize the array
 88       (defaults to `None` and inheriting from `jser: JsonSerializer`)
 89
 90    # Returns:
 91     - `JSONitem`
 92       json serialized array
 93
 94    # Raises:
 95     - `KeyError` : if the array mode is not valid
 96    """
 97
 98    if array_mode is None:
 99        array_mode = jser.array_mode
100
101    arr_type: str = f"{type(arr).__module__}.{type(arr).__name__}"
102    arr_np: np.ndarray = arr if isinstance(arr, np.ndarray) else np.array(arr)
103
104    # handle zero-dimensional arrays
105    if len(arr.shape) == 0:
106        return {
107            "__format__": f"{arr_type}:zero_dim",
108            "data": arr.item(),
109            **arr_metadata(arr),
110        }
111
112    if array_mode == "array_list_meta":
113        return {
114            "__format__": f"{arr_type}:array_list_meta",
115            "data": arr_np.tolist(),
116            **arr_metadata(arr_np),
117        }
118    elif array_mode == "list":
119        return arr_np.tolist()
120    elif array_mode == "array_hex_meta":
121        return {
122            "__format__": f"{arr_type}:array_hex_meta",
123            "data": arr_np.tobytes().hex(),
124            **arr_metadata(arr_np),
125        }
126    elif array_mode == "array_b64_meta":
127        return {
128            "__format__": f"{arr_type}:array_b64_meta",
129            "data": base64.b64encode(arr_np.tobytes()).decode(),
130            **arr_metadata(arr_np),
131        }
132    else:
133        raise KeyError(f"invalid array_mode: {array_mode}")
134
135
def infer_array_mode(arr: JSONitem) -> ArrayMode:
    """work out which mode a serialized array was written in

    assumes the array was serialized via `serialize_array()`: a mapping is
    classified by the suffix of its `__format__` entry, a bare list is mode `list`.

    # Parameters:
     - `arr : JSONitem` serialized array

    # Returns:
     - `ArrayMode`
       the detected mode

    # Raises:
     - `ValueError` : if the format tag is unknown, the `data` entry has the wrong
       type for the tagged mode, or `arr` is neither a mapping nor a list
    """
    if not isinstance(arr, typing.Mapping):
        if isinstance(arr, list):
            return "list"
        raise ValueError(f"cannot infer array_mode from\t{type(arr) = }\n{arr = }")

    fmt: str = arr.get("__format__", "")

    # modes whose `data` entry is type-checked: (mode, label used in the error, required type)
    checked_modes = (
        ("array_list_meta", "list", Iterable),
        ("array_hex_meta", "hex", str),
        ("array_b64_meta", "b64", str),
    )
    for mode, label, required in checked_modes:
        if fmt.endswith(f":{mode}"):
            if not isinstance(arr["data"], required):
                raise ValueError(
                    f"invalid {label} format: {type(arr['data']) = }\t{arr}"
                )
            return mode

    # modes recognized from the tag alone
    for mode in ("external", "zero_dim"):
        if fmt.endswith(f":{mode}"):
            return mode

    raise ValueError(f"invalid format: {arr}")
165
166
def load_array(arr: JSONitem, array_mode: Optional[ArrayMode] = None) -> Any:
    """load a json-serialized array, infer the mode if not specified

    # Parameters:
     - `arr : JSONitem` serialized array; a `numpy.ndarray` is returned unchanged
       when no `array_mode` is given
     - `array_mode : ArrayMode | None` mode to decode with
       (defaults to `None`, in which case the mode found by `infer_array_mode` is used;
       a warning is emitted when an explicit mode disagrees with the inferred one)

    # Returns:
     - `Any`
       a `numpy.ndarray` for every mode except `external`, where the object stored
       under `data` is returned as-is. arrays decoded from `array_hex_meta` and
       `array_b64_meta` are writable copies, and `zero_dim` values are rebuilt with
       their recorded dtype whenever numpy accepts it

    # Raises:
     - `ValueError` : if the mode is invalid or cannot be inferred, or the decoded
       shape does not match the recorded one
     - `KeyError` : if an `external` item has no `data` key
     - `AssertionError` : if `arr` is not the container type the mode requires
    """
    # return arr if its already a numpy array
    if isinstance(arr, np.ndarray) and array_mode is None:
        return arr

    # try to infer the array_mode
    array_mode_inferred: ArrayMode = infer_array_mode(arr)
    if array_mode is None:
        array_mode = array_mode_inferred
    elif array_mode != array_mode_inferred:
        warnings.warn(
            f"array_mode {array_mode} does not match inferred array_mode {array_mode_inferred}"
        )

    # actually load the array
    if array_mode == "array_list_meta":
        assert isinstance(
            arr, typing.Mapping
        ), f"invalid list format: {type(arr) = }\n{arr = }"

        data = np.array(arr["data"], dtype=arr["dtype"])
        if tuple(arr["shape"]) != tuple(data.shape):
            raise ValueError(f"invalid shape: {arr}")
        return data

    elif array_mode == "array_hex_meta":
        assert isinstance(
            arr, typing.Mapping
        ), f"invalid hex format: {type(arr) = }\n{arr = }"

        data = np.frombuffer(bytes.fromhex(arr["data"]), dtype=arr["dtype"])
        # `np.frombuffer` over immutable bytes yields a read-only view; copy so the
        # result is writable like the arrays returned by the list-based modes
        return data.reshape(arr["shape"]).copy()

    elif array_mode == "array_b64_meta":
        assert isinstance(
            arr, typing.Mapping
        ), f"invalid b64 format: {type(arr) = }\n{arr = }"

        data = np.frombuffer(base64.b64decode(arr["data"]), dtype=arr["dtype"])
        # same as the hex branch: hand back a writable copy, not a read-only view
        return data.reshape(arr["shape"]).copy()

    elif array_mode == "list":
        assert isinstance(
            arr, typing.Sequence
        ), f"invalid list format: {type(arr) = }\n{arr = }"

        return np.array(arr)
    elif array_mode == "external":
        # assume ZANJ has taken care of it
        assert isinstance(arr, typing.Mapping)
        if "data" not in arr:
            raise KeyError(
                f"invalid external array, expected key 'data', got keys: '{list(arr.keys())}' and arr: {arr}"
            )
        return arr["data"]
    elif array_mode == "zero_dim":
        assert isinstance(arr, typing.Mapping)
        try:
            # restore the recorded dtype so e.g. a float32 scalar does not come back as float64
            data = np.array(arr["data"], dtype=arr.get("dtype"))
        except (TypeError, ValueError):
            # the recorded dtype string is not one numpy can apply: let numpy infer the dtype
            data = np.array(arr["data"])
        if tuple(arr["shape"]) != tuple(data.shape):
            raise ValueError(f"invalid shape: {arr}")
        return data
    else:
        raise ValueError(f"invalid array_mode: {array_mode}")

ArrayMode = typing.Literal['list', 'array_list_meta', 'array_hex_meta', 'array_b64_meta', 'external', 'zero_dim']
def array_n_elements(arr) -> int:
def array_n_elements(arr) -> int:  # type: ignore[name-defined]
    """get the number of elements in an array

    supports numpy arrays and torch tensors. torch is never imported here:
    a tensor is recognized by the rendered name of its class, and the whole MRO
    is inspected so that subclasses of `torch.Tensor` are accepted too.

    # Parameters:
     - `arr` array to measure, a `numpy.ndarray` or a `torch.Tensor` (or a subclass of it)

    # Returns:
     - `int`
       total number of elements (`arr.size` for numpy, `arr.nelement()` for torch)

    # Raises:
     - `TypeError` : if `arr` is neither a numpy array nor a torch tensor
    """
    if isinstance(arr, np.ndarray):
        return arr.size

    # checking only `type(arr)` would reject tensor subclasses, so look at every base class
    if any(str(base) == "<class 'torch.Tensor'>" for base in type(arr).__mro__):
        return arr.nelement()

    raise TypeError(f"invalid type: {type(arr)}")

get the number of elements in an array

def arr_metadata(arr) -> dict[str, list[int] | str | int]:
def arr_metadata(arr) -> dict[str, list[int] | str | int]:
    """describe an array by its shape, dtype name and element count

    # Parameters:
     - `arr` object exposing `.shape` and `.dtype` that `array_n_elements` accepts

    # Returns:
     - `dict`
       with keys `shape` (list of ints), `dtype` (string) and `n_elements` (int)
    """
    shape = list(arr.shape)

    dtype = arr.dtype
    try:
        # use the dtype's `__name__` when it has one
        dtype_name = dtype.__name__
    except AttributeError:
        # otherwise fall back to its string form
        dtype_name = str(dtype)

    return dict(
        shape=shape,
        dtype=dtype_name,
        n_elements=array_n_elements(arr),
    )

get metadata for a numpy array

def serialize_array( jser: "'JsonSerializer'", arr: numpy.ndarray, path: Union[str, Sequence[str | int]], array_mode: Optional[Literal['list', 'array_list_meta', 'array_hex_meta', 'array_b64_meta', 'external', 'zero_dim']] = None) -> Union[bool, int, float, str, list, Dict[str, Any], NoneType]:
 60def serialize_array(
 61    jser: "JsonSerializer",  # type: ignore[name-defined] # noqa: F821
 62    arr: np.ndarray,
 63    path: str | Sequence[str | int],
 64    array_mode: ArrayMode | None = None,
 65) -> JSONitem:
 66    """serialize a numpy or pytorch array in one of several modes
 67
 68    if the object is zero-dimensional, simply get the unique item
 69
 70    `array_mode: ArrayMode` can be one of:
 71    - `list`: serialize as a list of values, no metadata (equivalent to `arr.tolist()`)
 72    - `array_list_meta`: serialize dict with metadata, actual list under the key `data`
 73    - `array_hex_meta`: serialize dict with metadata, actual hex string under the key `data`
 74    - `array_b64_meta`: serialize dict with metadata, actual base64 string under the key `data`
 75
 76    for `array_list_meta`, `array_hex_meta`, and `array_b64_meta`, the serialized object is:
 77    ```
 78    {
 79        "__format__": <array_list_meta|array_hex_meta>,
 80        "shape": arr.shape,
 81        "dtype": str(arr.dtype),
 82        "data": <arr.tolist()|arr.tobytes().hex()|base64.b64encode(arr.tobytes()).decode()>,
 83    }
 84    ```
 85
 86    # Parameters:
 87     - `arr : Any` array to serialize
 88     - `array_mode : ArrayMode` mode in which to serialize the array
 89       (defaults to `None` and inheriting from `jser: JsonSerializer`)
 90
 91    # Returns:
 92     - `JSONitem`
 93       json serialized array
 94
 95    # Raises:
 96     - `KeyError` : if the array mode is not valid
 97    """
 98
 99    if array_mode is None:
100        array_mode = jser.array_mode
101
102    arr_type: str = f"{type(arr).__module__}.{type(arr).__name__}"
103    arr_np: np.ndarray = arr if isinstance(arr, np.ndarray) else np.array(arr)
104
105    # handle zero-dimensional arrays
106    if len(arr.shape) == 0:
107        return {
108            "__format__": f"{arr_type}:zero_dim",
109            "data": arr.item(),
110            **arr_metadata(arr),
111        }
112
113    if array_mode == "array_list_meta":
114        return {
115            "__format__": f"{arr_type}:array_list_meta",
116            "data": arr_np.tolist(),
117            **arr_metadata(arr_np),
118        }
119    elif array_mode == "list":
120        return arr_np.tolist()
121    elif array_mode == "array_hex_meta":
122        return {
123            "__format__": f"{arr_type}:array_hex_meta",
124            "data": arr_np.tobytes().hex(),
125            **arr_metadata(arr_np),
126        }
127    elif array_mode == "array_b64_meta":
128        return {
129            "__format__": f"{arr_type}:array_b64_meta",
130            "data": base64.b64encode(arr_np.tobytes()).decode(),
131            **arr_metadata(arr_np),
132        }
133    else:
134        raise KeyError(f"invalid array_mode: {array_mode}")

serialize a numpy or pytorch array in one of several modes

if the object is zero-dimensional, simply get the unique item

array_mode: ArrayMode can be one of:

  • list: serialize as a list of values, no metadata (equivalent to arr.tolist())
  • array_list_meta: serialize dict with metadata, actual list under the key data
  • array_hex_meta: serialize dict with metadata, actual hex string under the key data
  • array_b64_meta: serialize dict with metadata, actual base64 string under the key data

for array_list_meta, array_hex_meta, and array_b64_meta, the serialized object is:

{
    "__format__": "<module>.<type name>:<array_list_meta|array_hex_meta|array_b64_meta>",
    "shape": list(arr.shape),
    "dtype": str(arr.dtype),
    "n_elements": <number of elements in arr>,
    "data": <arr.tolist()|arr.tobytes().hex()|base64.b64encode(arr.tobytes()).decode()>,
}

Parameters:

  • arr : Any array to serialize
  • array_mode : ArrayMode mode in which to serialize the array (defaults to None and inheriting from jser: JsonSerializer)

Returns:

  • JSONitem json serialized array

Raises:

  • KeyError : if the array mode is not valid
def infer_array_mode( arr: Union[bool, int, float, str, list, Dict[str, Any], NoneType]) -> Literal['list', 'array_list_meta', 'array_hex_meta', 'array_b64_meta', 'external', 'zero_dim']:
def infer_array_mode(arr: JSONitem) -> ArrayMode:
    """work out which mode a serialized array was written in

    assumes the array was serialized via `serialize_array()`: a mapping is
    classified by the suffix of its `__format__` entry, a bare list is mode `list`.

    # Parameters:
     - `arr : JSONitem` serialized array

    # Returns:
     - `ArrayMode`
       the detected mode

    # Raises:
     - `ValueError` : if the format tag is unknown, the `data` entry has the wrong
       type for the tagged mode, or `arr` is neither a mapping nor a list
    """
    if not isinstance(arr, typing.Mapping):
        if isinstance(arr, list):
            return "list"
        raise ValueError(f"cannot infer array_mode from\t{type(arr) = }\n{arr = }")

    fmt: str = arr.get("__format__", "")

    # modes whose `data` entry is type-checked: (mode, label used in the error, required type)
    checked_modes = (
        ("array_list_meta", "list", Iterable),
        ("array_hex_meta", "hex", str),
        ("array_b64_meta", "b64", str),
    )
    for mode, label, required in checked_modes:
        if fmt.endswith(f":{mode}"):
            if not isinstance(arr["data"], required):
                raise ValueError(
                    f"invalid {label} format: {type(arr['data']) = }\t{arr}"
                )
            return mode

    # modes recognized from the tag alone
    for mode in ("external", "zero_dim"):
        if fmt.endswith(f":{mode}"):
            return mode

    raise ValueError(f"invalid format: {arr}")

given a serialized array, infer the mode

assumes the array was serialized via serialize_array()

def load_array( arr: Union[bool, int, float, str, list, Dict[str, Any], NoneType], array_mode: Optional[Literal['list', 'array_list_meta', 'array_hex_meta', 'array_b64_meta', 'external', 'zero_dim']] = None) -> Any:
def load_array(arr: JSONitem, array_mode: Optional[ArrayMode] = None) -> Any:
    """load a json-serialized array, infer the mode if not specified

    # Parameters:
     - `arr : JSONitem` serialized array; a `numpy.ndarray` is returned unchanged
       when no `array_mode` is given
     - `array_mode : ArrayMode | None` mode to decode with
       (defaults to `None`, in which case the mode found by `infer_array_mode` is used;
       a warning is emitted when an explicit mode disagrees with the inferred one)

    # Returns:
     - `Any`
       a `numpy.ndarray` for every mode except `external`, where the object stored
       under `data` is returned as-is. arrays decoded from `array_hex_meta` and
       `array_b64_meta` are writable copies, and `zero_dim` values are rebuilt with
       their recorded dtype whenever numpy accepts it

    # Raises:
     - `ValueError` : if the mode is invalid or cannot be inferred, or the decoded
       shape does not match the recorded one
     - `KeyError` : if an `external` item has no `data` key
     - `AssertionError` : if `arr` is not the container type the mode requires
    """
    # return arr if its already a numpy array
    if isinstance(arr, np.ndarray) and array_mode is None:
        return arr

    # try to infer the array_mode
    array_mode_inferred: ArrayMode = infer_array_mode(arr)
    if array_mode is None:
        array_mode = array_mode_inferred
    elif array_mode != array_mode_inferred:
        warnings.warn(
            f"array_mode {array_mode} does not match inferred array_mode {array_mode_inferred}"
        )

    # actually load the array
    if array_mode == "array_list_meta":
        assert isinstance(
            arr, typing.Mapping
        ), f"invalid list format: {type(arr) = }\n{arr = }"

        data = np.array(arr["data"], dtype=arr["dtype"])
        if tuple(arr["shape"]) != tuple(data.shape):
            raise ValueError(f"invalid shape: {arr}")
        return data

    elif array_mode == "array_hex_meta":
        assert isinstance(
            arr, typing.Mapping
        ), f"invalid hex format: {type(arr) = }\n{arr = }"

        data = np.frombuffer(bytes.fromhex(arr["data"]), dtype=arr["dtype"])
        # `np.frombuffer` over immutable bytes yields a read-only view; copy so the
        # result is writable like the arrays returned by the list-based modes
        return data.reshape(arr["shape"]).copy()

    elif array_mode == "array_b64_meta":
        assert isinstance(
            arr, typing.Mapping
        ), f"invalid b64 format: {type(arr) = }\n{arr = }"

        data = np.frombuffer(base64.b64decode(arr["data"]), dtype=arr["dtype"])
        # same as the hex branch: hand back a writable copy, not a read-only view
        return data.reshape(arr["shape"]).copy()

    elif array_mode == "list":
        assert isinstance(
            arr, typing.Sequence
        ), f"invalid list format: {type(arr) = }\n{arr = }"

        return np.array(arr)
    elif array_mode == "external":
        # assume ZANJ has taken care of it
        assert isinstance(arr, typing.Mapping)
        if "data" not in arr:
            raise KeyError(
                f"invalid external array, expected key 'data', got keys: '{list(arr.keys())}' and arr: {arr}"
            )
        return arr["data"]
    elif array_mode == "zero_dim":
        assert isinstance(arr, typing.Mapping)
        try:
            # restore the recorded dtype so e.g. a float32 scalar does not come back as float64
            data = np.array(arr["data"], dtype=arr.get("dtype"))
        except (TypeError, ValueError):
            # the recorded dtype string is not one numpy can apply: let numpy infer the dtype
            data = np.array(arr["data"])
        if tuple(arr["shape"]) != tuple(data.shape):
            raise ValueError(f"invalid shape: {arr}")
        return data
    else:
        raise ValueError(f"invalid array_mode: {array_mode}")

load a json-serialized array, infer the mode if not specified