Coverage for tests\test_zanj_torch_cfgmismatch.py: 95%

59 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-02-14 12:57 -0700

1from __future__ import annotations 

2 

3from typing import Any 

4 

5import torch 

6from muutils.json_serialize import ( 

7 SerializableDataclass, 

8 serializable_dataclass, 

9 serializable_field, 

10) 

11 

12from zanj.torchutil import ( 

13 ConfigMismatchException, 

14 ConfiguredModel, 

15 assert_model_cfg_equality, 

16 set_config_class, 

17) 

18 

19# ConfiguredModel is imported above; MyGPTConfig and MyGPT are defined below as minimal fixtures for these tests.

20 

21 

@serializable_dataclass
class MyGPTConfig(SerializableDataclass):
    """Minimal GPT-style config used as a fixture by the tests in this file.

    Within this file only ``embedding_size`` and ``n_vocab`` are read by
    ``MyGPT``; the remaining fields exist so that two configs can differ
    and be compared.
    """

    n_layers: int
    n_heads: int
    embedding_size: int
    n_positions: int
    n_vocab: int
    # free-form payload; when omitted, the default comes from calling ``dict``
    junk_data: Any = serializable_field(default_factory=dict)

32 

33 

@set_config_class(MyGPTConfig)
class MyGPT(ConfiguredModel[MyGPTConfig]):
    """Toy stand-in for a GPT: a single linear layer sized from the config."""

    def __init__(self, config: MyGPTConfig):
        """Hand ``config`` to the base class, then build the linear layer."""
        super().__init__(config)
        # the layer takes embedding_size input features and emits n_vocab outputs
        self.transformer = torch.nn.Linear(
            in_features=config.embedding_size,
            out_features=config.n_vocab,
        )

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the result."""
        projected = self.transformer(x)
        return projected

44 

45 

def test_config_mismatch_exception_direct():
    """Build a ``ConfigMismatchException`` by hand and check ``diff`` and ``str()``.

    The exception must expose the diff it was given unchanged, and its string
    form must be the message followed by ``": "`` and the diff's repr.
    """
    mismatch = {"model_cfg": {"are_weights_processed": {"self": False, "other": True}}}
    error = ConfigMismatchException("Configs don't match", mismatch)

    assert error.diff == mismatch

    rendered = str(error)
    assert (
        rendered
        == r"Configs don't match: {'model_cfg': {'are_weights_processed': {'self': False, 'other': True}}}"
    )

56 

57 

def test_equal_configs():
    """Two models built from the same config object pass the equality check."""
    cfg_kwargs = {
        "n_layers": 2,
        "n_heads": 2,
        "embedding_size": 16,
        "n_positions": 16,
        "n_vocab": 128,
        "junk_data": {"a": 1, "b": 2},
    }
    shared_cfg = MyGPTConfig(**cfg_kwargs)

    # both models deliberately wrap the very same config instance
    first = MyGPT(shared_cfg)
    second = MyGPT(shared_cfg)

    # passes as long as no exception is raised
    assert_model_cfg_equality(first, second)

72 

73 

def test_unequal_configs():
    """Configs differing in ``n_layers`` and ``junk_data`` yield a precise diff.

    The raised ``ConfigMismatchException`` must carry a ``diff`` that lists
    only the two differing fields, each with its ``"self"`` and ``"other"``
    value.
    """
    # fields that are identical in both configs
    common = dict(n_heads=2, embedding_size=16, n_positions=16, n_vocab=128)

    cfg_first = MyGPTConfig(n_layers=2, junk_data={"a": 1, "b": 2}, **common)
    # the second config changes n_layers and junk_data
    cfg_second = MyGPTConfig(
        n_layers=3,
        junk_data={"a": 7, "something": "or other"},
        **common,
    )

    first = MyGPT(cfg_first)
    second = MyGPT(cfg_second)

    raised = None
    try:
        assert_model_cfg_equality(first, second)
    except ConfigMismatchException as exc:
        raised = exc

    # no exception at all means the comparison wrongly accepted the pair
    if raised is None:
        raise AssertionError("Expected a ConfigMismatchException!")

    expected_diff = {
        "n_layers": {"self": 2, "other": 3},
        "junk_data": {
            "self": {"a": 1, "b": 2},
            "other": {"a": 7, "something": "or other"},
        },
    }
    assert raised.diff == expected_diff

108 

109 

def test_unequal_configs_2():
    """Like ``test_unequal_configs``, but one ``junk_data`` is a plain string.

    The diff must still report ``junk_data`` with the dict on the ``"self"``
    side and the string on the ``"other"`` side.
    """
    # fields that are identical in both configs
    common = dict(n_heads=2, embedding_size=16, n_positions=16, n_vocab=128)

    cfg_first = MyGPTConfig(n_layers=2, junk_data={"a": 1, "b": 2}, **common)
    # the second config changes n_layers and stores a non-dict in junk_data
    cfg_second = MyGPTConfig(
        n_layers=3,
        junk_data="this isnt even a dict lol",
        **common,
    )

    first = MyGPT(cfg_first)
    second = MyGPT(cfg_second)

    raised = None
    try:
        assert_model_cfg_equality(first, second)
    except ConfigMismatchException as exc:
        raised = exc

    # no exception at all means the comparison wrongly accepted the pair
    if raised is None:
        raise AssertionError("Expected a ConfigMismatchException!")

    expected_diff = {
        "n_layers": {"self": 2, "other": 3},
        "junk_data": {
            "self": {"a": 1, "b": 2},
            "other": "this isnt even a dict lol",
        },
    }
    assert raised.diff == expected_diff

144 

145 

def test_incorrect_instance():
    """A second argument that is not a ``ConfiguredModel`` must be rejected.

    ``assert_model_cfg_equality`` is expected to raise an ``AssertionError``
    whose message is ``"model_b must be a ConfiguredModel"``. As in the
    sibling tests in this file, the ``else`` branch makes the test fail when
    no exception is raised; without it the test would pass without checking
    anything.
    """
    config = MyGPTConfig(
        n_layers=2,
        n_heads=2,
        embedding_size=16,
        n_positions=16,
        n_vocab=128,
    )

    model_a = MyGPT(config)
    model_b = "Not a ConfiguredModel instance"

    try:
        assert_model_cfg_equality(model_a, model_b)  # type: ignore
    except AssertionError as exc:
        assert str(exc) == "model_b must be a ConfiguredModel"
    else:
        # raised outside the ``try`` body, so the handler above cannot catch it
        raise AssertionError("Expected an AssertionError!")