bettermdptools.utils.plots
# -*- coding: utf-8 -*-

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns


class Plots:
    """Static plotting helpers for value functions and policies."""

    @staticmethod
    def values_heat_map(data, title, size, show=True):
        """Draw an annotated heat map of ``data`` reshaped to ``size``.

        Parameters
        ----------
        data : array-like
            Values to plot; converted with ``np.array`` and reshaped to ``size``.
        title : str
            Title set on the heat map axes.
        size : tuple of int
            Target shape passed to ``reshape``.
        show : bool, default True
            When true, call ``plt.show()`` after drawing.
        """
        # Round to two decimals so the per-cell annotations stay readable.
        data = np.around(np.array(data).reshape(size), 2)
        df = pd.DataFrame(data=data)
        sns.heatmap(df, annot=True).set_title(title)

        if show:
            plt.show()

    @staticmethod
    def v_iters_plot(data, title, show=True):
        """Draw a line plot of ``data`` with no legend.

        Parameters
        ----------
        data : array-like
            Values to plot; wrapped in a ``pandas.DataFrame``.
        title : str
            Title set on the plot axes.
        show : bool, default True
            When true, call ``plt.show()`` after drawing.

        Notes
        -----
        Calls ``sns.set_theme(style="whitegrid")`` before plotting.
        """
        df = pd.DataFrame(data=data)
        sns.set_theme(style="whitegrid")
        sns.lineplot(data=df, legend=None).set_title(title)

        if show:
            plt.show()

    # modified from https://gymnasium.farama.org/tutorials/training_agents/FrozenLake_tuto/
    @staticmethod
    def get_policy_map(pi, val_max, actions, map_size):
        """Map the best learned action to arrows.

        Parameters
        ----------
        pi : mapping or sequence
            Indexed as ``pi[state_index]`` to obtain the action for each state.
        val_max : array-like
            One value per state; its first dimension fixes the number of states.
        actions : mapping or sequence
            Indexed by action id to obtain the label shown for that action.
        map_size : sequence of int
            ``(rows, cols)`` used to reshape both outputs.

        Returns
        -------
        val_max : numpy.ndarray
            ``val_max`` reshaped to ``(map_size[0], map_size[1])``.
        policy_map : numpy.ndarray
            String array of the same shape holding one action label per state.
        """
        val_max = np.asarray(val_max)
        n_states = val_max.shape[0]
        best_action = np.array(
            [pi[state] for state in range(n_states)], dtype=np.int32
        )
        # Build the labels first and let NumPy size the string dtype from them:
        # ``np.empty(..., dtype=str)`` is ``<U1`` and would cut every label
        # down to its first character.
        labels = [str(actions[int(action)]) for action in best_action]
        policy_map = np.array(labels, dtype=str)
        policy_map = policy_map.reshape(map_size[0], map_size[1])
        val_max = val_max.reshape(map_size[0], map_size[1])
        return val_max, policy_map

    # modified from https://gymnasium.farama.org/tutorials/training_agents/FrozenLake_tuto/
    @staticmethod
    def plot_policy(val_max, directions, map_size, title, show=True):
        """Plot the policy learned.

        Parameters
        ----------
        val_max : array-like
            2-D values that colour the heat map cells.
        directions : array-like
            Per-cell annotations, passed to seaborn as ``annot``.
        map_size : sequence of int
            Not used for drawing; kept so existing callers keep working.
        title : str
            Title set on the heat map axes.
        show : bool, default True
            When true, call ``plt.show()`` after drawing.
        """
        sns.heatmap(
            val_max,
            annot=directions,
            fmt="",  # annotations are already strings; no number formatting
            cmap=sns.color_palette("Blues", as_cmap=True),
            linewidths=0.7,
            linecolor="black",
            xticklabels=[],
            yticklabels=[],
            annot_kws={"fontsize": "xx-large"},
        ).set(title=title)

        if show:
            plt.show()
class
Plots:
class Plots:
    """Static plotting helpers for value functions and policies."""

    @staticmethod
    def values_heat_map(data, title, size, show=True):
        """Draw an annotated heat map of ``data`` reshaped to ``size``.

        Parameters
        ----------
        data : array-like
            Values to plot; converted with ``np.array`` and reshaped to ``size``.
        title : str
            Title set on the heat map axes.
        size : tuple of int
            Target shape passed to ``reshape``.
        show : bool, default True
            When true, call ``plt.show()`` after drawing.
        """
        # Round to two decimals so the per-cell annotations stay readable.
        data = np.around(np.array(data).reshape(size), 2)
        df = pd.DataFrame(data=data)
        sns.heatmap(df, annot=True).set_title(title)

        if show:
            plt.show()

    @staticmethod
    def v_iters_plot(data, title, show=True):
        """Draw a line plot of ``data`` with no legend.

        Parameters
        ----------
        data : array-like
            Values to plot; wrapped in a ``pandas.DataFrame``.
        title : str
            Title set on the plot axes.
        show : bool, default True
            When true, call ``plt.show()`` after drawing.

        Notes
        -----
        Calls ``sns.set_theme(style="whitegrid")`` before plotting.
        """
        df = pd.DataFrame(data=data)
        sns.set_theme(style="whitegrid")
        sns.lineplot(data=df, legend=None).set_title(title)

        if show:
            plt.show()

    # modified from https://gymnasium.farama.org/tutorials/training_agents/FrozenLake_tuto/
    @staticmethod
    def get_policy_map(pi, val_max, actions, map_size):
        """Map the best learned action to arrows.

        Parameters
        ----------
        pi : mapping or sequence
            Indexed as ``pi[state_index]`` to obtain the action for each state.
        val_max : array-like
            One value per state; its first dimension fixes the number of states.
        actions : mapping or sequence
            Indexed by action id to obtain the label shown for that action.
        map_size : sequence of int
            ``(rows, cols)`` used to reshape both outputs.

        Returns
        -------
        val_max : numpy.ndarray
            ``val_max`` reshaped to ``(map_size[0], map_size[1])``.
        policy_map : numpy.ndarray
            String array of the same shape holding one action label per state.
        """
        val_max = np.asarray(val_max)
        n_states = val_max.shape[0]
        best_action = np.array(
            [pi[state] for state in range(n_states)], dtype=np.int32
        )
        # Build the labels first and let NumPy size the string dtype from them:
        # ``np.empty(..., dtype=str)`` is ``<U1`` and would cut every label
        # down to its first character.
        labels = [str(actions[int(action)]) for action in best_action]
        policy_map = np.array(labels, dtype=str)
        policy_map = policy_map.reshape(map_size[0], map_size[1])
        val_max = val_max.reshape(map_size[0], map_size[1])
        return val_max, policy_map

    # modified from https://gymnasium.farama.org/tutorials/training_agents/FrozenLake_tuto/
    @staticmethod
    def plot_policy(val_max, directions, map_size, title, show=True):
        """Plot the policy learned.

        Parameters
        ----------
        val_max : array-like
            2-D values that colour the heat map cells.
        directions : array-like
            Per-cell annotations, passed to seaborn as ``annot``.
        map_size : sequence of int
            Not used for drawing; kept so existing callers keep working.
        title : str
            Title set on the heat map axes.
        show : bool, default True
            When true, call ``plt.show()`` after drawing.
        """
        sns.heatmap(
            val_max,
            annot=directions,
            fmt="",  # annotations are already strings; no number formatting
            cmap=sns.color_palette("Blues", as_cmap=True),
            linewidths=0.7,
            linecolor="black",
            xticklabels=[],
            yticklabels=[],
            annot_kws={"fontsize": "xx-large"},
        ).set(title=title)

        if show:
            plt.show()
@staticmethod
def
get_policy_map(pi, val_max, actions, map_size):
@staticmethod
def get_policy_map(pi, val_max, actions, map_size):
    """Map the best learned action to arrows.

    Parameters
    ----------
    pi : mapping or sequence
        Indexed as ``pi[state_index]`` to obtain the action for each state.
    val_max : array-like
        One value per state; its first dimension fixes the number of states.
    actions : mapping or sequence
        Indexed by action id to obtain the label shown for that action.
    map_size : sequence of int
        ``(rows, cols)`` used to reshape both outputs.

    Returns
    -------
    val_max : numpy.ndarray
        ``val_max`` reshaped to ``(map_size[0], map_size[1])``.
    policy_map : numpy.ndarray
        String array of the same shape holding one action label per state.
    """
    val_max = np.asarray(val_max)
    n_states = val_max.shape[0]
    best_action = np.array(
        [pi[state] for state in range(n_states)], dtype=np.int32
    )
    # Build the labels first and let NumPy size the string dtype from them:
    # ``np.empty(..., dtype=str)`` is ``<U1`` and would cut every label
    # down to its first character.
    labels = [str(actions[int(action)]) for action in best_action]
    policy_map = np.array(labels, dtype=str)
    policy_map = policy_map.reshape(map_size[0], map_size[1])
    val_max = val_max.reshape(map_size[0], map_size[1])
    return val_max, policy_map
Map the best learned action to arrows.
@staticmethod
def
plot_policy(val_max, directions, map_size, title, show=True):
@staticmethod
def plot_policy(val_max, directions, map_size, title, show=True):
    """Plot the policy learned.

    Parameters
    ----------
    val_max : array-like
        2-D values that colour the heat map cells.
    directions : array-like
        Per-cell annotations, passed to seaborn as ``annot``.
    map_size : sequence of int
        Not used for drawing; kept so existing callers keep working.
    title : str
        Title set on the heat map axes.
    show : bool, default True
        When true, call ``plt.show()`` after drawing.
    """
    sns.heatmap(
        val_max,
        annot=directions,
        fmt="",  # annotations are already strings; no number formatting
        cmap=sns.color_palette("Blues", as_cmap=True),
        linewidths=0.7,
        linecolor="black",
        xticklabels=[],
        yticklabels=[],
        annot_kws={"fontsize": "xx-large"},
    ).set(title=title)

    if show:
        plt.show()
Plot the policy learned.