bettermdptools.envs.pendulum_wrapper

Author: Aleksandr Spiridonov. License: BSD 3-Clause.

  1"""
  2Author: Aleksandr Spiridonov
  3BSD 3-Clause License
  4"""
  5
  6import gymnasium as gym
  7
  8from bettermdptools.envs.pendulum_discretized import (
  9    DiscretizedPendulum,
 10)  # Ensure this path is correct
 11
 12
 13class CustomTransformObservation(gym.ObservationWrapper):
 14    def __init__(self, env, func, observation_space):
 15        """
 16        Helper class that modifies the observation space. The v26 gymnasium TransformObservation wrapper does not
 17        accept an observation_space parameter, which is needed in order to match the lambda conversion (tuple->int).
 18        Instead, we subclass gym.ObservationWrapper (parent class of gym.TransformObservation)
 19        to set both the conversion function and new observation space.
 20
 21        Parameters
 22        ----------
 23        env : gymnasium.Env
 24            Base environment to be wrapped
 25        func : lambda
 26            Function that converts the observation
 27        observation_space : gymnasium.spaces.Space
 28            New observation space
 29        """
 30        super().__init__(env)
 31        if observation_space is not None:
 32            self.observation_space = observation_space
 33        self.func = func
 34
 35    def observation(self, observation):
 36        """
 37        Applies a function to the observation received from the environment's step function,
 38        which is passed back to the user.
 39
 40        Parameters
 41        ----------
 42        observation : Tuple
 43            Base environment observation tuple
 44
 45        Returns
 46        -------
 47        int
 48            The converted observation (int).
 49        """
 50        return self.func(observation)
 51
 52
 53class PendulumWrapper(gym.Wrapper):
 54    def __init__(self, env, angle_bins=11, angular_velocity_bins=11, torque_bins=11):
 55        """
 56        Pendulum wrapper that modifies the observation and action spaces and creates a transition/reward matrix P.
 57
 58        Parameters
 59        ----------
 60        env : gymnasium.Env
 61            Base environment
 62        angle_bins : int
 63            Number of discrete bins for the pendulum's angle.
 64        angular_velocity_bins : int
 65            Number of discrete bins for the pendulum's angular velocity.
 66        torque_bins : int
 67            Number of discrete bins for the torque action.
 68        """
 69        super().__init__(env)
 70
 71        # Initialize the DiscretizedPendulum model
 72        self.discretized_pendulum = DiscretizedPendulum(
 73            angle_bins=angle_bins,
 74            angular_velocity_bins=angular_velocity_bins,
 75            torque_bins=torque_bins,
 76        )
 77
 78        # Transition probability matrix
 79        self._P = self.discretized_pendulum.P
 80
 81        # Transformation function from continuous to discrete observations
 82        self._transform_obs = self.discretized_pendulum.transform_cont_obs
 83        self._get_action_value = self.discretized_pendulum.get_action_value
 84
 85        # Wrap the environment's observation space
 86        self.observation_space = gym.spaces.Discrete(
 87            self.discretized_pendulum.state_space
 88        )
 89        self.env = CustomTransformObservation(
 90            env, self._transform_obs, self.observation_space
 91        )
 92        self.gym_env = env
 93
 94        # Override the action space to be discrete
 95        self.action_space = gym.spaces.Discrete(self.discretized_pendulum.action_space)
 96
 97    @property
 98    def P(self):
 99        """
100        Returns the transition probability matrix.
101
102        Returns
103        -------
104        dict
105        """
106        return self._P
107
108    @property
109    def transform_obs(self):
110        """
111        Returns the observation transformation function.
112
113        Returns
114        -------
115        lambda
116        """
117        return self._transform_obs
118
119    @property
120    def get_action_value(self):
121        f = lambda action: [self._get_action_value(action)]
122        return f
123
124    def step(self, action):
125        """
126        Takes a discrete action, maps it to a continuous torque, and interacts with the environment.
127
128        Parameters
129        ----------
130        action : int
131            The discrete action index.
132
133        Returns
134        -------
135        int
136            The discretized next state index.
137        float
138            The reward obtained from the environment.
139        bool
140            Whether the episode has terminated.
141        dict
142            Additional information from the environment.
143        """
144        # Map discrete action to continuous torque
145        torque = self.discretized_pendulum.get_action_value(action)
146
147        return self.env.step([torque])
148
149
def get_env_str(angle_bins, angular_velocity_bins, torque_bins):
    """
    Builds the environment name for a given discretization.

    Parameters
    ----------
    angle_bins : int
        Number of discrete bins for the pendulum's angle.
    angular_velocity_bins : int
        Number of discrete bins for the pendulum's angular velocity.
    torque_bins : int
        Number of discrete bins for the torque action.

    Returns
    -------
    str
        ``"pendulum_"`` followed by the three bin counts joined with
        underscores, in the order angle, angular velocity, torque.
    """
    bin_counts = (angle_bins, angular_velocity_bins, torque_bins)
    suffix = "_".join(f"{count}" for count in bin_counts)
    return "pendulum_" + suffix
169
170
def init_wrapper_env(angle_bins=11, angular_velocity_bins=11, torque_bins=11):
    """
    Creates a "Pendulum-v1" environment and wraps it in a PendulumWrapper.

    Parameters
    ----------
    angle_bins : int
        Number of discrete bins for the pendulum's angle.
    angular_velocity_bins : int
        Number of discrete bins for the pendulum's angular velocity.
    torque_bins : int
        Number of discrete bins for the torque action.

    Returns
    -------
    PendulumWrapper
        The wrapped environment, configured with the given bin counts.
    """
    base_env = gym.make("Pendulum-v1")
    return PendulumWrapper(
        base_env,
        angle_bins=angle_bins,
        angular_velocity_bins=angular_velocity_bins,
        torque_bins=torque_bins,
    )
class CustomTransformObservation(gymnasium.core.Wrapper[~WrapperObsType, ~ActType]):
14class CustomTransformObservation(gym.ObservationWrapper):
15    def __init__(self, env, func, observation_space):
16        """
17        Helper class that modifies the observation space. The v26 gymnasium TransformObservation wrapper does not
18        accept an observation_space parameter, which is needed in order to match the lambda conversion (tuple->int).
19        Instead, we subclass gym.ObservationWrapper (parent class of gym.TransformObservation)
20        to set both the conversion function and new observation space.
21
22        Parameters
23        ----------
24        env : gymnasium.Env
25            Base environment to be wrapped
26        func : lambda
27            Function that converts the observation
28        observation_space : gymnasium.spaces.Space
29            New observation space
30        """
31        super().__init__(env)
32        if observation_space is not None:
33            self.observation_space = observation_space
34        self.func = func
35
36    def observation(self, observation):
37        """
38        Applies a function to the observation received from the environment's step function,
39        which is passed back to the user.
40
41        Parameters
42        ----------
43        observation : Tuple
44            Base environment observation tuple
45
46        Returns
47        -------
48        int
49            The converted observation (int).
50        """
51        return self.func(observation)

Superclass of wrappers that can modify observations using observation() for reset() and step().

If you would like to apply a function to only the observation before passing it to the learning code, you can simply inherit from ObservationWrapper and overwrite the method observation() to implement that transformation. The transformation defined in that method must be reflected by the env observation space. Otherwise, you need to specify the new observation space of the wrapper by setting self.observation_space in the __init__() method of your wrapper.

Among others, Gymnasium provides the observation wrapper TimeAwareObservation, which adds information about the index of the timestep to the observation.

CustomTransformObservation(env, func, observation_space)
15    def __init__(self, env, func, observation_space):
16        """
17        Helper class that modifies the observation space. The v26 gymnasium TransformObservation wrapper does not
18        accept an observation_space parameter, which is needed in order to match the lambda conversion (tuple->int).
19        Instead, we subclass gym.ObservationWrapper (parent class of gym.TransformObservation)
20        to set both the conversion function and new observation space.
21
22        Parameters
23        ----------
24        env : gymnasium.Env
25            Base environment to be wrapped
26        func : lambda
27            Function that converts the observation
28        observation_space : gymnasium.spaces.Space
29            New observation space
30        """
31        super().__init__(env)
32        if observation_space is not None:
33            self.observation_space = observation_space
34        self.func = func

Helper class that modifies the observation space. The v26 gymnasium TransformObservation wrapper does not accept an observation_space parameter, which is needed in order to match the lambda conversion (tuple->int). Instead, we subclass gym.ObservationWrapper (parent class of gym.TransformObservation) to set both the conversion function and new observation space.

Parameters
  • env (gymnasium.Env): Base environment to be wrapped
  • func (lambda): Function that converts the observation
  • observation_space (gymnasium.spaces.Space): New observation space
func
def observation(self, observation):
36    def observation(self, observation):
37        """
38        Applies a function to the observation received from the environment's step function,
39        which is passed back to the user.
40
41        Parameters
42        ----------
43        observation : Tuple
44            Base environment observation tuple
45
46        Returns
47        -------
48        int
49            The converted observation (int).
50        """
51        return self.func(observation)

Applies the conversion function to each observation returned by the environment's reset() and step() methods before it is passed back to the user.

Parameters
  • observation (Tuple): Base environment observation tuple
Returns
  • int: The converted observation (int).
class PendulumWrapper(gymnasium.core.Env[~WrapperObsType, ~WrapperActType]):
 54class PendulumWrapper(gym.Wrapper):
 55    def __init__(self, env, angle_bins=11, angular_velocity_bins=11, torque_bins=11):
 56        """
 57        Pendulum wrapper that modifies the observation and action spaces and creates a transition/reward matrix P.
 58
 59        Parameters
 60        ----------
 61        env : gymnasium.Env
 62            Base environment
 63        angle_bins : int
 64            Number of discrete bins for the pendulum's angle.
 65        angular_velocity_bins : int
 66            Number of discrete bins for the pendulum's angular velocity.
 67        torque_bins : int
 68            Number of discrete bins for the torque action.
 69        """
 70        super().__init__(env)
 71
 72        # Initialize the DiscretizedPendulum model
 73        self.discretized_pendulum = DiscretizedPendulum(
 74            angle_bins=angle_bins,
 75            angular_velocity_bins=angular_velocity_bins,
 76            torque_bins=torque_bins,
 77        )
 78
 79        # Transition probability matrix
 80        self._P = self.discretized_pendulum.P
 81
 82        # Transformation function from continuous to discrete observations
 83        self._transform_obs = self.discretized_pendulum.transform_cont_obs
 84        self._get_action_value = self.discretized_pendulum.get_action_value
 85
 86        # Wrap the environment's observation space
 87        self.observation_space = gym.spaces.Discrete(
 88            self.discretized_pendulum.state_space
 89        )
 90        self.env = CustomTransformObservation(
 91            env, self._transform_obs, self.observation_space
 92        )
 93        self.gym_env = env
 94
 95        # Override the action space to be discrete
 96        self.action_space = gym.spaces.Discrete(self.discretized_pendulum.action_space)
 97
 98    @property
 99    def P(self):
100        """
101        Returns the transition probability matrix.
102
103        Returns
104        -------
105        dict
106        """
107        return self._P
108
109    @property
110    def transform_obs(self):
111        """
112        Returns the observation transformation function.
113
114        Returns
115        -------
116        lambda
117        """
118        return self._transform_obs
119
120    @property
121    def get_action_value(self):
122        f = lambda action: [self._get_action_value(action)]
123        return f
124
125    def step(self, action):
126        """
127        Takes a discrete action, maps it to a continuous torque, and interacts with the environment.
128
129        Parameters
130        ----------
131        action : int
132            The discrete action index.
133
134        Returns
135        -------
136        int
137            The discretized next state index.
138        float
139            The reward obtained from the environment.
140        bool
141            Whether the episode has terminated.
142        dict
143            Additional information from the environment.
144        """
145        # Map discrete action to continuous torque
146        torque = self.discretized_pendulum.get_action_value(action)
147
148        return self.env.step([torque])

Wraps a gymnasium.Env to allow a modular transformation of the step() and reset() methods.

This class is the base class of all wrappers to change the behavior of the underlying environment. Wrappers that inherit from this class can modify the action_space, observation_space, reward_range and metadata attributes, without changing the underlying environment's attributes. Moreover, the behavior of the step() and reset() methods can be changed by these wrappers.

Some attributes (spec, render_mode, np_random) will point back to the wrapper's environment (i.e. to the corresponding attributes of env).

Note: If you inherit from Wrapper, don't forget to call super().__init__(env)

PendulumWrapper(env, angle_bins=11, angular_velocity_bins=11, torque_bins=11)
55    def __init__(self, env, angle_bins=11, angular_velocity_bins=11, torque_bins=11):
56        """
57        Pendulum wrapper that modifies the observation and action spaces and creates a transition/reward matrix P.
58
59        Parameters
60        ----------
61        env : gymnasium.Env
62            Base environment
63        angle_bins : int
64            Number of discrete bins for the pendulum's angle.
65        angular_velocity_bins : int
66            Number of discrete bins for the pendulum's angular velocity.
67        torque_bins : int
68            Number of discrete bins for the torque action.
69        """
70        super().__init__(env)
71
72        # Initialize the DiscretizedPendulum model
73        self.discretized_pendulum = DiscretizedPendulum(
74            angle_bins=angle_bins,
75            angular_velocity_bins=angular_velocity_bins,
76            torque_bins=torque_bins,
77        )
78
79        # Transition probability matrix
80        self._P = self.discretized_pendulum.P
81
82        # Transformation function from continuous to discrete observations
83        self._transform_obs = self.discretized_pendulum.transform_cont_obs
84        self._get_action_value = self.discretized_pendulum.get_action_value
85
86        # Wrap the environment's observation space
87        self.observation_space = gym.spaces.Discrete(
88            self.discretized_pendulum.state_space
89        )
90        self.env = CustomTransformObservation(
91            env, self._transform_obs, self.observation_space
92        )
93        self.gym_env = env
94
95        # Override the action space to be discrete
96        self.action_space = gym.spaces.Discrete(self.discretized_pendulum.action_space)

Pendulum wrapper that modifies the observation and action spaces and creates a transition/reward matrix P.

Parameters
  • env (gymnasium.Env): Base environment
  • angle_bins (int): Number of discrete bins for the pendulum's angle.
  • angular_velocity_bins (int): Number of discrete bins for the pendulum's angular velocity.
  • torque_bins (int): Number of discrete bins for the torque action.
discretized_pendulum
observation_space: Union[gymnasium.spaces.space.Space[~ObsType], gymnasium.spaces.space.Space[~WrapperObsType]]
299    @property
300    def observation_space(
301        self,
302    ) -> spaces.Space[ObsType] | spaces.Space[WrapperObsType]:
303        """Return the :attr:`Env` :attr:`observation_space` unless overwritten then the wrapper :attr:`observation_space` is used."""
304        if self._observation_space is None:
305            return self.env.observation_space
306        return self._observation_space

Return the Env observation_space unless overwritten then the wrapper observation_space is used.

env
gym_env
action_space: Union[gymnasium.spaces.space.Space[~ActType], gymnasium.spaces.space.Space[~WrapperActType]]
286    @property
287    def action_space(
288        self,
289    ) -> spaces.Space[ActType] | spaces.Space[WrapperActType]:
290        """Return the :attr:`Env` :attr:`action_space` unless overwritten then the wrapper :attr:`action_space` is used."""
291        if self._action_space is None:
292            return self.env.action_space
293        return self._action_space

Return the Env action_space unless overwritten then the wrapper action_space is used.

P
 98    @property
 99    def P(self):
100        """
101        Returns the transition probability matrix.
102
103        Returns
104        -------
105        dict
106        """
107        return self._P

Returns the transition probability matrix.

Returns
  • dict
transform_obs
109    @property
110    def transform_obs(self):
111        """
112        Returns the observation transformation function.
113
114        Returns
115        -------
116        lambda
117        """
118        return self._transform_obs

Returns the observation transformation function.

Returns
  • lambda
get_action_value
120    @property
121    def get_action_value(self):
122        f = lambda action: [self._get_action_value(action)]
123        return f
def step(self, action):
125    def step(self, action):
126        """
127        Takes a discrete action, maps it to a continuous torque, and interacts with the environment.
128
129        Parameters
130        ----------
131        action : int
132            The discrete action index.
133
134        Returns
135        -------
136        int
137            The discretized next state index.
138        float
139            The reward obtained from the environment.
140        bool
141            Whether the episode has terminated.
142        dict
143            Additional information from the environment.
144        """
145        # Map discrete action to continuous torque
146        torque = self.discretized_pendulum.get_action_value(action)
147
148        return self.env.step([torque])

Takes a discrete action, maps it to a continuous torque, and interacts with the environment.

Parameters
  • action (int): The discrete action index.
Returns
  • int: The discretized next state index.
  • float: The reward obtained from the environment.
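  • bool: Whether the episode has terminated.
  • bool: Whether the episode was truncated (Gymnasium's step() returns this flag between the termination flag and the info dict).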
  • dict: Additional information from the environment.
def get_env_str(angle_bins, angular_velocity_bins, torque_bins):
151def get_env_str(angle_bins, angular_velocity_bins, torque_bins):
152    """
153    Returns the environment string based on the discretization parameters.
154
155    Parameters
156    ----------
157    angle_bins : int
158        Number of discrete bins for the pendulum's angle.
159    angular_velocity_bins : int
160        Number of discrete bins for the pendulum's angular velocity.
161    torque_bins : int
162        Number of discrete bins for the torque action.
163
164    Returns
165    -------
166    str
167        The environment string.
168    """
169    return f"pendulum_{angle_bins}_{angular_velocity_bins}_{torque_bins}"

Returns the environment string based on the discretization parameters.

Parameters
  • angle_bins (int): Number of discrete bins for the pendulum's angle.
  • angular_velocity_bins (int): Number of discrete bins for the pendulum's angular velocity.
  • torque_bins (int): Number of discrete bins for the torque action.
Returns
  • str: The environment string.
def init_wrapper_env(angle_bins=11, angular_velocity_bins=11, torque_bins=11):
172def init_wrapper_env(angle_bins=11, angular_velocity_bins=11, torque_bins=11):
173    """
174    Initializes the Pendulum wrapper environment.
175
176    Parameters
177    ----------
178    angle_bins : int
179        Number of discrete bins for the pendulum's angle.
180    angular_velocity_bins : int
181        Number of discrete bins for the pendulum's angular velocity.
182    torque_bins : int
183        Number of discrete bins for the torque action.
184
185    Returns
186    -------
187    PendulumWrapper
188        The Pendulum wrapper environment.
189    """
190    pendulum_genv_train = gym.make("Pendulum-v1")
191
192    pendulum_train = PendulumWrapper(
193        angle_bins=angle_bins,
194        angular_velocity_bins=angular_velocity_bins,
195        torque_bins=torque_bins,
196        env=pendulum_genv_train,
197    )
198
199    return pendulum_train

Initializes the Pendulum wrapper environment.

Parameters
  • angle_bins (int): Number of discrete bins for the pendulum's angle.
  • angular_velocity_bins (int): Number of discrete bins for the pendulum's angular velocity.
  • torque_bins (int): Number of discrete bins for the torque action.
Returns
  • PendulumWrapper: The Pendulum wrapper environment.