Coverage for test/test_util.py: 99%

179 statements  

« prev     ^ index     » next       coverage.py v7.2.3, created at 2023-05-04 13:14 +0700

1import pytest 

2import datetime 

3from copy import deepcopy 

4 

5from ..agent_model.util import (load_data_file, 

6 get_default_agent_data, 

7 get_default_currency_data, 

8 merge_json, 

9 recursively_clear_lists, 

10 evaluate_reference, 

11 pdf, 

12 sample_norm, 

13 sample_clipped_norm, 

14 sample_sigmoid, 

15 evaluate_growth, 

16 parse_data) 

17 

class TestDataFilesHandling:
    """Tests for the data-file loaders and JSON helpers imported from ``agent_model.util``."""

    def test_load_data_files(self):
        """A known data file loads; an unknown file name or directory raises AssertionError."""
        agent_desc = load_data_file('agent_desc.json')
        assert 'wheat' in agent_desc, 'Failed to load agent_desc'
        with pytest.raises(AssertionError):
            load_data_file('nonexistent_file.json')
        with pytest.raises(AssertionError):
            load_data_file('agent_desc.json', data_dir='nonexistent_dir')

    def test_get_default_agent_data(self):
        """Default data for 'wheat' contains every expected top-level key."""
        wheat_data = get_default_agent_data('wheat')
        # NOTE(review): the previous key list named 'storage' twice. The
        # duplicate is dropped here; if a different key was intended in its
        # place, add it to this tuple.
        expected_keys = ('amount', 'storage', 'properties', 'flows')
        # Collect the absent keys so a failure names them instead of just
        # reporting that all(...) was False.
        missing = [key for key in expected_keys if key not in wheat_data]
        assert not missing, f'Missing keys in default agent data: {missing}'

    def test_get_default_currency_data(self):
        """Each currency entry is either a class listing currencies or a currency naming a known class."""
        currency_data = get_default_currency_data()
        for name, entry in currency_data.items():
            assert 'currency_type' in entry, f'{name}: missing currency_type'
            if entry['currency_type'] == 'class':
                assert 'currencies' in entry, f'{name}: class without currencies'
            else:
                assert 'class' in entry, f'{name}: currency without class'
                assert entry['class'] in currency_data, f'{name}: unknown class'

    def test_merge_json(self):
        """Scalars are overridden, nested dicts are merged, and lists are unioned without repeats."""
        default = {'a': 'red', 'b': 2, 'c': {'d': 3, 'e': 4}, 'f': [1, 2, 3]}
        to_merge = {'a': 'blue', 'c': {'d': 6}, 'f': [3, 4, 5]}
        merged = merge_json(default, to_merge)
        assert merged == {'a': 'blue', 'b': 2, 'c': {'d': 6, 'e': 4}, 'f': [1, 2, 3, 4, 5]}

    def test_recursively_clear_lists(self):
        """Lists at any nesting depth are emptied while other values are left untouched."""
        data = {
            'a': 'string',
            'b': 1,
            'c': 2.1,
            'd': ['e', 'f'],
            'g': {'h': 'string2',
                  'i': ['j', 'k']}}
        data = recursively_clear_lists(data)
        assert data == {
            'a': 'string',
            'b': 1,
            'c': 2.1,
            'd': [],
            'g': {'h': 'string2',
                  'i': []}}

63 

class MockAgent:
    """Minimal agent stand-in with one attribute, two stored currencies and matching in/out flows."""

    def __init__(self, model):
        """Store *model* and populate the fixed attribute, storage and flow data."""
        def _flow():
            # A fresh dict per direction so 'in' and 'out' never share state.
            return {'value': 1, 'connections': ['test_agent_2']}

        self.model = model
        self.attributes = dict(test_attribute=1)
        self.storage = dict(test_currency_1=1, test_currency_2=2)
        self.flows = {
            direction: {'test_currency_1': _flow()}
            for direction in ('in', 'out')
        }

    def view(self, view):
        """Return stored amounts for a single currency or for the whole currency class.

        A single currency name yields a one-entry dict; 'test_currency_class'
        yields a deep copy of the full storage; anything else yields None.
        """
        if view == 'test_currency_class':
            return deepcopy(self.storage)
        if view in ('test_currency_1', 'test_currency_2'):
            return {view: self.storage[view]}
        return None

88 

class MockModel:
    """Minimal model stand-in: a precision setting, an agent registry and a small currency table."""

    floating_point_precision = 6
    # Class-level default registry.
    agents = {}
    # Two plain currencies that belong to one currency class, plus the class entry itself.
    currencies = {
        **{
            name: {'currency_type': 'currency', 'class': 'test_currency_class'}
            for name in ('test_currency_1', 'test_currency_2')
        },
        'test_currency_class': {
            'currency_type': 'class',
            'currencies': ['test_currency_1', 'test_currency_2'],
        },
    }

106 

@pytest.fixture(scope='function')
def test_model():
    """Provide a fresh ``MockModel`` populated with two ``MockAgent`` instances.

    The agents are registered under 'test_agent_1' and 'test_agent_2'. The
    registry is assigned on the instance, so it shadows the class-level
    ``MockModel.agents`` rather than mutating it.
    """
    model = MockModel()
    # MockAgent.__init__ already stores the model on the agent, so the
    # agents need no further wiring after construction.
    model.agents = {
        name: MockAgent(model)
        for name in ('test_agent_1', 'test_agent_2')
    }
    return model

119 

class TestEvaluateReference:
    """Checks ``evaluate_reference`` for attribute, storage, ratio and connected-agent paths."""

    def test_evaluate_reference_attribute(self, test_model):
        """An attribute threshold is met only once the attribute exceeds the value."""
        agent = test_model.agents['test_agent_1']
        reference = dict(path='test_attribute', limit='>', value=1)
        assert not evaluate_reference(agent, reference)
        agent.attributes['test_attribute'] = 2
        assert evaluate_reference(agent, reference)

    def test_evaluate_reference_storage(self, test_model):
        """A storage threshold is met only once the stored amount exceeds the value."""
        agent = test_model.agents['test_agent_1']
        reference = dict(path='test_currency_1', limit='>', value=1)
        assert not evaluate_reference(agent, reference)
        agent.storage['test_currency_1'] = 2
        assert evaluate_reference(agent, reference)

    def test_evaluate_reference_ratio(self, test_model):
        """The ratio path is not above 0.5 for the default storage, but is below it."""
        agent = test_model.agents['test_agent_1']
        reference = dict(path='test_currency_1_ratio', limit='>', value=0.5)
        assert not evaluate_reference(agent, reference)
        # Same reference object, opposite comparison.
        reference.update(limit='<')
        assert evaluate_reference(agent, reference)

    def test_evaluate_reference_connected(self, test_model):
        """An 'in_' path flips once the second agent's stored amount is raised."""
        agent = test_model.agents['test_agent_1']
        reference = dict(path='in_test_currency_1', limit='>', value=1)
        assert not evaluate_reference(agent, reference)
        neighbour = test_model.agents['test_agent_2']
        neighbour.storage['test_currency_1'] = 2
        assert evaluate_reference(agent, reference)

165 

class TestGrowthFuncs:
    """Shape checks for ``pdf``, ``sample_norm``, ``sample_clipped_norm`` and ``sample_sigmoid``."""

    def test_growth_pdf(self):
        """pdf over -4..4 peaks in the middle, is mirror-symmetric, and the cache mirrors the results."""
        cache = {}
        values = [pdf(x, 0.5, cache) for x in range(-4, 5)]
        # middle value should be highest, symmetrical either side
        assert values[4] == max(values)
        # Pair the first four samples with the last four, walking inwards.
        for left, right in zip(values[:4], values[::-1]):
            assert left == right
        assert list(cache.values()) == values

    def test_growth_sample_norm(self):
        """sample_norm averages about 1, peaks at its centre, and follows a shifted centre."""
        # Default: 0 < y < 1, x_center = 0.5
        n_samples = 100
        values = [sample_norm(step/100, n_samples=n_samples) for step in range(1, n_samples)]
        assert sum(values)/len(values) == pytest.approx(1, abs=0.02)
        # middle value should be highest, symmetrical either side
        peak_index = n_samples//2-1
        assert values[peak_index] == max(values)
        for left, right in zip(values[:peak_index], values[::-1]):
            assert left == pytest.approx(right)

        # Shift center
        x_center = 0.25
        values = [sample_norm(step/1000, center=x_center) for step in range(1000)]
        assert sum(values)/len(values) == pytest.approx(1, abs=0.01)
        assert values[250] == max(values)

        # TODO: Shift stdev

    def test_growth_sample_clippped_norm(self):
        """sample_clipped_norm tops out at exactly 1 in the middle and is symmetric."""
        values = [sample_clipped_norm(step/10) for step in range(1, 10)]
        peak = max(values)
        assert peak == 1
        assert values[4] == peak
        for left, right in zip(values[:4], values[::-1]):
            assert left == pytest.approx(right)

    def test_growth_sample_sigmoid(self):
        """sample_sigmoid never decreases and its slope is steepest at the centre."""
        values = [sample_sigmoid(step/1000) for step in range(1000)]
        assert values[-1] == max(values)
        assert all(lower <= upper for lower, upper in zip(values, values[1:]))
        # Derivative (slope) is greatest at center
        slopes = [upper - lower for lower, upper in zip(values, values[1:])]
        # Slopes rise up to index 500 ...
        for earlier, later in zip(slopes[:500], slopes[1:501]):
            assert earlier <= later
        assert slopes[500] == max(slopes)
        # ... and fall from index 500 through the final slope.
        for earlier, later in zip(slopes[500:998], slopes[501:999]):
            assert earlier >= later

213 

@pytest.fixture
def mock_agent():
    """Build an agent (age 10, lifetime 20) attached to a model whose time is 2019-01-01 12:00."""
    class MockModel:
        # Class-level default; tests may assign a different time on the instance.
        time = datetime.datetime(2019, 1, 1, 12)

    class MockAgent:
        def __init__(self, model):
            self.model = model
            self.attributes = dict(age=10)
            self.properties = dict(lifetime=dict(value=20))

    model = MockModel()
    return MockAgent(model)

224 

class TestEvaluateGrowth:
    """Checks ``evaluate_growth`` in 'daily' and 'lifetime' modes using the ``mock_agent`` fixture."""

    def test_evaluate_growth_daily(self, mock_agent):
        """Across the 24 hours of one day, growth is highest at noon and lowest at midnight."""
        mode, params = 'daily', {'type': 'norm'}

        def growth_at(hour):
            # Move the model clock, then sample growth for that hour.
            mock_agent.model.time = datetime.datetime(2019, 1, 1, hour)
            return evaluate_growth(mock_agent, mode, params)

        hourly = [growth_at(hour) for hour in range(24)]
        # Max growth (1) at noon
        assert hourly[12] == max(hourly)
        # Min growth (nearly 0) at midnight
        assert hourly[0] == min(hourly)

    def test_evaluate_growth_lifetime(self, mock_agent):
        """Sigmoid lifetime growth is 0.5 at half the lifetime and just under 1 at full lifetime."""
        mode, params = 'lifetime', {'type': 'sigmoid'}
        # Halfway growth at age 10/20
        halfway = evaluate_growth(mock_agent, mode, params)
        assert halfway == 0.5
        # Max growth (nearly 1) at age 20/20
        mock_agent.attributes['age'] = 20
        full = evaluate_growth(mock_agent, mode, params)
        assert 0.999 < full < 1.0

246 

@pytest.fixture
def mock_data():
    """Return a nested dict with two model-level fields and one agent record.

    The agent record holds a string, an int, a list and a dict of three
    equal-length lists.
    """
    dict_attribute = {
        'dict_attr_1': [2, 3, 4],
        'dict_attr_2': [3, 4, 5],
        'dict_attr_3': [4, 5, 6],
    }
    agent_record = {
        'agent_string_attribute': 'test',
        'agent_int_attribute': 1,
        'agent_list_attribute': [1, 2, 3],
        'agent_dict_attribute': dict_attribute,
    }
    return {
        'model_string_attribute': 'test',
        'model_int_attribute': 1,
        'test_agent': agent_record,
    }

263 

class TestParseData:
    """Checks ``parse_data`` path lookups against the ``mock_data`` fixture."""

    def test_parse_data_static_field(self, mock_data):
        """Scalar fields are returned as-is at both model and agent level."""
        model_string_attr = parse_data(mock_data, ['model_string_attribute'])
        assert model_string_attr == 'test'
        model_int_attr = parse_data(mock_data, ['model_int_attribute'])
        assert model_int_attr == 1
        agent_string_attr = parse_data(mock_data, ['test_agent', 'agent_string_attribute'])
        assert agent_string_attr == 'test'
        agent_int_attr = parse_data(mock_data, ['test_agent', 'agent_int_attribute'])
        assert agent_int_attr == 1

    def test_parse_data_missing_field(self, mock_data):
        """A missing key yields None, while a present but falsy value (0) is still returned."""
        missing_value = parse_data(mock_data, ['missing_value'])
        # Identity check: None is a singleton, and `== None` could be
        # satisfied by any object with a permissive __eq__.
        assert missing_value is None
        # But still propagate Zeros
        mock_data['model_int_attribute'] = 0
        zero_value = parse_data(mock_data, ['model_int_attribute'])
        assert zero_value == 0

    def test_parse_data_dict_keys(self, mock_data):
        """Dict fields support a single key, '*', a comma-separated key list, and 'SUM'."""
        single_field = parse_data(mock_data, ['test_agent', 'agent_dict_attribute', 'dict_attr_1'])
        assert single_field == [2, 3, 4]
        all_fields = parse_data(mock_data, ['test_agent', 'agent_dict_attribute', '*'])
        assert all_fields == mock_data['test_agent']['agent_dict_attribute']
        selected_fields = parse_data(mock_data, ['test_agent', 'agent_dict_attribute', 'dict_attr_1,dict_attr_2'])
        assert selected_fields == {'dict_attr_1': [2, 3, 4], 'dict_attr_2': [3, 4, 5]}
        # 'SUM' adds the three lists element-wise: 2+3+4, 3+4+5, 4+5+6.
        summed_fields = parse_data(mock_data, ['test_agent', 'agent_dict_attribute', 'SUM'])
        assert summed_fields == [9, 12, 15]

    def test_parse_data_list(self, mock_data):
        """List fields support '*', an integer index, and a 'start:stop' slice string."""
        all_items = parse_data(mock_data, ['test_agent', 'agent_list_attribute', '*'])
        assert all_items == [1, 2, 3]
        single_item = parse_data(mock_data, ['test_agent', 'agent_list_attribute', 1])
        assert single_item == 2
        slice_item = parse_data(mock_data, ['test_agent', 'agent_list_attribute', '0:2'])
        assert slice_item == [1, 2]