bettermdptools.envs.cartpole_wrapper

Author: John Mansfield. BSD 3-Clause License.

  1"""
  2Author: John Mansfield
  3BSD 3-Clause License
  4"""
  5
  6import gymnasium as gym
  7
  8from bettermdptools.envs.cartpole_model import DiscretizedCartPole
  9
 10
 11class CustomTransformObservation(gym.ObservationWrapper):
 12    def __init__(self, env, func, observation_space):
 13        """
 14        Helper class that modifies the observation space. The v26 gymnasium TransformObservation wrapper does not
 15        accept an observation_space parameter, which is needed in order to match the lambda conversion (tuple->int).
 16        Instead, we subclass gym.ObservationWrapper (parent class of gym.TransformObservation)
 17        to set both the conversion function and new observation space.
 18
 19        Parameters
 20        ----------
 21        env : gymnasium.Env
 22            Base environment to be wrapped.
 23        func : lambda
 24            Function that converts the observation.
 25        observation_space : gymnasium.spaces.Space
 26            New observation space.
 27        """
 28        super().__init__(env)
 29        if observation_space is not None:
 30            self.observation_space = observation_space
 31        self.func = func
 32
 33    def observation(self, observation):
 34        """
 35        Applies a function to the observation received from the environment's step function,
 36        which is passed back to the user.
 37
 38        Parameters
 39        ----------
 40        observation : Tuple
 41            Base environment observation tuple.
 42
 43        Returns
 44        -------
 45        int
 46            The converted observation.
 47        """
 48        return self.func(observation)
 49
 50
 51class CartpoleWrapper(gym.Wrapper):
 52    def __init__(
 53        self,
 54        env,
 55        position_bins=10,
 56        velocity_bins=10,
 57        angular_velocity_bins=10,
 58        angular_center_resolution=0.1,
 59        angular_outer_resolution=0.5,
 60    ):
 61        """
 62        Cartpole wrapper that modifies the observation space and creates a transition/reward matrix P.
 63
 64        Parameters
 65        ----------
 66        env : gymnasium.Env
 67            Base environment.
 68        position_bins : int, optional
 69            Number of discrete bins for the cart's position.
 70        velocity_bins : int, optional
 71            Number of discrete bins for the cart's velocity.
 72        angular_velocity_bins : int, optional
 73            Number of discrete bins for the pole's angular velocity.
 74        angular_center_resolution : float, optional
 75            The resolution of angle bins near the center (around zero).
 76        angular_outer_resolution : float, optional
 77            The resolution of angle bins away from the center.
 78        """
 79        dpole = DiscretizedCartPole(
 80            position_bins=position_bins,
 81            velocity_bins=velocity_bins,
 82            angular_velocity_bins=angular_velocity_bins,
 83            angular_center_resolution=angular_center_resolution,
 84            angular_outer_resolution=angular_outer_resolution,
 85        )
 86        self._P = dpole.P
 87        self._transform_obs = dpole.transform_obs
 88        env = CustomTransformObservation(
 89            env, self._transform_obs, gym.spaces.Discrete(dpole.n_states)
 90        )
 91        super().__init__(env)
 92
 93    @property
 94    def P(self):
 95        """
 96        Returns
 97        -------
 98        dict
 99            Transition/reward matrix.
100        """
101        return self._P
102
103    @property
104    def transform_obs(self):
105        """
106        Returns
107        -------
108        lambda
109            Function that converts the observation.
110        """
111        return self._transform_obs
class CustomTransformObservation(gymnasium.core.Wrapper[~WrapperObsType, ~ActType]):
12class CustomTransformObservation(gym.ObservationWrapper):
13    def __init__(self, env, func, observation_space):
14        """
15        Helper class that modifies the observation space. The v26 gymnasium TransformObservation wrapper does not
16        accept an observation_space parameter, which is needed in order to match the lambda conversion (tuple->int).
17        Instead, we subclass gym.ObservationWrapper (parent class of gym.TransformObservation)
18        to set both the conversion function and new observation space.
19
20        Parameters
21        ----------
22        env : gymnasium.Env
23            Base environment to be wrapped.
24        func : lambda
25            Function that converts the observation.
26        observation_space : gymnasium.spaces.Space
27            New observation space.
28        """
29        super().__init__(env)
30        if observation_space is not None:
31            self.observation_space = observation_space
32        self.func = func
33
34    def observation(self, observation):
35        """
36        Applies a function to the observation received from the environment's step function,
37        which is passed back to the user.
38
39        Parameters
40        ----------
41        observation : Tuple
42            Base environment observation tuple.
43
44        Returns
45        -------
46        int
47            The converted observation.
48        """
49        return self.func(observation)

Superclass of wrappers that can modify observations using observation() for reset() and step().

If you would like to apply a function to only the observation before passing it to the learning code, you can simply inherit from ObservationWrapper and overwrite the method observation() to implement that transformation. The transformation defined in that method must be reflected by the env observation space. Otherwise, you need to specify the new observation space of the wrapper by setting self.observation_space in the __init__() method of your wrapper.

Among others, Gymnasium provides the observation wrapper TimeAwareObservation, which adds information about the index of the timestep to the observation.

CustomTransformObservation(env, func, observation_space)
13    def __init__(self, env, func, observation_space):
14        """
15        Helper class that modifies the observation space. The v26 gymnasium TransformObservation wrapper does not
16        accept an observation_space parameter, which is needed in order to match the lambda conversion (tuple->int).
17        Instead, we subclass gym.ObservationWrapper (parent class of gym.TransformObservation)
18        to set both the conversion function and new observation space.
19
20        Parameters
21        ----------
22        env : gymnasium.Env
23            Base environment to be wrapped.
24        func : lambda
25            Function that converts the observation.
26        observation_space : gymnasium.spaces.Space
27            New observation space.
28        """
29        super().__init__(env)
30        if observation_space is not None:
31            self.observation_space = observation_space
32        self.func = func

Helper class that applies a conversion function to observations and sets the matching observation space. The TransformObservation wrapper in gymnasium v0.26 does not accept an observation_space parameter, so the wrapped environment would keep advertising the original space even though the conversion changes the observation type (tuple -> int). Instead, we subclass gym.ObservationWrapper (the parent class of gymnasium.wrappers.TransformObservation) to set both the conversion function and the new observation space.

Parameters
  • env (gymnasium.Env): Base environment to be wrapped.
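  • func (callable): Function that converts a base-environment observation into the wrapped observation; any callable taking the observation as its single argument works.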
  • observation_space (gymnasium.spaces.Space): New observation space.
func
def observation(self, observation):
34    def observation(self, observation):
35        """
36        Applies a function to the observation received from the environment's step function,
37        which is passed back to the user.
38
39        Parameters
40        ----------
41        observation : Tuple
42            Base environment observation tuple.
43
44        Returns
45        -------
46        int
47            The converted observation.
48        """
49        return self.func(observation)

Applies the conversion function to each observation returned by the environment's reset and step functions before it is passed back to the user.

Parameters
  • observation (Tuple): Base environment observation tuple.
Returns
  • int: The converted observation.
class CartpoleWrapper(gymnasium.core.Env[~WrapperObsType, ~WrapperActType]):
 52class CartpoleWrapper(gym.Wrapper):
 53    def __init__(
 54        self,
 55        env,
 56        position_bins=10,
 57        velocity_bins=10,
 58        angular_velocity_bins=10,
 59        angular_center_resolution=0.1,
 60        angular_outer_resolution=0.5,
 61    ):
 62        """
 63        Cartpole wrapper that modifies the observation space and creates a transition/reward matrix P.
 64
 65        Parameters
 66        ----------
 67        env : gymnasium.Env
 68            Base environment.
 69        position_bins : int, optional
 70            Number of discrete bins for the cart's position.
 71        velocity_bins : int, optional
 72            Number of discrete bins for the cart's velocity.
 73        angular_velocity_bins : int, optional
 74            Number of discrete bins for the pole's angular velocity.
 75        angular_center_resolution : float, optional
 76            The resolution of angle bins near the center (around zero).
 77        angular_outer_resolution : float, optional
 78            The resolution of angle bins away from the center.
 79        """
 80        dpole = DiscretizedCartPole(
 81            position_bins=position_bins,
 82            velocity_bins=velocity_bins,
 83            angular_velocity_bins=angular_velocity_bins,
 84            angular_center_resolution=angular_center_resolution,
 85            angular_outer_resolution=angular_outer_resolution,
 86        )
 87        self._P = dpole.P
 88        self._transform_obs = dpole.transform_obs
 89        env = CustomTransformObservation(
 90            env, self._transform_obs, gym.spaces.Discrete(dpole.n_states)
 91        )
 92        super().__init__(env)
 93
 94    @property
 95    def P(self):
 96        """
 97        Returns
 98        -------
 99        dict
100            Transition/reward matrix.
101        """
102        return self._P
103
104    @property
105    def transform_obs(self):
106        """
107        Returns
108        -------
109        lambda
110            Function that converts the observation.
111        """
112        return self._transform_obs

Wraps a gymnasium.Env to allow a modular transformation of the step() and reset() methods.

This class is the base class of all wrappers to change the behavior of the underlying environment. Wrappers that inherit from this class can modify the action_space, observation_space, reward_range and metadata attributes, without changing the underlying environment's attributes. Moreover, the behavior of the step() and reset() methods can be changed by these wrappers.

Some attributes (spec, render_mode, np_random) will point back to the wrapper's environment (i.e. to the corresponding attributes of env).

Note: If you inherit from Wrapper, don't forget to call super().__init__(env)

CartpoleWrapper( env, position_bins=10, velocity_bins=10, angular_velocity_bins=10, angular_center_resolution=0.1, angular_outer_resolution=0.5)
53    def __init__(
54        self,
55        env,
56        position_bins=10,
57        velocity_bins=10,
58        angular_velocity_bins=10,
59        angular_center_resolution=0.1,
60        angular_outer_resolution=0.5,
61    ):
62        """
63        Cartpole wrapper that modifies the observation space and creates a transition/reward matrix P.
64
65        Parameters
66        ----------
67        env : gymnasium.Env
68            Base environment.
69        position_bins : int, optional
70            Number of discrete bins for the cart's position.
71        velocity_bins : int, optional
72            Number of discrete bins for the cart's velocity.
73        angular_velocity_bins : int, optional
74            Number of discrete bins for the pole's angular velocity.
75        angular_center_resolution : float, optional
76            The resolution of angle bins near the center (around zero).
77        angular_outer_resolution : float, optional
78            The resolution of angle bins away from the center.
79        """
80        dpole = DiscretizedCartPole(
81            position_bins=position_bins,
82            velocity_bins=velocity_bins,
83            angular_velocity_bins=angular_velocity_bins,
84            angular_center_resolution=angular_center_resolution,
85            angular_outer_resolution=angular_outer_resolution,
86        )
87        self._P = dpole.P
88        self._transform_obs = dpole.transform_obs
89        env = CustomTransformObservation(
90            env, self._transform_obs, gym.spaces.Discrete(dpole.n_states)
91        )
92        super().__init__(env)

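Cartpole wrapper that converts the base environment's observations into discrete state indices (the observation space becomes a Discrete space with one entry per discretized state) and exposes the transition/reward matrix P built by DiscretizedCartPole.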

Parameters
  • env (gymnasium.Env): Base environment.
  • position_bins (int, optional): Number of discrete bins for the cart's position.
  • velocity_bins (int, optional): Number of discrete bins for the cart's velocity.
  • angular_velocity_bins (int, optional): Number of discrete bins for the pole's angular velocity.
  • angular_center_resolution (float, optional): The resolution of angle bins near the center (around zero).
  • angular_outer_resolution (float, optional): The resolution of angle bins away from the center.
P
 94    @property
 95    def P(self):
 96        """
 97        Returns
 98        -------
 99        dict
100            Transition/reward matrix.
101        """
102        return self._P
Returns
  • dict: Transition/reward matrix.
transform_obs
104    @property
105    def transform_obs(self):
106        """
107        Returns
108        -------
109        lambda
110            Function that converts the observation.
111        """
112        return self._transform_obs
Returns
  • callable: Function that converts a base-environment observation into its discretized form.