In [1]:
# %load ../../scripts/cartpole/sarsa.py
import gym

from keras_gym.preprocessing import DefaultPreprocessor
from keras_gym.value_functions import LinearQTypeI
from keras_gym.policies import EpsilonGreedy


# CartPole environment, wrapped in keras_gym's default preprocessor
env = DefaultPreprocessor(gym.make('CartPole-v0'))

# Q-function configured with update_strategy='sarsa' and bootstrap_n=1,
# together with the epsilon-greedy policy that is derived from it
Q = LinearQTypeI(
    env,
    lr=0.02,
    momentum=0.9,
    gamma=0.8,
    update_strategy='sarsa',
    bootstrap_n=1,
)
policy = EpsilonGreedy(Q)

# static parameters:
#   num_episodes -- maximum number of training episodes
#   num_steps    -- per-episode step limit, read from the env spec
num_episodes = 200
num_steps = env.spec.max_episode_steps

# used for early stopping: counts back-to-back episodes that reached the
# step limit
num_consecutive_successes = 0


# train: run up to num_episodes episodes, stopping early once 10 episodes in
# a row have lasted the full num_steps steps
for ep in range(num_episodes):
    s = env.reset()

    # larger epsilon for the first 10 episodes, smaller one afterwards
    policy.epsilon = 0.01 if ep >= 10 else 0.1

    for t in range(num_steps):
        a = policy(s)
        s_next, r, done, info = env.step(a)

        # NOTE(review): update() only receives (s, a, r, done); presumably
        # the value function keeps track of the transition history it needs
        # for the SARSA target itself -- confirm against the keras_gym docs
        Q.update(s, a, r, done)

        if not done:
            # episode still running: advance to the next state
            s = s_next
            continue

        # episode is over; it counts as a success only if it ended on the
        # very last allowed step
        if t != num_steps - 1:
            num_consecutive_successes = 0
            # NOTE: t is the zero-based index of the final step, so the
            # number of steps actually taken is t + 1
            print("failed after {} steps".format(t))
        else:
            num_consecutive_successes += 1
            print("num_consecutive_successes: {}"
                  .format(num_consecutive_successes))
        break

    # early stopping
    if num_consecutive_successes == 10:
        break


# run env one more time to render
#
# The rollout is wrapped in try/finally so that env.close() is always called,
# even when stepping/rendering raises or the run is interrupted (e.g. with
# Ctrl-C / a kernel interrupt while watching the animation).
try:
    s = env.reset()
    env.render()

    # epsilon = 0: presumably this disables random exploration so that the
    # learned policy is shown acting greedily -- confirm against EpsilonGreedy
    policy.epsilon = 0

    for t in range(num_steps):

        a = policy(s)
        s, r, done, info = env.step(a)
        env.render()

        if done:
            break
finally:
    # release the environment (and its render window) unconditionally
    env.close()
WARNING:tensorflow:From /home/kris/.local/lib/python3.6/site-packages/tensorflow/python/ops/resource_variable_ops.py:435: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
WARNING:tensorflow:From /home/kris/.local/lib/python3.6/site-packages/tensorflow/python/ops/losses/losses_impl.py:448: to_float (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.cast instead.
failed after 105 steps
failed after 63 steps
failed after 64 steps
failed after 126 steps
failed after 106 steps
failed after 84 steps
failed after 160 steps
failed after 80 steps
failed after 83 steps
failed after 190 steps
failed after 152 steps
failed after 85 steps
failed after 62 steps
failed after 157 steps
failed after 57 steps
failed after 64 steps
failed after 65 steps
failed after 91 steps
failed after 81 steps
failed after 137 steps
num_consecutive_successes: 1
num_consecutive_successes: 2
num_consecutive_successes: 3
num_consecutive_successes: 4
num_consecutive_successes: 5
num_consecutive_successes: 6
failed after 155 steps
num_consecutive_successes: 1
failed after 146 steps
num_consecutive_successes: 1
num_consecutive_successes: 2
num_consecutive_successes: 3
num_consecutive_successes: 4
num_consecutive_successes: 5
num_consecutive_successes: 6
num_consecutive_successes: 7
failed after 198 steps
num_consecutive_successes: 1
num_consecutive_successes: 2
num_consecutive_successes: 3
num_consecutive_successes: 4
num_consecutive_successes: 5
num_consecutive_successes: 6
num_consecutive_successes: 7
num_consecutive_successes: 8
num_consecutive_successes: 9
num_consecutive_successes: 10