Coverage for tests\test_zanj_torch.py: 95%

88 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-02-14 12:57 -0700

1from __future__ import annotations 

2 

3from pathlib import Path 

4 

5import numpy as np 

6import torch 

7from muutils.json_serialize import ( 

8 SerializableDataclass, 

9 serializable_dataclass, 

10 serializable_field, 

11) 

12from muutils.tensor_utils import compare_state_dicts 

13 

14from zanj import ZANJ 

15from zanj.torchutil import ( 

16 ConfiguredModel, 

17 assert_model_exact_equality, 

18 set_config_class, 

19) 

20 

# Seed NumPy's global random state once at import time.
# NOTE(review): only NumPy is seeded here; nothing in this module seeds torch.
np.random.seed(0)

# Directory the tests below write their `.zanj` files into; the path is
# relative, so it resolves against the current working directory.
TEST_DATA_PATH: Path = Path("tests/junk_data")

24 

25 

def test_torch_configmodel_minimal():
    """Round-trip a minimal `ConfiguredModel` through a ZANJ file.

    A one-layer model is saved once and then loaded back two ways: through
    the model class (`MyNN.read`) and through a plain `ZANJ().read`. Each
    loaded copy is checked against the original on its config, its training
    records, its state dict and via `assert_model_exact_equality`.
    """

    @serializable_dataclass
    class MyNNConfig(SerializableDataclass):
        # used by MyNN as the input width of its single linear layer
        n_layers: int

    @set_config_class(MyNNConfig)
    class MyNN(ConfiguredModel[MyNNConfig]):
        def __init__(self, config: MyNNConfig):
            super().__init__(config)

            self.layer = torch.nn.Linear(config.n_layers, 1)

        def forward(self, x):
            return self.layer(x)

    config: MyNNConfig = MyNNConfig(
        n_layers=2,
    )

    model: MyNN = MyNN(config)

    # This test gets its own file name: `test_torch_configmodel` writes
    # "test_torch_configmodel.zanj" into the same directory, and sharing that
    # path lets the two tests overwrite each other's artifact.
    fname: Path = TEST_DATA_PATH / "test_torch_configmodel_minimal.zanj"
    ZANJ().save(model, fname)

    print(f"saved model to {fname}")
    print(f"{model.zanj_model_config = }")

    # try to load the model through the model class itself
    model2: MyNN = MyNN.read(fname)
    print(f"loaded model from {fname}")
    print(f"{model2.zanj_model_config = }")

    assert model.zanj_model_config == model2.zanj_model_config
    assert model.training_records == model2.training_records

    compare_state_dicts(model.state_dict(), model2.state_dict())
    assert_model_exact_equality(model, model2)

    # load the same file again through a generic ZANJ reader, without naming
    # the model class at the call site
    model3: MyNN = ZANJ().read(fname)
    print(f"loaded model from {fname}")
    print(f"{model3.zanj_model_config = }")

    assert model.zanj_model_config == model3.zanj_model_config
    assert model.training_records == model3.training_records

    compare_state_dicts(model.state_dict(), model3.state_dict())
    assert_model_exact_equality(model, model3)

73 

74 

def test_torch_configmodel():
    """Round-trip a small transformer `ConfiguredModel` through a ZANJ file.

    The config carries, besides plain integers, a loss class and an optimizer
    class that are serialized by name. The model is given non-empty
    `training_records`, saved once, and loaded back two ways: through
    `MyGPT.read` and through a plain `ZANJ().read`. Each loaded copy is
    checked against the original on its config, its training records, its
    state dict and via `assert_model_exact_equality`.
    """
    # `torch`, `ConfiguredModel` and `set_config_class` are the module-level
    # imports; they are not re-imported inside the function.

    @serializable_dataclass
    class MyGPTConfig(SerializableDataclass):
        """basic test GPT config"""

        n_layers: int
        n_heads: int
        embedding_size: int
        # n_positions and n_vocab are stored in the config but are not read
        # by MyGPT below
        n_positions: int
        n_vocab: int

        # Holds the loss *class* (the `loss` property instantiates it). It is
        # written out as the class name and looked up on `torch.nn` by that
        # name when loading.
        # NOTE(review): the loading function indexes its argument with the
        # field name, so it presumably receives the whole serialized config
        # mapping -- confirm against muutils.
        loss_factory: torch.nn.modules.loss._Loss = serializable_field(
            default_factory=lambda: torch.nn.CrossEntropyLoss,
            serialization_fn=lambda x: x.__name__,
            loading_fn=lambda x: getattr(torch.nn, x["loss_factory"]),
        )

        loss_kwargs: dict = serializable_field(default_factory=dict)

        @property
        def loss(self):
            # build a fresh loss object from the stored class and kwargs
            return self.loss_factory(**self.loss_kwargs)

        # Same scheme as `loss_factory`, for an optimizer class looked up on
        # `torch.optim`.
        optim_factory: torch.optim.Optimizer = serializable_field(
            default_factory=lambda: torch.optim.Adam,
            serialization_fn=lambda x: x.__name__,
            loading_fn=lambda x: getattr(torch.optim, x["optim_factory"]),
        )

        optim_kwargs: dict = serializable_field(default_factory=dict)

        def optim(self, model):
            # build an optimizer over the given model's parameters
            return self.optim_factory(model.parameters(), **self.optim_kwargs)  # type: ignore

    @set_config_class(MyGPTConfig)
    class MyGPT(ConfiguredModel[MyGPTConfig]):
        """basic GPT model"""

        def __init__(self, config: MyGPTConfig):
            super().__init__(config)

            # implementation of a GPT style model with decoders only

            self.transformer = torch.nn.Transformer(
                d_model=config.embedding_size,
                nhead=config.n_heads,
                num_encoder_layers=0,
                num_decoder_layers=config.n_layers,
            )

        def forward(self, x):
            return self.transformer(x)

    config: MyGPTConfig = MyGPTConfig(
        n_layers=2,
        n_heads=2,
        embedding_size=16,
        n_positions=16,
        n_vocab=128,
        loss_factory=torch.nn.CrossEntropyLoss,
    )

    model: MyGPT = MyGPT(config)
    # non-empty records, so the equality checks below compare real content
    model.training_records = dict(loss=[3, 2, 1], accuracy=[0.1, 0.2, 0.3])

    fname: Path = TEST_DATA_PATH / "test_torch_configmodel.zanj"
    ZANJ().save(model, fname)

    print(f"saved model to {fname}")
    print(f"{model.zanj_model_config = }")

    # try to load the model through the model class itself
    model2: MyGPT = MyGPT.read(fname)
    print(f"loaded model from {fname}")
    print(f"{model2.zanj_model_config = }")

    assert model.zanj_model_config == model2.zanj_model_config
    assert model.training_records == model2.training_records

    compare_state_dicts(model.state_dict(), model2.state_dict())
    assert_model_exact_equality(model, model2)

    # load the same file again through a generic ZANJ reader, without naming
    # the model class at the call site
    model3: MyGPT = ZANJ().read(fname)
    print(f"loaded model from {fname}")
    print(f"{model3.zanj_model_config = }")

    assert model.zanj_model_config == model3.zanj_model_config
    assert model.training_records == model3.training_records

    compare_state_dicts(model.state_dict(), model3.state_dict())
    assert_model_exact_equality(model, model3)