SARSA
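
The cell below trains an epsilon-greedy agent on CartPole-v0 with keras-gym, using a linear Q-function over elementwise-quadratic features. Sarsa.update(s, a, r, s_next, a_next) applies the on-policy TD(0) update, moving Q(s, a) toward the bootstrapped target r + gamma * Q(s_next, a_next). A minimal sketch of that target, assuming a generic q(s, a) lookup rather than keras-gym's internal API:

def sarsa_target(q, r, s_next, a_next, gamma=0.8, done=False):
    # bootstrap from the action actually selected in the next state;
    # terminal transitions get no bootstrap term
    return r if done else r + gamma * q(s_next, a_next)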

In [1]:
import gym

from tensorflow import keras
from tensorflow.keras import backend as K

from keras_gym.value_functions import LinearQ
from keras_gym.policies import ValuePolicy
from keras_gym.algorithms import Sarsa



# the Gym environment
env = gym.make('CartPole-v0')


# define Q, its induced policy and update algorithm
Q = LinearQ(env, lr=0.08, interaction='elementwise_quadratic')
policy = ValuePolicy(Q)
algo = Sarsa(Q, gamma=0.8)


# number of iterations
num_episodes = 200
max_episode_steps = env._max_episode_steps


# used for early stopping
num_consecutive_successes = 0


for episode in range(1, num_episodes + 1):
    last_episode = episode == num_episodes or num_consecutive_successes == 9
    
    # init
    s = env.reset()
    a = env.action_space.sample()
    
    # amount of random exploration
    if last_episode:
        epsilon = 0
        env.render()
    elif episode < 10:
        epsilon = 0.5
    else:
        epsilon = 0.01
    
    for t in range(1, max_episode_steps + 1):
        s_next, r, done, info = env.step(a)
        a_next = policy.epsilon_greedy(s_next, epsilon)
        
        # update or render
        if not last_episode:
            algo.update(s, a, r, s_next, a_next)         
        else:
            env.render()
        
        # keep track of consecutive successes
        if done:
            if t == max_episode_steps:
                num_consecutive_successes += 1
                print(f"num_consecutive_successes = {num_consecutive_successes}")
            else:
                num_consecutive_successes = 0
                print(f"failed after {t} steps")
            break
    
        # prepare for next step
        s, a = s_next, a_next

        
    if last_episode:
        break


env.close()
failed after 12 steps
failed after 43 steps
failed after 17 steps
failed after 24 steps
failed after 11 steps
failed after 53 steps
failed after 134 steps
failed after 39 steps
failed after 80 steps
failed after 108 steps
failed after 177 steps
num_consecutive_successes = 1
num_consecutive_successes = 2
num_consecutive_successes = 3
failed after 187 steps
num_consecutive_successes = 1
num_consecutive_successes = 2
num_consecutive_successes = 3
num_consecutive_successes = 4
num_consecutive_successes = 5
num_consecutive_successes = 6
num_consecutive_successes = 7
num_consecutive_successes = 8
num_consecutive_successes = 9
num_consecutive_successes = 10

SARSA with scikit-learn function approximator
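
The same SARSA loop, but the Q-function is now backed by a scikit-learn SGDRegressor wrapped in SklearnModelWrapper and plugged into GenericQ. The FunctionTransformer reproduces the elementwise-quadratic feature expansion by stacking the raw state with its squares, which is why the cell needs numpy. The feature map on its own, as a small sketch:

import numpy as np

def quadratic_features(x):
    # concatenate the raw features with their element-wise squares: [x, x**2]
    x = np.asarray(x, dtype=float)
    return np.hstack([x, x ** 2])

quadratic_features([0.1, -0.2, 0.05, 0.3])  # length-8 vector for CartPole's 4-dimensional state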

In [2]:
import gym
import numpy as np

from sklearn.linear_model import SGDRegressor
from sklearn.preprocessing import FunctionTransformer

from keras_gym.value_functions import GenericQ
from keras_gym.policies import ValuePolicy
from keras_gym.algorithms import Sarsa
from keras_gym.wrappers import SklearnModelWrapper


# the Gym environment
env = gym.make('CartPole-v0')


# define sklearn model for approximating Q-function
model = SklearnModelWrapper(
    estimator=SGDRegressor(eta0=0.08, learning_rate='constant'),
    transformer=FunctionTransformer(
        lambda x: np.hstack([x, x ** 2]), validate=False),
)


# define Q, its induced policy and update algorithm
Q = GenericQ(env, model)
policy = ValuePolicy(Q)
algo = Sarsa(Q, gamma=0.8)


# number of iterations
num_episodes = 200
max_episode_steps = env._max_episode_steps


# used for early stopping
num_consecutive_successes = 0


for episode in range(1, num_episodes + 1):
    last_episode = episode == num_episodes or num_consecutive_successes == 9
    
    # init
    s = env.reset()
    a = env.action_space.sample()
    
    # amount of random exploration
    if last_episode:
        epsilon = 0
        env.render()
    elif episode < 10:
        epsilon = 0.5
    else:
        epsilon = 0.01
    
    for t in range(1, max_episode_steps + 1):
        s_next, r, done, info = env.step(a)
        a_next = policy.epsilon_greedy(s_next, epsilon)
        
        # update or render
        if not last_episode:
            algo.update(s, a, r, s_next, a_next)         
        else:
            env.render()
        
        # keep track of consecutive successes
        if done:
            if t == max_episode_steps:
                num_consecutive_successes += 1
                print(f"num_consecutive_successes = {num_consecutive_successes}")
            else:
                num_consecutive_successes = 0
                print(f"failed after {t} steps")
            break
    
        # prepare for next step
        s, a = s_next, a_next

        
    if last_episode:
        break


env.close()
failed after 15 steps
failed after 22 steps
failed after 12 steps
failed after 13 steps
failed after 33 steps
failed after 36 steps
failed after 29 steps
failed after 8 steps
failed after 20 steps
failed after 14 steps
failed after 10 steps
failed after 8 steps
failed after 39 steps
failed after 10 steps
failed after 8 steps
failed after 25 steps
failed after 31 steps
failed after 9 steps
failed after 29 steps
failed after 29 steps
failed after 14 steps
failed after 27 steps
failed after 10 steps
failed after 34 steps
failed after 17 steps
failed after 47 steps
failed after 8 steps
failed after 23 steps
failed after 9 steps
failed after 9 steps
failed after 8 steps
failed after 9 steps
failed after 8 steps
failed after 8 steps
failed after 11 steps
failed after 58 steps
failed after 12 steps
failed after 26 steps
failed after 11 steps
failed after 10 steps
failed after 11 steps
failed after 10 steps
failed after 43 steps
failed after 10 steps
failed after 36 steps
failed after 10 steps
failed after 9 steps
failed after 10 steps
failed after 32 steps
failed after 125 steps
failed after 121 steps
failed after 47 steps
failed after 120 steps
failed after 68 steps
failed after 146 steps
failed after 163 steps
failed after 28 steps
failed after 61 steps
failed after 51 steps
failed after 47 steps
failed after 67 steps
failed after 62 steps
failed after 188 steps
failed after 125 steps
failed after 103 steps
failed after 179 steps
num_consecutive_successes = 1
failed after 163 steps
failed after 125 steps
failed after 50 steps
failed after 114 steps
num_consecutive_successes = 1
num_consecutive_successes = 2
num_consecutive_successes = 3
failed after 99 steps
failed after 88 steps
failed after 57 steps
failed after 175 steps
failed after 117 steps
failed after 161 steps
num_consecutive_successes = 1
failed after 157 steps
failed after 113 steps
failed after 53 steps
num_consecutive_successes = 1
failed after 160 steps
failed after 190 steps
failed after 127 steps
num_consecutive_successes = 1
failed after 187 steps
num_consecutive_successes = 1
failed after 134 steps
num_consecutive_successes = 1
num_consecutive_successes = 2
num_consecutive_successes = 3
num_consecutive_successes = 4
failed after 162 steps
num_consecutive_successes = 1
num_consecutive_successes = 2
num_consecutive_successes = 3
num_consecutive_successes = 4
num_consecutive_successes = 5
num_consecutive_successes = 6
num_consecutive_successes = 7
num_consecutive_successes = 8
num_consecutive_successes = 9
num_consecutive_successes = 10

Q-learning
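
Q-learning differs from SARSA only in the update: QLearning.update(s, a, r, s_next) is off-policy and bootstraps from the greedy value of the next state, max_a' Q(s_next, a'), so no a_next needs to be sampled before updating. A sketch of the target, assuming a q_values(s) callable that returns one value per action (not keras-gym's internal API):

import numpy as np

def q_learning_target(q_values, r, s_next, gamma=0.8, done=False):
    # bootstrap from the best action available in the next state
    return r if done else r + gamma * float(np.max(q_values(s_next)))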

In [3]:
import gym

from tensorflow import keras
from tensorflow.keras import backend as K

from keras_gym.value_functions import LinearQ
from keras_gym.policies import ValuePolicy
from keras_gym.algorithms import QLearning


# the Gym environment
env = gym.make('CartPole-v0')


# define Q, its induced policy and update algorithm
Q = LinearQ(env, interaction='elementwise_quadratic', lr=0.8, momentum=0., decay=0.1)
policy = ValuePolicy(Q)
algo = QLearning(Q, gamma=0.8)


# number of iterations
num_episodes = 100
max_episode_steps = env._max_episode_steps


# used for early stopping
num_consecutive_successes = 0


for episode in range(1, num_episodes + 1):
    last_episode = episode == num_episodes or num_consecutive_successes == 9
    
    # init
    s = env.reset()
    
    # amount of random exploration
    if last_episode:
        epsilon = 0
        env.render()
    elif episode < 10:
        epsilon = 0.5
    else:
        epsilon = 0.01
    
    for t in range(1, max_episode_steps + 1):
        a = policy.epsilon_greedy(s, epsilon)
        s_next, r, done, info = env.step(a)
        
        # update or render
        if not last_episode:          
            algo.update(s, a, r, s_next)            
        else:
            env.render()
        
        # keep track of consecutive successes
        if done:
            if t == max_episode_steps:
                num_consecutive_successes += 1
                print(f"num_consecutive_successes = {num_consecutive_successes}")
            else:
                num_consecutive_successes = 0
                print(f"failed after {t} steps")
            break
    
        # prepare for next step
        s = s_next

        
    if last_episode:
        break


env.close()
failed after 16 steps
failed after 14 steps
failed after 24 steps
failed after 12 steps
failed after 15 steps
failed after 14 steps
failed after 17 steps
failed after 10 steps
failed after 9 steps
failed after 10 steps
failed after 10 steps
failed after 8 steps
failed after 10 steps
failed after 10 steps
failed after 10 steps
failed after 10 steps
failed after 9 steps
failed after 8 steps
failed after 10 steps
failed after 10 steps
failed after 9 steps
failed after 9 steps
failed after 9 steps
failed after 9 steps
failed after 9 steps
failed after 13 steps
failed after 10 steps
failed after 10 steps
failed after 10 steps
failed after 113 steps
failed after 77 steps
failed after 130 steps
failed after 75 steps
failed after 56 steps
failed after 160 steps
failed after 71 steps
failed after 147 steps
failed after 133 steps
num_consecutive_successes = 1
num_consecutive_successes = 2
failed after 71 steps
failed after 94 steps
failed after 85 steps
failed after 79 steps
num_consecutive_successes = 1
failed after 86 steps
failed after 100 steps
failed after 69 steps
failed after 79 steps
num_consecutive_successes = 1
failed after 74 steps
num_consecutive_successes = 1
failed after 89 steps
failed after 65 steps
num_consecutive_successes = 1
num_consecutive_successes = 2
num_consecutive_successes = 3
num_consecutive_successes = 4
num_consecutive_successes = 5
num_consecutive_successes = 6
num_consecutive_successes = 7
failed after 192 steps
num_consecutive_successes = 1
num_consecutive_successes = 2
num_consecutive_successes = 3
failed after 131 steps
failed after 89 steps
num_consecutive_successes = 1
failed after 84 steps
num_consecutive_successes = 1
failed after 65 steps
num_consecutive_successes = 1
num_consecutive_successes = 2
num_consecutive_successes = 3
num_consecutive_successes = 4
num_consecutive_successes = 5
num_consecutive_successes = 6
failed after 79 steps
failed after 86 steps
num_consecutive_successes = 1
failed after 191 steps
num_consecutive_successes = 1
num_consecutive_successes = 2
num_consecutive_successes = 3
failed after 75 steps
failed after 155 steps
failed after 83 steps
num_consecutive_successes = 1
failed after 174 steps
failed after 158 steps
num_consecutive_successes = 1
failed after 147 steps
failed after 69 steps
failed after 72 steps
num_consecutive_successes = 1
failed after 69 steps
num_consecutive_successes = 1
num_consecutive_successes = 2
failed after 186 steps
failed after 119 steps

Q-learning with custom Keras function approximator
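
Instead of LinearQ, the cell below hands GenericQ a hand-rolled Keras model: a Lambda layer concatenates the state with its element-wise squares and a Dense(1) head regresses the Q-value, i.e. the same linear-on-quadratic-features model as interaction='elementwise_quadratic', just written explicitly. What the Lambda layer computes, checked in plain numpy (a sketch; the actual layer operates on Keras tensors):

import numpy as np

# concatenate [x, x**2] along the feature axis, exactly like the Lambda layer below
x = np.array([[1.0, 2.0, 3.0, 4.0]])
np.concatenate([x, x ** 2], axis=-1)  # -> [[1., 2., 3., 4., 1., 4., 9., 16.]]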

In [4]:
import gym

from tensorflow import keras
from tensorflow.keras import backend as K

from keras_gym.value_functions import GenericQ
from keras_gym.policies import ValuePolicy
from keras_gym.algorithms import QLearning


# the Gym environment
env = gym.make('CartPole-v0')


# custom function approximator (linear regression)
model = keras.Sequential(layers=[
    keras.layers.Lambda(lambda x: K.concatenate([x, x ** 2])),
    keras.layers.Dense(1),
])
model.compile(
    optimizer=keras.optimizers.SGD(lr=0.05, momentum=0.5),
    loss=keras.losses.mean_squared_error)


# define Q, its induced policy and update algorithm
Q = GenericQ(env, model)
policy = ValuePolicy(Q)
algo = QLearning(Q, gamma=0.8)


# number of iterations
num_episodes = 100
max_episode_steps = env._max_episode_steps


# used for early stopping
num_consecutive_successes = 0


for episode in range(1, num_episodes + 1):
    last_episode = episode == num_episodes or num_consecutive_successes == 9
    
    # init
    s = env.reset()
    
    # amount of random exploration
    if last_episode:
        epsilon = 0
        env.render()
    elif episode < 10:
        epsilon = 0.5
    else:
        epsilon = 0.01
    
    for t in range(1, max_episode_steps + 1):
        a = policy.epsilon_greedy(s, epsilon)
        s_next, r, done, info = env.step(a)
        
        # update or render
        if not last_episode:
            algo.update(s, a, r, s_next)            
        else:
            env.render()
        
        # keep track of consecutive successes
        if done:
            if t == max_episode_steps:
                num_consecutive_successes += 1
                print(f"num_consecutive_successes = {num_consecutive_successes}")
            else:
                num_consecutive_successes = 0
                print(f"failed after {t} steps")
            break
    
        # prepare for next step
        s = s_next

        
    if last_episode:
        break


env.close()
failed after 15 steps
failed after 12 steps
failed after 13 steps
failed after 11 steps
failed after 18 steps
failed after 12 steps
failed after 9 steps
failed after 11 steps
failed after 12 steps
failed after 28 steps
failed after 10 steps
failed after 10 steps
failed after 9 steps
failed after 10 steps
failed after 10 steps
failed after 18 steps
failed after 157 steps
failed after 113 steps
failed after 195 steps
num_consecutive_successes = 1
failed after 86 steps
failed after 73 steps
failed after 132 steps
failed after 105 steps
num_consecutive_successes = 1
num_consecutive_successes = 2
failed after 162 steps
failed after 86 steps
failed after 91 steps
num_consecutive_successes = 1
num_consecutive_successes = 2
failed after 159 steps
failed after 129 steps
failed after 132 steps
failed after 193 steps
failed after 117 steps
num_consecutive_successes = 1
num_consecutive_successes = 2
failed after 111 steps
failed after 156 steps
failed after 125 steps
failed after 95 steps
num_consecutive_successes = 1
failed after 66 steps
failed after 115 steps
num_consecutive_successes = 1
num_consecutive_successes = 2
failed after 124 steps
failed after 80 steps
num_consecutive_successes = 1
num_consecutive_successes = 2
failed after 144 steps
failed after 97 steps
num_consecutive_successes = 1
failed after 108 steps
num_consecutive_successes = 1
num_consecutive_successes = 2
failed after 119 steps
num_consecutive_successes = 1
num_consecutive_successes = 2
num_consecutive_successes = 3
num_consecutive_successes = 4
num_consecutive_successes = 5
num_consecutive_successes = 6
num_consecutive_successes = 7
num_consecutive_successes = 8
num_consecutive_successes = 9
num_consecutive_successes = 10

Q-learning with type-II model
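
Here model_type=2 selects the other flavour of LinearQ: presumably a model that outputs one value per action for a given state (type-II), as opposed to scoring a single (state, action) pair (type-I, used in the cells above). With CartPole's two actions, the greedy choice then takes a single forward pass. A rough sketch of the two calling styles with hypothetical q1/q2 stand-ins (not keras-gym objects):

import numpy as np

n_actions = 2  # CartPole-v0 has two discrete actions

def q1(s, a):          # type-I style: score one (state, action) pair at a time
    return float(a)    # dummy values, for illustration only
def q2(s):             # type-II style: one call yields all action values
    return np.array([0.0, 1.0])

s = np.zeros(4)
greedy_type1 = max(range(n_actions), key=lambda a: q1(s, a))  # n_actions evaluations
greedy_type2 = int(np.argmax(q2(s)))                          # a single evaluation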

In [5]:
import gym

from tensorflow import keras
from tensorflow.keras import backend as K

from keras_gym.value_functions import LinearQ
from keras_gym.policies import ValuePolicy
from keras_gym.algorithms import QLearning


# the Gym environment
env = gym.make('CartPole-v0')


# define Q, its induced policy and update algorithm
Q = LinearQ(env, model_type=2, lr=0.05, momentum=0.5, interaction='elementwise_quadratic')
policy = ValuePolicy(Q)
algo = QLearning(Q, gamma=0.8)


# number of iterations
num_episodes = 100
max_episode_steps = env._max_episode_steps


# used for early stopping
num_consecutive_successes = 0


for episode in range(1, num_episodes + 1):
    last_episode = episode == num_episodes or num_consecutive_successes == 9
    
    # init
    s = env.reset()
    
    # amount of random exploration
    if last_episode:
        epsilon = 0
        env.render()
    elif episode < 10:
        epsilon = 0.5
    else:
        epsilon = 0.01
    
    for t in range(1, max_episode_steps + 1):
        a = policy.epsilon_greedy(s, epsilon)
        s_next, r, done, info = env.step(a)

        
        # update or render
        if not last_episode:
            algo.update(s, a, r, s_next)         
        else:
            env.render()
        
        # keep track of consecutive successes
        if done:
            if t == max_episode_steps:
                num_consecutive_successes += 1
                print(f"num_consecutive_successes = {num_consecutive_successes}")
            else:
                num_consecutive_successes = 0
                print(f"failed after {t} steps")
            break
    
        # prepare for next step
        s = s_next

        
    if last_episode:
        break


env.close()
failed after 30 steps
failed after 20 steps
failed after 12 steps
failed after 10 steps
failed after 16 steps
failed after 10 steps
failed after 10 steps
failed after 8 steps
failed after 14 steps
failed after 9 steps
failed after 10 steps
failed after 9 steps
failed after 9 steps
failed after 10 steps
failed after 9 steps
failed after 9 steps
failed after 15 steps
failed after 24 steps
failed after 94 steps
failed after 72 steps
failed after 42 steps
failed after 49 steps
failed after 85 steps
failed after 105 steps
failed after 24 steps
failed after 38 steps
failed after 50 steps
failed after 55 steps
failed after 47 steps
failed after 168 steps
failed after 63 steps
failed after 52 steps
failed after 123 steps
failed after 60 steps
failed after 44 steps
failed after 137 steps
failed after 63 steps
failed after 93 steps
failed after 61 steps
failed after 128 steps
failed after 56 steps
failed after 75 steps
failed after 148 steps
failed after 77 steps
failed after 76 steps
failed after 175 steps
failed after 67 steps
failed after 71 steps
failed after 127 steps
failed after 166 steps
failed after 9 steps
failed after 46 steps
failed after 69 steps
failed after 106 steps
failed after 68 steps
failed after 48 steps
failed after 80 steps
failed after 54 steps
failed after 49 steps
failed after 64 steps
failed after 71 steps
failed after 72 steps
failed after 102 steps
failed after 67 steps
failed after 68 steps
failed after 87 steps
failed after 79 steps
failed after 71 steps
failed after 96 steps
failed after 63 steps
failed after 63 steps
failed after 64 steps
failed after 178 steps
failed after 76 steps
failed after 82 steps
failed after 67 steps
failed after 102 steps
failed after 83 steps
failed after 92 steps
failed after 88 steps
failed after 99 steps
failed after 89 steps
failed after 102 steps
failed after 69 steps
failed after 119 steps
failed after 127 steps
failed after 59 steps
failed after 90 steps
failed after 127 steps
failed after 70 steps
failed after 136 steps
failed after 78 steps
failed after 114 steps
failed after 84 steps
failed after 68 steps
failed after 108 steps
failed after 82 steps
failed after 111 steps
failed after 99 steps
failed after 79 steps

Expected SARSA
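
ExpectedSarsa is constructed with both Q and the policy because its target averages the next-state action values under the policy's probabilities instead of sampling a single a_next: r + gamma * sum_a pi(a | s_next) * Q(s_next, a). A sketch of that target for an epsilon-greedy distribution, assuming q_values(s) returns one value per action (not keras-gym's internals):

import numpy as np

def expected_sarsa_target(q_values, r, s_next, epsilon=0.01, gamma=0.8, done=False):
    if done:
        return r
    q = np.asarray(q_values(s_next), dtype=float)
    # epsilon-greedy probabilities: mostly greedy, epsilon spread uniformly over all actions
    probs = np.full(len(q), epsilon / len(q))
    probs[np.argmax(q)] += 1.0 - epsilon
    return r + gamma * float(np.dot(probs, q))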

In [6]:
import gym

from tensorflow import keras
from tensorflow.keras import backend as K

from keras_gym.value_functions import LinearQ
from keras_gym.policies import ValuePolicy
from keras_gym.algorithms import ExpectedSarsa



# the Gym environment
env = gym.make('CartPole-v0')


# define Q, its induced policy and update algorithm
Q = LinearQ(env, interaction='elementwise_quadratic', lr=0.01)
policy = ValuePolicy(Q)
algo = ExpectedSarsa(Q, policy, gamma=0.8)


# number of iterations
num_episodes = 200
max_episode_steps = env._max_episode_steps


# used for early stopping
num_consecutive_successes = 0


for episode in range(1, num_episodes + 1):
    last_episode = episode == num_episodes or num_consecutive_successes == 9
    
    # init
    s = env.reset()
    
    # amount of random exploration
    if last_episode:
        epsilon = 0
        env.render()
    elif episode < 10:
        epsilon = 0.5
    else:
        epsilon = 0.01
    
    for t in range(1, max_episode_steps + 1):
        a = policy.epsilon_greedy(s, epsilon)
        s_next, r, done, info = env.step(a)
        
        # update or render
        if not last_episode:
            algo.update(s, a, r, s_next)            
        else:
            env.render()
        
        # keep track of consecutive successes
        if done:
            if t == max_episode_steps:
                num_consecutive_successes += 1
                print(f"num_consecutive_successes = {num_consecutive_successes}")
            else:
                num_consecutive_successes = 0
                print(f"failed after {t} steps")
            break
    
        # prepare for next step
        s = s_next

        
    if last_episode:
        break


env.close()
failed after 27 steps
failed after 21 steps
failed after 28 steps
failed after 17 steps
failed after 11 steps
failed after 11 steps
failed after 15 steps
failed after 13 steps
failed after 9 steps
failed after 12 steps
failed after 13 steps
failed after 10 steps
failed after 8 steps
failed after 11 steps
failed after 10 steps
failed after 11 steps
failed after 10 steps
failed after 10 steps
failed after 8 steps
failed after 10 steps
failed after 9 steps
failed after 9 steps
failed after 9 steps
failed after 8 steps
failed after 9 steps
failed after 10 steps
failed after 9 steps
failed after 9 steps
failed after 9 steps
failed after 8 steps
failed after 11 steps
failed after 9 steps
failed after 9 steps
failed after 9 steps
failed after 9 steps
failed after 9 steps
failed after 9 steps
failed after 10 steps
failed after 10 steps
failed after 8 steps
failed after 9 steps
failed after 9 steps
failed after 10 steps
failed after 10 steps
failed after 10 steps
failed after 8 steps
failed after 9 steps
failed after 11 steps
failed after 10 steps
failed after 9 steps
failed after 9 steps
failed after 10 steps
failed after 10 steps
failed after 9 steps
failed after 10 steps
failed after 10 steps
failed after 10 steps
failed after 9 steps
failed after 8 steps
failed after 9 steps
failed after 9 steps
failed after 10 steps
failed after 9 steps
failed after 9 steps
failed after 9 steps
failed after 9 steps
failed after 9 steps
failed after 8 steps
failed after 9 steps
failed after 8 steps
failed after 9 steps
failed after 10 steps
failed after 10 steps
failed after 10 steps
failed after 9 steps
failed after 10 steps
failed after 9 steps
failed after 10 steps
failed after 10 steps
failed after 10 steps
failed after 10 steps
failed after 9 steps
failed after 9 steps
failed after 11 steps
failed after 8 steps
failed after 13 steps
failed after 9 steps
failed after 8 steps
failed after 10 steps
failed after 14 steps
failed after 10 steps
failed after 36 steps
failed after 39 steps
failed after 33 steps
failed after 41 steps
failed after 49 steps
failed after 79 steps
failed after 96 steps
failed after 101 steps
failed after 104 steps
failed after 57 steps
num_consecutive_successes = 1
num_consecutive_successes = 2
failed after 55 steps
failed after 57 steps
failed after 69 steps
failed after 69 steps
failed after 84 steps
failed after 133 steps
num_consecutive_successes = 1
num_consecutive_successes = 2
failed after 73 steps
num_consecutive_successes = 1
failed after 84 steps
num_consecutive_successes = 1
num_consecutive_successes = 2
num_consecutive_successes = 3
failed after 125 steps
num_consecutive_successes = 1
num_consecutive_successes = 2
failed after 66 steps
failed after 82 steps
failed after 65 steps
num_consecutive_successes = 1
failed after 70 steps
num_consecutive_successes = 1
num_consecutive_successes = 2
failed after 67 steps
num_consecutive_successes = 1
num_consecutive_successes = 2
num_consecutive_successes = 3
num_consecutive_successes = 4
failed after 141 steps
num_consecutive_successes = 1
failed after 57 steps
failed after 95 steps
failed after 89 steps
failed after 114 steps
num_consecutive_successes = 1
failed after 157 steps
num_consecutive_successes = 1
failed after 103 steps
failed after 116 steps
failed after 111 steps
failed after 157 steps
failed after 86 steps
failed after 89 steps
failed after 126 steps
num_consecutive_successes = 1
failed after 133 steps
failed after 82 steps
failed after 189 steps
failed after 165 steps
num_consecutive_successes = 1
num_consecutive_successes = 2
num_consecutive_successes = 3
num_consecutive_successes = 4
failed after 136 steps
failed after 179 steps
failed after 85 steps
failed after 67 steps
failed after 189 steps
failed after 96 steps
failed after 86 steps
failed after 85 steps
failed after 80 steps
failed after 88 steps
failed after 158 steps
failed after 158 steps
num_consecutive_successes = 1
failed after 96 steps
failed after 78 steps
failed after 188 steps
failed after 130 steps
num_consecutive_successes = 1
num_consecutive_successes = 2
failed after 104 steps
failed after 191 steps
num_consecutive_successes = 1
failed after 101 steps
failed after 87 steps
failed after 132 steps
failed after 95 steps
failed after 121 steps
failed after 191 steps
num_consecutive_successes = 1
failed after 177 steps
failed after 188 steps
failed after 102 steps
num_consecutive_successes = 1
num_consecutive_successes = 2
failed after 112 steps
num_consecutive_successes = 1
num_consecutive_successes = 2
failed after 77 steps
failed after 180 steps
num_consecutive_successes = 1
num_consecutive_successes = 2
num_consecutive_successes = 3
num_consecutive_successes = 4

Record Video

In [ ]:
# env = gym.make('CartPole-v1')
# env = gym.wrappers.Monitor(env, os.path.join('data', 'video', 'cartpole-linear-model-sarsa'), force=True)
# s = env.reset()
# env.render()
# done = False

# while not done:
#     a = policy.greedy(s)
#     s, _, done, _ = env.step(a)
#     env.render()
    
# env.close()