utils.plots
# -*- coding: utf-8 -*-

import math
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import warnings
from matplotlib.colors import LinearSegmentedColormap


class Plots:
    """Static plotting helpers: value heat maps, line plots and policy maps."""

    @staticmethod
    def values_heat_map(data, title, size, show=True):
        """Draw an annotated heat map of ``data`` reshaped to ``size``.

        Args:
            data: Array-like of numbers; reshaped to ``size`` and rounded to
                two decimals before plotting.
            title: Title set on the heat-map axes.
            size: Target shape passed to ``ndarray.reshape``.
            show: If True, call ``plt.show()`` after drawing.
        """
        data = np.around(np.array(data).reshape(size), 2)
        df = pd.DataFrame(data=data)
        sns.heatmap(df, annot=True).set_title(title)

        if show:
            plt.show()

    @staticmethod
    def v_iters_plot(data, title, show=True):
        """Draw ``data`` as a line plot without a legend.

        Note: calls ``sns.set_theme(style="whitegrid")`` on every invocation,
        which changes the global plotting theme for subsequent figures too.

        Args:
            data: Anything accepted by ``pd.DataFrame(data=...)``.
            title: Title set on the line-plot axes.
            show: If True, call ``plt.show()`` after drawing.
        """
        df = pd.DataFrame(data=data)
        sns.set_theme(style="whitegrid")
        sns.lineplot(data=df, legend=None).set_title(title)

        if show:
            plt.show()

    # modified from https://gymnasium.farama.org/tutorials/training_agents/FrozenLake_tuto/
    @staticmethod
    def get_policy_map(pi, val_max, actions, map_size):
        """Map the best learned action of each state to its display label.

        Args:
            pi: Indexable by integer state index ``0..n-1`` (a list, array or
                int-keyed dict), giving the chosen action for each state.
            val_max: NumPy array whose first dimension ``n`` is the number of
                states; returned reshaped to ``map_size``.
            actions: Mapping or sequence from action index to label (e.g. an
                arrow character such as ``"←"``).
            map_size: ``(rows, cols)``; ``rows * cols`` must equal ``n``.

        Returns:
            Tuple ``(val_max, policy_map)``, both reshaped to ``(rows, cols)``;
            ``policy_map`` is a NumPy string array holding the full labels.
        """
        n_states = val_max.shape[0]
        # Action chosen for each state index, coerced to int32 so it can be
        # used as a key/index into ``actions``.
        best_action = np.array([pi[s] for s in range(n_states)], dtype=np.int32)
        # Build the string array from the labels themselves so NumPy sizes the
        # dtype to the longest label. A preallocated np.empty(..., dtype=str)
        # is '<U1' and would silently truncate multi-character labels.
        labels = [actions[a] for a in best_action.tolist()]
        policy_map = np.array(labels, dtype=str).reshape(map_size[0], map_size[1])
        val_max = val_max.reshape(map_size[0], map_size[1])
        return val_max, policy_map

    # modified from https://gymnasium.farama.org/tutorials/training_agents/FrozenLake_tuto/
    @staticmethod
    def plot_policy(val_max, directions, map_size, title, show=True):
        """Plot a policy as a value heat map annotated with action labels.

        Args:
            val_max: 2-D array-like of values used for the cell colours.
            directions: Array of the same shape as ``val_max`` whose entries
                are drawn as the cell annotations (e.g. the ``policy_map``
                returned by ``get_policy_map``).
            map_size: Grid shape ``(rows, cols)``. Not used when drawing;
                kept so existing callers keep working.
            title: Axes title.
            show: If True, call ``plt.show()`` after drawing.
        """
        sns.heatmap(
            val_max,
            annot=directions,
            fmt="",  # annotations are strings, so no numeric formatting
            cmap=sns.color_palette("Blues", as_cmap=True),
            linewidths=0.7,
            linecolor="black",
            xticklabels=[],
            yticklabels=[],
            annot_kws={"fontsize": "xx-large"},
        ).set(title=title)

        if show:
            plt.show()
class
Plots:
class Plots:
    """Static plotting helpers: value heat maps, line plots and policy maps."""

    @staticmethod
    def values_heat_map(data, title, size, show=True):
        """Draw an annotated heat map of ``data`` reshaped to ``size``.

        Args:
            data: Array-like of numbers; reshaped to ``size`` and rounded to
                two decimals before plotting.
            title: Title set on the heat-map axes.
            size: Target shape passed to ``ndarray.reshape``.
            show: If True, call ``plt.show()`` after drawing.
        """
        data = np.around(np.array(data).reshape(size), 2)
        df = pd.DataFrame(data=data)
        sns.heatmap(df, annot=True).set_title(title)

        if show:
            plt.show()

    @staticmethod
    def v_iters_plot(data, title, show=True):
        """Draw ``data`` as a line plot without a legend.

        Note: calls ``sns.set_theme(style="whitegrid")`` on every invocation,
        which changes the global plotting theme for subsequent figures too.

        Args:
            data: Anything accepted by ``pd.DataFrame(data=...)``.
            title: Title set on the line-plot axes.
            show: If True, call ``plt.show()`` after drawing.
        """
        df = pd.DataFrame(data=data)
        sns.set_theme(style="whitegrid")
        sns.lineplot(data=df, legend=None).set_title(title)

        if show:
            plt.show()

    # modified from https://gymnasium.farama.org/tutorials/training_agents/FrozenLake_tuto/
    @staticmethod
    def get_policy_map(pi, val_max, actions, map_size):
        """Map the best learned action of each state to its display label.

        Args:
            pi: Indexable by integer state index ``0..n-1`` (a list, array or
                int-keyed dict), giving the chosen action for each state.
            val_max: NumPy array whose first dimension ``n`` is the number of
                states; returned reshaped to ``map_size``.
            actions: Mapping or sequence from action index to label (e.g. an
                arrow character such as ``"←"``).
            map_size: ``(rows, cols)``; ``rows * cols`` must equal ``n``.

        Returns:
            Tuple ``(val_max, policy_map)``, both reshaped to ``(rows, cols)``;
            ``policy_map`` is a NumPy string array holding the full labels.
        """
        n_states = val_max.shape[0]
        # Action chosen for each state index, coerced to int32 so it can be
        # used as a key/index into ``actions``.
        best_action = np.array([pi[s] for s in range(n_states)], dtype=np.int32)
        # Build the string array from the labels themselves so NumPy sizes the
        # dtype to the longest label. A preallocated np.empty(..., dtype=str)
        # is '<U1' and would silently truncate multi-character labels.
        labels = [actions[a] for a in best_action.tolist()]
        policy_map = np.array(labels, dtype=str).reshape(map_size[0], map_size[1])
        val_max = val_max.reshape(map_size[0], map_size[1])
        return val_max, policy_map

    # modified from https://gymnasium.farama.org/tutorials/training_agents/FrozenLake_tuto/
    @staticmethod
    def plot_policy(val_max, directions, map_size, title, show=True):
        """Plot a policy as a value heat map annotated with action labels.

        Args:
            val_max: 2-D array-like of values used for the cell colours.
            directions: Array of the same shape as ``val_max`` whose entries
                are drawn as the cell annotations (e.g. the ``policy_map``
                returned by ``get_policy_map``).
            map_size: Grid shape ``(rows, cols)``. Not used when drawing;
                kept so existing callers keep working.
            title: Axes title.
            show: If True, call ``plt.show()`` after drawing.
        """
        sns.heatmap(
            val_max,
            annot=directions,
            fmt="",  # annotations are strings, so no numeric formatting
            cmap=sns.color_palette("Blues", as_cmap=True),
            linewidths=0.7,
            linecolor="black",
            xticklabels=[],
            yticklabels=[],
            annot_kws={"fontsize": "xx-large"},
        ).set(title=title)

        if show:
            plt.show()
@staticmethod
def
get_policy_map(pi, val_max, actions, map_size):
@staticmethod
def get_policy_map(pi, val_max, actions, map_size):
    """Map the best learned action of each state to its display label.

    Args:
        pi: Indexable by integer state index ``0..n-1`` (a list, array or
            int-keyed dict), giving the chosen action for each state.
        val_max: NumPy array whose first dimension ``n`` is the number of
            states; returned reshaped to ``map_size``.
        actions: Mapping or sequence from action index to label (e.g. an
            arrow character such as ``"←"``).
        map_size: ``(rows, cols)``; ``rows * cols`` must equal ``n``.

    Returns:
        Tuple ``(val_max, policy_map)``, both reshaped to ``(rows, cols)``;
        ``policy_map`` is a NumPy string array holding the full labels.
    """
    n_states = val_max.shape[0]
    # Action chosen for each state index, coerced to int32 so it can be
    # used as a key/index into ``actions``.
    best_action = np.array([pi[s] for s in range(n_states)], dtype=np.int32)
    # Build the string array from the labels themselves so NumPy sizes the
    # dtype to the longest label. A preallocated np.empty(..., dtype=str)
    # is '<U1' and would silently truncate multi-character labels.
    labels = [actions[a] for a in best_action.tolist()]
    policy_map = np.array(labels, dtype=str).reshape(map_size[0], map_size[1])
    val_max = val_max.reshape(map_size[0], map_size[1])
    return val_max, policy_map
Map the best learned action to arrows.
@staticmethod
def
plot_policy(val_max, directions, map_size, title, show=True):
@staticmethod
def plot_policy(val_max, directions, map_size, title, show=True):
    """Plot a policy as a value heat map annotated with action labels.

    Args:
        val_max: 2-D array-like of values used for the cell colours.
        directions: Array of the same shape as ``val_max`` whose entries
            are drawn as the cell annotations (e.g. the ``policy_map``
            returned by ``get_policy_map``).
        map_size: Grid shape ``(rows, cols)``. Not used when drawing;
            kept so existing callers keep working.
        title: Axes title.
        show: If True, call ``plt.show()`` after drawing.
    """
    sns.heatmap(
        val_max,
        annot=directions,
        fmt="",  # annotations are strings, so no numeric formatting
        cmap=sns.color_palette("Blues", as_cmap=True),
        linewidths=0.7,
        linecolor="black",
        xticklabels=[],
        yticklabels=[],
        annot_kws={"fontsize": "xx-large"},
    ).set(title=title)

    if show:
        plt.show()
Plot the policy learned.