bettermdptools.utils.plots

 1# -*- coding: utf-8 -*-
 2
 3import matplotlib.pyplot as plt
 4import numpy as np
 5import pandas as pd
 6import seaborn as sns
 7
 8
 9class Plots:
10    @staticmethod
11    def values_heat_map(data, title, size, show=True):
12        data = np.around(np.array(data).reshape(size), 2)
13        df = pd.DataFrame(data=data)
14        sns.heatmap(df, annot=True).set_title(title)
15
16        if show:
17            plt.show()
18
19    @staticmethod
20    def v_iters_plot(data, title, show=True):
21        df = pd.DataFrame(data=data)
22        sns.set_theme(style="whitegrid")
23        sns.lineplot(data=df, legend=None).set_title(title)
24
25        if show:
26            plt.show()
27
28    # modified from https://gymnasium.farama.org/tutorials/training_agents/FrozenLake_tuto/
29    @staticmethod
30    def get_policy_map(pi, val_max, actions, map_size):
31        """Map the best learned action to arrows."""
32        # convert pi to numpy array
33        best_action = np.zeros(val_max.shape[0], dtype=np.int32)
34        for idx, val in enumerate(val_max):
35            best_action[idx] = pi[idx]
36        policy_map = np.empty(best_action.flatten().shape, dtype=str)
37        for idx, val in enumerate(best_action.flatten()):
38            policy_map[idx] = actions[val]
39        policy_map = policy_map.reshape(map_size[0], map_size[1])
40        val_max = val_max.reshape(map_size[0], map_size[1])
41        return val_max, policy_map
42
43    # modified from https://gymnasium.farama.org/tutorials/training_agents/FrozenLake_tuto/
44    @staticmethod
45    def plot_policy(val_max, directions, map_size, title, show=True):
46        """Plot the policy learned."""
47        sns.heatmap(
48            val_max,
49            annot=directions,
50            fmt="",
51            cmap=sns.color_palette("Blues", as_cmap=True),
52            linewidths=0.7,
53            linecolor="black",
54            xticklabels=[],
55            yticklabels=[],
56            annot_kws={"fontsize": "xx-large"},
57        ).set(title=title)
58        img_title = f"Policy_{map_size[0]}x{map_size[1]}.png"
59
60        if show:
61            plt.show()
class Plots:
    """Static plotting helpers for value functions and policies."""

    @staticmethod
    def values_heat_map(data, title, size, show=True):
        """Draw values as an annotated heat map.

        Parameters
        ----------
        data : array-like
            Values to plot; converted with ``np.array`` and reshaped to ``size``.
        title : str
            Title set on the heat map axes.
        size : int or tuple of int
            Target shape handed to ``reshape``.
        show : bool, default True
            When true, ``plt.show()`` is called after drawing.
        """
        # Round to two decimals so the per-cell annotations stay short.
        rounded = np.around(np.array(data).reshape(size), 2)
        frame = pd.DataFrame(data=rounded)
        axes = sns.heatmap(frame, annot=True)
        axes.set_title(title)

        if show:
            plt.show()

    @staticmethod
    def v_iters_plot(data, title, show=True):
        """Draw ``data`` as a seaborn line plot without a legend.

        Parameters
        ----------
        data : array-like
            Anything accepted by ``pd.DataFrame(data=...)``.
        title : str
            Title set on the plot axes.
        show : bool, default True
            When true, ``plt.show()`` is called after drawing.

        Notes
        -----
        Calls ``sns.set_theme(style="whitegrid")``, which changes the global
        seaborn theme.
        """
        frame = pd.DataFrame(data=data)
        sns.set_theme(style="whitegrid")
        axes = sns.lineplot(data=frame, legend=None)
        axes.set_title(title)

        if show:
            plt.show()

    # modified from https://gymnasium.farama.org/tutorials/training_agents/FrozenLake_tuto/
    @staticmethod
    def get_policy_map(pi, val_max, actions, map_size):
        """Map the best learned action to arrows.

        Parameters
        ----------
        pi : mapping or sequence
            Indexed as ``pi[state]`` for every ``state`` in
            ``range(len(val_max))``; each entry is the action for that state.
        val_max : array-like
            One value per state; its first dimension sets the number of states.
        actions : mapping or sequence
            Indexed by the integer action to obtain the label to display.
        map_size : sequence of int
            ``(rows, cols)`` used to reshape both returned arrays.

        Returns
        -------
        tuple of numpy.ndarray
            ``val_max`` reshaped to ``map_size`` and the matching string array
            of action labels.
        """
        val_max = np.asarray(val_max)
        n_states = val_max.shape[0]
        # Coerce each action to a plain int so NumPy scalars and whole-number
        # floats index ``actions`` the same way.
        labels = [actions[int(pi[state])] for state in range(n_states)]
        # Let NumPy size the string dtype from the labels themselves:
        # np.empty(..., dtype=str) allocates one-character cells and would
        # silently truncate any label longer than a single character.
        policy_map = np.array(labels, dtype=str)
        policy_map = policy_map.reshape(map_size[0], map_size[1])
        val_max = val_max.reshape(map_size[0], map_size[1])
        return val_max, policy_map

    # modified from https://gymnasium.farama.org/tutorials/training_agents/FrozenLake_tuto/
    @staticmethod
    def plot_policy(val_max, directions, map_size, title, show=True):
        """Plot the learned policy as an annotated heat map.

        Parameters
        ----------
        val_max : array-like
            2-D grid of values used to colour the cells.
        directions : array-like
            Per-cell annotation text, same shape as ``val_max``.
        map_size : sequence of int
            Not used by the plot; kept so existing callers keep working.
        title : str
            Title set on the heat map axes.
        show : bool, default True
            When true, ``plt.show()`` is called after drawing.
        """
        axes = sns.heatmap(
            val_max,
            annot=directions,
            fmt="",  # annotations are already strings; no number formatting
            cmap=sns.color_palette("Blues", as_cmap=True),
            linewidths=0.7,
            linecolor="black",
            xticklabels=[],
            yticklabels=[],
            annot_kws={"fontsize": "xx-large"},
        )
        axes.set(title=title)

        if show:
            plt.show()
@staticmethod
def values_heat_map(data, title, size, show=True):
    """Draw values as an annotated heat map.

    ``data`` is converted with ``np.array``, reshaped to ``size`` and rounded
    to two decimals before plotting. ``title`` is set on the resulting axes,
    and ``plt.show()`` is called when ``show`` is true.
    """
    rounded = np.around(np.array(data).reshape(size), 2)
    frame = pd.DataFrame(data=rounded)
    axes = sns.heatmap(frame, annot=True)
    axes.set_title(title)

    if show:
        plt.show()
@staticmethod
def v_iters_plot(data, title, show=True):
    """Draw ``data`` as a seaborn line plot without a legend.

    ``data`` is wrapped in a ``pd.DataFrame``, and the global seaborn theme is
    switched to ``"whitegrid"`` before plotting. ``title`` is set on the
    resulting axes, and ``plt.show()`` is called when ``show`` is true.
    """
    frame = pd.DataFrame(data=data)
    sns.set_theme(style="whitegrid")
    axes = sns.lineplot(data=frame, legend=None)
    axes.set_title(title)

    if show:
        plt.show()
@staticmethod
def get_policy_map(pi, val_max, actions, map_size):
    """Map the best learned action to arrows.

    Parameters
    ----------
    pi : mapping or sequence
        Indexed as ``pi[state]`` for every ``state`` in
        ``range(len(val_max))``; each entry is the action for that state.
    val_max : array-like
        One value per state; its first dimension sets the number of states.
    actions : mapping or sequence
        Indexed by the integer action to obtain the label to display.
    map_size : sequence of int
        ``(rows, cols)`` used to reshape both returned arrays.

    Returns
    -------
    tuple of numpy.ndarray
        ``val_max`` reshaped to ``map_size`` and the matching string array of
        action labels.
    """
    val_max = np.asarray(val_max)
    n_states = val_max.shape[0]
    # Coerce each action to a plain int so NumPy scalars and whole-number
    # floats index ``actions`` the same way.
    labels = [actions[int(pi[state])] for state in range(n_states)]
    # Let NumPy size the string dtype from the labels themselves:
    # np.empty(..., dtype=str) allocates one-character cells and would
    # silently truncate any label longer than a single character.
    policy_map = np.array(labels, dtype=str)
    policy_map = policy_map.reshape(map_size[0], map_size[1])
    val_max = val_max.reshape(map_size[0], map_size[1])
    return val_max, policy_map

Map the best learned action to arrows.

@staticmethod
def plot_policy(val_max, directions, map_size, title, show=True):
    """Plot the learned policy as an annotated heat map.

    Parameters
    ----------
    val_max : array-like
        2-D grid of values used to colour the cells.
    directions : array-like
        Per-cell annotation text, same shape as ``val_max``.
    map_size : sequence of int
        Not used by the plot; kept so existing callers keep working.
    title : str
        Title set on the heat map axes.
    show : bool, default True
        When true, ``plt.show()`` is called after drawing.
    """
    axes = sns.heatmap(
        val_max,
        annot=directions,
        fmt="",  # annotations are already strings; no number formatting
        cmap=sns.color_palette("Blues", as_cmap=True),
        linewidths=0.7,
        linecolor="black",
        xticklabels=[],
        yticklabels=[],
        annot_kws={"fontsize": "xx-large"},
    )
    axes.set(title=title)

    if show:
        plt.show()

Plot the learned policy.