utils.plots

 1# -*- coding: utf-8 -*-
 2
 3import math
 4import numpy as np
 5import pandas as pd
 6import seaborn as sns
 7import matplotlib.pyplot as plt
 8import warnings
 9from matplotlib.colors import LinearSegmentedColormap
10
11
class Plots:
    """Static plotting helpers built on seaborn and matplotlib."""

    @staticmethod
    def values_heat_map(data, title, size, show=True):
        """Draw ``data`` as an annotated heat map.

        Args:
            data: Values to plot; converted with ``np.array`` and reshaped
                to ``size``.
            title: Title set on the heat map axes.
            size: Target shape handed to ``reshape``.
            show: When true, call ``plt.show()`` after drawing.
        """
        # Round to two decimals so the per-cell annotations stay short.
        data = np.around(np.array(data).reshape(size), 2)
        df = pd.DataFrame(data=data)
        sns.heatmap(df, annot=True).set_title(title)

        if show:
            plt.show()

    @staticmethod
    def v_iters_plot(data, title, show=True):
        """Draw ``data`` as a line plot without a legend.

        Args:
            data: Anything accepted by ``pd.DataFrame(data=...)``.
            title: Title set on the plot axes.
            show: When true, call ``plt.show()`` after drawing.
        """
        df = pd.DataFrame(data=data)
        # NOTE: set_theme changes the seaborn theme globally, not just here.
        sns.set_theme(style="whitegrid")
        sns.lineplot(data=df, legend=None).set_title(title)

        if show:
            plt.show()

    # modified from https://gymnasium.farama.org/tutorials/training_agents/FrozenLake_tuto/
    @staticmethod
    def get_policy_map(pi, val_max, actions, map_size):
        """Map the best learned action of every state to its label.

        Args:
            pi: Indexable policy; ``pi[state]`` is the action for ``state``.
            val_max: NumPy array whose first dimension is the number of
                states; returned reshaped to the map.
            actions: Indexable mapping from an action to its label
                (for example an arrow character).
            map_size: Pair ``(rows, cols)`` describing the map shape.

        Returns:
            Tuple ``(val_max, policy_map)``, both reshaped to
            ``(map_size[0], map_size[1])``. ``policy_map`` is a string array.
        """
        n_states = val_max.shape[0]
        best_action = np.zeros(n_states, dtype=np.int32)
        for state in range(n_states):
            best_action[state] = pi[state]

        # Let NumPy size the string dtype from the labels: a preallocated
        # ``dtype=str`` array is ``<U1`` and cuts labels to one character.
        labels = [actions[action] for action in best_action]
        policy_map = np.array(labels, dtype=str)

        policy_map = policy_map.reshape(map_size[0], map_size[1])
        val_max = val_max.reshape(map_size[0], map_size[1])
        return val_max, policy_map

    # modified from https://gymnasium.farama.org/tutorials/training_agents/FrozenLake_tuto/
    @staticmethod
    def plot_policy(val_max, directions, map_size, title, show=True):
        """Plot the policy learned.

        Args:
            val_max: 2-D values used to colour the heat map cells.
            directions: Per-cell annotations drawn on top of the cells.
            map_size: Unused by the plot; kept so existing callers keep
                working.
            title: Title set on the heat map axes.
            show: When true, call ``plt.show()`` after drawing.
        """
        sns.heatmap(
            val_max,
            annot=directions,
            fmt="",  # annotations are labels, not numbers to format
            cmap=sns.color_palette("Blues", as_cmap=True),
            linewidths=0.7,
            linecolor="black",
            xticklabels=[],
            yticklabels=[],
            annot_kws={"fontsize": "xx-large"},
        ).set(title=title)

        if show:
            plt.show()
class Plots:
    """Static plotting helpers built on seaborn and matplotlib."""

    @staticmethod
    def values_heat_map(data, title, size, show=True):
        """Draw ``data`` as an annotated heat map.

        Args:
            data: Values to plot; converted with ``np.array`` and reshaped
                to ``size``.
            title: Title set on the heat map axes.
            size: Target shape handed to ``reshape``.
            show: When true, call ``plt.show()`` after drawing.
        """
        # Round to two decimals so the per-cell annotations stay short.
        data = np.around(np.array(data).reshape(size), 2)
        df = pd.DataFrame(data=data)
        sns.heatmap(df, annot=True).set_title(title)

        if show:
            plt.show()

    @staticmethod
    def v_iters_plot(data, title, show=True):
        """Draw ``data`` as a line plot without a legend.

        Args:
            data: Anything accepted by ``pd.DataFrame(data=...)``.
            title: Title set on the plot axes.
            show: When true, call ``plt.show()`` after drawing.
        """
        df = pd.DataFrame(data=data)
        # NOTE: set_theme changes the seaborn theme globally, not just here.
        sns.set_theme(style="whitegrid")
        sns.lineplot(data=df, legend=None).set_title(title)

        if show:
            plt.show()

    # modified from https://gymnasium.farama.org/tutorials/training_agents/FrozenLake_tuto/
    @staticmethod
    def get_policy_map(pi, val_max, actions, map_size):
        """Map the best learned action of every state to its label.

        Args:
            pi: Indexable policy; ``pi[state]`` is the action for ``state``.
            val_max: NumPy array whose first dimension is the number of
                states; returned reshaped to the map.
            actions: Indexable mapping from an action to its label
                (for example an arrow character).
            map_size: Pair ``(rows, cols)`` describing the map shape.

        Returns:
            Tuple ``(val_max, policy_map)``, both reshaped to
            ``(map_size[0], map_size[1])``. ``policy_map`` is a string array.
        """
        n_states = val_max.shape[0]
        best_action = np.zeros(n_states, dtype=np.int32)
        for state in range(n_states):
            best_action[state] = pi[state]

        # Let NumPy size the string dtype from the labels: a preallocated
        # ``dtype=str`` array is ``<U1`` and cuts labels to one character.
        labels = [actions[action] for action in best_action]
        policy_map = np.array(labels, dtype=str)

        policy_map = policy_map.reshape(map_size[0], map_size[1])
        val_max = val_max.reshape(map_size[0], map_size[1])
        return val_max, policy_map

    # modified from https://gymnasium.farama.org/tutorials/training_agents/FrozenLake_tuto/
    @staticmethod
    def plot_policy(val_max, directions, map_size, title, show=True):
        """Plot the policy learned.

        Args:
            val_max: 2-D values used to colour the heat map cells.
            directions: Per-cell annotations drawn on top of the cells.
            map_size: Unused by the plot; kept so existing callers keep
                working.
            title: Title set on the heat map axes.
            show: When true, call ``plt.show()`` after drawing.
        """
        sns.heatmap(
            val_max,
            annot=directions,
            fmt="",  # annotations are labels, not numbers to format
            cmap=sns.color_palette("Blues", as_cmap=True),
            linewidths=0.7,
            linecolor="black",
            xticklabels=[],
            yticklabels=[],
            annot_kws={"fontsize": "xx-large"},
        ).set(title=title)

        if show:
            plt.show()
@staticmethod
def values_heat_map(data, title, size, show=True):
    """Render ``data`` as an annotated heat map.

    Args:
        data: Values to plot; converted with ``np.array`` and reshaped to
            ``size``.
        title: Title set on the heat map axes.
        size: Target shape handed to ``reshape``.
        show: When true, call ``plt.show()`` after drawing.
    """
    grid = np.array(data).reshape(size)
    # Two decimals keep the per-cell annotations short.
    frame = pd.DataFrame(data=np.around(grid, 2))
    axes = sns.heatmap(frame, annot=True)
    axes.set_title(title)

    if not show:
        return
    plt.show()
@staticmethod
def v_iters_plot(data, title, show=True):
    """Render ``data`` as a line plot with no legend.

    Args:
        data: Anything accepted by ``pd.DataFrame(data=...)``.
        title: Title set on the plot axes.
        show: When true, call ``plt.show()`` after drawing.
    """
    frame = pd.DataFrame(data=data)
    # NOTE: set_theme changes the seaborn theme globally, not just here.
    sns.set_theme(style="whitegrid")
    axes = sns.lineplot(data=frame, legend=None)
    axes.set_title(title)

    if not show:
        return
    plt.show()
@staticmethod
def get_policy_map(pi, val_max, actions, map_size):
    """Map the best learned action of every state to its label.

    Args:
        pi: Indexable policy; ``pi[state]`` is the action for ``state``.
        val_max: NumPy array whose first dimension is the number of states;
            returned reshaped to the map.
        actions: Indexable mapping from an action to its label (for example
            an arrow character).
        map_size: Pair ``(rows, cols)`` describing the map shape.

    Returns:
        Tuple ``(val_max, policy_map)``, both reshaped to
        ``(map_size[0], map_size[1])``. ``policy_map`` is a string array.
    """
    n_states = val_max.shape[0]
    best_action = np.zeros(n_states, dtype=np.int32)
    for state in range(n_states):
        best_action[state] = pi[state]

    # Let NumPy size the string dtype from the labels: a preallocated
    # ``dtype=str`` array is ``<U1`` and cuts labels to one character.
    labels = [actions[action] for action in best_action]
    policy_map = np.array(labels, dtype=str)

    policy_map = policy_map.reshape(map_size[0], map_size[1])
    val_max = val_max.reshape(map_size[0], map_size[1])
    return val_max, policy_map

Map the best learned action to arrows.

@staticmethod
def plot_policy(val_max, directions, map_size, title, show=True):
    """Plot the policy learned.

    Args:
        val_max: 2-D values used to colour the heat map cells.
        directions: Per-cell annotations drawn on top of the cells.
        map_size: Unused by the plot; kept so existing callers keep working.
        title: Title set on the heat map axes.
        show: When true, call ``plt.show()`` after drawing.
    """
    sns.heatmap(
        val_max,
        annot=directions,
        fmt="",  # annotations are labels, not numbers to format
        cmap=sns.color_palette("Blues", as_cmap=True),
        linewidths=0.7,
        linecolor="black",
        xticklabels=[],
        yticklabels=[],
        annot_kws={"fontsize": "xx-large"},
    ).set(title=title)

    if show:
        plt.show()

Plot the policy learned.