bettermdptools.envs.pendulum_discretized
Author: Aleksandr Spiridonov BSD 3-Clause License
"""
Author: Aleksandr Spiridonov
BSD 3-Clause License
"""

import gzip
import os
import pickle
from concurrent.futures import ProcessPoolExecutor, as_completed

import numpy as np
from gymnasium.envs.classic_control.acrobot import wrap
from gymnasium.envs.classic_control.pendulum import angle_normalize

from bettermdptools.envs.binning import generate_bin_edges

CACHED_P_PATH_FORMAT = "cached_P_discretized_pendulum_{angle_bins}_{angular_velocity_bins}_{action_bins}.pkl.gz"


def index_to_state(index, angle_bins, angular_velocity_bins):
    """
    Split a flat state index into ``(angle_idx, angular_velocity_idx)``.

    States are laid out row-major with angle as the outer dimension; this is
    the inverse of :func:`state_to_index`. ``angle_bins`` does not enter the
    computation.
    """
    angle_idx, angular_velocity_idx = divmod(index, angular_velocity_bins)
    return angle_idx, angular_velocity_idx


def index_to_continous_state(index, angle_bin_edges, angular_velocity_bin_edges):
    """Return the bin-midpoint ``(angle, angular_velocity)`` of a flat state index."""
    angle_idx, angular_velocity_idx = index_to_state(
        index, len(angle_bin_edges) - 1, len(angular_velocity_bin_edges) - 1
    )
    angle = (angle_bin_edges[angle_idx] + angle_bin_edges[angle_idx + 1]) / 2.0
    angular_velocity = (
        angular_velocity_bin_edges[angular_velocity_idx]
        + angular_velocity_bin_edges[angular_velocity_idx + 1]
    ) / 2.0
    return angle, angular_velocity


def state_to_index(angle_idx, angular_velocity_idx, angular_velocity_bins):
    """Flatten ``(angle_idx, angular_velocity_idx)`` into a row-major state index."""
    return angle_idx * angular_velocity_bins + angular_velocity_idx


def get_torque_value(torque_bin_edges, action):
    """Return the torque applied by ``action``: the midpoint of its torque bin."""
    return (torque_bin_edges[action] + torque_bin_edges[action + 1]) / 2.0


def compute_next_probable_states(
    angle_idx,
    angular_velocity_idx,
    action,
    angle_bin_edges,
    angular_velocity_bin_edges,
    torque_bin_edges,
    num_samples=11,
    g=10.0,
    l=1.0,
    m=1.0,
    dt=0.05,
):
    """
    Estimate the transition distribution of one discrete (state, action) pair.

    A ``num_samples`` x ``num_samples`` grid is laid over the state's angle and
    angular-velocity bin. The points on the bin edges are dropped, and every
    remaining point is advanced one time step under the action's midpoint
    torque. Successor states are binned and aggregated into probabilities
    (share of samples) and mean rewards.

    Parameters:
    -----------
    angle_idx, angular_velocity_idx : int
        Bin indices of the current state.
    action : int
        Torque bin index; the applied torque is that bin's midpoint.
    angle_bin_edges, angular_velocity_bin_edges, torque_bin_edges : array-like
        Bin edges of each discretized quantity.
    num_samples : int, optional (default=11)
        Grid points per dimension including the two edge points, which are
        discarded; must be at least 3.
    g, l, m, dt : float
        Gravity, pendulum length, mass and integration time step.

    Returns:
    --------
    list of tuple
        ``(probability, next_state, mean_reward, terminated)`` for every
        reached successor, in first-reached order; ``terminated`` is always
        False.

    Raises:
    -------
    ValueError
        If ``num_samples < 3`` or a successor falls outside the state grid.
    """
    if num_samples < 3:
        raise ValueError(
            f"num_samples must be at least 3, got {num_samples}: the two bin "
            "edge points are discarded, so no samples would remain"
        )

    angle_low, angle_high = angle_bin_edges[angle_idx], angle_bin_edges[angle_idx + 1]
    angular_velocity_low, angular_velocity_high = (
        angular_velocity_bin_edges[angular_velocity_idx],
        angular_velocity_bin_edges[angular_velocity_idx + 1],
    )
    torque = get_torque_value(torque_bin_edges, action)

    min_angular_velocity = angular_velocity_bin_edges[0]
    max_angular_velocity = angular_velocity_bin_edges[-1]

    # Interior points only: a point on an edge is ambiguous between two bins.
    angle_samples = np.linspace(angle_low, angle_high, num_samples)[1:-1]
    angular_velocity_samples = np.linspace(
        angular_velocity_low, angular_velocity_high, num_samples
    )[1:-1]

    angular_velocity_bins = len(angular_velocity_bin_edges) - 1
    n_states = (len(angle_bin_edges) - 1) * angular_velocity_bins

    # Successor state -> rewards of the samples that reached it. Dict
    # insertion order keeps successors in the order they were first reached.
    rewards_by_state = {}

    for angle in angle_samples:
        for angular_velocity in angular_velocity_samples:
            # Cost and Euler update follow Gymnasium's Pendulum dynamics.
            costs = (
                angle_normalize(angle) ** 2
                + 0.1 * angular_velocity**2
                + 0.001 * (torque**2)
            )

            new_angular_velocity = (
                angular_velocity
                + (3 * g / (2 * l) * np.sin(angle) + 3.0 / (m * l**2) * torque) * dt
            )
            # Stay strictly inside the outer edges so np.digitize yields a valid bin.
            new_angular_velocity = np.clip(
                new_angular_velocity,
                min_angular_velocity + 1e-6,
                max_angular_velocity - 1e-6,
            )

            new_angle = wrap(angle + new_angular_velocity * dt, -np.pi, np.pi)

            new_angle_idx = np.digitize(new_angle, angle_bin_edges) - 1
            new_angular_velocity_idx = (
                np.digitize(new_angular_velocity, angular_velocity_bin_edges) - 1
            )

            # Builtin int so P holds plain ints rather than NumPy scalars.
            new_state = int(
                state_to_index(
                    new_angle_idx, new_angular_velocity_idx, angular_velocity_bins
                )
            )

            if new_state < 0 or new_state >= n_states:
                raise ValueError(f"Invalid state index: {new_state}")

            rewards_by_state.setdefault(new_state, []).append(-costs)

    n_total = len(angle_samples) * len(angular_velocity_samples)

    return [
        (len(rewards) / n_total, new_state, sum(rewards) / len(rewards), False)
        for new_state, rewards in rewards_by_state.items()
    ]


def setup_transition_probabilities_for_state(args):
    """
    Compute the transition table of a single discrete state.

    Intended as a ProcessPoolExecutor task, so it takes one packed tuple
    ``(state, angle_bin_edges, angular_velocity_bin_edges, torque_bin_edges,
    dim_samples)``.

    Returns:
    --------
    tuple
        ``(state, P_state)``. ``P_state`` maps every action index to the list
        returned by :func:`compute_next_probable_states`.

    Exceptions propagate to the caller, where they surface via ``Future.result()``.
    """
    (
        state,
        angle_bin_edges,
        angular_velocity_bin_edges,
        torque_bin_edges,
        dim_samples,
    ) = args
    angle_bins = len(angle_bin_edges) - 1
    angular_velocity_bins = len(angular_velocity_bin_edges) - 1
    action_bins = len(torque_bin_edges) - 1

    angle_idx, angular_velocity_idx = index_to_state(
        state, angle_bins, angular_velocity_bins
    )

    P_state = {
        action: compute_next_probable_states(
            angle_idx,
            angular_velocity_idx,
            action,
            angle_bin_edges,
            angular_velocity_bin_edges,
            torque_bin_edges,
            num_samples=dim_samples,
        )
        for action in range(action_bins)
    }
    return state, P_state


class DiscretizedPendulum:
    """
    Tabular (discretized) model of the Gymnasium pendulum.

    Angle, angular velocity and torque are binned. The transition model ``P``
    is estimated by simulating one dynamics step from a grid of sample points
    inside every state bin, for every torque bin. ``P`` is cached on disk as a
    gzip-compressed pickle and reused by later constructions with the same
    configuration.

    Parameters:
    -----------
    angle_bins : int
        Number of bins to discretize the angle.
    angular_velocity_bins : int
        Number of bins to discretize the angular velocity.
    torque_bins : int, optional (default=11)
        Number of bins to discretize the torque.
    n_workers : int, optional (default=4)
        Number of worker processes to use for setting up transition probabilities.
    cache_dir : str, optional (default='./cached')
        Directory to cache the transition probabilities.
    dim_samples : int, optional (default=11)
        Number of evenly spaced points per state dimension used when setting
        up transition probabilities. The two points on the bin edges are
        discarded, so ``(dim_samples - 2) ** 2`` samples are simulated per
        (state, action) pair; must be at least 3.

    Attributes:
    -----------
    angle_bins : int
        Number of bins to discretize the angle. Must be odd.
    angular_velocity_bins : int
        Number of bins to discretize the angular velocity. Must be odd.
    dim_samples : int
        Number of points per state dimension (including the two discarded
        edge points) used when setting up transition probabilities.
    angle_bin_edges : numpy.ndarray
        Edges of the bins for discretizing the angle.
    angular_velocity_bin_edges : numpy.ndarray
        Edges of the bins for discretizing the angular velocity.
    torque_bin_edges : numpy.ndarray
        Edges of the bins for discretizing the torque.
    state_space : int
        Total number of discrete states.
    action_space : int
        Total number of discrete actions.
    P : dict
        Transition model: ``P[state][action]`` is a list of
        ``(probability, next_state, reward, terminated)`` tuples.
    n_workers : int
        Number of worker processes to use for setting up transition probabilities.
    """

    def __init__(
        self,
        angle_bins,
        angular_velocity_bins,
        torque_bins=11,
        n_workers=4,
        cache_dir="./cached",
        dim_samples=11,
    ):
        self.angle_bins = angle_bins
        self.angular_velocity_bins = angular_velocity_bins
        self.dim_samples = dim_samples
        self.angle_bin_edges = generate_bin_edges(np.pi, angle_bins, 3, center=True)
        self.angular_velocity_bin_edges = generate_bin_edges(
            8, angular_velocity_bins, 3, center=False
        )
        self.torque_bin_edges = generate_bin_edges(2, torque_bins, 3, center=False)

        self.state_space = angle_bins * angular_velocity_bins
        self.action_space = torque_bins

        # Placeholder with the final shape; replaced below by the cached or
        # freshly computed model.
        self.P = {
            state: {action: [] for action in range(torque_bins)}
            for state in range(self.state_space)
        }

        self.n_workers = n_workers

        cache_filename = CACHED_P_PATH_FORMAT.format(
            angle_bins=angle_bins,
            angular_velocity_bins=angular_velocity_bins,
            action_bins=torque_bins,
        )
        if dim_samples != 11:
            # The legacy file name does not encode dim_samples. Tag non-default
            # densities so a model estimated with one sampling density is never
            # silently reused for another. Default-density caches keep their
            # original name.
            cache_filename = cache_filename.replace(
                ".pkl.gz", f"_samples{dim_samples}.pkl.gz"
            )
        cached_P_filepath = os.path.join(cache_dir, cache_filename)

        os.makedirs(cache_dir, exist_ok=True)

        if os.path.exists(cached_P_filepath):
            # NOTE: unpickling can execute arbitrary code; only load cache
            # files this library wrote itself.
            with gzip.open(cached_P_filepath, "rb") as f:
                self.P = pickle.load(f)
        else:
            self.setup_transition_probabilities()
            # Write to a temporary file and atomically rename it, so an
            # interrupted run never leaves a truncated cache behind.
            tmp_filepath = f"{cached_P_filepath}.{os.getpid()}.tmp"
            try:
                with gzip.open(tmp_filepath, "wb") as f:
                    pickle.dump(self.P, f)
                os.replace(tmp_filepath, cached_P_filepath)
            finally:
                if os.path.exists(tmp_filepath):
                    os.remove(tmp_filepath)

    def discretize_angle(self, angle):
        """Return the angle bin index of ``angle`` (``np.digitize`` minus one)."""
        return np.digitize(angle, self.angle_bin_edges) - 1

    def discretize_angular_velocity(self, angular_velocity):
        """Return the angular-velocity bin index of ``angular_velocity``."""
        return np.digitize(angular_velocity, self.angular_velocity_bin_edges) - 1

    def index_to_state(self, index):
        """Split a flat state index into ``(angle_idx, angular_velocity_idx)``."""
        return index_to_state(index, self.angle_bins, self.angular_velocity_bins)

    def state_to_index(self, angle_idx, angular_velocity_idx):
        """
        Flatten bin indices into a flat state index.

        Raises ValueError if the result lies outside ``[0, state_space)``.
        """
        idx = state_to_index(
            angle_idx, angular_velocity_idx, self.angular_velocity_bins
        )
        if idx < 0 or idx >= self.state_space:
            raise ValueError(f"Invalid state index: {idx}")
        return idx

    def transform_cont_obs(self, cont_obs):
        """
        Map a continuous pendulum observation to a discrete state index.

        ``cont_obs`` is read as ``(cos(theta), sin(theta), theta_dot)``. The
        angle is recovered with ``arctan2`` and wrapped to ``[-pi, pi)``. The
        velocity is clipped just inside the outermost velocity bin edges.
        """
        cos_theta, sin_theta, theta_dot = cont_obs[0], cont_obs[1], cont_obs[2]
        theta = wrap(np.arctan2(sin_theta, cos_theta), -np.pi, np.pi)
        # Use the actual bin edges, with the same 1e-6 margin as the dynamics
        # code, so np.digitize can never return an out-of-range bin.
        theta_dot = np.clip(
            theta_dot,
            self.angular_velocity_bin_edges[0] + 1e-6,
            self.angular_velocity_bin_edges[-1] - 1e-6,
        )

        angle_idx = self.discretize_angle(theta)
        angular_velocity_idx = self.discretize_angular_velocity(theta_dot)

        return self.state_to_index(angle_idx, angular_velocity_idx)

    def get_action_value(self, action):
        """Return the continuous torque of discrete ``action`` (its bin midpoint)."""
        return get_torque_value(self.torque_bin_edges, action)

    def setup_transition_probabilities(self):
        """
        Compute ``self.P`` for every state using a pool of worker processes.

        States are submitted in batches of 1000 so that only one batch of
        futures is outstanding at a time. Progress is printed every 100
        completed tasks.

        Raises:
        -------
        RuntimeError
            If any state's computation failed. ``self.P`` is then left
            unchanged, so an incomplete model is never assigned or cached.
        """
        args = [
            (
                state,
                self.angle_bin_edges,
                self.angular_velocity_bin_edges,
                self.torque_bin_edges,
                self.dim_samples,
            )
            for state in range(self.state_space)
        ]

        new_P = {}
        failed_states = []
        n_completed = 0
        batch_size = 1000

        with ProcessPoolExecutor(max_workers=self.n_workers) as executor:
            for start in range(0, len(args), batch_size):
                # future -> state, so failures can be reported by state id.
                futures = {
                    executor.submit(setup_transition_probabilities_for_state, arg): arg[0]
                    for arg in args[start : start + batch_size]
                }
                for future in as_completed(futures):
                    n_completed += 1
                    try:
                        state, P_state = future.result()
                    except Exception as e:
                        print(f"Error in future: {e}")
                        print("task failed")
                        failed_states.append(futures[future])
                        continue
                    new_P[state] = P_state
                    if n_completed % 100 == 0:
                        print(f"Completed {n_completed}/{self.state_space}")

        if failed_states:
            raise RuntimeError(
                f"Failed to compute transitions for {len(failed_states)} of "
                f"{self.state_space} states (first few: {sorted(failed_states)[:10]})"
            )

        # Sort by state so the stored (and pickled) model is deterministic
        # regardless of task completion order.
        self.P = {state: new_P[state] for state in sorted(new_P)}


if __name__ == "__main__":
    n_bins = 31
    angle_bins = n_bins
    angular_velocity_bins = n_bins

    discretized_pendulum = DiscretizedPendulum(
        angle_bins=angle_bins, angular_velocity_bins=angular_velocity_bins
    )

    angle = np.pi / 2
    angular_velocity = 3

    obs = np.array([np.cos(angle), np.sin(angle), angular_velocity])

    state = discretized_pendulum.transform_cont_obs(obs)
    print(f"Discretized state index: {state}")

    for action in range(discretized_pendulum.action_space):
        transitions = discretized_pendulum.P[state][action]
        for prob, next_state, reward, terminated in transitions:
            print(
                f"Action: {action}, Probability: {prob}, Next state: {next_state}, Reward: {reward}, Terminated: {terminated}"
            )
def index_to_continous_state(index, angle_bin_edges, angular_velocity_bin_edges):
    """
    Return the continuous ``(angle, angular_velocity)`` at the centre of a state.

    The flat ``index`` is decomposed row-major, with angle as the outer
    dimension (quotient and remainder by the number of angular-velocity bins).
    Each component is then the midpoint of its bin.
    """
    n_velocity_bins = len(angular_velocity_bin_edges) - 1
    a_idx, v_idx = divmod(index, n_velocity_bins)

    angle_mid = (angle_bin_edges[a_idx] + angle_bin_edges[a_idx + 1]) / 2.0
    velocity_mid = (
        angular_velocity_bin_edges[v_idx] + angular_velocity_bin_edges[v_idx + 1]
    ) / 2.0
    return angle_mid, velocity_mid
def compute_next_probable_states(
    angle_idx,
    angular_velocity_idx,
    action,
    angle_bin_edges,
    angular_velocity_bin_edges,
    torque_bin_edges,
    num_samples=11,
    g=10.0,
    l=1.0,
    m=1.0,
    dt=0.05,
):
    """
    Estimate the transition distribution of one discrete (state, action) pair.

    A ``num_samples`` x ``num_samples`` grid is laid over the state's angle and
    angular-velocity bin. The points on the bin edges are dropped, and every
    remaining point is advanced one time step under the action's midpoint
    torque. Successor states are binned and aggregated into probabilities
    (share of samples) and mean rewards.

    Parameters:
    -----------
    angle_idx, angular_velocity_idx : int
        Bin indices of the current state.
    action : int
        Torque bin index; the applied torque is that bin's midpoint.
    angle_bin_edges, angular_velocity_bin_edges, torque_bin_edges : array-like
        Bin edges of each discretized quantity.
    num_samples : int, optional (default=11)
        Grid points per dimension including the two edge points, which are
        discarded; must be at least 3.
    g, l, m, dt : float
        Gravity, pendulum length, mass and integration time step.

    Returns:
    --------
    list of tuple
        ``(probability, next_state, mean_reward, terminated)`` for every
        reached successor, in first-reached order; ``terminated`` is always
        False.

    Raises:
    -------
    ValueError
        If ``num_samples < 3`` or a successor falls outside the state grid.
    """
    if num_samples < 3:
        raise ValueError(
            f"num_samples must be at least 3, got {num_samples}: the two bin "
            "edge points are discarded, so no samples would remain"
        )

    angle_low, angle_high = angle_bin_edges[angle_idx], angle_bin_edges[angle_idx + 1]
    angular_velocity_low, angular_velocity_high = (
        angular_velocity_bin_edges[angular_velocity_idx],
        angular_velocity_bin_edges[angular_velocity_idx + 1],
    )
    torque = get_torque_value(torque_bin_edges, action)

    min_angular_velocity = angular_velocity_bin_edges[0]
    max_angular_velocity = angular_velocity_bin_edges[-1]

    # Interior points only: a point on an edge is ambiguous between two bins.
    angle_samples = np.linspace(angle_low, angle_high, num_samples)[1:-1]
    angular_velocity_samples = np.linspace(
        angular_velocity_low, angular_velocity_high, num_samples
    )[1:-1]

    angular_velocity_bins = len(angular_velocity_bin_edges) - 1
    n_states = (len(angle_bin_edges) - 1) * angular_velocity_bins

    # Successor state -> rewards of the samples that reached it. Dict
    # insertion order keeps successors in the order they were first reached.
    rewards_by_state = {}

    for angle in angle_samples:
        for angular_velocity in angular_velocity_samples:
            # Cost and Euler update follow Gymnasium's Pendulum dynamics.
            costs = (
                angle_normalize(angle) ** 2
                + 0.1 * angular_velocity**2
                + 0.001 * (torque**2)
            )

            new_angular_velocity = (
                angular_velocity
                + (3 * g / (2 * l) * np.sin(angle) + 3.0 / (m * l**2) * torque) * dt
            )
            # Stay strictly inside the outer edges so np.digitize yields a valid bin.
            new_angular_velocity = np.clip(
                new_angular_velocity,
                min_angular_velocity + 1e-6,
                max_angular_velocity - 1e-6,
            )

            new_angle = wrap(angle + new_angular_velocity * dt, -np.pi, np.pi)

            new_angle_idx = np.digitize(new_angle, angle_bin_edges) - 1
            new_angular_velocity_idx = (
                np.digitize(new_angular_velocity, angular_velocity_bin_edges) - 1
            )

            # Builtin int so P holds plain ints rather than NumPy scalars.
            new_state = int(
                state_to_index(
                    new_angle_idx, new_angular_velocity_idx, angular_velocity_bins
                )
            )

            if new_state < 0 or new_state >= n_states:
                raise ValueError(f"Invalid state index: {new_state}")

            rewards_by_state.setdefault(new_state, []).append(-costs)

    n_total = len(angle_samples) * len(angular_velocity_samples)

    return [
        (len(rewards) / n_total, new_state, sum(rewards) / len(rewards), False)
        for new_state, rewards in rewards_by_state.items()
    ]
def setup_transition_probabilities_for_state(args):
    """
    Compute the transition table of a single discrete state.

    Intended as a ProcessPoolExecutor task, so it takes one packed tuple
    ``(state, angle_bin_edges, angular_velocity_bin_edges, torque_bin_edges,
    dim_samples)``.

    Returns:
    --------
    tuple
        ``(state, P_state)``. ``P_state`` maps every action index to the list
        returned by ``compute_next_probable_states``.

    Exceptions propagate to the caller, where they surface via ``Future.result()``.
    """
    (
        state,
        angle_bin_edges,
        angular_velocity_bin_edges,
        torque_bin_edges,
        dim_samples,
    ) = args
    angle_bins = len(angle_bin_edges) - 1
    angular_velocity_bins = len(angular_velocity_bin_edges) - 1
    action_bins = len(torque_bin_edges) - 1

    angle_idx, angular_velocity_idx = index_to_state(
        state, angle_bins, angular_velocity_bins
    )

    P_state = {
        action: compute_next_probable_states(
            angle_idx,
            angular_velocity_idx,
            action,
            angle_bin_edges,
            angular_velocity_bin_edges,
            torque_bin_edges,
            num_samples=dim_samples,
        )
        for action in range(action_bins)
    }
    return state, P_state
class DiscretizedPendulum:
    """
    Tabular (discretized) model of the Gymnasium pendulum.

    Angle, angular velocity and torque are binned. The transition model ``P``
    is estimated by simulating one dynamics step from a grid of sample points
    inside every state bin, for every torque bin. ``P`` is cached on disk as a
    gzip-compressed pickle and reused by later constructions with the same
    configuration.

    Parameters:
    -----------
    angle_bins : int
        Number of bins to discretize the angle.
    angular_velocity_bins : int
        Number of bins to discretize the angular velocity.
    torque_bins : int, optional (default=11)
        Number of bins to discretize the torque.
    n_workers : int, optional (default=4)
        Number of worker processes to use for setting up transition probabilities.
    cache_dir : str, optional (default='./cached')
        Directory to cache the transition probabilities.
    dim_samples : int, optional (default=11)
        Number of evenly spaced points per state dimension used when setting
        up transition probabilities. The two points on the bin edges are
        discarded, so ``(dim_samples - 2) ** 2`` samples are simulated per
        (state, action) pair; values below 3 leave no samples.

    Attributes:
    -----------
    angle_bins : int
        Number of bins to discretize the angle. Must be odd.
    angular_velocity_bins : int
        Number of bins to discretize the angular velocity. Must be odd.
    dim_samples : int
        Number of points per state dimension (including the two discarded
        edge points) used when setting up transition probabilities.
    angle_bin_edges : numpy.ndarray
        Edges of the bins for discretizing the angle.
    angular_velocity_bin_edges : numpy.ndarray
        Edges of the bins for discretizing the angular velocity.
    torque_bin_edges : numpy.ndarray
        Edges of the bins for discretizing the torque.
    state_space : int
        Total number of discrete states.
    action_space : int
        Total number of discrete actions.
    P : dict
        Transition model: ``P[state][action]`` is a list of
        ``(probability, next_state, reward, terminated)`` tuples.
    n_workers : int
        Number of worker processes to use for setting up transition probabilities.
    """

    def __init__(
        self,
        angle_bins,
        angular_velocity_bins,
        torque_bins=11,
        n_workers=4,
        cache_dir="./cached",
        dim_samples=11,
    ):
        self.angle_bins = angle_bins
        self.angular_velocity_bins = angular_velocity_bins
        self.dim_samples = dim_samples
        self.angle_bin_edges = generate_bin_edges(np.pi, angle_bins, 3, center=True)
        self.angular_velocity_bin_edges = generate_bin_edges(
            8, angular_velocity_bins, 3, center=False
        )
        self.torque_bin_edges = generate_bin_edges(2, torque_bins, 3, center=False)

        self.state_space = angle_bins * angular_velocity_bins
        self.action_space = torque_bins

        # Placeholder with the final shape; replaced below by the cached or
        # freshly computed model.
        self.P = {
            state: {action: [] for action in range(torque_bins)}
            for state in range(self.state_space)
        }

        self.n_workers = n_workers

        cache_filename = CACHED_P_PATH_FORMAT.format(
            angle_bins=angle_bins,
            angular_velocity_bins=angular_velocity_bins,
            action_bins=torque_bins,
        )
        if dim_samples != 11:
            # The legacy file name does not encode dim_samples. Tag non-default
            # densities so a model estimated with one sampling density is never
            # silently reused for another. Default-density caches keep their
            # original name.
            cache_filename = cache_filename.replace(
                ".pkl.gz", f"_samples{dim_samples}.pkl.gz"
            )
        cached_P_filepath = os.path.join(cache_dir, cache_filename)

        os.makedirs(cache_dir, exist_ok=True)

        if os.path.exists(cached_P_filepath):
            # NOTE: unpickling can execute arbitrary code; only load cache
            # files this library wrote itself.
            with gzip.open(cached_P_filepath, "rb") as f:
                self.P = pickle.load(f)
        else:
            self.setup_transition_probabilities()
            # Write to a temporary file and atomically rename it, so an
            # interrupted run never leaves a truncated cache behind.
            tmp_filepath = f"{cached_P_filepath}.{os.getpid()}.tmp"
            try:
                with gzip.open(tmp_filepath, "wb") as f:
                    pickle.dump(self.P, f)
                os.replace(tmp_filepath, cached_P_filepath)
            finally:
                if os.path.exists(tmp_filepath):
                    os.remove(tmp_filepath)

    def discretize_angle(self, angle):
        """Return the angle bin index of ``angle`` (``np.digitize`` minus one)."""
        return np.digitize(angle, self.angle_bin_edges) - 1

    def discretize_angular_velocity(self, angular_velocity):
        """Return the angular-velocity bin index of ``angular_velocity``."""
        return np.digitize(angular_velocity, self.angular_velocity_bin_edges) - 1

    def index_to_state(self, index):
        """Split a flat state index into ``(angle_idx, angular_velocity_idx)``."""
        return index_to_state(index, self.angle_bins, self.angular_velocity_bins)

    def state_to_index(self, angle_idx, angular_velocity_idx):
        """
        Flatten bin indices into a flat state index.

        Raises ValueError if the result lies outside ``[0, state_space)``.
        """
        idx = state_to_index(
            angle_idx, angular_velocity_idx, self.angular_velocity_bins
        )
        if idx < 0 or idx >= self.state_space:
            raise ValueError(f"Invalid state index: {idx}")
        return idx

    def transform_cont_obs(self, cont_obs):
        """
        Map a continuous pendulum observation to a discrete state index.

        ``cont_obs`` is read as ``(cos(theta), sin(theta), theta_dot)``. The
        angle is recovered with ``arctan2`` and wrapped to ``[-pi, pi)``. The
        velocity is clipped just inside the outermost velocity bin edges.
        """
        cos_theta, sin_theta, theta_dot = cont_obs[0], cont_obs[1], cont_obs[2]
        theta = wrap(np.arctan2(sin_theta, cos_theta), -np.pi, np.pi)
        # Use the actual bin edges, with the same 1e-6 margin as the dynamics
        # code, so np.digitize can never return an out-of-range bin.
        theta_dot = np.clip(
            theta_dot,
            self.angular_velocity_bin_edges[0] + 1e-6,
            self.angular_velocity_bin_edges[-1] - 1e-6,
        )

        angle_idx = self.discretize_angle(theta)
        angular_velocity_idx = self.discretize_angular_velocity(theta_dot)

        return self.state_to_index(angle_idx, angular_velocity_idx)

    def get_action_value(self, action):
        """Return the continuous torque of discrete ``action`` (its bin midpoint)."""
        return get_torque_value(self.torque_bin_edges, action)

    def setup_transition_probabilities(self):
        """
        Compute ``self.P`` for every state using a pool of worker processes.

        States are submitted in batches of 1000 so that only one batch of
        futures is outstanding at a time. Progress is printed every 100
        completed tasks.

        Raises:
        -------
        RuntimeError
            If any state's computation failed. ``self.P`` is then left
            unchanged, so an incomplete model is never assigned or cached.
        """
        args = [
            (
                state,
                self.angle_bin_edges,
                self.angular_velocity_bin_edges,
                self.torque_bin_edges,
                self.dim_samples,
            )
            for state in range(self.state_space)
        ]

        new_P = {}
        failed_states = []
        n_completed = 0
        batch_size = 1000

        with ProcessPoolExecutor(max_workers=self.n_workers) as executor:
            for start in range(0, len(args), batch_size):
                # future -> state, so failures can be reported by state id.
                futures = {
                    executor.submit(setup_transition_probabilities_for_state, arg): arg[0]
                    for arg in args[start : start + batch_size]
                }
                for future in as_completed(futures):
                    n_completed += 1
                    try:
                        state, P_state = future.result()
                    except Exception as e:
                        print(f"Error in future: {e}")
                        print("task failed")
                        failed_states.append(futures[future])
                        continue
                    new_P[state] = P_state
                    if n_completed % 100 == 0:
                        print(f"Completed {n_completed}/{self.state_space}")

        if failed_states:
            raise RuntimeError(
                f"Failed to compute transitions for {len(failed_states)} of "
                f"{self.state_space} states (first few: {sorted(failed_states)[:10]})"
            )

        # Sort by state so the stored (and pickled) model is deterministic
        # regardless of task completion order.
        self.P = {state: new_P[state] for state in sorted(new_P)}
Initialize the DiscretizedPendulum environment.
Parameters:
angle_bins : int Number of bins to discretize the angle. angular_velocity_bins : int Number of bins to discretize the angular velocity. torque_bins : int, optional (default=11) Number of bins to discretize the torque. n_workers : int, optional (default=4) Number of worker processes to use for setting up transition probabilities. cache_dir : str, optional (default='./cached') Directory to cache the transition probabilities. dim_samples : int, optional (default=11) Number of evenly spaced points per state dimension used when setting up transition probabilities. The two points on the bin edges are discarded, so (dim_samples - 2)**2 samples are simulated per state-action pair; values below 3 leave no samples.
Attributes:
angle_bins : int Number of bins to discretize the angle. Must be odd. angular_velocity_bins : int Number of bins to discretize the angular velocity. Must be odd. dim_samples : int Number of points per state dimension (including the two discarded bin-edge points) used when setting up transition probabilities. angle_bin_edges : numpy.ndarray Edges of the bins for discretizing the angle. angular_velocity_bin_edges : numpy.ndarray Edges of the bins for discretizing the angular velocity. torque_bin_edges : numpy.ndarray Edges of the bins for discretizing the torque. state_space : int Total number of discrete states. action_space : int Total number of discrete actions. P : dict Transition model: P[state][action] is a list of (probability, next_state, reward, terminated) tuples. n_workers : int Number of worker processes to use for setting up transition probabilities.
def __init__(
    self,
    angle_bins,
    angular_velocity_bins,
    torque_bins=11,
    n_workers=4,
    cache_dir="./cached",
    dim_samples=11,
):
    """
    Build the bin edges and load or compute the transition model ``P``.

    ``P`` is loaded from a gzip-compressed pickle in ``cache_dir`` when one
    exists for this configuration. Otherwise it is computed with
    ``setup_transition_probabilities`` and written there.
    """
    self.angle_bins = angle_bins
    self.angular_velocity_bins = angular_velocity_bins
    self.dim_samples = dim_samples
    self.angle_bin_edges = generate_bin_edges(np.pi, angle_bins, 3, center=True)
    self.angular_velocity_bin_edges = generate_bin_edges(
        8, angular_velocity_bins, 3, center=False
    )
    self.torque_bin_edges = generate_bin_edges(2, torque_bins, 3, center=False)

    self.state_space = angle_bins * angular_velocity_bins
    self.action_space = torque_bins

    # Placeholder with the final shape; replaced below by the cached or
    # freshly computed model.
    self.P = {
        state: {action: [] for action in range(torque_bins)}
        for state in range(self.state_space)
    }

    self.n_workers = n_workers

    cache_filename = CACHED_P_PATH_FORMAT.format(
        angle_bins=angle_bins,
        angular_velocity_bins=angular_velocity_bins,
        action_bins=torque_bins,
    )
    if dim_samples != 11:
        # The legacy file name does not encode dim_samples. Tag non-default
        # densities so a model estimated with one sampling density is never
        # silently reused for another. Default-density caches keep their
        # original name.
        cache_filename = cache_filename.replace(
            ".pkl.gz", f"_samples{dim_samples}.pkl.gz"
        )
    cached_P_filepath = os.path.join(cache_dir, cache_filename)

    os.makedirs(cache_dir, exist_ok=True)

    if os.path.exists(cached_P_filepath):
        # NOTE: unpickling can execute arbitrary code; only load cache files
        # this library wrote itself.
        with gzip.open(cached_P_filepath, "rb") as f:
            self.P = pickle.load(f)
    else:
        self.setup_transition_probabilities()
        # Write to a temporary file and atomically rename it, so an
        # interrupted run never leaves a truncated cache behind.
        tmp_filepath = f"{cached_P_filepath}.{os.getpid()}.tmp"
        try:
            with gzip.open(tmp_filepath, "wb") as f:
                pickle.dump(self.P, f)
            os.replace(tmp_filepath, cached_P_filepath)
        finally:
            if os.path.exists(tmp_filepath):
                os.remove(tmp_filepath)
def transform_cont_obs(self, cont_obs):
    """
    Map a continuous pendulum observation to a discrete state index.

    ``cont_obs`` is read as ``(cos(theta), sin(theta), theta_dot)``. The angle
    is recovered with ``arctan2`` and wrapped to ``[-pi, pi)``. The velocity is
    clipped just inside the outermost angular-velocity bin edges.

    Raises ValueError (via ``state_to_index``) if the resulting index is invalid.
    """
    cos_theta, sin_theta, theta_dot = cont_obs[0], cont_obs[1], cont_obs[2]
    theta = wrap(np.arctan2(sin_theta, cos_theta), -np.pi, np.pi)
    # Clip against the actual bin edges rather than a duplicated literal,
    # using the same 1e-6 margin as the dynamics code, so np.digitize can
    # never return an out-of-range bin.
    theta_dot = np.clip(
        theta_dot,
        self.angular_velocity_bin_edges[0] + 1e-6,
        self.angular_velocity_bin_edges[-1] - 1e-6,
    )

    angle_idx = self.discretize_angle(theta)
    angular_velocity_idx = self.discretize_angular_velocity(theta_dot)

    return self.state_to_index(angle_idx, angular_velocity_idx)
def setup_transition_probabilities(self):
    """
    Compute ``self.P`` for every state using a pool of worker processes.

    States are submitted in batches of 1000 so that only one batch of futures
    is outstanding at a time. Progress is printed every 100 completed tasks.

    Raises:
    -------
    RuntimeError
        If any state's computation failed. ``self.P`` is then left unchanged,
        so an incomplete model is never assigned (and ``__init__`` does not
        reach its cache write).
    """
    args = [
        (
            state,
            self.angle_bin_edges,
            self.angular_velocity_bin_edges,
            self.torque_bin_edges,
            self.dim_samples,
        )
        for state in range(self.state_space)
    ]

    new_P = {}
    failed_states = []
    n_completed = 0
    batch_size = 1000

    with ProcessPoolExecutor(max_workers=self.n_workers) as executor:
        for start in range(0, len(args), batch_size):
            # future -> state, so failures can be reported by state id.
            futures = {
                executor.submit(setup_transition_probabilities_for_state, arg): arg[0]
                for arg in args[start : start + batch_size]
            }
            for future in as_completed(futures):
                n_completed += 1
                try:
                    state, P_state = future.result()
                except Exception as e:
                    print(f"Error in future: {e}")
                    print("task failed")
                    failed_states.append(futures[future])
                    continue
                new_P[state] = P_state
                if n_completed % 100 == 0:
                    print(f"Completed {n_completed}/{self.state_space}")

    if failed_states:
        raise RuntimeError(
            f"Failed to compute transitions for {len(failed_states)} of "
            f"{self.state_space} states (first few: {sorted(failed_states)[:10]})"
        )

    # Sort by state so the stored (and pickled) model is deterministic
    # regardless of task completion order.
    self.P = {state: new_P[state] for state in sorted(new_P)}