In [ ]:
import gym
import numpy as np

from sklearn.linear_model import SGDRegressor
from sklearn.preprocessing import FunctionTransformer

from keras_gym.utils import reload_all
reload_all()

from keras_gym.value_functions import GenericQ
from keras_gym.policies import ValuePolicy
from keras_gym.algorithms import MonteCarlo
from keras_gym.environments import ConnectFour
from keras_gym.errors import UnavailableActionError


# define the scikit-learn model used to approximate the Q-function
regressor = SGDRegressor(eta0=0.1, learning_rate='constant')
transformer = None


# define two Q-functions (one per player), their induced policies, and the Monte Carlo update algorithms
Q1 = GenericQ(ConnectFour, regressor, transformer)
Q2 = GenericQ(ConnectFour, regressor, transformer)
policy1 = ValuePolicy(Q1)
policy2 = ValuePolicy(Q2)
algo1 = MonteCarlo(Q1, gamma=0.9)
algo2 = MonteCarlo(Q2, gamma=0.9)



env = ConnectFour()

swap_policies = True
num_episodes = 100000

# track the average return achieved against the random adversary
total_return = 0.0
num_observations = 0.0
history = []


for episode in range(1, num_episodes + 1):
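    # Every 11th episode is played against a uniform-random adversary for
    # evaluation only (no learning update); otherwise the two learned
    # policies alternate between the roles of learner and adversary.
    # Epsilon is annealed linearly from 1 to 0 over the first 90% of episodes.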
    random_adversary = episode % 11 == 0
    epsilon = np.clip(1 - episode / (0.9 * num_episodes), 0, 1)

    if swap_policies and episode % 2:
        env.adversary_policy = policy2
        policy = policy1
        algo = algo1
    else:
        env.adversary_policy = policy1
        policy = policy2
        algo = algo2
        
    if random_adversary:
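        # NOTE: RandomPolicy is defined in a later cell and must already be
        # in scope before this cell is run.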
        env.adversary_policy = RandomPolicy(env)

    s = env.reset()

    for t in range(1, env.max_time_steps + 1):
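        # epsilon-greedy action; fall back to a random legal column if the
        # chosen column is already full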
        a = policy.epsilon_greedy(s, epsilon)
        if a not in env.available_actions:
            a = env.rnd.choice(env.available_actions)
        
        s_next, r, done, info = env.step(a)
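        # only update the value function on self-play episodes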
        if not random_adversary:
            algo.update(s, a, r, s_next, done)

        if done:
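            # record returns from the evaluation episodes and log a snapshot
            # at power-of-two episode numbers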
            if random_adversary:
                total_return += r
                num_observations += 1.0
                history.append((total_return, num_observations, total_return / num_observations))
            if np.log2(episode).is_integer() and num_observations > 0:
                print("episode={}, avg_return={}".format(episode, total_return / num_observations))
                env.render()
            break
        
        s = s_next
    

# import pickle

# with open("policy1.pkl", 'wb') as f:
#     pickle.dump(policy1, f)

# with open("policy2.pkl", 'wb') as f:
#     pickle.dump(policy2, f)
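
# To restore the saved policies in a later session (assuming the dumps above
# have been run), the standard pickle counterpart would be:
#
# with open("policy1.pkl", 'rb') as f:
#     policy1 = pickle.load(f)
#
# with open("policy2.pkl", 'rb') as f:
#     policy2 = pickle.load(f)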
module reloaded: keras_gym.errors
module reloaded: keras_gym.environments.adversarial
module reloaded: keras_gym.environments
episode=16, avg_return=-1.0
                          ▽  
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   |   |   |   | ○ |
+---+---+---+---+---+---+---+
|   |   |   |   |   |   | ○ |
+---+---+---+---+---+---+---+
|   |   |   | ● |   |   | ○ |
+---+---+---+---+---+---+---+
| ● |   |   | ● |   |   | ○ |
+---+---+---+---+---+---+---+

episode=32, avg_return=-1.0
              ▽              
+---+---+---+---+---+---+---+
|   |   |   |   |   | ○ |   |
+---+---+---+---+---+---+---+
| ● |   |   |   |   | ○ |   |
+---+---+---+---+---+---+---+
| ● | ● |   | ○ |   | ● |   |
+---+---+---+---+---+---+---+
| ○ | ○ |   | ● |   | ○ |   |
+---+---+---+---+---+---+---+
| ○ | ○ | ● | ● | ● | ● | ● |
+---+---+---+---+---+---+---+
| ○ | ○ | ● | ● | ○ | ○ | ● |
+---+---+---+---+---+---+---+

episode=64, avg_return=-1.0
                  ▽          
+---+---+---+---+---+---+---+
| ● |   |   |   |   |   | ○ |
+---+---+---+---+---+---+---+
| ○ |   |   |   |   |   | ● |
+---+---+---+---+---+---+---+
| ● |   |   |   | ○ |   | ○ |
+---+---+---+---+---+---+---+
| ○ |   |   |   | ○ |   | ○ |
+---+---+---+---+---+---+---+
| ● |   |   |   | ○ |   | ● |
+---+---+---+---+---+---+---+
| ○ | ● | ● | ● | ○ |   | ● |
+---+---+---+---+---+---+---+

episode=128, avg_return=-0.2727272727272727
                  ▽          
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   | ● |   |   |   |
+---+---+---+---+---+---+---+
| ● |   |   | ● |   |   |   |
+---+---+---+---+---+---+---+
| ● |   |   | ○ | ○ |   | ○ |
+---+---+---+---+---+---+---+
| ● | ○ |   | ○ | ○ |   | ○ |
+---+---+---+---+---+---+---+
| ● | ● | ● | ○ | ● |   | ○ |
+---+---+---+---+---+---+---+

episode=256, avg_return=-0.30434782608695654
              ▽              
+---+---+---+---+---+---+---+
|   |   |   |   | ● | ● | ○ |
+---+---+---+---+---+---+---+
| ○ | ○ |   |   | ● | ● | ○ |
+---+---+---+---+---+---+---+
| ○ | ● | ● |   | ● | ○ | ● |
+---+---+---+---+---+---+---+
| ● | ● | ○ |   | ○ | ○ | ○ |
+---+---+---+---+---+---+---+
| ● | ○ | ○ | ○ | ● | ● | ○ |
+---+---+---+---+---+---+---+
| ○ | ● | ● | ● | ○ | ○ | ● |
+---+---+---+---+---+---+---+

episode=512, avg_return=-0.043478260869565216
                          ▽  
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   |   |   |   | ○ |
+---+---+---+---+---+---+---+
|   |   |   |   |   | ● | ○ |
+---+---+---+---+---+---+---+
|   |   | ● |   |   | ○ | ○ |
+---+---+---+---+---+---+---+
| ○ | ● | ● |   |   | ● | ○ |
+---+---+---+---+---+---+---+
| ● | ○ | ○ |   | ● | ○ | ● |
+---+---+---+---+---+---+---+

episode=1024, avg_return=-0.053763440860215055
                          ▽  
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   | ○ |   | ● | ○ |
+---+---+---+---+---+---+---+
|   |   |   | ○ |   | ○ | ○ |
+---+---+---+---+---+---+---+
| ● |   |   | ● | ● | ● | ● |
+---+---+---+---+---+---+---+
| ○ | ● |   | ○ | ○ | ● | ● |
+---+---+---+---+---+---+---+

episode=2048, avg_return=-0.03225806451612903
                          ▽  
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   | ○ |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   | ○ |   |   |   |
+---+---+---+---+---+---+---+
|   | ● | ● | ● | ● | ● | ○ |
+---+---+---+---+---+---+---+
|   | ○ | ● | ○ | ● | ○ | ○ |
+---+---+---+---+---+---+---+

episode=4096, avg_return=-0.09543010752688172
                          ▽  
+---+---+---+---+---+---+---+
|   |   |   |   |   |   | ○ |
+---+---+---+---+---+---+---+
|   |   |   |   | ● |   | ○ |
+---+---+---+---+---+---+---+
|   |   |   |   | ● |   | ○ |
+---+---+---+---+---+---+---+
|   | ● |   | ● | ● |   | ● |
+---+---+---+---+---+---+---+
| ● | ○ |   | ○ | ● | ● | ○ |
+---+---+---+---+---+---+---+
| ● | ○ | ○ | ● | ○ | ○ | ○ |
+---+---+---+---+---+---+---+

episode=8192, avg_return=-0.02217741935483871
          ▽                  
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   |   |   |   | ○ |
+---+---+---+---+---+---+---+
|   |   |   |   |   |   | ● |
+---+---+---+---+---+---+---+
|   |   | ○ | ○ | ● |   | ○ |
+---+---+---+---+---+---+---+
|   |   | ● | ○ | ● |   | ○ |
+---+---+---+---+---+---+---+
| ● | ● | ● | ● | ○ | ○ | ● |
+---+---+---+---+---+---+---+

episode=16384, avg_return=-0.0047011417058428475
  ▽                          
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   | ● |   |   |   |   |
+---+---+---+---+---+---+---+
| ○ | ○ | ○ |   |   |   |   |
+---+---+---+---+---+---+---+
| ○ | ○ | ● | ● | ● | ● |   |
+---+---+---+---+---+---+---+

episode=32768, avg_return=-0.010577568838146407
                          ▽  
+---+---+---+---+---+---+---+
|   |   |   |   |   |   | ○ |
+---+---+---+---+---+---+---+
| ○ |   |   | ● |   |   | ○ |
+---+---+---+---+---+---+---+
| ● |   | ● | ● | ○ |   | ○ |
+---+---+---+---+---+---+---+
| ● |   | ● | ○ | ● |   | ○ |
+---+---+---+---+---+---+---+
| ● |   | ○ | ○ | ● |   | ● |
+---+---+---+---+---+---+---+
| ○ | ● | ○ | ● | ○ |   | ○ |
+---+---+---+---+---+---+---+

In [36]:
class RandomPolicy:
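    """Adversary that ignores the state and samples a uniformly random action."""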
    def __init__(self, env):
        self.env = env

    def greedy(self, s):
        return self.env.action_space.sample()

    def thompson(self, s):
        return self.env.action_space.sample()


env.adversary_policy = RandomPolicy(env)
policy = policy1
algo = algo1


avg_return = 0.0
num_episodes = 1024


for episode in range(1, num_episodes + 1):
    s = env.reset()

    for t in range(1, env.max_time_steps + 1):
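        # purely greedy action; fall back to a random legal column if the
        # chosen column is already full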
        a = policy.greedy(s)
        if a not in env.available_actions:
            a = env.rnd.choice(env.available_actions)

        s, r, done, info = env.step(a)

        if done:
            avg_return += (r - avg_return) / episode
            if np.log2(episode).is_integer():
                print("episode={}, avg_return={}".format(episode, avg_return))
                env.render()
            break
episode=1, avg_return=1.0
              ▽              
+---+---+---+---+---+---+---+
|   |   | ○ |   |   | ● |   |
+---+---+---+---+---+---+---+
|   |   | ● | ○ | ○ | ● |   |
+---+---+---+---+---+---+---+
| ○ |   | ○ | ● | ● | ○ |   |
+---+---+---+---+---+---+---+
| ○ | ● | ● | ○ | ○ | ○ |   |
+---+---+---+---+---+---+---+
| ● | ● | ○ | ● | ● | ● | ● |
+---+---+---+---+---+---+---+
| ○ | ● | ● | ○ | ○ | ● | ○ |
+---+---+---+---+---+---+---+

episode=2, avg_return=1.0
                          ▽  
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   |   |   | ○ |   |
+---+---+---+---+---+---+---+
|   |   | ○ | ● |   | ● |   |
+---+---+---+---+---+---+---+
| ○ |   | ● | ● | ● | ○ | ○ |
+---+---+---+---+---+---+---+
| ● |   | ● | ● | ○ | ○ | ○ |
+---+---+---+---+---+---+---+
| ○ |   | ○ | ● | ● | ● | ○ |
+---+---+---+---+---+---+---+

episode=4, avg_return=0.5
          ▽                  
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   | ○ |   |   |   |
+---+---+---+---+---+---+---+
|   |   | ○ | ○ |   |   |   |
+---+---+---+---+---+---+---+
|   |   | ● | ○ | ○ |   |   |
+---+---+---+---+---+---+---+
| ● |   | ● | ● | ○ | ● |   |
+---+---+---+---+---+---+---+
| ● |   | ● | ○ | ● | ○ | ○ |
+---+---+---+---+---+---+---+

episode=8, avg_return=0.25
                      ▽      
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   | ● |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   | ● |   | ○ |   |
+---+---+---+---+---+---+---+
| ○ |   |   | ● | ○ | ● |   |
+---+---+---+---+---+---+---+
| ○ |   | ● | ● | ○ | ● | ○ |
+---+---+---+---+---+---+---+

episode=16, avg_return=0.125
                      ▽      
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   |   |   | ○ |   |
+---+---+---+---+---+---+---+
|   |   | ● | ○ |   | ● |   |
+---+---+---+---+---+---+---+
| ○ | ○ | ● | ○ |   | ● |   |
+---+---+---+---+---+---+---+
| ○ | ● | ● | ● |   | ○ |   |
+---+---+---+---+---+---+---+
| ● | ○ | ● | ● | ○ | ○ |   |
+---+---+---+---+---+---+---+

episode=32, avg_return=0.0625
      ▽                      
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   | ○ |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   | ○ |   |   |   |   |   |
+---+---+---+---+---+---+---+
| ● | ○ |   |   | ● | ● |   |
+---+---+---+---+---+---+---+
| ● | ○ | ● | ○ | ○ | ● |   |
+---+---+---+---+---+---+---+

episode=64, avg_return=0.1875
      ▽                      
+---+---+---+---+---+---+---+
|   |   |   | ● |   |   |   |
+---+---+---+---+---+---+---+
| ● |   |   | ○ |   |   |   |
+---+---+---+---+---+---+---+
| ○ |   | ○ | ● | ○ | ● |   |
+---+---+---+---+---+---+---+
| ○ |   | ● | ○ | ○ | ○ |   |
+---+---+---+---+---+---+---+
| ● |   | ○ | ● | ○ | ● | ● |
+---+---+---+---+---+---+---+
| ● | ○ | ● | ○ | ● | ● | ○ |
+---+---+---+---+---+---+---+

episode=128, avg_return=0.31249999999999994
                          ▽  
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   | ● |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   | ○ |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   | ● |   |   |   | ○ |
+---+---+---+---+---+---+---+
| ○ | ○ | ● | ● | ● | ● | ○ |
+---+---+---+---+---+---+---+
| ● | ○ | ○ | ○ | ● | ● | ○ |
+---+---+---+---+---+---+---+

episode=256, avg_return=0.33593749999999994
                          ▽  
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   | ○ | ● |   |   |   |
+---+---+---+---+---+---+---+
|   |   | ○ | ○ |   |   |   |
+---+---+---+---+---+---+---+
|   |   | ● | ● | ● | ● | ○ |
+---+---+---+---+---+---+---+
|   |   | ● | ○ | ● | ○ | ○ |
+---+---+---+---+---+---+---+

episode=512, avg_return=0.27734375
      ▽                      
+---+---+---+---+---+---+---+
|   |   |   |   |   | ● |   |
+---+---+---+---+---+---+---+
|   |   | ○ |   |   | ○ |   |
+---+---+---+---+---+---+---+
|   | ○ | ○ |   | ○ | ● |   |
+---+---+---+---+---+---+---+
|   | ○ | ● |   | ○ | ● |   |
+---+---+---+---+---+---+---+
|   | ○ | ● | ● | ○ | ● |   |
+---+---+---+---+---+---+---+
| ● | ○ | ● | ● | ● | ○ |   |
+---+---+---+---+---+---+---+

episode=1024, avg_return=0.2929687499999998
                  ▽          
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   | ○ |   | ○ |   |   |
+---+---+---+---+---+---+---+
|   |   | ● | ● | ● | ● |   |
+---+---+---+---+---+---+---+
|   | ○ | ● | ○ | ○ | ● | ○ |
+---+---+---+---+---+---+---+

In [38]:
s = env.reset()

for t in range(1, env.max_time_steps + 1):
    a = policy.greedy(s)
    if a not in env.available_actions:
        a = env.rnd.choice(env.available_actions)
    
    env.render()
    s, r, done, info = env.step(a)

    if done:
        avg_return += (r - avg_return) / episode
        if np.log2(episode).is_integer():
            print("episode={}, avg_return={}".format(episode, avg_return))
            env.render()
        break
                             
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+

      ▽                      
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   | ○ |   |   |   | ● |   |
+---+---+---+---+---+---+---+

                      ▽      
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   |   |   | ○ |   |
+---+---+---+---+---+---+---+
|   | ○ |   | ● |   | ● |   |
+---+---+---+---+---+---+---+

                          ▽  
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   |   |   | ● |   |
+---+---+---+---+---+---+---+
|   |   |   |   |   | ○ |   |
+---+---+---+---+---+---+---+
|   | ○ |   | ● |   | ● | ○ |
+---+---+---+---+---+---+---+

  ▽                          
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   |   |   | ● |   |
+---+---+---+---+---+---+---+
|   |   |   |   |   | ○ |   |
+---+---+---+---+---+---+---+
| ○ | ○ |   | ● | ● | ● | ○ |
+---+---+---+---+---+---+---+

          ▽                  
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   |   |   | ● |   |
+---+---+---+---+---+---+---+
|   |   |   |   |   | ● |   |
+---+---+---+---+---+---+---+
|   |   |   |   |   | ○ |   |
+---+---+---+---+---+---+---+
| ○ | ○ | ○ | ● | ● | ● | ○ |
+---+---+---+---+---+---+---+

              ▽              
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   |   |   | ● |   |
+---+---+---+---+---+---+---+
|   |   |   |   |   | ● |   |
+---+---+---+---+---+---+---+
|   |   |   | ○ | ● | ○ |   |
+---+---+---+---+---+---+---+
| ○ | ○ | ○ | ● | ● | ● | ○ |
+---+---+---+---+---+---+---+

                  ▽          
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   |   |   | ● |   |
+---+---+---+---+---+---+---+
|   |   |   |   | ○ | ● |   |
+---+---+---+---+---+---+---+
|   |   | ● | ○ | ● | ○ |   |
+---+---+---+---+---+---+---+
| ○ | ○ | ○ | ● | ● | ● | ○ |
+---+---+---+---+---+---+---+

  ▽                          
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   |   |   | ● |   |
+---+---+---+---+---+---+---+
|   |   | ● |   | ○ | ● |   |
+---+---+---+---+---+---+---+
| ○ |   | ● | ○ | ● | ○ |   |
+---+---+---+---+---+---+---+
| ○ | ○ | ○ | ● | ● | ● | ○ |
+---+---+---+---+---+---+---+

              ▽              
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   | ● |   |   | ● |   |
+---+---+---+---+---+---+---+
|   |   | ● | ○ | ○ | ● |   |
+---+---+---+---+---+---+---+
| ○ |   | ● | ○ | ● | ○ |   |
+---+---+---+---+---+---+---+
| ○ | ○ | ○ | ● | ● | ● | ○ |
+---+---+---+---+---+---+---+

episode=1024, avg_return=0.294348996132612
              ▽              
+---+---+---+---+---+---+---+
|   |   |   |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   | ● |   |   |   |   |
+---+---+---+---+---+---+---+
|   |   | ● |   |   | ● |   |
+---+---+---+---+---+---+---+
|   |   | ● | ○ | ○ | ● |   |
+---+---+---+---+---+---+---+
| ○ |   | ● | ○ | ● | ○ |   |
+---+---+---+---+---+---+---+
| ○ | ○ | ○ | ● | ● | ● | ○ |
+---+---+---+---+---+---+---+