Coverage for test/test_model.py: 100%

181 statements  

« prev     ^ index     » next       coverage.py v7.2.3, created at 2023-05-04 13:14 +0700

1import json 

2import time 

3import datetime 

4from unittest.mock import Mock 

5 

6import pytest 

7import numpy as np 

8 

9from ..agent_model.Model import (DEFAULT_START_TIME, 

10 DEFAULT_LOCATION, 

11 DEFAULT_PRIORITIES, 

12 Model, 

13 Scheduler) 

14from ..agent_model.agents import BaseAgent 

15from ..agent_model.util import get_default_agent_data 

16 

class MockModel:
    """Lightweight stand-in for ``Model`` used by the scheduler tests.

    Provides only three attributes: a seeded ``rng``, an ordered
    ``priorities`` list of agent classes, and an ``agents`` dict.
    """

    def __init__(self, seed=1000):
        # Filled in by the ``mock_model`` fixture.
        self.agents = {}
        # Class 2 is deliberately ranked ahead of class 1.
        self.priorities = ['test_agent_class_2', 'test_agent_class_1']
        # Fixed seed -> reproducible random draws (presumably consumed by
        # Scheduler for within-class ordering; confirm against Scheduler).
        self.rng = np.random.RandomState(seed)

22 

class MockAgent:
    """Agent double that remembers the ``dT`` it was stepped with and when."""

    def __init__(self, agent_id, agent_class):
        self.agent_id, self.agent_class = agent_id, agent_class
        # Both stay None until step() is called.
        self.dT = self.stepped_at = None

    def step(self, dT):
        """Store ``dT`` and the wall-clock time of this call.

        A short sleep precedes the timestamp so that agents stepped one
        after another get distinguishable ``stepped_at`` values, which the
        scheduler tests compare to infer stepping order.
        """
        time.sleep(.001)
        self.dT = dT
        self.stepped_at = datetime.datetime.now()

33 

34 

@pytest.fixture
def mock_model():
    """Return a ``MockModel`` populated with five ``MockAgent``s.

    'a' and 'b' share class 1, 'c' is class 2, and 'd'/'e' belong to
    classes that do not appear in ``MockModel.priorities``.
    """
    agent_specs = [
        ('a', 'test_agent_class_1'),
        ('b', 'test_agent_class_1'),
        ('c', 'test_agent_class_2'),
        ('d', 'test_agent_class_3'),
        ('e', 'test_agent_class_4'),
    ]
    model = MockModel()
    model.agents = {agent_id: MockAgent(agent_id, agent_class)
                    for agent_id, agent_class in agent_specs}
    return model

46 

class TestScheduler:
    """Tests for ``Scheduler`` using the ``mock_model`` fixture."""

    def test_scheduler_init(self, mock_model):
        """Agents are grouped by class; unlisted classes fall under 'other'."""
        test_scheduler = Scheduler(mock_model)
        assert test_scheduler.model is mock_model
        assert test_scheduler.priorities == [*mock_model.priorities, 'other']
        assert test_scheduler.class_agents == {
            'test_agent_class_1': ['a', 'b'],
            'test_agent_class_2': ['c'],
            'other': ['d', 'e'],
        }

    def test_scheduler_step(self, mock_model):
        """``step`` forwards dT, honours class priority, shuffles within a class."""
        test_scheduler = Scheduler(mock_model)

        # dT is passed appropriately to every agent
        test_scheduler.step(0.42)
        for agent in mock_model.agents.values():
            assert agent.dT == 0.42

        # Expected class order: c, [a and b], [d and e].
        # Verify the whole ordering, not just the endpoints. `<=` (rather
        # than `<`) tolerates coarse clocks where consecutive timestamps
        # can coincide.
        stepped_at = {k: v.stepped_at for k, v in mock_model.agents.items()}
        assert stepped_at['c'] <= min(stepped_at['a'], stepped_at['b'])
        assert (max(stepped_at['a'], stepped_at['b'])
                <= min(stepped_at['d'], stepped_at['e']))

        # Within classes, order is random: across several steps each of
        # 'a' and 'b' should be stepped first at least once.
        samples = []
        for _ in range(5):
            test_scheduler.step(0.42)
            samples.append((mock_model.agents['a'].stepped_at,
                            mock_model.agents['b'].stepped_at))
        assert any(a < b for a, b in samples)
        assert any(b < a for a, b in samples)

80 

@pytest.fixture
def model_kwargs():
    """Return keyword arguments for constructing a ``Model`` in tests.

    Includes a single time-based termination condition of 10 days. The
    dict is rebuilt on every call, so individual tests may mutate it.
    """
    termination = [{"condition": "time", "value": 10, "unit": "day"}]
    return dict(
        termination=termination,
        location='test_location',
        priorities=['test_agent_class_1', 'test_agent_class_2'],
        start_time='2020-01-01T00:00:00',
        elapsed_time=0,
        step_num=0,
        seed=1000,
        is_terminated=False,
        termination_reason='test_termination',
    )

98 

class TestModel:
    """Tests for ``Model``: construction, registration, stepping,
    termination, record retrieval and save/reload."""

    def test_model_init_basic(self):
        """With no arguments, ``Model`` uses the module-level defaults."""
        test_model = Model()
        assert test_model.termination == []
        assert test_model.location == DEFAULT_LOCATION
        assert test_model.priorities == DEFAULT_PRIORITIES
        assert test_model.start_time == datetime.datetime.fromisoformat(DEFAULT_START_TIME)
        assert test_model.elapsed_time == datetime.timedelta()
        assert test_model.step_num == 0
        # No seed given: only check it lies in RandomState's valid seed range.
        assert 0 <= test_model.seed <= 2**32 - 1
        assert test_model.is_terminated is None
        assert test_model.termination_reason == ''
        assert test_model.agents == {}
        assert test_model.currencies == {}
        # rng and scheduler stay None until register() is called
        # (see test_model_register).
        assert test_model.rng is None
        assert test_model.scheduler is None
        assert test_model.records == {'time': [], 'step_num': []}

    def test_model_init_complex(self, model_kwargs):
        """Explicit kwargs are stored; ``elapsed_time`` is given in seconds."""
        model_kwargs['elapsed_time'] = 42
        model_kwargs['step_num'] = 100
        test_model = Model(**model_kwargs)
        assert test_model.termination == model_kwargs['termination']
        assert test_model.location == model_kwargs['location']
        assert test_model.priorities == model_kwargs['priorities']
        assert test_model.start_time == datetime.datetime.fromisoformat(model_kwargs['start_time'])
        assert test_model.elapsed_time == datetime.timedelta(seconds=42)
        assert test_model.step_num == 100
        assert test_model.seed == model_kwargs['seed']
        assert test_model.is_terminated is False
        assert test_model.termination_reason == 'test_termination'

    def test_model_add_agent(self, model_kwargs):
        """A duplicate agent id raises ValueError and keeps the first agent."""
        model = Model(**model_kwargs)
        test_agent = object()
        model.add_agent('test_agent_id', test_agent)
        with pytest.raises(ValueError):
            model.add_agent('test_agent_id', object())
        assert model.agents == {'test_agent_id': test_agent}

    def test_model_add_currency(self, model_kwargs):
        """A duplicate currency id raises ValueError."""
        model = Model(**model_kwargs)
        model.add_currency('test_currency_id', {})
        with pytest.raises(ValueError):
            model.add_currency('test_currency_id', {})
        assert model.currencies == {'test_currency_id': {}}

    def test_model_register(self, model_kwargs):
        """``register`` creates rng/scheduler, optionally records the initial
        state, and forwards the flag to each agent's ``register``."""
        # With record initial state
        model = Model(**model_kwargs)
        test_agent = Mock()
        model.add_agent('test_agent', test_agent)
        model.register(record_initial_state=True)
        assert isinstance(model.rng, np.random.RandomState)
        assert isinstance(model.scheduler, Scheduler)
        assert model.records == {'time': [model_kwargs['start_time']], 'step_num': [0]}
        test_agent.register.assert_called_once_with(True)
        assert model.registered

        # Without record initial state
        model = Model(**model_kwargs)
        test_agent = Mock()
        model.add_agent('test_agent', test_agent)
        model.register(record_initial_state=False)
        assert model.records == {'time': [], 'step_num': []}
        test_agent.register.assert_called_once_with(False)

    def test_model_from_config(self, model_kwargs):
        """``from_config`` builds, merges defaults into, and registers agents."""
        agents = {'o2_storage': {'description': 'test_description'},
                  'test_agent': {'capacity': {'test_currency': 0},
                                 'storage': {'test_currency': 0}}}
        currencies = {'test_currency': {'description': 'test_description'}}
        model = Model.from_config(agents, currencies, **model_kwargs)
        assert list(model.agents.keys()) == ['o2_storage', 'test_agent']
        assert 'test_currency' in model.currencies
        assert model.registered
        assert len(model.records['time']) == 1

        # Check that agent is merged with default agent data
        default_o2_storage_agent = get_default_agent_data('o2_storage')
        o2_storage_agent = model.agents['o2_storage']
        assert isinstance(o2_storage_agent, BaseAgent)
        assert o2_storage_agent.description == 'test_description'
        assert o2_storage_agent.capacity == default_o2_storage_agent['capacity']

        # Check that loaded models (step_num != 0) don't record initial state
        model_kwargs['step_num'] = 1
        model = Model.from_config(agents, currencies, **model_kwargs)
        assert model.registered
        assert len(model.records['time']) == 0

    def test_model_time(self, model_kwargs):
        """``time`` is ``start_time`` offset by ``elapsed_time``."""
        model = Model(**model_kwargs)
        assert model.time == model.start_time
        model.elapsed_time = datetime.timedelta(seconds=42)
        assert model.time == model.start_time + model.elapsed_time

    def test_model_step(self, model_kwargs):
        """One step on an unregistered model registers it and advances 1 hour."""
        model = Model(**model_kwargs)
        model.step()
        assert model.registered
        assert model.step_num == 1
        assert model.elapsed_time == datetime.timedelta(hours=1)
        assert model.records['time'] == ['2020-01-01T01:00:00']
        assert model.records['step_num'] == [1]

    def test_model_terminate(self, model_kwargs):
        """A time termination fires exactly on the step reaching the limit."""
        for (unit, value) in (('day', 2), ('hour', 30)):
            model_kwargs['termination'][0]['unit'] = unit
            model_kwargs['termination'][0]['value'] = value
            model = Model(**model_kwargs)
            # One step is one hour (see test_model_step).
            expected_steps = value * 24 if unit == 'day' else value
            for _ in range(expected_steps - 1):
                model.step()
            assert not model.is_terminated
            model.step()
            assert model.is_terminated
            assert model.termination_reason == 'time'

    def test_model_get_records(self, model_kwargs):
        """``get_records`` returns dynamic/static records and can clear them."""
        model = Model.from_config({'o2_storage': {}}, {}, **model_kwargs)
        model.step()
        # Basic records: initial state + one step
        records = model.get_records()
        assert list(records.keys()) == ['time', 'step_num', 'agents']
        assert len(records['time']) == 2
        assert len(records['agents']['o2_storage']['active']) == 2
        # Include static records
        records = model.get_records(static=True)
        assert 'static' in records
        assert 'static' in records['agents']['o2_storage']
        assert list(records['static'].keys()) == ['currencies', 'termination',
                                                  'location', 'priorities',
                                                  'start_time', 'seed']
        # Everything in records must be JSON-serializable; json.dumps raises
        # (failing the test) otherwise.
        json.dumps(records)
        # Clear cache
        model.get_records(clear_cache=True)
        assert len(model.records['time']) == 0
        assert len(model.agents['o2_storage'].records['active']) == 0

    def test_model_save(self, model_kwargs):
        """``save`` output is JSON-serializable and can be reloaded."""
        model = Model.from_config({'o2_storage': {}}, {}, **model_kwargs)
        model.step()
        saved = model.save()
        assert 'records' not in saved
        # json.dumps raises (failing the test) if anything is unserializable.
        json.dumps(saved)
        # Re-load: rebuilding and stepping from the saved state must not raise.
        # TODO: Test save/load in depth
        model = Model.from_config(**saved)
        model.step()
        # Include records
        saved = model.save(records=True)
        assert 'records' in saved