Coverage for tests\test_zanj_sdc_modelcfg.py: 100%

90 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-02-14 12:57 -0700

1from __future__ import annotations 

2 

3import json 

4import sys 

5import typing 

6from pathlib import Path 

7 

8import numpy as np 

9import torch 

10from muutils.json_serialize import ( 

11 SerializableDataclass, 

12 serializable_dataclass, 

13 serializable_field, 

14) 

15 

16from zanj import ZANJ 

17 

# Seed NumPy's global RNG once, at import time.
np.random.seed(0)

# pylint: disable=missing-function-docstring,missing-class-docstring

# Directory (relative to the working directory) that the tests below write
# their JSON / ZANJ artifacts into.
TEST_DATA_PATH: Path = Path("tests/junk_data")

# Passed as ``kw_only=`` to the dataclass decorators below; only true on
# interpreters at version 3.10 or newer.
SUPPORTS_KW_ONLY: bool = sys.version_info >= (3, 10)

26 

27 

@serializable_dataclass
class MyModelCfg(SerializableDataclass):
    """Minimal model configuration used as a nested field in the tests.

    The decorator is applied without ``kw_only``, so the four fields can be
    given positionally in declaration order.
    """

    # field order is part of the interface: instances are built positionally
    name: str
    num_layers: int
    hidden_size: int
    dropout: float

34 

35 

@serializable_dataclass(kw_only=SUPPORTS_KW_ONLY)
class TrainCfg(SerializableDataclass):
    """Training configuration holding an optimizer *class* and its kwargs.

    The optimizer class is serialized as its ``__name__`` and, on load,
    looked up again by that name as an attribute of ``torch.optim``.
    """

    name: str
    weight_decay: float
    # an optimizer class (not an instance); Adam unless overridden
    optimizer: typing.Type[torch.optim.Optimizer] = serializable_field(
        default_factory=lambda: torch.optim.Adam,
        serialization_fn=lambda optim_cls: optim_cls.__name__,
        # the loader indexes the dict it is given by this field's own key
        loading_fn=lambda serialized: getattr(torch.optim, serialized["optimizer"]),
    )
    # fresh dict per instance, so the default is never shared between configs
    optimizer_kwargs: typing.Dict[str, typing.Any] = serializable_field(  # type: ignore
        default_factory=lambda: {"lr": 0.000001}
    )

48 

49 

class CustomCfg:
    """Plain (non-dataclass) config object with hand-written (de)serialization.

    Instances compare equal when both ``x`` and ``y`` match. Because
    ``__eq__`` is defined without ``__hash__``, instances are unhashable.
    """

    def __init__(self, x: int, y: str):
        self.x = x
        self.y = y

    def __eq__(self, other):
        # Defer to Python's fallback for foreign types instead of raising
        # AttributeError on a missing ``.x`` / ``.y`` (e.g. comparing to None).
        if not isinstance(other, CustomCfg):
            return NotImplemented
        return self.x == other.x and self.y == other.y

    def serialize(self):
        """Return a dict with keys ``"x"`` and ``"y"``."""
        return {"x": self.x, "y": self.y}

    @classmethod
    def load(cls, data):
        """Build an instance from a mapping with ``"x"`` and ``"y"`` keys.

        Raises:
            KeyError: if either key is missing from ``data``.
        """
        return cls(x=data["x"], y=data["y"])

69 

70 

@serializable_dataclass(kw_only=SUPPORTS_KW_ONLY)
class BasicCfgHolder(SerializableDataclass):
    """Top-level config nesting two serializable dataclasses and a custom object.

    ``custom`` is not a serializable dataclass, so it is converted through
    ``CustomCfg.serialize`` / ``CustomCfg.load`` explicitly.
    """

    model: MyModelCfg
    optimizer: TrainCfg
    custom: typing.Optional[CustomCfg] = serializable_field(
        default=None,
        serialization_fn=lambda cfg: cfg.serialize(),
        # the loader picks this field's entry out of the dict it is given
        loading_fn=lambda serialized: CustomCfg.load(serialized["custom"]),
    )

80 

81 

# Shared fixture: a fully populated holder with every field set to a
# non-default value, reused by the JSON and ZANJ round-trip tests.
instance_basic: BasicCfgHolder = BasicCfgHolder(  # type: ignore
    model=MyModelCfg(  # type: ignore
        name="lstm",
        num_layers=3,
        hidden_size=128,
        dropout=0.1,
    ),
    optimizer=TrainCfg(  # type: ignore
        name="adamw",
        weight_decay=0.2,
        optimizer=torch.optim.AdamW,
        optimizer_kwargs={"lr": 0.0001},
    ),
    custom=CustomCfg(x=42, y="forty-two"),
)

92 

93 

def test_config_holder():
    """Round-trip ``instance_basic`` through a JSON file and compare.

    The serialized form is written to disk, read back, and loaded into a new
    ``BasicCfgHolder``; nested fields must come back as their own classes and
    the result must equal the original instance.
    """
    # The output directory may not exist yet (e.g. when this test runs first
    # or alone), and ``open(..., "w")`` does not create parent directories.
    TEST_DATA_PATH.mkdir(parents=True, exist_ok=True)
    json_path = TEST_DATA_PATH / "test_config_holder.json"

    instance_stored = instance_basic.serialize()
    # explicit encoding keeps the file identical across platforms
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(instance_stored, f, indent="\t")
    with open(json_path, "r", encoding="utf-8") as f:
        instance_stored_read = json.load(f)

    recovered = BasicCfgHolder.load(instance_stored_read)
    assert isinstance(recovered.model, MyModelCfg)
    assert isinstance(recovered.optimizer, TrainCfg)
    assert isinstance(recovered.custom, CustomCfg)
    assert recovered.custom.x == 42
    assert instance_basic == recovered

106 

107 

def test_config_holder_zanj():
    """Save ``instance_basic`` with ZANJ, read it back, and compare."""
    archive_path = TEST_DATA_PATH / "test_config_holder.zanj"
    zanj_io = ZANJ()
    zanj_io.save(instance_basic, archive_path)
    loaded = zanj_io.read(archive_path)

    # each nested field must come back as an instance of its own class
    expected_types = (
        ("model", MyModelCfg),
        ("optimizer", TrainCfg),
        ("custom", CustomCfg),
    )
    for field_name, field_type in expected_types:
        assert isinstance(getattr(loaded, field_name), field_type)

    assert loaded.custom.x == 42
    assert instance_basic == loaded

118 

119 

@serializable_dataclass(kw_only=SUPPORTS_KW_ONLY)
class BaseGPTConfig(SerializableDataclass):
    """Flat model configuration made only of ``str`` and ``int`` fields.

    Used as the nested ``model_cfg`` of ``AdvCfgHolder``.
    """

    # string-valued fields
    name: str
    act_fn: str
    # integer-valued fields
    d_model: int
    d_head: int
    n_layers: int

127 

128 

def _load_adv_tokenizer(data: typing.Dict[str, typing.Any]) -> None:
    """Loader for ``AdvCfgHolder.tokenizer``.

    A ``None`` tokenizer loads as ``None``. Anything else was stored as a
    ``repr()`` string, which cannot be turned back into an object here, so
    loading fails loudly instead of handing back a placeholder value.

    Raises:
        NotImplementedError: if the stored tokenizer is not ``None``.
    """
    if data["tokenizer"] is None:
        return None
    raise NotImplementedError(
        "loading a non-None tokenizer is not supported: "
        "it is serialized with repr() and cannot be reconstructed"
    )


@serializable_dataclass(kw_only=SUPPORTS_KW_ONLY)
class AdvCfgHolder(SerializableDataclass):
    """Config holder with a nested model config and an optional tokenizer.

    The tokenizer is serialized one-way (via ``repr``); only ``None`` can be
    loaded back.
    """

    model_cfg: BaseGPTConfig
    name: str = serializable_field(default="default")
    tokenizer: typing.Optional[CustomCfg] = serializable_field(
        default=None,
        serialization_fn=lambda x: repr(x) if x is not None else None,
        # Previously a lambda that *returned* the NotImplementedError class as
        # the field value; a named function is needed so it can raise instead.
        loading_fn=_load_adv_tokenizer,
    )

140 

141 

# Shared fixture for the AdvCfgHolder round-trip tests. ``name`` is left at
# its default and the tokenizer is explicitly absent.
instance_adv: AdvCfgHolder = AdvCfgHolder(  # type: ignore
    tokenizer=None,
    model_cfg=BaseGPTConfig(  # type: ignore
        name="gpt2",
        act_fn="gelu",
        n_layers=3,
        d_model=128,
        d_head=64,
    ),
)

152 

153 

def test_adv_config_holder():
    """Round-trip ``instance_adv`` through a JSON file and compare.

    The serialized form is written to disk and read back before loading, so
    the on-disk JSON (not just the in-memory dict) is what gets verified.
    """
    # The output directory may not exist yet (e.g. when this test runs first
    # or alone), and ``open(..., "w")`` does not create parent directories.
    TEST_DATA_PATH.mkdir(parents=True, exist_ok=True)
    json_path = TEST_DATA_PATH / "test_adv_config_holder.json"

    instance_stored = instance_adv.serialize()
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(instance_stored, f, indent="\t")
    # load from the file that was just written, matching test_config_holder
    with open(json_path, "r", encoding="utf-8") as f:
        instance_stored_read = json.load(f)

    recovered = AdvCfgHolder.load(instance_stored_read)
    assert isinstance(recovered.model_cfg, BaseGPTConfig)
    assert instance_adv == recovered

161 

162 

def test_adv_config_holder_zanj():
    """Save ``instance_adv`` with ZANJ, read it back, and compare."""
    archive_path = TEST_DATA_PATH / "test_adv_config_holder.zanj"
    zanj_io = ZANJ()
    zanj_io.save(instance_adv, archive_path)
    reloaded = zanj_io.read(archive_path)

    # the nested config must come back as an instance of its own class
    assert isinstance(reloaded.model_cfg, BaseGPTConfig)
    assert instance_adv == reloaded