bettermdptools.envs.pendulum_wrapper
Author: Aleksandr Spiridonov BSD 3-Clause License
"""
Author: Aleksandr Spiridonov
BSD 3-Clause License
"""

import gymnasium as gym

from bettermdptools.envs.pendulum_discretized import (
    DiscretizedPendulum,
)  # Ensure this path is correct


class CustomTransformObservation(gym.ObservationWrapper):
    def __init__(self, env, func, observation_space):
        """
        Helper wrapper that converts observations with ``func`` and, optionally,
        replaces the observation space.

        The TransformObservation wrapper in gymnasium v0.26 does not accept an
        observation_space parameter, which is needed so that the reported space
        matches the converted observations (continuous observation -> discrete int).
        Instead, we subclass gym.ObservationWrapper (parent class of
        gym.TransformObservation) to set both the conversion function and the new
        observation space.

        Parameters
        ----------
        env : gymnasium.Env
            Base environment to be wrapped.
        func : callable
            Function that converts the observation.
        observation_space : gymnasium.spaces.Space or None
            New observation space. If None, the wrapped environment's observation
            space is reported unchanged.
        """
        super().__init__(env)
        if observation_space is not None:
            self.observation_space = observation_space
        self.func = func

    def observation(self, observation):
        """
        Convert an observation emitted by the wrapped environment.

        gym.ObservationWrapper routes the observations returned by both ``reset()``
        and ``step()`` through this method; the result is what the caller receives.

        Parameters
        ----------
        observation : object
            Observation from the base environment (for Pendulum-v1 this is a NumPy
            array, not a tuple).

        Returns
        -------
        object
            Whatever ``func`` returns; for PendulumWrapper this is the discrete
            state index.
        """
        return self.func(observation)


class PendulumWrapper(gym.Wrapper):
    def __init__(self, env, angle_bins=11, angular_velocity_bins=11, torque_bins=11):
        """
        Pendulum wrapper that modifies the observation and action spaces and creates a transition/reward matrix P.

        Parameters
        ----------
        env : gymnasium.Env
            Base environment
        angle_bins : int
            Number of discrete bins for the pendulum's angle.
        angular_velocity_bins : int
            Number of discrete bins for the pendulum's angular velocity.
        torque_bins : int
            Number of discrete bins for the torque action.
        """
        super().__init__(env)

        # Initialize the DiscretizedPendulum model
        self.discretized_pendulum = DiscretizedPendulum(
            angle_bins=angle_bins,
            angular_velocity_bins=angular_velocity_bins,
            torque_bins=torque_bins,
        )

        # Transition probability matrix
        self._P = self.discretized_pendulum.P

        # Transformation function from continuous to discrete observations
        self._transform_obs = self.discretized_pendulum.transform_cont_obs
        self._get_action_value = self.discretized_pendulum.get_action_value

        # Wrap the environment's observation space
        self.observation_space = gym.spaces.Discrete(
            self.discretized_pendulum.state_space
        )
        # Replace the inner env with an observation-converting wrapper so that the
        # reset()/step() calls gym.Wrapper delegates to self.env yield discrete states.
        self.env = CustomTransformObservation(
            env, self._transform_obs, self.observation_space
        )
        # Keep a handle to the original (continuous-observation) environment.
        self.gym_env = env

        # Override the action space to be discrete
        self.action_space = gym.spaces.Discrete(self.discretized_pendulum.action_space)

    @property
    def P(self):
        """
        Returns the transition probability matrix.

        Returns
        -------
        dict
        """
        return self._P

    @property
    def transform_obs(self):
        """
        Returns the observation transformation function.

        Returns
        -------
        callable
            Function mapping a continuous observation to a discrete state index.
        """
        return self._transform_obs

    @property
    def get_action_value(self):
        """
        Returns a function that maps a discrete action index to an environment action.

        The returned function wraps the continuous torque in a one-element list,
        which is the form ``step`` passes to the base environment.

        Returns
        -------
        callable
        """

        def to_env_action(action):
            return [self._get_action_value(action)]

        return to_env_action

    def step(self, action):
        """
        Takes a discrete action, maps it to a continuous torque, and interacts with the environment.

        Parameters
        ----------
        action : int
            The discrete action index.

        Returns
        -------
        int
            The discretized next state index.
        float
            The reward obtained from the environment.
        bool
            Whether the episode has terminated.
        bool
            Whether the episode has been truncated.
        dict
            Additional information from the environment.
        """
        # Reuse the same discrete-index -> [torque] mapping exposed by get_action_value.
        return self.env.step(self.get_action_value(action))


def get_env_str(angle_bins, angular_velocity_bins, torque_bins):
    """
    Returns the environment string based on the discretization parameters.

    Parameters
    ----------
    angle_bins : int
        Number of discrete bins for the pendulum's angle.
    angular_velocity_bins : int
        Number of discrete bins for the pendulum's angular velocity.
    torque_bins : int
        Number of discrete bins for the torque action.

    Returns
    -------
    str
        The environment string.
    """
    return f"pendulum_{angle_bins}_{angular_velocity_bins}_{torque_bins}"


def init_wrapper_env(angle_bins=11, angular_velocity_bins=11, torque_bins=11):
    """
    Initializes the Pendulum wrapper environment.

    Parameters
    ----------
    angle_bins : int
        Number of discrete bins for the pendulum's angle.
    angular_velocity_bins : int
        Number of discrete bins for the pendulum's angular velocity.
    torque_bins : int
        Number of discrete bins for the torque action.

    Returns
    -------
    PendulumWrapper
        The Pendulum wrapper environment.
    """
    pendulum_genv_train = gym.make("Pendulum-v1")

    pendulum_train = PendulumWrapper(
        angle_bins=angle_bins,
        angular_velocity_bins=angular_velocity_bins,
        torque_bins=torque_bins,
        env=pendulum_genv_train,
    )

    return pendulum_train
class CustomTransformObservation(gym.ObservationWrapper):
    def __init__(self, env, func, observation_space):
        """
        Helper wrapper that converts observations with ``func`` and, optionally,
        replaces the observation space.

        The TransformObservation wrapper in gymnasium v0.26 does not accept an
        observation_space parameter, which is needed so that the reported space
        matches the converted observations (continuous observation -> discrete int).
        Instead, we subclass gym.ObservationWrapper (parent class of
        gym.TransformObservation) to set both the conversion function and the new
        observation space.

        Parameters
        ----------
        env : gymnasium.Env
            Base environment to be wrapped.
        func : callable
            Function that converts the observation.
        observation_space : gymnasium.spaces.Space or None
            New observation space. If None, the wrapped environment's observation
            space is reported unchanged.
        """
        super().__init__(env)
        if observation_space is not None:
            self.observation_space = observation_space
        self.func = func

    def observation(self, observation):
        """
        Convert an observation emitted by the wrapped environment.

        gym.ObservationWrapper routes the observations returned by both ``reset()``
        and ``step()`` through this method; the result is what the caller receives.

        Parameters
        ----------
        observation : object
            Observation from the base environment (for Pendulum-v1 this is a NumPy
            array, not a tuple).

        Returns
        -------
        object
            Whatever ``func`` returns; for PendulumWrapper this is the discrete
            state index.
        """
        return self.func(observation)
Superclass of wrappers that can modify observations using `observation()` for `reset()` and `step()`.

If you would like to apply a function to only the observation before passing it to the learning code, you can simply inherit from `ObservationWrapper` and override the method `observation()` to implement that transformation. The transformation defined in that method must be reflected by the `env` observation space. Otherwise, you need to specify the new observation space of the wrapper by setting `self.observation_space` in the `__init__()` method of your wrapper.

Among others, Gymnasium provides the observation wrapper `TimeAwareObservation`, which adds information about the index of the timestep to the observation.
def __init__(self, env, func, observation_space):
    """
    Helper wrapper that converts observations with ``func`` and, optionally,
    replaces the observation space.

    The TransformObservation wrapper in gymnasium v0.26 does not accept an
    observation_space parameter, which is needed so that the reported space
    matches the converted observations (continuous observation -> discrete int).
    Instead, we subclass gym.ObservationWrapper (parent class of
    gym.TransformObservation) to set both the conversion function and the new
    observation space.

    Parameters
    ----------
    env : gymnasium.Env
        Base environment to be wrapped.
    func : callable
        Function that converts the observation.
    observation_space : gymnasium.spaces.Space or None
        New observation space. If None, the wrapped environment's observation
        space is reported unchanged.
    """
    super().__init__(env)
    if observation_space is not None:
        self.observation_space = observation_space
    self.func = func
Helper class that converts observations with a given function and replaces the observation space. The TransformObservation wrapper in gymnasium v0.26 does not accept an observation_space parameter, which is needed so that the observation space matches the converted observations (continuous observation -> discrete int). Instead, this class subclasses gym.ObservationWrapper (the parent class of gym.TransformObservation) to set both the conversion function and the new observation space.
Parameters
- env (gymnasium.Env): Base environment to be wrapped
- func (callable): Function that converts the observation
- observation_space (gymnasium.spaces.Space): New observation space
def observation(self, observation):
    """
    Convert an observation emitted by the wrapped environment.

    gym.ObservationWrapper routes the observations returned by both ``reset()``
    and ``step()`` through this method; the result is what the caller receives.

    Parameters
    ----------
    observation : object
        Observation from the base environment (for Pendulum-v1 this is a NumPy
        array, not a tuple).

    Returns
    -------
    object
        Whatever ``func`` returns; for PendulumWrapper this is the discrete
        state index.
    """
    return self.func(observation)
Applies the conversion function to an observation received from the wrapped environment's `reset()` or `step()` and returns the result, which is what the user receives.
Parameters
- observation: Observation from the base environment (for Pendulum-v1 this is a NumPy array, not a tuple)
Returns
- int: The converted observation, as returned by the conversion function.
class PendulumWrapper(gym.Wrapper):
    def __init__(self, env, angle_bins=11, angular_velocity_bins=11, torque_bins=11):
        """
        Pendulum wrapper that modifies the observation and action spaces and creates a transition/reward matrix P.

        Parameters
        ----------
        env : gymnasium.Env
            Base environment
        angle_bins : int
            Number of discrete bins for the pendulum's angle.
        angular_velocity_bins : int
            Number of discrete bins for the pendulum's angular velocity.
        torque_bins : int
            Number of discrete bins for the torque action.
        """
        super().__init__(env)

        # Initialize the DiscretizedPendulum model
        self.discretized_pendulum = DiscretizedPendulum(
            angle_bins=angle_bins,
            angular_velocity_bins=angular_velocity_bins,
            torque_bins=torque_bins,
        )

        # Transition probability matrix
        self._P = self.discretized_pendulum.P

        # Transformation function from continuous to discrete observations
        self._transform_obs = self.discretized_pendulum.transform_cont_obs
        self._get_action_value = self.discretized_pendulum.get_action_value

        # Wrap the environment's observation space
        self.observation_space = gym.spaces.Discrete(
            self.discretized_pendulum.state_space
        )
        # Replace the inner env with an observation-converting wrapper so that the
        # reset()/step() calls gym.Wrapper delegates to self.env yield discrete states.
        self.env = CustomTransformObservation(
            env, self._transform_obs, self.observation_space
        )
        # Keep a handle to the original (continuous-observation) environment.
        self.gym_env = env

        # Override the action space to be discrete
        self.action_space = gym.spaces.Discrete(self.discretized_pendulum.action_space)

    @property
    def P(self):
        """
        Returns the transition probability matrix.

        Returns
        -------
        dict
        """
        return self._P

    @property
    def transform_obs(self):
        """
        Returns the observation transformation function.

        Returns
        -------
        callable
            Function mapping a continuous observation to a discrete state index.
        """
        return self._transform_obs

    @property
    def get_action_value(self):
        """
        Returns a function that maps a discrete action index to an environment action.

        The returned function wraps the continuous torque in a one-element list,
        which is the form ``step`` passes to the base environment.

        Returns
        -------
        callable
        """

        def to_env_action(action):
            return [self._get_action_value(action)]

        return to_env_action

    def step(self, action):
        """
        Takes a discrete action, maps it to a continuous torque, and interacts with the environment.

        Parameters
        ----------
        action : int
            The discrete action index.

        Returns
        -------
        int
            The discretized next state index.
        float
            The reward obtained from the environment.
        bool
            Whether the episode has terminated.
        bool
            Whether the episode has been truncated.
        dict
            Additional information from the environment.
        """
        # Reuse the same discrete-index -> [torque] mapping exposed by get_action_value.
        return self.env.step(self.get_action_value(action))
Wraps a `gymnasium.Env` to allow a modular transformation of the `step()` and `reset()` methods.

This class is the base class of all wrappers to change the behavior of the underlying environment. Wrappers that inherit from this class can modify the `action_space`, `observation_space`, `reward_range` and `metadata` attributes without changing the underlying environment's attributes. Moreover, the behavior of the `step()` and `reset()` methods can be changed by these wrappers.

Some attributes (`spec`, `render_mode`, `np_random`) will point back to the wrapper's environment (i.e. to the corresponding attributes of `env`).

Note:
If you inherit from `Wrapper`, don't forget to call `super().__init__(env)`.
def __init__(self, env, angle_bins=11, angular_velocity_bins=11, torque_bins=11):
    """
    Wrap a continuous Pendulum environment with discrete observation and action spaces.

    A DiscretizedPendulum model is built from the bin counts. Its transition/reward
    table is exposed as ``P``, its state count sizes the Discrete observation space,
    and its action count sizes the Discrete action space.

    Parameters
    ----------
    env : gymnasium.Env
        Base environment
    angle_bins : int
        Number of discrete bins for the pendulum's angle.
    angular_velocity_bins : int
        Number of discrete bins for the pendulum's angular velocity.
    torque_bins : int
        Number of discrete bins for the torque action.
    """
    super().__init__(env)

    model = DiscretizedPendulum(
        angle_bins=angle_bins,
        angular_velocity_bins=angular_velocity_bins,
        torque_bins=torque_bins,
    )
    self.discretized_pendulum = model

    # Cache the model's transition/reward table and its conversion helpers.
    self._P = model.P
    self._transform_obs = model.transform_cont_obs
    self._get_action_value = model.get_action_value

    # Observations are reported as discrete state indices: the inner env is
    # swapped for a wrapper that converts every observation it emits.
    discrete_obs_space = gym.spaces.Discrete(model.state_space)
    self.observation_space = discrete_obs_space
    self.env = CustomTransformObservation(env, self._transform_obs, discrete_obs_space)
    self.gym_env = env

    # Actions are discrete indices; step() maps them back to torques.
    self.action_space = gym.spaces.Discrete(model.action_space)
Pendulum wrapper that modifies the observation and action spaces and creates a transition/reward matrix P.
Parameters
- env (gymnasium.Env): Base environment
- angle_bins (int): Number of discrete bins for the pendulum's angle.
- angular_velocity_bins (int): Number of discrete bins for the pendulum's angular velocity.
- torque_bins (int): Number of discrete bins for the torque action.
@property
def observation_space(
    self,
) -> spaces.Space[ObsType] | spaces.Space[WrapperObsType]:
    """Give the wrapper's own observation space when one was assigned, otherwise the wrapped env's."""
    own_space = self._observation_space
    if own_space is not None:
        return own_space
    return self.env.observation_space
Return the `Env` `observation_space` unless it has been overwritten, in which case the wrapper's `observation_space` is used.
@property
def action_space(
    self,
) -> spaces.Space[ActType] | spaces.Space[WrapperActType]:
    """Give the wrapper's own action space when one was assigned, otherwise the wrapped env's."""
    own_space = self._action_space
    if own_space is not None:
        return own_space
    return self.env.action_space
Return the `Env` `action_space` unless it has been overwritten, in which case the wrapper's `action_space` is used.
@property
def P(self):
    """
    Access the transition/reward model built by the DiscretizedPendulum.

    Returns
    -------
    dict
    """
    transitions = self._P
    return transitions
Returns the transition probability matrix.
Returns
- dict
@property
def transform_obs(self):
    """
    Access the function that converts a continuous observation into a discrete state.

    Returns
    -------
    callable
    """
    converter = self._transform_obs
    return converter
Returns the observation transformation function.
Returns
- callable: Function that converts a continuous observation into a discrete state index.
def step(self, action):
    """
    Takes a discrete action, maps it to a continuous torque, and interacts with the environment.

    Parameters
    ----------
    action : int
        The discrete action index.

    Returns
    -------
    int
        The discretized next state index.
    float
        The reward obtained from the environment.
    bool
        Whether the episode has terminated.
    bool
        Whether the episode has been truncated.
    dict
        Additional information from the environment.
    """
    # Map discrete action to continuous torque
    torque = self.discretized_pendulum.get_action_value(action)

    # The base environment receives the torque as a one-element action.
    return self.env.step([torque])
Takes a discrete action, maps it to a continuous torque, and interacts with the environment.
Parameters
- action (int): The discrete action index.
Returns
- int: The discretized next state index.
- float: The reward obtained from the environment.
- bool: Whether the episode has terminated.
- bool: Whether the episode has been truncated.
- dict: Additional information from the environment.
def get_env_str(angle_bins, angular_velocity_bins, torque_bins):
    """
    Build an identifier string that encodes the discretization settings.

    Parameters
    ----------
    angle_bins : int
        Number of discrete bins for the pendulum's angle.
    angular_velocity_bins : int
        Number of discrete bins for the pendulum's angular velocity.
    torque_bins : int
        Number of discrete bins for the torque action.

    Returns
    -------
    str
        Identifier of the form ``pendulum_<angle>_<velocity>_<torque>``.
    """
    template = "pendulum_{}_{}_{}"
    return template.format(angle_bins, angular_velocity_bins, torque_bins)
Returns the environment string based on the discretization parameters.
Parameters
- angle_bins (int): Number of discrete bins for the pendulum's angle.
- angular_velocity_bins (int): Number of discrete bins for the pendulum's angular velocity.
- torque_bins (int): Number of discrete bins for the torque action.
Returns
- str: The environment string.
def init_wrapper_env(angle_bins=11, angular_velocity_bins=11, torque_bins=11):
    """
    Create a Pendulum-v1 environment and wrap it with discrete spaces.

    Parameters
    ----------
    angle_bins : int
        Number of discrete bins for the pendulum's angle.
    angular_velocity_bins : int
        Number of discrete bins for the pendulum's angular velocity.
    torque_bins : int
        Number of discrete bins for the torque action.

    Returns
    -------
    PendulumWrapper
        The wrapped, discretized Pendulum environment.
    """
    base_env = gym.make("Pendulum-v1")
    return PendulumWrapper(
        env=base_env,
        angle_bins=angle_bins,
        angular_velocity_bins=angular_velocity_bins,
        torque_bins=torque_bins,
    )
Initializes the Pendulum wrapper environment.
Parameters
- angle_bins (int): Number of discrete bins for the pendulum's angle.
- angular_velocity_bins (int): Number of discrete bins for the pendulum's angular velocity.
- torque_bins (int): Number of discrete bins for the torque action.
Returns
- PendulumWrapper: The Pendulum wrapper environment.