Source code for mavis.cluster.main

import inspect
import itertools
import os
from shortuuid import uuid
import time

from .cluster import merge_breakpoint_pairs
from .constants import DEFAULTS
from ..constants import COLUMNS
from ..util import (
    filter_on_overlap, filter_uninformative, generate_complete_stamp, log, log_arguments,
    mkdirp, output_tabbed_file, read_inputs, write_bed_file
)


def split_clusters(clusters, outputdir, batch_id, min_clusters_per_file=0, max_files=1, write_bed_summary=True):
    """
    For a set of clusters creates a bed file representation of all clusters. Also splits the clusters
    evenly into multiple files based on the user parameters (min_clusters_per_file, max_files)

    Returns:
        list: of output file names (not including the bed file)
    """
    if write_bed_summary:
        bedfile = os.path.join(outputdir, 'clusters.bed')
        write_bed_file(bedfile, itertools.chain.from_iterable([b.get_bed_repesentation() for b in clusters]))

    number_of_jobs = len(clusters) // min_clusters_per_file
    if number_of_jobs > max_files:
        number_of_jobs = max_files
    elif number_of_jobs == 0:
        number_of_jobs = 1

    jobs = [[] for j in range(0, number_of_jobs)]
    clusters = sorted(clusters, key=lambda x: (x.break1.chr, x.break1.start, x.break2.chr, x.break2.start))

    # split up consecutive clusters
    for i, cluster in enumerate(clusters):
        jobs[i % len(jobs)].append(cluster)
    assert sum([len(j) for j in jobs]) == len(clusters)

    output_files = []
    for i, job in enumerate(jobs):
        # generate an output file
        filename = os.path.join(outputdir, '{}-{}.tab'.format(batch_id, i + 1))
        output_files.append(filename)
        output_tabbed_file(job, filename)
    return output_files

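# Illustrative sketch: a minimal example of calling split_clusters directly, assuming
# `clusters` is a list of BreakpointPair objects (e.g. the keys returned by
# merge_breakpoint_pairs); the output directory, batch id and parameter values here are
# hypothetical.
#
#     output_files = split_clusters(
#         clusters,
#         outputdir='/path/to/cluster_output',
#         batch_id='batch-example',
#         min_clusters_per_file=50,
#         max_files=10,
#         write_bed_summary=True,  # also writes clusters.bed to the output directory
#     )
#     # output_files -> ['/path/to/cluster_output/batch-example-1.tab', ...]
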
def main(
    inputs, output,
    strand_specific, library, protocol, disease_status, masking, annotations,
    limit_to_chr=DEFAULTS.limit_to_chr,
    cluster_initial_size_limit=DEFAULTS.cluster_initial_size_limit,
    cluster_radius=DEFAULTS.cluster_radius,
    uninformative_filter=DEFAULTS.uninformative_filter,
    max_proximity=DEFAULTS.max_proximity,
    min_clusters_per_file=DEFAULTS.min_clusters_per_file,
    max_files=DEFAULTS.max_files,
    log_args=False,
    batch_id=None,
    split_only=False,
    start_time=int(time.time()),
    **kwargs
):
    """
    Args:
        inputs (:class:`List` of :class:`str`): list of input files to read
        output (str): path to the output directory
        strand_specific (bool): is the bam using a strand specific protocol
        library (str): the library to look for in each of the input files
        protocol (PROTOCOL): the sequence protocol (genome or transcriptome)
        masking (object): see :func:`~mavis.annotate.file_io.load_masking_regions`
        cluster_initial_size_limit (int): the maximum size of cliques to search for using the exact algorithm
        cluster_radius (int): distance (in breakpoint pairs) used in deciding to join bpps in a cluster
        uninformative_filter (bool): if True then clusters should be filtered out if they are not
            within a specified (max_proximity) distance to any annotation
        max_proximity (int): the maximum distance away an annotation can be before the uninformative_filter is applied
        annotations (object): see :func:`~mavis.annotate.file_io.load_reference_genes`
        min_clusters_per_file (int): the minimum number of clusters to output to a file
        max_files (int): the maximum number of files to split clusters into
    """
    if log_args:
        frame = inspect.currentframe()
        args, _, _, values = inspect.getargvalues(frame)
        args = {arg: values[arg] for arg in args if arg != 'log_args'}
        log_arguments(args)

    # output files
    batch_id = 'batch-' + str(uuid()) if batch_id is None else batch_id
    filtered_output = os.path.join(output, 'filtered_pairs.tab')
    cluster_assign_output = os.path.join(output, 'cluster_assignment.tab')

    # load the input files
    breakpoint_pairs = read_inputs(
        inputs,
        cast={COLUMNS.tools: lambda x: set(x.split(';')) if x else set() if not split_only else x},
        add_default={
            COLUMNS.library: library,
            COLUMNS.protocol: protocol,
            COLUMNS.tools: '',
            COLUMNS.disease_status: disease_status,
            COLUMNS.stranded: False,
            COLUMNS.tracking_id: ''
        },
        expand_strand=False, expand_orient=True, expand_svtype=True
    )
    # filter any breakpoint pairs where the library and protocol don't match
    other_libs = set()
    other_chr = set()
    unfiltered_breakpoint_pairs = []
    filtered_pairs = []
    log('filtering by library and chr name')
    for bpp in breakpoint_pairs:
        if bpp.library is None:
            bpp.library = library
        if bpp.library != library:
            other_libs.add(bpp.library)
            bpp.data[COLUMNS.filter_comment] = 'Not the target library name'
            filtered_pairs.append(bpp)
        elif bpp.break1.chr in limit_to_chr and bpp.break2.chr in limit_to_chr:
            unfiltered_breakpoint_pairs.append(bpp)
        else:
            other_chr.update({bpp.break1.chr, bpp.break2.chr})
            bpp.data[COLUMNS.filter_comment] = 'Non standard chromosome name'
            filtered_pairs.append(bpp)
    other_chr -= set(limit_to_chr)
    breakpoint_pairs = unfiltered_breakpoint_pairs
    if other_libs:
        log('warning: ignoring breakpoints found for other libraries:', sorted([l for l in other_libs]))
    if other_chr:
        log('warning: filtered events on chromosomes not found in "limit_to_chr"', other_chr)

    # filter by masking file
    breakpoint_pairs, masked_pairs = filter_on_overlap(breakpoint_pairs, masking)
    for bpp in masked_pairs:
        filtered_pairs.append(bpp)

    # filter by informative
    if uninformative_filter:
        log('filtering from', len(breakpoint_pairs), 'breakpoint pairs using informative filter')
        pass_clusters, uninformative_clusters = filter_uninformative(annotations, breakpoint_pairs, max_proximity=max_proximity)
        log(
            'filtered from', len(breakpoint_pairs),
            'down to', len(pass_clusters),
            '(removed {})'.format(len(uninformative_clusters))
        )
        breakpoint_pairs = pass_clusters
        for bpp in uninformative_clusters:
            bpp.data[COLUMNS.filter_comment] = 'Uninformative'
            filtered_pairs.append(bpp)
    else:
        log('did not apply uninformative filter')

    output_tabbed_file(filtered_pairs, filtered_output)

    mkdirp(output)

    if not split_only:
        log('computing clusters')
        clusters = merge_breakpoint_pairs(
            breakpoint_pairs, cluster_radius=cluster_radius, cluster_initial_size_limit=cluster_initial_size_limit)

        hist = {}
        length_hist = {}
        for cluster in clusters:
            input_pairs = clusters[cluster]
            hist[len(input_pairs)] = hist.get(len(input_pairs), 0) + 1
            cluster1 = round(len(cluster[0]), -2)
            cluster2 = round(len(cluster[1]), -2)
            length_hist[cluster1] = length_hist.get(cluster1, 0) + 1
            length_hist[cluster2] = length_hist.get(cluster2, 0) + 1
            cluster.data[COLUMNS.cluster_id] = str(uuid())
            cluster.data[COLUMNS.cluster_size] = len(input_pairs)
            temp = set()
            data_items = set()
            combined_tracking_id = set()
            # group the tracking ids
            for pair in input_pairs:
                temp.update(pair.data[COLUMNS.tools])
                data_items.update(pair.data.keys())
                if COLUMNS.tracking_id in pair.data and pair.tracking_id:
                    combined_tracking_id.update(pair.tracking_id.split(';'))
            cluster.data[COLUMNS.tools] = ';'.join(sorted(list(temp)))
            cluster.data[COLUMNS.tracking_id] = ';'.join(sorted(list(combined_tracking_id)))
            data_items -= {COLUMNS.tools, COLUMNS.tracking_id}
            # retain all data where data is consistent between the input pairs
            for item in data_items:
                common_data = [p.data.get(item, None) for p in input_pairs]
                common_data = set(common_data)
                if len(common_data) == 1:
                    cluster.data[item] = list(common_data)[0]
        log('computed', len(clusters), 'clusters', time_stamp=False)
        log('cluster input pairs distribution', sorted(hist.items()), time_stamp=False)
        log('cluster intervals lengths', sorted(length_hist.items()), time_stamp=False)

        # map input pairs to cluster ids
        # now create the mapping from the original input files to the cluster(s)
        rows = {}
        for cluster, input_pairs in clusters.items():
            for pair in input_pairs:
                if pair not in rows:
                    rows[pair] = pair.flatten()
                rows[pair][COLUMNS.tools].update(pair.data[COLUMNS.tools])
                rows[pair].setdefault('clusters', set()).add(cluster.data[COLUMNS.cluster_id])
        for row in rows.values():
            row['clusters'] = ';'.join([str(c) for c in sorted(list(row['clusters']))])
            row[COLUMNS.tools] = ';'.join(sorted(list(row[COLUMNS.tools])))
        output_tabbed_file(rows.values(), cluster_assign_output)
        breakpoint_pairs = list(clusters.keys())

    output_files = split_clusters(
        breakpoint_pairs, output, batch_id,
        min_clusters_per_file=min_clusters_per_file,
        max_files=max_files,
        write_bed_summary=True
    )

    generate_complete_stamp(output, log, start_time=start_time)
    return output_files

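# Illustrative sketch: a minimal direct call to main(), assuming the masking and annotations
# objects have already been loaded (see mavis.annotate.file_io.load_masking_regions and
# mavis.annotate.file_io.load_reference_genes); the paths, library name and disease status
# shown here are hypothetical.
#
#     output_files = main(
#         inputs=['/path/to/sv_caller_output.tab'],
#         output='/path/to/output/MY_LIBRARY/cluster',
#         strand_specific=False,
#         library='MY_LIBRARY',
#         protocol='genome',
#         disease_status='diseased',
#         masking=masking_regions,
#         annotations=reference_annotations,
#     )
#     # output_files lists the split cluster files written under the output directory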