bettermdptools.envs.cartpole_wrapper
Author: John Mansfield BSD 3-Clause License
1""" 2Author: John Mansfield 3BSD 3-Clause License 4""" 5 6import gymnasium as gym 7 8from bettermdptools.envs.cartpole_model import DiscretizedCartPole 9 10 11class CustomTransformObservation(gym.ObservationWrapper): 12 def __init__(self, env, func, observation_space): 13 """ 14 Helper class that modifies the observation space. The v26 gymnasium TransformObservation wrapper does not 15 accept an observation_space parameter, which is needed in order to match the lambda conversion (tuple->int). 16 Instead, we subclass gym.ObservationWrapper (parent class of gym.TransformObservation) 17 to set both the conversion function and new observation space. 18 19 Parameters 20 ---------- 21 env : gymnasium.Env 22 Base environment to be wrapped. 23 func : lambda 24 Function that converts the observation. 25 observation_space : gymnasium.spaces.Space 26 New observation space. 27 """ 28 super().__init__(env) 29 if observation_space is not None: 30 self.observation_space = observation_space 31 self.func = func 32 33 def observation(self, observation): 34 """ 35 Applies a function to the observation received from the environment's step function, 36 which is passed back to the user. 37 38 Parameters 39 ---------- 40 observation : Tuple 41 Base environment observation tuple. 42 43 Returns 44 ------- 45 int 46 The converted observation. 47 """ 48 return self.func(observation) 49 50 51class CartpoleWrapper(gym.Wrapper): 52 def __init__( 53 self, 54 env, 55 position_bins=10, 56 velocity_bins=10, 57 angular_velocity_bins=10, 58 angular_center_resolution=0.1, 59 angular_outer_resolution=0.5, 60 ): 61 """ 62 Cartpole wrapper that modifies the observation space and creates a transition/reward matrix P. 63 64 Parameters 65 ---------- 66 env : gymnasium.Env 67 Base environment. 68 position_bins : int, optional 69 Number of discrete bins for the cart's position. 70 velocity_bins : int, optional 71 Number of discrete bins for the cart's velocity. 
72 angular_velocity_bins : int, optional 73 Number of discrete bins for the pole's angular velocity. 74 angular_center_resolution : float, optional 75 The resolution of angle bins near the center (around zero). 76 angular_outer_resolution : float, optional 77 The resolution of angle bins away from the center. 78 """ 79 dpole = DiscretizedCartPole( 80 position_bins=position_bins, 81 velocity_bins=velocity_bins, 82 angular_velocity_bins=angular_velocity_bins, 83 angular_center_resolution=angular_center_resolution, 84 angular_outer_resolution=angular_outer_resolution, 85 ) 86 self._P = dpole.P 87 self._transform_obs = dpole.transform_obs 88 env = CustomTransformObservation( 89 env, self._transform_obs, gym.spaces.Discrete(dpole.n_states) 90 ) 91 super().__init__(env) 92 93 @property 94 def P(self): 95 """ 96 Returns 97 ------- 98 dict 99 Transition/reward matrix. 100 """ 101 return self._P 102 103 @property 104 def transform_obs(self): 105 """ 106 Returns 107 ------- 108 lambda 109 Function that converts the observation. 110 """ 111 return self._transform_obs
class CustomTransformObservation(gym.ObservationWrapper):
    def __init__(self, env, func, observation_space=None):
        """
        Helper class that converts observations and, optionally, replaces the observation space.

        The ``TransformObservation`` wrapper in gymnasium v0.26 does not accept an
        ``observation_space`` parameter, which is needed so that the declared space
        matches the converted observations (tuple -> int). Instead, we subclass
        gym.ObservationWrapper (parent class of gym.TransformObservation) to set
        both the conversion function and the new observation space.

        Parameters
        ----------
        env : gymnasium.Env
            Base environment to be wrapped.
        func : callable
            Function that converts the observation.
        observation_space : gymnasium.spaces.Space, optional
            New observation space. If None (the default), ``self.observation_space``
            is not overridden by this wrapper.
        """
        super().__init__(env)
        if observation_space is not None:
            self.observation_space = observation_space
        self.func = func

    def observation(self, observation):
        """
        Applies ``func`` to an observation from the wrapped environment.

        gym.ObservationWrapper calls this for observations returned by both
        ``reset()`` and ``step()``, so the converted value is what the user receives.

        Parameters
        ----------
        observation : Tuple
            Base environment observation tuple.

        Returns
        -------
        int
            The converted observation.
        """
        return self.func(observation)
Superclass of wrappers that can modify observations using `observation()` for `reset()` and `step()`.
If you would like to apply a function to only the observation before passing it to the learning code, you can simply inherit from `ObservationWrapper` and overwrite the method `observation()` to implement that transformation. The transformation defined in that method must be reflected by the `env` observation space. Otherwise, you need to specify the new observation space of the wrapper by setting `self.observation_space` in the `__init__()` method of your wrapper.
Among others, Gymnasium provides the observation wrapper `TimeAwareObservation`, which adds information about the index of the timestep to the observation.
def __init__(self, env, func, observation_space=None):
    """
    Helper class that converts observations and, optionally, replaces the observation space.

    The ``TransformObservation`` wrapper in gymnasium v0.26 does not accept an
    ``observation_space`` parameter, which is needed so that the declared space
    matches the converted observations (tuple -> int). Instead, we subclass
    gym.ObservationWrapper (parent class of gym.TransformObservation) to set
    both the conversion function and the new observation space.

    Parameters
    ----------
    env : gymnasium.Env
        Base environment to be wrapped.
    func : callable
        Function that converts the observation.
    observation_space : gymnasium.spaces.Space, optional
        New observation space. If None (the default), ``self.observation_space``
        is not overridden by this wrapper.
    """
    super().__init__(env)
    if observation_space is not None:
        self.observation_space = observation_space
    self.func = func
Helper class that modifies the observation space. The `TransformObservation` wrapper in gymnasium v0.26 does not accept an `observation_space` parameter, which is needed so that the declared space matches the lambda conversion (tuple->int). Instead, we subclass `gym.ObservationWrapper` (parent class of `gym.TransformObservation`) to set both the conversion function and the new observation space.
Parameters
- env (gymnasium.Env): Base environment to be wrapped.
- func (lambda): Function that converts the observation.
- observation_space (gymnasium.spaces.Space): New observation space.
def observation(self, observation):
    """
    Convert one raw observation with the stored conversion function.

    gym.ObservationWrapper routes observations from both ``reset()`` and
    ``step()`` through this method, so its return value is what the user sees.

    Parameters
    ----------
    observation : Tuple
        Observation produced by the base environment.

    Returns
    -------
    int
        The observation after ``self.func`` has been applied.
    """
    convert = self.func
    return convert(observation)
Applies a function to the observation returned by the environment's `reset()` and `step()` functions, and passes the result back to the user.
Parameters
- observation (Tuple): Base environment observation tuple.
Returns
- int: The converted observation.
class CartpoleWrapper(gym.Wrapper):
    def __init__(
        self,
        env,
        position_bins=10,
        velocity_bins=10,
        angular_velocity_bins=10,
        angular_center_resolution=0.1,
        angular_outer_resolution=0.5,
    ):
        """
        Cartpole wrapper that discretizes observations and exposes the transition/reward matrix P.

        Parameters
        ----------
        env : gymnasium.Env
            Base environment.
        position_bins : int, optional
            How many discrete bins the cart's position is split into.
        velocity_bins : int, optional
            How many discrete bins the cart's velocity is split into.
        angular_velocity_bins : int, optional
            How many discrete bins the pole's angular velocity is split into.
        angular_center_resolution : float, optional
            Width of the angle bins near the center (around zero).
        angular_outer_resolution : float, optional
            Width of the angle bins away from the center.
        """
        discretizer = DiscretizedCartPole(
            position_bins=position_bins,
            velocity_bins=velocity_bins,
            angular_velocity_bins=angular_velocity_bins,
            angular_center_resolution=angular_center_resolution,
            angular_outer_resolution=angular_outer_resolution,
        )
        # Keep the discretized model's transition table and observation converter.
        self._P, self._transform_obs = discretizer.P, discretizer.transform_obs
        # One Discrete value per discretized state, matching what transform_obs returns.
        discrete_space = gym.spaces.Discrete(discretizer.n_states)
        super().__init__(
            CustomTransformObservation(env, self._transform_obs, discrete_space)
        )

    @property
    def P(self):
        """
        Returns
        -------
        dict
            Transition/reward matrix of the discretized model.
        """
        return self._P

    @property
    def transform_obs(self):
        """
        Returns
        -------
        callable
            The function used to convert raw observations.
        """
        return self._transform_obs
Wraps a `gymnasium.Env` to allow a modular transformation of the `step()` and `reset()` methods.
This class is the base class of all wrappers to change the behavior of the underlying environment. Wrappers that inherit from this class can modify the `action_space`, `observation_space`, `reward_range` and `metadata` attributes, without changing the underlying environment's attributes.
Moreover, the behavior of the `step()` and `reset()` methods can be changed by these wrappers.
Some attributes (`spec`, `render_mode`, `np_random`) will point back to the wrapper's environment (i.e. to the corresponding attributes of `env`).
Note:
If you inherit from `Wrapper`, don't forget to call `super().__init__(env)`.
def __init__(
    self,
    env,
    position_bins=10,
    velocity_bins=10,
    angular_velocity_bins=10,
    angular_center_resolution=0.1,
    angular_outer_resolution=0.5,
):
    """
    Cartpole wrapper that discretizes observations and exposes the transition/reward matrix P.

    Parameters
    ----------
    env : gymnasium.Env
        Base environment.
    position_bins : int, optional
        How many discrete bins the cart's position is split into.
    velocity_bins : int, optional
        How many discrete bins the cart's velocity is split into.
    angular_velocity_bins : int, optional
        How many discrete bins the pole's angular velocity is split into.
    angular_center_resolution : float, optional
        Width of the angle bins near the center (around zero).
    angular_outer_resolution : float, optional
        Width of the angle bins away from the center.
    """
    discretizer = DiscretizedCartPole(
        position_bins=position_bins,
        velocity_bins=velocity_bins,
        angular_velocity_bins=angular_velocity_bins,
        angular_center_resolution=angular_center_resolution,
        angular_outer_resolution=angular_outer_resolution,
    )
    # Keep the discretized model's transition table and observation converter.
    self._P, self._transform_obs = discretizer.P, discretizer.transform_obs
    # One Discrete value per discretized state, matching what transform_obs returns.
    discrete_space = gym.spaces.Discrete(discretizer.n_states)
    super().__init__(
        CustomTransformObservation(env, self._transform_obs, discrete_space)
    )
Cartpole wrapper that modifies the observation space and creates a transition/reward matrix P.
Parameters
- env (gymnasium.Env): Base environment.
- position_bins (int, optional): Number of discrete bins for the cart's position.
- velocity_bins (int, optional): Number of discrete bins for the cart's velocity.
- angular_velocity_bins (int, optional): Number of discrete bins for the pole's angular velocity.
- angular_center_resolution (float, optional): The resolution of angle bins near the center (around zero).
- angular_outer_resolution (float, optional): The resolution of angle bins away from the center.