Coverage for tests\test_isolate_zanj_handler_store.py: 98%

47 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-02-14 12:57 -0700

1from __future__ import annotations 

2 

3import json 

4import typing 

5import zipfile 

6from pathlib import Path 

7 

8import numpy as np 

9from muutils.json_serialize import ( 

10 SerializableDataclass, 

11 serializable_dataclass, 

12 serializable_field, 

13) 

14 

15from zanj import ZANJ 

16from zanj.loading import LOADER_MAP 

17 

# Seed numpy's global RNG once, at import time, so any random draws made
# while these tests run are reproducible.
np.random.seed(0)

# pylint: disable=missing-function-docstring,missing-class-docstring

# Directory (relative to the current working directory) that the tests below
# write their .zanj archives into.
TEST_DATA_PATH: Path = Path("tests") / "junk_data"

23 

24 

@serializable_dataclass
class Basic(SerializableDataclass):
    """Minimal serializable dataclass that ``test_Basic`` round-trips through ZANJ.

    The tests in this module expect its load handler to be registered under
    the key ``"Basic(SerializableDataclass)"``.
    """

    a: str  # required
    q: int = 42
    # list default goes through a factory so instances do not share one list
    c: typing.List[int] = serializable_field(default_factory=list)

30 

31 

def test_Basic():
    """Save a ``Basic`` instance to a .zanj file and check it reads back equal."""
    original = Basic(a="hello", q=42, c=[1, 2, 3])
    archive_path = TEST_DATA_PATH / "test_Basic.zanj"

    zj = ZANJ()
    zj.save(original, archive_path)
    # compare in the same order as before: the original instance on the left
    assert original == zj.read(archive_path)

40 

41 

# import-time diagnostic: list the LOADER_MAP keys now that `Basic` is defined
print(list(LOADER_MAP))

43 

44 

@serializable_dataclass
class ModelCfg(SerializableDataclass):
    """Small config-style dataclass saved and reloaded by ``test_isolate_handlers``.

    All four fields are required (no defaults). The tests in this module
    expect its load handler under the key ``"ModelCfg(SerializableDataclass)"``.
    """

    name: str
    num_layers: int
    hidden_size: int
    dropout: float

51 

52 

# import-time diagnostic: list the LOADER_MAP keys now that `ModelCfg` is defined
print(list(LOADER_MAP))

54 

55 

def test_isolate_handlers():
    """Round-trip a ``ModelCfg`` and check which load handlers are recorded.

    Saves a ``ModelCfg`` instance with ZANJ, reads it back and asserts
    equality. It then checks that the handler keys for both dataclasses
    defined in this module are present in ``zanj.loading.LOADER_MAP``. It
    also checks that they are listed under ``zanj_cfg.load_handlers`` in the
    archive's ``__zanj_meta__.json``. ``Basic``'s handler is expected there
    even though only a ``ModelCfg`` was saved.
    """
    instance = ModelCfg("lstm", 3, 128, 0.1)

    print(list(LOADER_MAP.keys()))

    z = ZANJ()
    path = TEST_DATA_PATH / "00-test_isolate_handlers.zanj"
    z.save(instance, path)
    recovered = z.read(path)
    assert instance == recovered

    assert "Basic(SerializableDataclass)" in LOADER_MAP
    assert "ModelCfg(SerializableDataclass)" in LOADER_MAP

    # check they are in the zanj file. Open the archive member in its own
    # `with` so the member file object is closed deterministically instead of
    # being left for garbage collection.
    with zipfile.ZipFile(path, "r") as zfile:
        with zfile.open("__zanj_meta__.json", "r") as meta_file:
            zmeta = json.load(meta_file)

    load_handlers = zmeta["zanj_cfg"]["load_handlers"]
    assert "Basic(SerializableDataclass)" in load_handlers
    assert "ModelCfg(SerializableDataclass)" in load_handlers

75 

76 

if __name__ == "__main__":
    # When this module is executed directly (not via pytest), run only the
    # handler-isolation check.
    test_isolate_handlers()