Source code for mavis.illustrate.scatter

import os

from ..bam.read import sequenced_strand, pileup
from ..util import log
from ..interval import Interval
from ..validate.constants import DEFAULTS as VALIDATION_DEFAULTS


def bam_to_scatter(bam_file, chrom, start, end, bin_size, strand=None, axis_name=None, ymax=None, min_mapping_quality=0):
    """
    build a scatter plot of binned pileup depth for a genomic window of a bam file

    Args:
        bam_file (str): path to the bam file
        chrom (str): chromosome name
        start (int): first genomic position to include in the plot
        end (int): last genomic position to include in the plot
        bin_size (int): number of consecutive pileup positions merged (averaged) into a single plotted point
        strand (STRAND): expected strand; when None, reads are not filtered by strand
        axis_name (str): label of the y axis; defaults to the base name of the bam file
        ymax (int): upper limit of the y axis; when None, the largest binned value is used (but never less than 100)
        min_mapping_quality (int): reads with a mapping quality below this value are filtered out

    Returns:
        ScatterPlot: the scatter plot representing the bam pileup
    """
    import pysam

    axis_name = axis_name or os.path.basename(bam_file)
    # one plot per bam
    log('reading:', bam_file)
    samfile = pysam.AlignmentFile(bam_file, 'rb')

    def read_filter(read):
        # returns True for reads that should be left out of the pileup
        # (assumed contract of pileup's filter_func -- it is defined outside this module)
        if read.mapping_quality < min_mapping_quality:
            return True
        if strand is None:
            return False
        try:
            read_strand = sequenced_strand(read, VALIDATION_DEFAULTS.strand_determining_read)
        except ValueError:
            # the strand of the read could not be determined
            return True
        return read_strand != strand

    try:
        avg_points = []
        try:
            points = [
                (refpos, count)
                for refpos, count in pileup(samfile.fetch(chrom, start, end), filter_func=read_filter)
                if start <= refpos <= end
            ]
        except ValueError:  # chrom not in bam
            pass
        else:
            # collapse every run of bin_size consecutive points into one point: the x value
            # spans the positions of the bin and the y value is the mean depth of the bin
            for offset in range(0, len(points), bin_size):
                positions, depths = zip(*points[offset:offset + bin_size])
                span = Interval(min(positions), max(positions))
                mean_depth = Interval(sum(depths) / len(depths))
                avg_points.append((span, mean_depth))
        log('scatter plot {} has {} points'.format(axis_name, len(avg_points)))
        if ymax is None:
            # the extra 100 keeps the axis from collapsing for empty or low coverage windows
            ymax = max([depth.start for _, depth in avg_points] + [100])
        return ScatterPlot(avg_points, axis_name, ymin=0, ymax=ymax)
    finally:
        samfile.close()
class ScatterPlot:
    """
    holds the points and display settings of a scatter plot (consumed by draw_scatter)

    Each point is an ``(x, y)`` pair where both members expose ``start`` and ``end``
    attributes (e.g. Interval objects).
    """

    def __init__(
        self, points, y_axis_label,
        ymax=None, ymin=None, xmin=None, xmax=None,
        hmarkers=None, height=100, point_radius=2,
        title='', yticks=None, colors=None
    ):
        """
        Args:
            points (list): the (x, y) pairs to be plotted
            y_axis_label (str): label of the y axis
            ymax: upper limit of the y axis; defaults to the largest y end / y tick
            ymin: lower limit of the y axis; defaults to the smallest y start / y tick
            xmin: lower limit of the x axis; defaults to the smallest x start
            xmax: upper limit of the x axis; defaults to the largest x end
            hmarkers (list): y values at which horizontal marker lines are drawn
            height (int): height of the plot
            point_radius (int): radius of the plotted points
            title (str): title of the plot
            yticks (list): y values at which axis ticks are drawn
            colors (dict): fill color by (x, y) point; points without an entry use the default color
        """
        self.hmarkers = hmarkers if hmarkers is not None else []
        self.yticks = yticks if yticks is not None else []
        self.colors = colors if colors else {}
        self.ymin = ymin
        self.ymax = ymax
        self.points = points

        # use the normalized ticks (never None) as a list so they can always be
        # concatenated with the point values, even when yticks was omitted or is a tuple
        tick_values = list(self.yticks)
        if self.ymin is None and (tick_values or points):
            self.ymin = min([y.start for x, y in points] + tick_values)
        if self.ymax is None and (tick_values or points):
            self.ymax = max([y.end for x, y in points] + tick_values)

        self.xmin = xmin
        self.xmax = xmax
        if self.xmin is None and points:
            self.xmin = min([x.start for x, y in points])
        if self.xmax is None and points:
            self.xmax = max([x.end for x, y in points])

        self.y_axis_label = y_axis_label
        # honour the caller supplied values rather than silently resetting them to the defaults
        self.height = height
        self.point_radius = point_radius
        self.title = title
def draw_scatter(ds, canvas, plot, xmapping):
    """
    build the svg group of a scatter plot, placing the points along the x axis with xmapping

    Args:
        ds (DiagramSettings): the settings/constants to use for building the svg
        canvas (svgwrite.canvas): the svgwrite object used to create new svg elements
        plot (ScatterPlot): the plot to be drawn
        xmapping (:class:`dict` of :class:`Interval` by :class:`Interval`):
            dict used for conversion of coordinates in the xaxis to pixel positions

    Returns:
        the svg group containing the plot; its ``height`` attribute is set to the plot height
    """
    group = canvas.g(class_='scatter_plot')
    # number of pixels per unit of the y axis
    yratio = plot.height / (abs(plot.ymax - plot.ymin))

    def y_to_px(value):
        # values further from ymin are mapped to smaller pixel values
        return plot.height - abs(value - plot.ymin) * yratio

    # first pass: convert the points to pixel intervals. Points whose x conversion
    # raises an IndexError are skipped
    x_pixels = []
    y_pixels = []
    for x_interval, y_interval in plot.points:
        try:
            px_start = Interval.convert_ratioed_pos(xmapping, x_interval.start)
            px_end = Interval.convert_ratioed_pos(xmapping, x_interval.end)
            x_pixels.append((px_end | px_start, x_interval))
            px_low = y_to_px(y_interval.start)
            y_pixels.append((Interval(y_to_px(y_interval.end), px_low), y_interval))
        except IndexError:
            continue

    # second pass: draw each point, adding an error bar along any axis where the
    # pixel interval is longer than the marker radius
    for (xp, x_interval), (yp, y_interval) in zip(x_pixels, y_pixels):
        if xp.length() > ds.scatter_marker_radius:
            horizontal_bar = canvas.line(
                (xp.start, yp.center),
                (xp.end, yp.center),
                stroke='#000000',
                stroke_width=ds.scatter_error_bar_stroke_width
            )
            group.add(horizontal_bar)
        if yp.length() > ds.scatter_marker_radius:
            vertical_bar = canvas.line(
                (xp.center, yp.start),
                (xp.center, yp.end),
                stroke='#000000',
                stroke_width=ds.scatter_error_bar_stroke_width
            )
            group.add(vertical_bar)
        marker = canvas.circle(
            center=(xp.center, yp.center),
            fill=plot.colors.get((x_interval, y_interval), '#000000'),
            r=ds.scatter_marker_radius
        )
        group.add(marker)

    # horizontal marker lines run from the y axis to the pixel position of plot.xmax
    right_edge = Interval.convert_ratioed_pos(xmapping, plot.xmax).end
    for marker_value in plot.hmarkers:
        marker_px = y_to_px(marker_value)
        group.add(canvas.line(start=(0, marker_px), end=(right_edge, marker_px), stroke='blue'))

    # draw left y axis
    group.add(canvas.line(start=(0, 0), end=(0, plot.height), stroke='#000000'))

    # draw the y axis ticks, tracking the longest tick label (in characters) so that the
    # axis label can be shifted clear of the tick text
    shift = 0
    for tick in plot.yticks:
        shift = max(shift, len(str(tick)))
        tick_px = y_to_px(tick)
        group.add(canvas.line(
            start=(0 - ds.scatter_yaxis_tick_size, tick_px),
            end=(0, tick_px),
            stroke='#000000'
        ))
        group.add(canvas.text(
            str(tick),
            insert=(
                0 - ds.scatter_yaxis_tick_size - ds.padding,
                tick_px + ds.scatter_ytick_font_size * ds.font_central_shift_ratio
            ),
            fill=ds.label_color,
            style=ds.font_style.format(font_size=ds.scatter_ytick_font_size, text_anchor='end')
        ))

    label_x = (
        0 - ds.padding * 2 - ds.scatter_axis_font_size - ds.scatter_yaxis_tick_size
        - ds.scatter_ytick_font_size * ds.font_width_height_ratio * shift
    )
    label_y = plot.height / 2
    yaxis = canvas.text(
        plot.y_axis_label,
        insert=(label_x, label_y),
        fill=ds.label_color,
        style=ds.font_style.format(font_size=ds.scatter_axis_font_size, text_anchor='start'),
        class_='y_axis_label'
    )
    group.add(yaxis)
    # rotate the label about a point offset by half of its estimated text width
    half_label_width = len(plot.y_axis_label) * ds.font_width_height_ratio * ds.scatter_axis_font_size / 2
    yaxis.rotate(270, (label_x + half_label_width, label_y))
    yaxis.translate(0, 0)

    group.height = plot.height
    return group