envs.pendulum_wrapper

Author: Aleksandr Spiridonov BSD 3-Clause License

  1"""
  2Author: Aleksandr Spiridonov
  3BSD 3-Clause License
  4"""
  5
  6import gymnasium as gym
  7from bettermdptools.envs.pendulum_discretized import DiscretizedPendulum  # Ensure this path is correct
  8
  9
 10class CustomTransformObservation(gym.ObservationWrapper):
 11    def __init__(self, env, func, observation_space):
 12        """
 13        Helper class that modifies the observation space. The v26 gymnasium TransformObservation wrapper does not
 14        accept an observation_space parameter, which is needed in order to match the lambda conversion (tuple->int).
 15        Instead, we subclass gym.ObservationWrapper (parent class of gym.TransformObservation)
 16        to set both the conversion function and new observation space.
 17
 18        Parameters
 19        ----------------------------
 20        env {gymnasium.Env}:
 21            Base environment to be wrapped
 22
 23        func {lambda}:
 24            Function that converts the observation
 25
 26        observation_space {gymnasium.spaces.Space}:
 27            New observation space
 28        """
 29        super().__init__(env)
 30        if observation_space is not None:
 31            self.observation_space = observation_space
 32        self.func = func
 33
 34    def observation(self, observation):
 35        """
 36        Applies a function to the observation received from the environment's step function,
 37        which is passed back to the user.
 38
 39        Parameters
 40        ----------------------------
 41        observation {Tuple}:
 42            Base environment observation tuple
 43
 44        Returns
 45        ----------------------------
 46        func(observation) {int}:
 47            The converted observation (int).
 48        """
 49        return self.func(observation)
 50
 51
 52class PendulumWrapper(gym.Wrapper):
 53    def __init__(self,
 54                 env,
 55                 angle_bins=11,
 56                 angular_velocity_bins=11,
 57                 torque_bins=11):
 58        """
 59        Pendulum wrapper that modifies the observation and action spaces and creates a transition/reward matrix P.
 60
 61        Parameters
 62        ----------------------------
 63        env {gymnasium.Env}: Base environment
 64        angle_bins (int): Number of discrete bins for the pendulum's angle.
 65        angular_velocity_bins (int): Number of discrete bins for the pendulum's angular velocity.
 66        action_bins (int): Number of discrete bins for the torque action.
 67        angular_center_resolution (float): The resolution of angle bins near the center (around zero).
 68        angular_outer_resolution (float): The resolution of angle bins away from the center.
 69        torque_range (tuple): The minimum and maximum torque values.
 70        """
 71        super().__init__(env)
 72
 73        # Initialize the DiscretizedPendulum model
 74        self.discretized_pendulum = DiscretizedPendulum(
 75            angle_bins=angle_bins,
 76            angular_velocity_bins=angular_velocity_bins,
 77            torque_bins=torque_bins,
 78        )
 79
 80        # Transition probability matrix
 81        self._P = self.discretized_pendulum.P
 82
 83        # Transformation function from continuous to discrete observations
 84        self._transform_obs = self.discretized_pendulum.transform_cont_obs
 85        self._get_action_value = self.discretized_pendulum.get_action_value
 86
 87        # Wrap the environment's observation space
 88        self.observation_space = gym.spaces.Discrete(self.discretized_pendulum.state_space)
 89        self.env = CustomTransformObservation(env, self._transform_obs, self.observation_space)
 90        self.gym_env = env
 91
 92        # Override the action space to be discrete
 93        self.action_space = gym.spaces.Discrete(self.discretized_pendulum.action_space)
 94
 95    @property
 96    def P(self):
 97        """
 98        Returns the transition probability matrix.
 99
100        Returns
101        ----------------------------
102        _P {dict}
103        """
104        return self._P
105
106    @property
107    def transform_obs(self):
108        """
109        Returns the observation transformation function.
110
111        Returns
112        ----------------------------
113        _transform_obs {lambda}
114        """
115        return self._transform_obs
116    
117    @property
118    def get_action_value(self):
119        f = lambda action: [self._get_action_value(action)]
120        return f
121
122    def step(self, action):
123        """
124        Takes a discrete action, maps it to a continuous torque, and interacts with the environment.
125
126        Parameters
127        ----------------------------
128        action {int}:
129            The discrete action index.
130
131        Returns
132        ----------------------------
133        state {int}:
134            The discretized next state index.
135        reward {float}:
136            The reward obtained from the environment.
137        done {bool}:
138            Whether the episode has terminated.
139        info {dict}:
140            Additional information from the environment.
141        """
142        # Map discrete action to continuous torque
143        torque = self.discretized_pendulum.get_action_value(action)
144
145        return self.env.step([torque])
146
def get_env_str(angle_bins, angular_velocity_bins, torque_bins):
    """
    Returns the environment string based on the discretization parameters.

    Parameters
    ----------------------------
    angle_bins (int): Number of discrete bins for the pendulum's angle.
    angular_velocity_bins (int): Number of discrete bins for the pendulum's angular velocity.
    torque_bins (int): Number of discrete bins for the torque action.

    Returns
    ----------------------------
    env_str {str}: The environment string, formatted as
        'pendulum_{angle_bins}_{angular_velocity_bins}_{torque_bins}'.
    """
    return f'pendulum_{angle_bins}_{angular_velocity_bins}_{torque_bins}'
162
def init_wrapper_env(angle_bins=11, angular_velocity_bins=11, torque_bins=11):
    """
    Build a discretized Pendulum environment.

    Creates a fresh 'Pendulum-v1' gymnasium environment and wraps it in a
    PendulumWrapper configured with the requested discretization.

    Parameters
    ----------------------------
    angle_bins (int): Number of discrete bins for the pendulum's angle.
    angular_velocity_bins (int): Number of discrete bins for the pendulum's angular velocity.
    torque_bins (int): Number of discrete bins for the torque action.

    Returns
    ----------------------------
    pendulum_env {PendulumWrapper}: The Pendulum wrapper environment.
    """
    base_env = gym.make('Pendulum-v1')
    return PendulumWrapper(
        base_env,
        angle_bins=angle_bins,
        angular_velocity_bins=angular_velocity_bins,
        torque_bins=torque_bins,
    )
class CustomTransformObservation(gymnasium.core.Wrapper[~WrapperObsType, ~ActType]):
class CustomTransformObservation(gym.ObservationWrapper):
    """
    Observation wrapper that applies ``func`` to every observation and can
    replace the advertised observation space.

    The v26 gymnasium TransformObservation wrapper does not accept an
    observation_space parameter, which is needed in order to match the
    converted observations (e.g. tuple -> int). Subclassing
    gym.ObservationWrapper (parent class of gym.TransformObservation) lets us
    set both the conversion function and the new observation space.
    """

    def __init__(self, env, func, observation_space=None):
        """
        Parameters
        ----------------------------
        env {gymnasium.Env}:
            Base environment to be wrapped.

        func {callable}:
            Function that converts a raw observation from ``env``.

        observation_space {gymnasium.spaces.Space or None}:
            New observation space. When None (the default), no override is
            set and the wrapped environment's observation space is used.
        """
        super().__init__(env)
        # Only override when given; otherwise gym.Wrapper falls back to
        # env.observation_space.
        if observation_space is not None:
            self.observation_space = observation_space
        self.func = func

    def observation(self, observation):
        """
        Applies ``func`` to an observation coming from the wrapped environment.

        gym.ObservationWrapper calls this for the observations returned by
        both reset() and step(), and the result is passed back to the caller.

        Parameters
        ----------------------------
        observation:
            Raw observation from the base environment.

        Returns
        ----------------------------
        func(observation):
            The converted observation (e.g. a discrete state index).
        """
        return self.func(observation)

Superclass of wrappers that can modify observations using observation() for reset() and step().

If you would like to apply a function to only the observation before passing it to the learning code, you can simply inherit from ObservationWrapper and overwrite the method observation() to implement that transformation. The transformation defined in that method must be reflected by the env observation space. Otherwise, you need to specify the new observation space of the wrapper by setting self.observation_space in the __init__() method of your wrapper.

Among others, Gymnasium provides the observation wrapper TimeAwareObservation, which adds information about the index of the timestep to the observation.

CustomTransformObservation(env, func, observation_space)
12    def __init__(self, env, func, observation_space):
13        """
14        Helper class that modifies the observation space. The v26 gymnasium TransformObservation wrapper does not
15        accept an observation_space parameter, which is needed in order to match the lambda conversion (tuple->int).
16        Instead, we subclass gym.ObservationWrapper (parent class of gym.TransformObservation)
17        to set both the conversion function and new observation space.
18
19        Parameters
20        ----------------------------
21        env {gymnasium.Env}:
22            Base environment to be wrapped
23
24        func {lambda}:
25            Function that converts the observation
26
27        observation_space {gymnasium.spaces.Space}:
28            New observation space
29        """
30        super().__init__(env)
31        if observation_space is not None:
32            self.observation_space = observation_space
33        self.func = func

Helper class that modifies the observation space. The v26 gymnasium TransformObservation wrapper does not accept an observation_space parameter, which is needed in order to match the lambda conversion (tuple->int). Instead, we subclass gym.ObservationWrapper (parent class of gym.TransformObservation) to set both the conversion function and new observation space.

Parameters
  • env {gymnasium.Env}:: Base environment to be wrapped
  • func {lambda}:: Function that converts the observation
  • observation_space {gymnasium.spaces.Space}:: New observation space
func
def observation(self, observation):
35    def observation(self, observation):
36        """
37        Applies a function to the observation received from the environment's step function,
38        which is passed back to the user.
39
40        Parameters
41        ----------------------------
42        observation {Tuple}:
43            Base environment observation tuple
44
45        Returns
46        ----------------------------
47        func(observation) {int}:
48            The converted observation (int).
49        """
50        return self.func(observation)

Applies func to an observation received from the wrapped environment (gym.ObservationWrapper calls this for observations returned by both reset() and step()), and returns the result to the caller.

Parameters
  • observation {Tuple}:: Base environment observation tuple
Returns
  • func(observation) {int}:: The converted observation (int).
class PendulumWrapper(gymnasium.core.Env[~WrapperObsType, ~WrapperActType]):
 53class PendulumWrapper(gym.Wrapper):
 54    def __init__(self,
 55                 env,
 56                 angle_bins=11,
 57                 angular_velocity_bins=11,
 58                 torque_bins=11):
 59        """
 60        Pendulum wrapper that modifies the observation and action spaces and creates a transition/reward matrix P.
 61
 62        Parameters
 63        ----------------------------
 64        env {gymnasium.Env}: Base environment
 65        angle_bins (int): Number of discrete bins for the pendulum's angle.
 66        angular_velocity_bins (int): Number of discrete bins for the pendulum's angular velocity.
 67        action_bins (int): Number of discrete bins for the torque action.
 68        angular_center_resolution (float): The resolution of angle bins near the center (around zero).
 69        angular_outer_resolution (float): The resolution of angle bins away from the center.
 70        torque_range (tuple): The minimum and maximum torque values.
 71        """
 72        super().__init__(env)
 73
 74        # Initialize the DiscretizedPendulum model
 75        self.discretized_pendulum = DiscretizedPendulum(
 76            angle_bins=angle_bins,
 77            angular_velocity_bins=angular_velocity_bins,
 78            torque_bins=torque_bins,
 79        )
 80
 81        # Transition probability matrix
 82        self._P = self.discretized_pendulum.P
 83
 84        # Transformation function from continuous to discrete observations
 85        self._transform_obs = self.discretized_pendulum.transform_cont_obs
 86        self._get_action_value = self.discretized_pendulum.get_action_value
 87
 88        # Wrap the environment's observation space
 89        self.observation_space = gym.spaces.Discrete(self.discretized_pendulum.state_space)
 90        self.env = CustomTransformObservation(env, self._transform_obs, self.observation_space)
 91        self.gym_env = env
 92
 93        # Override the action space to be discrete
 94        self.action_space = gym.spaces.Discrete(self.discretized_pendulum.action_space)
 95
 96    @property
 97    def P(self):
 98        """
 99        Returns the transition probability matrix.
100
101        Returns
102        ----------------------------
103        _P {dict}
104        """
105        return self._P
106
107    @property
108    def transform_obs(self):
109        """
110        Returns the observation transformation function.
111
112        Returns
113        ----------------------------
114        _transform_obs {lambda}
115        """
116        return self._transform_obs
117    
118    @property
119    def get_action_value(self):
120        f = lambda action: [self._get_action_value(action)]
121        return f
122
123    def step(self, action):
124        """
125        Takes a discrete action, maps it to a continuous torque, and interacts with the environment.
126
127        Parameters
128        ----------------------------
129        action {int}:
130            The discrete action index.
131
132        Returns
133        ----------------------------
134        state {int}:
135            The discretized next state index.
136        reward {float}:
137            The reward obtained from the environment.
138        done {bool}:
139            Whether the episode has terminated.
140        info {dict}:
141            Additional information from the environment.
142        """
143        # Map discrete action to continuous torque
144        torque = self.discretized_pendulum.get_action_value(action)
145
146        return self.env.step([torque])

Wraps a gymnasium.Env to allow a modular transformation of the step() and reset() methods.

This class is the base class of all wrappers to change the behavior of the underlying environment. Wrappers that inherit from this class can modify the action_space, observation_space, reward_range and metadata attributes, without changing the underlying environment's attributes. Moreover, the behavior of the step() and reset() methods can be changed by these wrappers.

Some attributes (spec, render_mode, np_random) will point back to the wrapper's environment (i.e. to the corresponding attributes of env).

Note: If you inherit from Wrapper, don't forget to call super().__init__(env)

PendulumWrapper(env, angle_bins=11, angular_velocity_bins=11, torque_bins=11)
54    def __init__(self,
55                 env,
56                 angle_bins=11,
57                 angular_velocity_bins=11,
58                 torque_bins=11):
59        """
60        Pendulum wrapper that modifies the observation and action spaces and creates a transition/reward matrix P.
61
62        Parameters
63        ----------------------------
64        env {gymnasium.Env}: Base environment
65        angle_bins (int): Number of discrete bins for the pendulum's angle.
66        angular_velocity_bins (int): Number of discrete bins for the pendulum's angular velocity.
67        action_bins (int): Number of discrete bins for the torque action.
68        angular_center_resolution (float): The resolution of angle bins near the center (around zero).
69        angular_outer_resolution (float): The resolution of angle bins away from the center.
70        torque_range (tuple): The minimum and maximum torque values.
71        """
72        super().__init__(env)
73
74        # Initialize the DiscretizedPendulum model
75        self.discretized_pendulum = DiscretizedPendulum(
76            angle_bins=angle_bins,
77            angular_velocity_bins=angular_velocity_bins,
78            torque_bins=torque_bins,
79        )
80
81        # Transition probability matrix
82        self._P = self.discretized_pendulum.P
83
84        # Transformation function from continuous to discrete observations
85        self._transform_obs = self.discretized_pendulum.transform_cont_obs
86        self._get_action_value = self.discretized_pendulum.get_action_value
87
88        # Wrap the environment's observation space
89        self.observation_space = gym.spaces.Discrete(self.discretized_pendulum.state_space)
90        self.env = CustomTransformObservation(env, self._transform_obs, self.observation_space)
91        self.gym_env = env
92
93        # Override the action space to be discrete
94        self.action_space = gym.spaces.Discrete(self.discretized_pendulum.action_space)

Pendulum wrapper that modifies the observation and action spaces and creates a transition/reward matrix P.

Parameters
  • env {gymnasium.Env} (Base environment):

  • angle_bins (int) (Number of discrete bins for the pendulum's angle.):

  • angular_velocity_bins (int) (Number of discrete bins for the pendulum's angular velocity.):

  • torque_bins (int) (Number of discrete bins for the torque action.):

discretized_pendulum
observation_space: Union[gymnasium.spaces.space.Space[~ObsType], gymnasium.spaces.space.Space[~WrapperObsType]]
299    @property
300    def observation_space(
301        self,
302    ) -> spaces.Space[ObsType] | spaces.Space[WrapperObsType]:
303        """Return the :attr:`Env` :attr:`observation_space` unless overwritten then the wrapper :attr:`observation_space` is used."""
304        if self._observation_space is None:
305            return self.env.observation_space
306        return self._observation_space

Return the Env observation_space unless overwritten then the wrapper observation_space is used.

env
gym_env
action_space: Union[gymnasium.spaces.space.Space[~ActType], gymnasium.spaces.space.Space[~WrapperActType]]
286    @property
287    def action_space(
288        self,
289    ) -> spaces.Space[ActType] | spaces.Space[WrapperActType]:
290        """Return the :attr:`Env` :attr:`action_space` unless overwritten then the wrapper :attr:`action_space` is used."""
291        if self._action_space is None:
292            return self.env.action_space
293        return self._action_space

Return the Env action_space unless overwritten then the wrapper action_space is used.

P
 96    @property
 97    def P(self):
 98        """
 99        Returns the transition probability matrix.
100
101        Returns
102        ----------------------------
103        _P {dict}
104        """
105        return self._P

Returns the transition probability matrix.

Returns
  • _P {dict}
transform_obs
107    @property
108    def transform_obs(self):
109        """
110        Returns the observation transformation function.
111
112        Returns
113        ----------------------------
114        _transform_obs {lambda}
115        """
116        return self._transform_obs

Returns the observation transformation function.

Returns
  • _transform_obs {lambda}
get_action_value
118    @property
119    def get_action_value(self):
120        f = lambda action: [self._get_action_value(action)]
121        return f
def step(self, action):
123    def step(self, action):
124        """
125        Takes a discrete action, maps it to a continuous torque, and interacts with the environment.
126
127        Parameters
128        ----------------------------
129        action {int}:
130            The discrete action index.
131
132        Returns
133        ----------------------------
134        state {int}:
135            The discretized next state index.
136        reward {float}:
137            The reward obtained from the environment.
138        done {bool}:
139            Whether the episode has terminated.
140        info {dict}:
141            Additional information from the environment.
142        """
143        # Map discrete action to continuous torque
144        torque = self.discretized_pendulum.get_action_value(action)
145
146        return self.env.step([torque])

Takes a discrete action, maps it to a continuous torque, and steps the wrapped environment.

Parameters
  • action {int}:: The discrete action index.
Returns
  • observation {int}:: The discretized next state index.
  • reward {float}:: The reward obtained from the environment.
  • terminated {bool}:: Whether the episode has terminated.
  • truncated {bool}:: Whether the episode was truncated.
  • info {dict}:: Additional information from the environment.
def get_env_str(angle_bins, angular_velocity_bins, torque_bins):
def get_env_str(angle_bins, angular_velocity_bins, torque_bins):
    """
    Returns the environment string based on the discretization parameters.

    Parameters
    ----------------------------
    angle_bins (int): Number of discrete bins for the pendulum's angle.
    angular_velocity_bins (int): Number of discrete bins for the pendulum's angular velocity.
    torque_bins (int): Number of discrete bins for the torque action.

    Returns
    ----------------------------
    env_str {str}: The environment string, formatted as
        'pendulum_{angle_bins}_{angular_velocity_bins}_{torque_bins}'.
    """
    return f'pendulum_{angle_bins}_{angular_velocity_bins}_{torque_bins}'

Returns the environment string based on the discretization parameters.

Parameters
  • angle_bins (int) (Number of discrete bins for the pendulum's angle.):

  • angular_velocity_bins (int) (Number of discrete bins for the pendulum's angular velocity.):

  • torque_bins (int) (Number of discrete bins for the torque action.):

Returns
  • env_str {str} (The environment string.):
def init_wrapper_env(angle_bins=11, angular_velocity_bins=11, torque_bins=11):
def init_wrapper_env(angle_bins=11, angular_velocity_bins=11, torque_bins=11):
    """
    Build a discretized Pendulum environment.

    Creates a fresh 'Pendulum-v1' gymnasium environment and wraps it in a
    PendulumWrapper configured with the requested discretization.

    Parameters
    ----------------------------
    angle_bins (int): Number of discrete bins for the pendulum's angle.
    angular_velocity_bins (int): Number of discrete bins for the pendulum's angular velocity.
    torque_bins (int): Number of discrete bins for the torque action.

    Returns
    ----------------------------
    pendulum_env {PendulumWrapper}: The Pendulum wrapper environment.
    """
    base_env = gym.make('Pendulum-v1')
    return PendulumWrapper(
        base_env,
        angle_bins=angle_bins,
        angular_velocity_bins=angular_velocity_bins,
        torque_bins=torque_bins,
    )

Initializes the Pendulum wrapper environment.

Parameters
  • angle_bins (int) (Number of discrete bins for the pendulum's angle.):

  • angular_velocity_bins (int) (Number of discrete bins for the pendulum's angular velocity.):

  • torque_bins (int) (Number of discrete bins for the torque action.):

Returns
  • pendulum_env {PendulumWrapper} (The Pendulum wrapper environment.):