Source code for pydtmc.plotting

# -*- coding: utf-8 -*-

# Public API of the module: the plotting functions exported by a star-import.
__all__ = [
    'plot_eigenvalues',
    'plot_graph',
    'plot_redistributions',
    'plot_walk'
]


###########
# IMPORTS #
###########


# Major

import matplotlib.colors as mplc
import matplotlib.image as mpli
import matplotlib.pyplot as pp
import matplotlib.ticker as mplt
import networkx as nx
import numpy as np
import numpy.linalg as npl

# Minor

from inspect import (
    trace
)

from io import (
    BytesIO
)

from subprocess import (
    call,
    PIPE
)

# Internal

from .custom_types import (
    # Specific
    tdistributions,
    tmc,
    oplot,
    ostate,
    tstateswalk_flex,
    ostatus,
    # Lists
    tlist_str
)

from .exceptions import (
    ValidationError
)

from .validation import (
    validate_boolean,
    validate_distribution,
    validate_enumerator,
    validate_dpi,
    validate_markov_chain,
    validate_state,
    validate_status,
    validate_walk
)


#############
# CONSTANTS #
#############


# Hex colors shared by the plotting functions below.
# Black: end of the edge gradient, bar outlines and the "Start" marker.
color_black = '#000000'
# Gray: start of the edge gradient used by extended graphs.
color_gray = '#E0E0E0'
# White: low end of the heatmap color maps.
color_white = '#FFFFFF'
# Palette cycled through for node and line colors; its first entry is also the high end of the heatmap color maps and the bar fill color.
colors = ['#80B1D3', '#FFED6F', '#B3DE69', '#BEBADA', '#FDB462', '#8DD3C7', '#FB8072', '#FCCDE5']


#############
# FUNCTIONS #
#############


def plot_eigenvalues(mc: tmc, dpi: int = 100) -> oplot:

    """
    The function plots the eigenvalues of the Markov chain on the complex plane.

    The unit circle is always drawn. When the chain is ergodic and its second largest
    eigenvalue modulus is neither one nor zero, the circle of that radius and the shaded
    spectral gap between the two circles are drawn as well.

    :param mc: the target Markov chain.
    :param dpi: the resolution of the plot expressed in dots per inch (by default, 100).
    :return: None if Matplotlib is in interactive mode as the plot is immediately displayed, otherwise the handles of the plot.
    :raises ValidationError: if any input argument is not compliant.
    """

    try:

        # Each validation must stay on its own "name = validate(name)" line: the handler
        # below recovers the offending argument name from the source line that raised.
        mc = validate_markov_chain(mc)
        dpi = validate_dpi(dpi)

    except Exception as e:
        argument = ''.join(trace()[0][4]).split('=', 1)[0].strip()
        raise ValidationError(str(e).replace('@arg@', argument)) from None

    figure, ax = pp.subplots(dpi=dpi)

    handles = []
    labels = []

    theta = np.linspace(0.0, 2.0 * np.pi, 200)
    x_unit_circle = np.cos(theta)
    y_unit_circle = np.sin(theta)

    values = npl.eigvals(mc.p).astype(complex)
    # The eigenvalue 1 is always appended so that it is plotted even when the numerical
    # result is slightly off; np.unique removes exact duplicates.
    values_final = np.unique(np.append(values, np.array([1.0]).astype(complex)))

    if mc.is_ergodic:

        values_abs = np.sort(np.abs(values))
        values_ct1 = np.isclose(values_abs, 1.0)

        if not np.all(values_ct1):

            # Largest modulus strictly below one.
            mu = values_abs[~values_ct1][-1]

            if not np.isclose(mu, 0.0):

                gap_color = 'r'
                gap_alpha = 0.2

                cs = np.linspace(-1.1, 1.1, 201)
                x_spectral_gap, y_spectral_gap = np.meshgrid(cs, cs)
                z_spectral_gap = x_spectral_gap**2 + y_spectral_gap**2

                # Shade the annulus mu^2 <= x^2 + y^2 <= 1.
                ax.contourf(x_spectral_gap, y_spectral_gap, z_spectral_gap, alpha=gap_alpha, colors=gap_color, levels=[mu ** 2.0, 1.0])

                # The legend patch color is computed directly instead of being read back from
                # "ContourSet.collections", which recent Matplotlib releases no longer provide.
                handles.append(pp.Rectangle((0.0, 0.0), 1.0, 1.0, fc=mplc.to_rgba(gap_color, gap_alpha)))
                labels.append('Spectral Gap')

                ax.plot(mu * x_unit_circle, mu * y_unit_circle, color='red', linestyle='--', linewidth=1.5)

    ax.plot(x_unit_circle, y_unit_circle, color='red', linestyle='-', linewidth=3.0)

    h, = ax.plot(np.real(values_final), np.imag(values_final), color='blue', linestyle='None', marker='*', markersize=12.5)
    handles.append(h)
    labels.append('Eigenvalues')

    ax.set_xlim(-1.1, 1.1)
    ax.set_ylim(-1.1, 1.1)
    ax.set_aspect('equal')

    formatter = mplt.FormatStrFormatter('%g')
    ax.xaxis.set_major_formatter(formatter)
    ax.yaxis.set_major_formatter(formatter)
    ax.set_xticks(np.linspace(-1.0, 1.0, 9))
    ax.set_yticks(np.linspace(-1.0, 1.0, 9))
    ax.grid(which='major')

    # Reversed so that the eigenvalues entry precedes the spectral gap entry.
    ax.legend(handles[::-1], labels[::-1], bbox_to_anchor=(0.5, -0.1), loc='upper center', ncol=len(handles))
    ax.set_title('Eigenplot', fontsize=15.0, fontweight='bold')

    pp.subplots_adjust(bottom=0.2)

    if pp.isinteractive():
        pp.show(block=False)
        return None

    return figure, ax
def plot_graph(mc: tmc, nodes_color: bool = True, nodes_type: bool = True, edges_color: bool = True, edges_value: bool = True, dpi: int = 100) -> oplot:

    """
    The function plots the directed graph of the Markov chain.

    | **Notes:** Graphviz and Pydot are not required, but they provide access to extended graphs with additional features.

    :param mc: the target Markov chain.
    :param nodes_color: a boolean indicating whether to display colored nodes based on communicating classes (by default, True).
    :param nodes_type: a boolean indicating whether to use a different shape for every node type (by default, True).
    :param edges_color: a boolean indicating whether to display edges using a gradient based on transition probabilities, valid only for extended graphs (by default, True).
    :param edges_value: a boolean indicating whether to display the transition probability of every edge (by default, True).
    :param dpi: the resolution of the plot expressed in dots per inch (by default, 100).
    :return: None if Matplotlib is in interactive mode as the plot is immediately displayed, otherwise the handles of the plot.
    :raises ValidationError: if any input argument is not compliant.
    """

    def edge_colors(hex_from: str, hex_to: str, steps: int) -> tlist_str:

        # Linear RGB gradient of "steps" hex colors going from "hex_from" to "hex_to".
        begin = [int(hex_from[i:i + 2], 16) for i in range(1, 6, 2)]
        end = [int(hex_to[i:i + 2], 16) for i in range(1, 6, 2)]

        clist = [hex_from]

        for s in range(1, steps):
            rgb = [int(begin[j] + (float(s) / (steps - 1)) * (end[j] - begin[j])) for j in range(3)]
            clist.append('#' + ''.join(f'{v:02x}' for v in rgb))

        return clist

    def node_colors(count: int) -> tlist_str:

        # The palette is reused cyclically when there are more classes than colors.
        return [colors[i % len(colors)] for i in range(count)]

    def edge_label(probability: float) -> str:

        # Integral probabilities are shown as "1.0", all the others rounded to two decimals.
        if probability.is_integer():
            return f' {probability:g}.0 '

        return f' {round(probability,2):g} '

    try:

        # Each validation must stay on its own "name = validate(name)" line: the handler
        # below recovers the offending argument name from the source line that raised.
        mc = validate_markov_chain(mc)
        nodes_color = validate_boolean(nodes_color)
        nodes_type = validate_boolean(nodes_type)
        edges_color = validate_boolean(edges_color)
        edges_value = validate_boolean(edges_value)
        dpi = validate_dpi(dpi)

    except Exception as e:
        argument = ''.join(trace()[0][4]).split('=', 1)[0].strip()
        raise ValidationError(str(e).replace('@arg@', argument)) from None

    # Extended graphs need both the Graphviz "dot" executable and the Pydot package;
    # the absence of either one silently falls back to the plain networkx rendering.
    extended_graph = True

    # noinspection PyBroadException
    try:
        call(['dot', '-V'], stdout=PIPE, stderr=PIPE)
    except Exception:
        extended_graph = False

    try:
        import pydot  # noqa: F401 (availability probe only)
    except ImportError:
        extended_graph = False

    g = mc.to_directed_graph()

    if extended_graph:

        g_pydot = nx.nx_pydot.to_pydot(g)

        if nodes_color:

            c = node_colors(len(mc.communicating_classes))

            for node in g_pydot.get_nodes():

                state = node.get_name()

                for x, cc in enumerate(mc.communicating_classes):
                    if state in cc:
                        node.set_style('filled')
                        node.set_fillcolor(c[x])
                        break

        if nodes_type:
            for node in g_pydot.get_nodes():
                node.set_shape('box' if node.get_name() in mc.transient_states else 'ellipse')

        if edges_color:

            steps = 20
            c = edge_colors(color_gray, color_black, steps)

            for edge in g_pydot.get_edges():

                probability = mc.transition_probability(edge.get_destination(), edge.get_source())

                # Clamped: without the lower bound, probabilities below 0.025 produce index -1,
                # which wraps around to the darkest color of the gradient.
                x = min(max(int(round(probability * steps)) - 1, 0), steps - 1)

                edge.set_style('filled')
                edge.set_color(c[x])

        if edges_value:
            for edge in g_pydot.get_edges():
                probability = mc.transition_probability(edge.get_destination(), edge.get_source())
                edge.set_label(edge_label(probability))

        buffer = BytesIO()
        buffer.write(g_pydot.create_png())
        buffer.seek(0)

        img = mpli.imread(buffer)
        # Image arrays are (rows, columns, ...), while figsize expects (width, height).
        img_height = img.shape[0] / dpi
        img_width = img.shape[1] / dpi

        figure = pp.figure(figsize=(img_width, img_height), dpi=dpi)
        figure.figimage(img)
        ax = figure.gca()

    else:

        # Interactive mode is suspended while the graph is built piece by piece and
        # restored afterwards, even if one of the drawing calls fails.
        mpi = pp.isinteractive()
        pp.interactive(False)

        try:

            figure, ax = pp.subplots(dpi=dpi)

            positions = nx.spring_layout(g)
            node_colors_all = node_colors(len(mc.communicating_classes))

            for node in g.nodes:

                # Only the current node is drawn at every iteration; optional attributes
                # are passed only when requested, so that networkx defaults apply otherwise.
                draw_options = {'nodelist': [node], 'edgecolors': 'k'}

                if nodes_color:
                    for x, cc in enumerate(mc.communicating_classes):
                        if node in cc:
                            draw_options['node_color'] = node_colors_all[x]
                            break

                if nodes_type:
                    draw_options['node_shape'] = 's' if node in mc.transient_states else 'o'

                nx.draw_networkx_nodes(g, positions, ax=ax, **draw_options)

            nx.draw_networkx_labels(g, positions, ax=ax)
            nx.draw_networkx_edges(g, positions, ax=ax, arrows=False)

            if edges_value:

                edges_values = {}

                for edge in g.edges:
                    probability = mc.transition_probability(edge[1], edge[0])
                    edges_values[(edge[0], edge[1])] = edge_label(probability)

                nx.draw_networkx_edge_labels(g, positions, ax=ax, edge_labels=edges_values, label_pos=0.7)

        finally:
            pp.interactive(mpi)

    if pp.isinteractive():
        pp.show(block=False)
        return None

    return figure, ax
def plot_redistributions(mc: tmc, distributions: tdistributions, initial_status: ostatus = None, plot_type: str = 'projection', dpi: int = 100) -> oplot:

    """
    The function plots a redistribution of states on the given Markov chain.

    :param mc: the target Markov chain.
    :param distributions: a sequence of redistributions or the number of redistributions to perform.
    :param initial_status: the initial state or the initial distribution of the states (if omitted, the states are assumed to be uniformly distributed).
    :param plot_type: the type of plot to display (either heatmap or projection; projection by default).
    :param dpi: the resolution of the plot expressed in dots per inch (by default, 100).
    :return: None if Matplotlib is in interactive mode as the plot is immediately displayed, otherwise the handles of the plot.
    :raises ValueError: if the "distributions" parameter represents a sequence of redistributions and the "initial_status" parameter does not match its first element.
    :raises ValidationError: if any input argument is not compliant.
    """

    try:

        # Each validation must stay on its own "name = validate(name)" line: the handler
        # below recovers the offending argument name from the source line that raised.
        mc = validate_markov_chain(mc)
        distributions = validate_distribution(distributions, mc.size)

        if initial_status is not None:
            initial_status = validate_status(initial_status, mc.states)

        plot_type = validate_enumerator(plot_type, ['heatmap', 'projection'])
        dpi = validate_dpi(dpi)

    except Exception as e:
        argument = ''.join(trace()[0][4]).split('=', 1)[0].strip()
        raise ValidationError(str(e).replace('@arg@', argument)) from None

    if isinstance(distributions, int):
        # A count was provided: the redistributions are computed, initial one included.
        distributions = mc.redistribute(distributions, initial_status=initial_status, include_initial=True, output_last=False)
    elif initial_status is not None and not np.array_equal(distributions[0], initial_status):
        raise ValueError('The "initial_status" parameter, if specified when the "distributions" parameter represents a sequence of redistributions, must match the first element.')

    distribution_len = len(distributions)
    distributions = np.array(distributions)

    # One tick per step for short sequences, one every ten steps otherwise.
    steps_ticks = np.arange(0, distribution_len, 1 if distribution_len <= 11 else 10)

    figure, ax = pp.subplots(dpi=dpi)

    if plot_type == 'heatmap':

        color_map = mplc.LinearSegmentedColormap.from_list('ColorMap', [color_white, colors[0]], 20)
        ax_is = ax.imshow(np.transpose(distributions), aspect='auto', cmap=color_map, interpolation='none', vmin=0.0, vmax=1.0)

        # Major ticks and their labels are built from the same array: a different number
        # of locations and labels is rejected by recent Matplotlib releases.
        ax.set_xlabel('Steps', fontsize=13.0)
        ax.set_xticks(steps_ticks, minor=False)
        ax.set_xticks(np.arange(-0.5, distribution_len, 1.0), minor=True)
        ax.set_xticklabels(steps_ticks)
        ax.set_xlim(-0.5, distribution_len - 0.5)

        ax.set_yticks(np.arange(0.0, mc.size, 1.0), minor=False)
        ax.set_yticks(np.arange(-0.5, mc.size, 1.0), minor=True)
        ax.set_yticklabels(mc.states)

        # The minor ticks sit on the cell borders, hence the grid outlines the cells.
        ax.grid(which='minor', color='k')

        cb = figure.colorbar(ax_is, drawedges=True, orientation='horizontal', ticks=[0.0, 0.25, 0.5, 0.75, 1.0])
        cb.ax.set_xticklabels([0.0, 0.25, 0.5, 0.75, 1.0])

        ax.set_title('Redistplot (Heatmap)', fontsize=15.0, fontweight='bold')

    else:

        ax.set_prop_cycle('color', colors)

        # With two points only, markers are added to the plotted lines.
        line_options = {'marker': 'o'} if distribution_len == 2 else {}
        steps = np.arange(0.0, distribution_len, 1.0)

        for i in range(mc.size):
            ax.plot(steps, distributions[:, i], label=mc.states[i], **line_options)

        # A uniform initial distribution is highlighted with a dedicated marker.
        if np.array_equal(distributions[0, :], np.ones(mc.size, dtype=float) / mc.size):
            ax.plot(0.0, distributions[0, 0], color=color_black, label="Start", marker='o', markeredgecolor=color_black, markerfacecolor=color_black)
            legend_size = mc.size + 1
        else:
            legend_size = mc.size

        ax.set_xlabel('Steps', fontsize=13.0)
        ax.set_xticks(steps_ticks)
        ax.set_xticklabels(steps_ticks)
        ax.set_xlim(-1.0 * (distribution_len * 0.05), distribution_len * 1.05)

        ax.set_ylabel('Frequencies', fontsize=13.0)
        ax.set_yticks(np.linspace(0.0, 1.0, 11))
        ax.set_ylim(-0.05, 1.05)

        ax.grid()
        ax.legend(bbox_to_anchor=(0.5, -0.15), loc='upper center', ncol=legend_size)
        ax.set_title('Redistplot (Projection)', fontsize=15.0, fontweight='bold')

        pp.subplots_adjust(bottom=0.2)

    if pp.isinteractive():
        pp.show(block=False)
        return None

    return figure, ax
def plot_walk(mc: tmc, walk: tstateswalk_flex, initial_state: ostate = None, plot_type: str = 'histogram', dpi: int = 100) -> oplot:

    """
    The function plots a random walk on the given Markov chain.

    :param mc: the target Markov chain.
    :param walk: a sequence of states or the number of simulations to perform.
    :param initial_state: the initial state of the walk (if omitted, it is chosen uniformly at random).
    :param plot_type: the type of plot to display (one of histogram, sequence and transitions; histogram by default).
    :param dpi: the resolution of the plot expressed in dots per inch (by default, 100).
    :return: None if Matplotlib is in interactive mode as the plot is immediately displayed, otherwise the handles of the plot.
    :raises ValueError: if the "walk" parameter represents a sequence of states and the "initial_state" parameter does not match its first element.
    :raises ValidationError: if any input argument is not compliant.
    """

    try:

        # Each validation must stay on its own "name = validate(name)" line: the handler
        # below recovers the offending argument name from the source line that raised.
        mc = validate_markov_chain(mc)
        walk = validate_walk(walk, mc.states)

        if initial_state is not None:
            initial_state = validate_state(initial_state, mc.states)

        plot_type = validate_enumerator(plot_type, ['histogram', 'sequence', 'transitions'])
        dpi = validate_dpi(dpi)

    except Exception as e:
        argument = ''.join(trace()[0][4]).split('=', 1)[0].strip()
        raise ValidationError(str(e).replace('@arg@', argument)) from None

    if isinstance(walk, int):
        # A count was provided: the walk is simulated and returned as state indices.
        walk = mc.walk(walk, initial_state=initial_state, include_initial=True, output_indices=True)
    elif initial_state is not None and (walk[0] != initial_state):
        raise ValueError('The "initial_state" parameter, if specified when the "walk" parameter represents a sequence of states, must match the first element.')

    walk_len = len(walk)

    figure, ax = pp.subplots(dpi=dpi)

    if plot_type == 'histogram':

        # Relative frequency of every state along the walk (the walk holds state indices).
        walk_histogram = np.bincount(walk, minlength=mc.size).astype(float) / walk_len

        ax.bar(np.arange(0.0, mc.size, 1.0), walk_histogram, edgecolor=color_black, facecolor=colors[0])

        ax.set_xlabel('States', fontsize=13.0)
        ax.set_xticks(np.arange(0.0, mc.size, 1.0))
        ax.set_xticklabels(mc.states)

        ax.set_ylabel('Frequencies', fontsize=13.0)
        ax.set_yticks(np.linspace(0.0, 1.0, 11))
        ax.set_ylim(0.0, 1.0)

        ax.set_title('Walkplot (Histogram)', fontsize=15.0, fontweight='bold')

    elif plot_type == 'sequence':

        # One-hot matrix: rows are states, columns are steps.
        walk_sequence = np.zeros((mc.size, walk_len), dtype=float)
        walk_sequence[walk, np.arange(walk_len)] = 1.0

        color_map = mplc.LinearSegmentedColormap.from_list('ColorMap', [color_white, colors[0]], 2)
        ax.imshow(walk_sequence, aspect='auto', cmap=color_map, interpolation='none', vmin=0.0, vmax=1.0)

        # Major ticks and their labels are built from the same array: a different number
        # of locations and labels is rejected by recent Matplotlib releases.
        steps_ticks = np.arange(0, walk_len, 1 if walk_len <= 11 else 10)

        ax.set_xlabel('Steps', fontsize=13.0)
        ax.set_xticks(steps_ticks, minor=False)
        ax.set_xticks(np.arange(-0.5, walk_len, 1.0), minor=True)
        ax.set_xticklabels(steps_ticks)
        ax.set_xlim(-0.5, walk_len - 0.5)

        ax.set_ylabel('States', fontsize=13.0)
        ax.set_yticks(np.arange(0.0, mc.size, 1.0), minor=False)
        ax.set_yticks(np.arange(-0.5, mc.size, 1.0), minor=True)
        ax.set_yticklabels(mc.states)

        # The minor ticks sit on the cell borders, hence the grid outlines the cells.
        ax.grid(which='minor', color='k')

        ax.set_title('Walkplot (Sequence)', fontsize=15.0, fontweight='bold')

    else:

        # Relative frequency of every observed (origin, target) transition.
        walk_transitions = np.zeros((mc.size, mc.size), dtype=float)

        for state_from, state_to in zip(walk[:-1], walk[1:]):
            walk_transitions[state_from, state_to] += 1.0

        transitions_count = np.sum(walk_transitions)

        # A walk made of a single state has no transitions: the matrix is left at zero
        # instead of being turned into NaNs by a division by zero.
        if transitions_count > 0.0:
            walk_transitions = walk_transitions / transitions_count

        color_map = mplc.LinearSegmentedColormap.from_list('ColorMap', [color_white, colors[0]], 20)
        ax_is = ax.imshow(walk_transitions, aspect='auto', cmap=color_map, interpolation='none', vmin=0.0, vmax=1.0)

        ax.set_xticks(np.arange(0.0, mc.size, 1.0), minor=False)
        ax.set_xticks(np.arange(-0.5, mc.size, 1.0), minor=True)
        ax.set_xticklabels(mc.states)

        ax.set_yticks(np.arange(0.0, mc.size, 1.0), minor=False)
        ax.set_yticks(np.arange(-0.5, mc.size, 1.0), minor=True)
        ax.set_yticklabels(mc.states)

        ax.grid(which='minor', color='k')

        cb = figure.colorbar(ax_is, drawedges=True, orientation='horizontal', ticks=[0.0, 0.25, 0.5, 0.75, 1.0])
        cb.ax.set_xticklabels([0.0, 0.25, 0.5, 0.75, 1.0])

        ax.set_title('Walkplot (Transitions)', fontsize=15.0, fontweight='bold')

    if pp.isinteractive():
        pp.show(block=False)
        return None

    return figure, ax