Coverage for test/test_model.py: 100%
181 statements
« prev ^ index » next coverage.py v7.2.3, created at 2023-05-04 13:14 +0700
« prev ^ index » next coverage.py v7.2.3, created at 2023-05-04 13:14 +0700
1import json
2import time
3import datetime
4from unittest.mock import Mock
6import pytest
7import numpy as np
9from ..agent_model.Model import (DEFAULT_START_TIME,
10 DEFAULT_LOCATION,
11 DEFAULT_PRIORITIES,
12 Model,
13 Scheduler)
14from ..agent_model.agents import BaseAgent
15from ..agent_model.util import get_default_agent_data
class MockModel:
    """Minimal stand-in for ``Model`` exposing only what ``Scheduler`` reads.

    Holds a seeded ``rng``, a fixed class-priority list and an ``agents``
    mapping that starts empty and is filled in by the tests.
    """

    def __init__(self, seed=1000):
        # agent_id -> agent; the ``mock_model`` fixture assigns this directly.
        self.agents = {}
        # Class-priority order: 'test_agent_class_2' ranks ahead of
        # 'test_agent_class_1'.
        self.priorities = ['test_agent_class_2', 'test_agent_class_1']
        # Seeded RNG; presumably what Scheduler uses to shuffle agents within
        # a class — TODO confirm against Scheduler.
        self.rng = np.random.RandomState(seed)
class MockAgent:
    """Agent stub that records the ``dT`` it was stepped with and when.

    ``stepped_at`` is taken from a logical clock shared by all ``MockAgent``
    instances, so every ``step`` call gets a strictly later timestamp than
    the previous call on any ``MockAgent``. The scheduler tests compare
    these timestamps to infer stepping order.
    """

    # Shared tick counter, advanced once per ``step`` call on any instance.
    _clock_ticks = 0
    # Arbitrary fixed origin for the logical clock; only relative order matters.
    _clock_origin = datetime.datetime(2000, 1, 1)

    def __init__(self, agent_id, agent_class):
        self.agent_id = agent_id
        self.agent_class = agent_class
        self.dT = None
        self.stepped_at = None

    def step(self, dT):
        """Record ``dT`` and stamp this agent with the next logical time.

        Wall-clock time (``datetime.now()``) is not monotonic and may have
        coarse resolution, so consecutive steps could tie or run backwards.
        A counter-based clock guarantees strict ordering without sleeping.
        """
        # Mutate via the class (not ``self``) so every instance shares one clock.
        MockAgent._clock_ticks += 1
        self.dT = dT
        self.stepped_at = MockAgent._clock_origin + datetime.timedelta(
            seconds=MockAgent._clock_ticks)
@pytest.fixture
def mock_model():
    """Return a ``MockModel`` holding five agents spread over four classes.

    ``a`` and ``b`` share 'test_agent_class_1', ``c`` is alone in
    'test_agent_class_2', and ``d``/``e`` belong to classes that are not in
    ``MockModel.priorities``.
    """
    agent_specs = (
        ('a', 'test_agent_class_1'),
        ('b', 'test_agent_class_1'),
        ('c', 'test_agent_class_2'),
        ('d', 'test_agent_class_3'),
        ('e', 'test_agent_class_4'),
    )
    model = MockModel()
    model.agents = {
        agent_id: MockAgent(agent_id, agent_class)
        for agent_id, agent_class in agent_specs
    }
    return model
class TestScheduler:
    """Tests for ``Scheduler`` driven by the lightweight ``mock_model`` fixture."""

    def test_scheduler_init(self, mock_model):
        """Agents are bucketed by class; unprioritized classes go to 'other'."""
        scheduler = Scheduler(mock_model)
        expected_priorities = list(mock_model.priorities) + ['other']
        expected_buckets = {
            'test_agent_class_1': ['a', 'b'],
            'test_agent_class_2': ['c'],
            'other': ['d', 'e'],
        }
        assert scheduler.model == mock_model
        assert scheduler.priorities == expected_priorities
        assert scheduler.class_agents == expected_buckets

    def test_scheduler_step(self, mock_model):
        """Step forwards dT, honors class priority, and shuffles within a class."""
        scheduler = Scheduler(mock_model)

        # Every agent must receive exactly the dT given to the scheduler.
        scheduler.step(0.42)
        assert all(agent.dT == 0.42 for agent in mock_model.agents.values())

        # Priority order: c first, then {a, b}, then the 'other' bucket {d, e}.
        times = {
            agent_id: agent.stepped_at
            for agent_id, agent in mock_model.agents.items()
        }
        earliest = min(times.values())
        latest = max(times.values())
        assert times['c'] == earliest
        assert times['d'] == latest or times['e'] == latest

        # Over repeated steps, a and b should each go first at least once.
        orderings = []
        for _ in range(5):
            scheduler.step(0.42)
            a_time = mock_model.agents['a'].stepped_at.timestamp()
            b_time = mock_model.agents['b'].stepped_at.timestamp()
            orderings.append((a_time, b_time))
        assert any(a_time < b_time for a_time, b_time in orderings)
        assert any(b_time < a_time for a_time, b_time in orderings)
@pytest.fixture
def model_kwargs():
    """Return keyword arguments for constructing a ``Model`` in tests.

    A new dict (with a new nested termination list) is built on every call,
    so tests may mutate it freely. Termination is a single time-based
    condition of 10 days.
    """
    ten_days = {"condition": "time", "value": 10, "unit": "day"}
    return dict(
        termination=[ten_days],
        location='test_location',
        priorities=['test_agent_class_1', 'test_agent_class_2'],
        start_time='2020-01-01T00:00:00',
        elapsed_time=0,
        step_num=0,
        seed=1000,
        is_terminated=False,
        termination_reason='test_termination',
    )
class TestModel:
    """Tests for ``Model`` construction, registration, stepping and export."""

    def test_model_init_basic(self):
        """A bare ``Model()`` uses module defaults and has no runtime state yet."""
        test_model = Model()
        assert test_model.termination == []
        assert test_model.location == DEFAULT_LOCATION
        assert test_model.priorities == DEFAULT_PRIORITIES
        assert test_model.start_time == datetime.datetime.fromisoformat(DEFAULT_START_TIME)
        assert test_model.elapsed_time == datetime.timedelta()
        assert test_model.step_num == 0
        # With no seed given, a seed is drawn within RandomState's valid range.
        assert 0 <= test_model.seed <= 2**32 - 1
        assert test_model.is_terminated is None
        assert test_model.termination_reason == ''
        assert test_model.agents == {}
        assert test_model.currencies == {}
        # rng and scheduler are created by register() (see test_model_register).
        assert test_model.rng is None
        assert test_model.scheduler is None
        assert test_model.records == {'time': [], 'step_num': []}

    def test_model_init_complex(self, model_kwargs):
        """Explicit kwargs are stored, with elapsed_time converted to a timedelta."""
        model_kwargs['elapsed_time'] = 42
        model_kwargs['step_num'] = 100
        test_model = Model(**model_kwargs)
        assert test_model.termination == model_kwargs['termination']
        assert test_model.location == model_kwargs['location']
        assert test_model.priorities == model_kwargs['priorities']
        assert test_model.start_time == datetime.datetime.fromisoformat(model_kwargs['start_time'])
        # An integer elapsed_time is interpreted as seconds.
        assert test_model.elapsed_time == datetime.timedelta(seconds=42)
        assert test_model.step_num == 100
        assert test_model.seed == model_kwargs['seed']
        assert test_model.is_terminated is False
        assert test_model.termination_reason == 'test_termination'

    def test_model_add_agent(self, model_kwargs):
        """Adding a duplicate agent id raises ValueError and keeps the first agent."""
        model = Model(**model_kwargs)
        test_agent = object()
        model.add_agent('test_agent_id', test_agent)
        with pytest.raises(ValueError):
            model.add_agent('test_agent_id', object())
        assert model.agents == {'test_agent_id': test_agent}

    def test_model_add_currency(self, model_kwargs):
        """Adding a duplicate currency id raises ValueError."""
        model = Model(**model_kwargs)
        model.add_currency('test_currency_id', {})
        with pytest.raises(ValueError):
            model.add_currency('test_currency_id', {})
        assert model.currencies == {'test_currency_id': {}}

    def test_model_register(self, model_kwargs):
        """register() builds rng/scheduler and forwards the record flag to agents."""
        # With record initial state
        model = Model(**model_kwargs)
        test_agent = Mock()
        model.add_agent('test_agent', test_agent)
        model.register(record_initial_state=True)
        assert isinstance(model.rng, np.random.RandomState)
        assert isinstance(model.scheduler, Scheduler)
        assert model.records == {'time': [model_kwargs['start_time']], 'step_num': [0]}
        test_agent.register.assert_called_once_with(True)
        assert model.registered

        # Without record initial state
        model = Model(**model_kwargs)
        test_agent = Mock()
        model.add_agent('test_agent', test_agent)
        model.register(record_initial_state=False)
        assert model.records == {'time': [], 'step_num': []}
        test_agent.register.assert_called_once_with(False)

    def test_model_from_config(self, model_kwargs):
        """from_config builds agents/currencies, merges defaults and registers."""
        agents = {'o2_storage': {'description': 'test_description'},
                  'test_agent': {'capacity': {'test_currency': 0}, 'storage': {'test_currency': 0}}}
        currencies = {'test_currency': {'description': 'test_description'}}
        model = Model.from_config(agents, currencies, **model_kwargs)
        assert list(model.agents.keys()) == ['o2_storage', 'test_agent']
        assert 'test_currency' in model.currencies
        assert model.registered
        assert len(model.records['time']) == 1

        # Check that agent is merged with default agent data
        default_o2_storage_agent = get_default_agent_data('o2_storage')
        o2_storage_agent = model.agents['o2_storage']
        assert isinstance(o2_storage_agent, BaseAgent)
        assert o2_storage_agent.description == 'test_description'
        assert o2_storage_agent.capacity == default_o2_storage_agent['capacity']

        # Check that loaded models (step_num != 0) don't record initial state
        model_kwargs['step_num'] = 1
        model = Model.from_config(agents, currencies, **model_kwargs)
        assert model.registered
        assert len(model.records['time']) == 0

    def test_model_time(self, model_kwargs):
        """model.time is start_time plus elapsed_time."""
        model = Model(**model_kwargs)
        assert model.time == model.start_time
        model.elapsed_time = datetime.timedelta(seconds=42)
        assert model.time == model.start_time + model.elapsed_time

    def test_model_step(self, model_kwargs):
        """A single step registers the model and advances one hour."""
        model = Model(**model_kwargs)
        model.step()
        assert model.registered
        assert model.step_num == 1
        assert model.elapsed_time == datetime.timedelta(hours=1)
        assert model.records['time'] == ['2020-01-01T01:00:00']
        assert model.records['step_num'] == [1]

    def test_model_terminate(self, model_kwargs):
        """Time-based termination fires exactly on the step that reaches the limit."""
        for (unit, value) in (('day', 2), ('hour', 30)):
            model_kwargs['termination'][0]['unit'] = unit
            model_kwargs['termination'][0]['value'] = value
            model = Model(**model_kwargs)
            # Each step advances one hour, so a day-based limit is 24 steps/day.
            expected_steps = value * 24 if unit == 'day' else value
            for _ in range(expected_steps - 1):
                model.step()
            assert not model.is_terminated
            model.step()
            assert model.is_terminated
            assert model.termination_reason == 'time'

    def test_model_get_records(self, model_kwargs):
        """get_records returns dynamic (and optionally static) JSON-safe records."""
        model = Model.from_config({'o2_storage': {}}, {}, **model_kwargs)
        model.step()
        # Basic records: initial state + one step = 2 entries.
        records = model.get_records()
        assert list(records.keys()) == ['time', 'step_num', 'agents']
        assert len(records['time']) == 2
        assert len(records['agents']['o2_storage']['active']) == 2
        # Include static records
        records = model.get_records(static=True)
        assert 'static' in records
        assert 'static' in records['agents']['o2_storage']
        assert list(records['static'].keys()) == ['currencies', 'termination',
                                                  'location', 'priorities',
                                                  'start_time', 'seed']
        # Everything must be serializable; json.dumps raises TypeError otherwise.
        json.dumps(records)
        # Clear cache
        model.get_records(clear_cache=True)
        assert len(model.records['time']) == 0
        assert len(model.agents['o2_storage'].records['active']) == 0

    def test_model_save(self, model_kwargs):
        """save() is JSON-serializable, reloadable, and optionally includes records."""
        model = Model.from_config({'o2_storage': {}}, {}, **model_kwargs)
        model.step()
        saved = model.save()
        assert 'records' not in saved
        # json.dumps raises TypeError if anything is not serializable.
        json.dumps(saved)
        # Re-load; reaching the next assertion means reload + step did not raise.
        # TODO: Test save/load in depth
        model = Model.from_config(**saved)
        model.step()
        # Include records
        saved = model.save(records=True)
        assert 'records' in saved