Source code for dohlee.plot

import matplotlib
matplotlib.use('Agg')

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import sys

from adjustText import adjust_text
from sklearn.decomposition import PCA
from matplotlib.lines import Line2D
from collections import Counter
from functools import wraps


def _get_ax_to_draw(ax, figsize=None):
    """If ax is not specified, return an axis to draw a plot.
    Otherwise, return ax.
    """
    if ax:
        return ax
    else:
        fig = plt.figure(figsize=figsize) if figsize else plt.figure()
        ax = fig.add_subplot(111)
        return ax


def _try_save(file, dpi=150):
    """If file is specified, save the figure to the file with given resolution in dpi.
    Otherwise, show the figure.
    """
    return None if file is None else plt.savefig(file, dpi=dpi)


[docs]def save(file, dpi=120, tight_layout=True): """Save plot to a file. Attributes: file (str): Path to the resulting image file. dpi (int, default=120): Resolution. tight_layout (bool, default=True): Whether to run plt.tight_layout() before saving the plot. """ if tight_layout: plt.tight_layout() plt.savefig(file, dpi=dpi)
def _my_plot(func): @wraps(func) def wrapper(*args, **kwargs): ax = _get_ax_to_draw(kwargs.get('ax', None), kwargs.get('figsize', None)) if 'ax' in kwargs: del kwargs['ax'] if 'figsize' in kwargs: del kwargs['figsize'] if 'title' in kwargs: ax.set_title(kwargs['title']) del kwargs['title'] if 'xlim' in kwargs: ax.set_xlim(kwargs['xlim']) del kwargs['xlim'] if 'ylim' in kwargs: ax.set_ylim(kwargs['ylim']) del kwargs['ylim'] if 'xlabel' in kwargs: ax.set_xlabel(kwargs['xlabel']) del kwargs['xlabel'] if 'ylabel' in kwargs: ax.set_ylabel(kwargs['ylabel']) del kwargs['ylabel'] result = func(*args, ax=ax, **kwargs) _try_save(kwargs.get('file', None)) return wrapper # Set plot preference which looks good to me.
[docs]def set_style(style='white', palette='deep', context='talk', font='Helvetica Neue', font_scale=1.25, rcparams={'figure.figsize': (11.7, 8.27)}): sns.set(style=style, palette=palette, context=context, font=font, font_scale=font_scale, rc=rcparams)
[docs]@_my_plot def frequency(data, order=None, sort_by_values=False, ax=None, **kwargs): """Plot frequency bar chart. Examples: frequency([1, 2, 2, 3, 3, 3], order=[3, 1, 2], sort_by_values=True) Attributes: data (list): A list of elements. order (list): A list of elements which represents the order of the elements to be plotted. sort_by_values (bool): If True, the plot will be sorted in decreasing order of frequency values. ax (pyplot axis): Axis to draw the plot. """ counter = Counter(data) if order is None: if sort_by_values: order = sorted(counter, key=counter.get, reverse=True) else: order = sorted(counter.keys()) else: assert set(order) == set(counter.keys()), 'The order must contain all the elements.' counts = [counter[key] for key in order] ax = _get_ax_to_draw(ax) # Some parameters used for plot configuration. height = max(counts) * 1.167 xticks = list(range(len(order))) # Preset plot. ax.set_xticks(xticks) ax.set_xticklabels(order) ax.set_xlim([-0.66, len(counter) - 0.33]) ax.set_ylim([0, height]) # Plot bar chart. ax.bar(x=xticks, height=counts, width=0.66, **kwargs) # Add text indicating frequency for each bar. for x, count in zip(xticks, counts): ax.text(x=x, y=count, s=str(count), size='small', va='bottom', ha='center')
[docs]@_my_plot def histogram(data, ax=None, **kwargs): plt.hist(data, color='black', ec='white', lw=1.33, **kwargs)
[docs]@_my_plot def volcano(data, x, y, padj, label, cutoff=0.05, sample1=None, sample2=None, ax=None): """Draw a volcano plot. >>> volcano(data=data, x='log2FoldChange', y='pvalue', label='Gene_Symbol', cutoff=0.05, padj='padj', figsize=(10.8, 8.4)) :param dataframe data: A dataframe resulting from DEG-discovery tool. :param str x: Column name denoting log2 fold change. :param str y: Column name denoting p-value. (Note that p-values will be log10-transformed, so they should not be transformed beforehand.) :param str padj: Column name denoting adjusted p-value. :param str label: Column name denoting gene identifier. :param float cutoff: (Optional) Adjusted p-value cutoff value to report significant DEGs. :param str sample1: (Optional) First sample name. :param str sample2: (Optional) Second sample name. :param axis ax: (Optional) Matplotlib axis to draw the plot on. """ # Set x and y extent. x_limit = max(-np.min(data[x].values), np.max(data[x].values)) x_extent = [-x_limit * 1.22, x_limit * 1.22] ax.set_xlim(x_extent) ax.set_ylim([0, max(-np.log10(data[y].values)) * 1.1]) data_not_significant = data[data[padj] >= cutoff] ax.scatter(data_not_significant[x].values, -np.log10(data_not_significant[y].values), color='grey', marker='.') data_significant = data[data.padj < cutoff] ax.scatter(data_significant[x].values, -np.log10(data_significant[y].values), color='red', marker='.') texts = [ax.text(row[x], -np.log10(row[y]), row[label], fontsize=12) for _, row in data_significant.iterrows()] adjust_text(texts) line = Line2D([0], [0], color='red', lw=2.33, label='Adjusted p < %g' % cutoff) plt.legend(handles=[line]) if (not sample1 is None) and (not sample2 is None): ax.set_xlabel(r'$log_2$FC ($log_{2}\frac{%s}{%s}$)' % (sample1.replace(' ', '\ '), sample2.replace(' ', '\ '))) else: ax.set_xlabel(r'$log_2$FC') ax.set_ylabel(r'$log_{10}$(p-value)')
[docs]@_my_plot def pca(data, labels=None, ax=None, **kwargs): '''Draw a simple principle component analysis plot of the data. :param matrix data: Input data. Numpy array recommended. :param list labels: (Optional) Corresponding labels to each datum. If specified, data points in the plot will be colored according to the label. :param axis ax: (Optional) Matplotlib axis to draw the plot on. :param kwargs: Any other keyword arguments will be passed onto matplotlib.pyplot.scatter. ''' # Fit PCA and get pc's pca = PCA(n_components=2) pca.fit(data) pc = pca.transform(data) if labels is None: plt.scatter(x=pc[:, 0], y=pc[:, 1]) else: # If labels are attached, color them in different colors labels = np.array(labels) for label in set(labels): toDraw = (labels == label) # only draw these points this time plt.scatter(x=pc[toDraw, 0], y=pc[toDraw, 1], label=label, **kwargs) plt.legend(loc='best') # show explained variance ratio in the plot axes explainedVarianceRatio = pca.explained_variance_ratio_ plt.xlabel('PC1 ({:.2%})'.format(explainedVarianceRatio[0])) plt.ylabel('PC2 ({:.2%})'.format(explainedVarianceRatio[1]))