Source code for hicberg.benchmark

import uuid
from os import mkdir #path
from shutil import rmtree, copy
from pathlib import Path
from functools import partial

from multiprocessing import Process, Pool


from glob import glob
# import tempfile as tmpf
# import multiprocessing as mp
import subprocess as sp
from datetime import datetime


from itertools import product

from shutil import rmtree
import click
import logging

import numpy as np
import cooler

import hicberg.io as hio
import hicberg.utils as hut
# import hicberg.align as hal
import hicberg.statistics as hst
import hicberg.eval as hev
import hicberg.plot as hpl

from hicberg import logger

# BAM_FOR = "group1.1_subsampled.bam" 
# BAM_REV = 'group1.2_subsampled.bam'
# File names used by benchmark(). All of them are resolved relative to either
# the user-supplied output directory or the per-run "<id>/data" sub-directory.

# Forward / reverse alignment files copied into the per-run data directory.
BAM_FOR = "group1.1.bam" 
BAM_REV = 'group1.2.bam'
# Reference contact map; loaded with hio.load_cooler and compared to the rescued map.
BASE_MATRIX = "original_map.cool"
# Auxiliary inputs copied as-is from the output directory into the data directory.
FRAGMENTS = "fragments_fixed_sizes.txt"
DIST_FRAGS = "dist.frag.npy"
CHROMOSOME_SIZES = "chromosome_sizes.npy"
XS = "xs.npy"
# Contact maps referenced after read selection (unrescued) and after
# reattribution (rescued).
UNRESCUED_MATRIX = "unrescued_map.cool"
RESCUED_MATRIX = "rescued_map.cool"
# Must exist in the output directory, otherwise benchmark() raises ValueError.
RESTRICTION_MAP = "restriction_map.npy"
# Names of the ".in" / ".out" alignment files inside the data directory; the
# ".out" pair is what benchmark() passes to build_pairs and the statistics step.
# NOTE(review): presumably ".in" holds the reads selected for removal and ".out"
# the remaining ones - confirm against hev.select_reads_multithreads.
FORWARD_IN_FILE = "group1.1.in.bam"
REVERSE_IN_FILE = "group1.2.in.bam"
FORWARD_OUT_FILE = "group1.1.out.bam"
REVERSE_OUT_FILE = "group1.2.out.bam"

# TODO : Complete docstring
def _is_unset(value):
    """Return True when an option is ``None`` or the ``"-1"`` placeholder string."""
    return value is None or (isinstance(value, str) and value == "-1")


def _as_list(value, cast):
    """Turn a comma-separated string, a list/tuple or a scalar into a list of ``cast`` items."""
    if isinstance(value, str):
        return [cast(item) for item in value.split(",")]
    if isinstance(value, (list, tuple)):
        return [cast(item) for item in value]
    return [cast(value)]


def benchmark(output_dir : str = None, chromosome : str = "", position : int = 0, trans_chromosome : str = None, trans_position : int = None, strides : list[int] = [], mode : str = "full", auto : int = None, kernel_size : int = 11, deviation : float = 0.5, bins : int = None, circular :str = "", genome : str = None, pattern : str = None, threshold : float = 0.0, jitter : int = 0, trend : bool = True, top : int = 100, force : bool = False, iterations : int = 3, cpus : int = 8):
    """
    Performs a benchmark of the HiC-Berg method by simulating the deletion of a
    genomic region and evaluating the method's ability to recover the original signal.

    The benchmark simulates the deletion of a genomic region by removing reads that
    map to that region from a Hi-C dataset. It then applies the HiC-Berg method to the
    depleted dataset to try to recover the original signal. The benchmark evaluates the
    method's performance by comparing the rescued contact map to the original contact map.

    Every run works inside ``<output_dir>/<id>/data`` where ``<id>`` is a random
    8-character tag, appends one tab-separated line per (mode, iteration) to
    ``<output_dir>/benchmark.csv`` and finally creates ``<output_dir>/flag.txt``.

    Parameters
    ----------
    output_dir : str, optional
        Path to the output directory. Must already exist and contain the input files.
    chromosome : str, optional
        Chromosome to perform the benchmark on. Comma-separated when no pattern is given.
    position : int, optional
        Position of the deletion on the chromosome. Overridden by the first detected
        pattern (sorted by ``start1``) when ``pattern`` is given.
    trans_chromosome : str, optional
        Chromosome(s) to consider for trans interactions, comma-separated.
        ``None`` or ``"-1"`` disables it.
    trans_position : int, optional
        Position(s) of the deletion on the trans chromosome, comma-separated.
        ``None`` or ``"-1"`` disables it.
    strides : list[int], optional
        Strides to use for the deletion, either as a comma-separated string or as a
        list of integers. In pattern mode, ``None`` or ``"-1"`` derives the strides
        from the detected patterns.
    mode : str, optional
        Mode(s) of the HiC-Berg method to use, comma-separated. Can be "full" or "density".
    auto : int, optional
        Automatically determine the size of the deletion based on the given number of bins.
    kernel_size : int, optional
        Size of the kernel to use for the density estimation.
    deviation : float, optional
        Standard deviation of the kernel to use for the density estimation.
    bins : int, optional
        Number of bins to use for the deletion.
    circular : str, optional
        Whether the chromosome is circular (forwarded to ``hst.get_patterns``).
    genome : str, optional
        Genome assembly to use (forwarded to ``hst.generate_coverages``).
    pattern : str, optional
        Pattern to use for Chromosight pre-call. ``None`` or ``"-1"`` selects the
        position-based benchmark.
    threshold : float, optional
        Threshold to use for Chromosight pre-call.
    jitter : int, optional
        Jitter to use for Chromosight pre-call.
    trend : bool, optional
        Whether to use trend for Chromosight pre-call.
    top : int, optional
        Top patterns to use for Chromosight pre-call.
    force : bool, optional
        Whether to force the benchmark to run even if the output directory already
        exists. Currently not read by this function.
    iterations : int, optional
        Number of iterations to run the HiC-Berg method.
    cpus : int, optional
        Number of CPUs to use for parallel processing.

    Raises
    ------
    ValueError
        If the output directory does not exist.
    ValueError
        If the restriction map file does not exist.
    """
    # Snapshot the call arguments before any other local variable exists.
    run_arguments = dict(locals())

    # Validate the inputs first so that a wrong path fails with the documented
    # ValueError and leaves no half-created run directory behind.
    output_path = Path(output_dir)
    if not output_path.exists():
        raise ValueError(f"Output directory {output_dir} does not exist.")

    restriction_map_path = output_path / RESTRICTION_MAP
    if not restriction_map_path.exists():
        raise ValueError(f"Restriction map file {restriction_map_path} does not exist. PLease provide an existing restriction map file.")

    # Keep track of the arguments used
    for name, value in run_arguments.items():
        logger.info("%s: %s", name, value)

    learning_status = False
    picking_status = False

    # Unique id to keep track of the experiments, and its child directories
    id_tag = str(uuid.uuid4())[:8]
    output_data_path = output_path / id_tag / "data"
    output_plot_path = output_path / id_tag / "plots"
    for directory in (output_data_path, output_plot_path):
        directory.mkdir(parents = True, exist_ok = True)

    # File storing the results of every run
    header = "id\tdate\tchrom\tpos\tstride\ttrans_chrom\ttrans_pos\tauto\tbins\tmode\tnb_reads\tpattern\tprecision\trecall\tf1_score\tscore\n"
    results = output_path / "benchmark.csv"

    # Work on private copies of the inputs
    for file_name in (BASE_MATRIX, BAM_FOR, BAM_REV, RESTRICTION_MAP, FRAGMENTS, CHROMOSOME_SIZES, DIST_FRAGS, XS):
        copy(output_path / file_name, output_data_path / file_name)

    base_matrix = hio.load_cooler(output_data_path / BASE_MATRIX)
    bin_size = base_matrix.binsize

    flag_file = output_path / "flag.txt"
    forward_out_path = output_data_path / FORWARD_OUT_FILE
    reverse_out_path = output_data_path / REVERSE_OUT_FILE

    pattern_mode = not _is_unset(pattern)

    # "-1" is used as a placeholder when the benchmark is called through Snakemake
    # without trans cases.
    trans_chromosome = None if _is_unset(trans_chromosome) else _as_list(trans_chromosome, str)
    trans_position = None if _is_unset(trans_position) else _as_list(trans_position, int)
    nb_bins = bins

    if not pattern_mode:
        # Position based benchmarking
        chromosome = _as_list(chromosome, str)
        strides = _as_list(strides, int)
    else:
        # Pattern based benchmarking: Chromosight pre-call on the original map
        pre_recall_cmd = hev.chromosight_cmd_generator(file = output_data_path / BASE_MATRIX, pattern = pattern, untrend = trend, output_dir = output_data_path)
        logger.info("Starting Chromosight pre-call")
        # NOTE(review): the command is executed through the shell; it is built by
        # hicberg itself, but its parts come from user-supplied options.
        sp.run(pre_recall_cmd, shell = True)

        df = hev.get_top_pattern(file = output_data_path / "original.tsv", top = top, threshold = threshold, chromosome = chromosome).sort_values(by='start1', ascending=True)
        print(df)

        # First pattern once sorted by start coordinate
        position = df.iloc[0].start1

        if _is_unset(strides):
            # Offsets of every other pattern relative to the smallest start
            first_start = df['start1'].min()
            strides = [int(start - first_start) for start in df['start1'].iloc[1:]]
        else:
            strides = _as_list(strides, int)

    sub_modes = mode.split(",")

    for sub_mode in sub_modes:

        precision = None
        recall = None
        f1_score = None

        # Pick reads (done once, shared by every mode)
        if not picking_status:

            # Chunk bam files
            forward_chunks, reverse_chunks = hut.get_chunks(output_dir = output_path.as_posix())

            coordinates = hev.generate_dict_coordinates(matrix_file = output_path / BASE_MATRIX, position = position, chromosome = chromosome, strides = strides, trans_chromosome = trans_chromosome, trans_position = trans_position, auto = auto, nb_bins = nb_bins, output_dir = output_data_path)
            # TODO : pass coordinates to select_reads_multithreads to avoid seeding problems through multithreading

            with Pool(processes = cpus) as pool:
                pool.map(partial(hev.select_reads_multithreads, interval_dictionary = coordinates, output_dir = output_data_path), zip(forward_chunks, reverse_chunks))
                pool.close()
                pool.join()

            hio.merge_predictions(output_dir = output_data_path, clean = True, stage = "benchmark", cpus = cpus)

            # TODO : put aside in function
            # Bin indexes corresponding to the selected coordinates
            indexes = hev.get_bin_indexes(matrix = base_matrix, dictionary = coordinates, )

            picking_status = True

        # Re-build pairs and cooler matrix from the reads that were kept
        # NOTE(review): executed for every mode; the inputs do not change between modes.
        hio.build_pairs(bam_for = forward_out_path, bam_rev = reverse_out_path, output_dir = output_data_path)
        hio.build_matrix(balance = True, output_dir = output_data_path)

        # Learning step (done once, shared by every mode)
        if not learning_status:

            learning_processes = [
                Process(target = hst.get_patterns, kwargs = dict(forward_bam_file = forward_out_path, reverse_bam_file = reverse_out_path, circular = circular, output_dir = output_data_path)),
                Process(target = hst.generate_trans_ps, kwargs = dict(output_dir = output_data_path)),
                Process(target = hst.generate_coverages, kwargs = dict(forward_bam_file = forward_out_path, reverse_bam_file = reverse_out_path, genome = genome, bins = bin_size, output_dir = output_data_path)),
                Process(target = hst.generate_d1d2, kwargs = dict(forward_bam_file = forward_out_path, reverse_bam_file = reverse_out_path, output_dir = output_data_path)),
            ]

            logger.info("Full mode selected. Learning step will be performed.")

            for process in learning_processes:
                process.start()

            for process in learning_processes:
                process.join()

            logger.info("Learning step completed")

            if "density" in sub_modes:
                hst.compute_density(cooler_file = UNRESCUED_MATRIX, kernel_size = kernel_size, deviation = deviation, threads = cpus, output_dir = output_data_path)

            learning_status = True

        for _ in range(iterations):

            # Reattribute reads
            logger.info("Re-attributing reads")

            # Get chunk_for_*.in.bam/chunk_rev_*.in.bam
            forward_chunks = sorted(glob(str(output_data_path / "chunk_for_*.in.bam")))
            reverse_chunks = sorted(glob(str(output_data_path / "chunk_rev_*.in.bam")))

            # Drop the couples where either side is empty. A new list is built
            # because removing items from a list while iterating over it skips
            # the element following each removal.
            chunk_couples = [
                (forward_chunk, reverse_chunk)
                for forward_chunk, reverse_chunk in zip(forward_chunks, reverse_chunks)
                if not (hut.is_empty_alignment(forward_chunk) or hut.is_empty_alignment(reverse_chunk))
            ]

            with Pool(processes = cpus) as pool:
                pool.map(partial(hst.reattribute_reads, mode = sub_mode, restriction_map = output_data_path / RESTRICTION_MAP, output_dir = output_data_path), chunk_couples)
                pool.close()
                pool.join()

            hio.merge_predictions(output_dir = output_data_path, clean = True, cpus = cpus)

            hio.build_pairs(bam_for = "group1.1.out.bam", bam_rev = "group1.2.out.bam", bam_for_rescued = "group2.1.rescued.bam", bam_rev_rescued = "group2.2.rescued.bam", mode = True, output_dir = output_data_path)
            hio.build_matrix(mode = True, balance = False, output_dir = output_data_path)

            rescued_matrix_path = output_data_path / RESCUED_MATRIX
            rescued_matrix = hio.load_cooler(rescued_matrix_path)

            pearson = hst.pearson_score(original_matrix = base_matrix, rescued_matrix = rescued_matrix , markers = indexes)

            if not pattern_mode:
                chromosome_set = [*chromosome, *trans_chromosome] if trans_chromosome is not None else chromosome

            else:
                # In pattern mode ``chromosome`` is still the raw string
                chromosome_set = [chromosome, *trans_chromosome] if trans_chromosome is not None else chromosome

                post_recall_cmd = hev.chromosight_cmd_generator(file = rescued_matrix_path, pattern = pattern, untrend = trend, mode = True, output_dir = output_data_path)
                logger.info("Starting Chromosight post-call")
                sp.run(post_recall_cmd, shell = True)

                original_table = output_data_path / "original.tsv"
                rescued_table = output_data_path / "rescued.tsv"

                # TODO : move to top
                # NOTE(review): the scores read the tables from the output directory
                # while the TP/FP/FN tables read them from the data directory, which
                # is where Chromosight is told to write - confirm which is intended.
                df_original = (output_path / "original.tsv").as_posix()
                df_rescued = (output_path / "rescued.tsv").as_posix()

                true_positives = hev.get_TP_table(df_pattern = original_table, df_pattern_recall = rescued_table, chromosome = chromosome, bin_size = bin_size, jitter = jitter, threshold = threshold)
                false_positives = hev.get_FP_table(df_pattern = original_table, df_pattern_recall = rescued_table, chromosome = chromosome, bin_size = bin_size, jitter = jitter, threshold = threshold)
                false_negatives = hev.get_FN_table(df_pattern = original_table, df_pattern_recall = rescued_table, chromosome = chromosome, bin_size = bin_size, jitter = jitter, threshold = threshold)

                # Get scores
                precision = hev.get_precision(df_pattern = df_original, df_pattern_recall = df_rescued, chromosome = chromosome, bin_size = bin_size, jitter = jitter, threshold = threshold)
                recall = hev.get_recall(df_pattern = df_original, df_pattern_recall = df_rescued, chromosome = chromosome, bin_size = bin_size, jitter = jitter, threshold = threshold)
                f1_score = hev.get_f1_score(df_pattern = df_original, df_pattern_recall = df_rescued, chromosome = chromosome, bin_size = bin_size, jitter = jitter, threshold = threshold)

                # Get plots
                hpl.plot_pattern_reconstruction(table = true_positives, original_cool = output_data_path / BASE_MATRIX, rescued_cool = rescued_matrix_path, chromosome = chromosome, threshold = threshold, case = "true_positives", output_dir = output_data_path)
                hpl.plot_pattern_reconstruction(table = false_positives, original_cool = output_data_path / BASE_MATRIX, rescued_cool = rescued_matrix_path, chromosome = chromosome, threshold = threshold, case = "false_positives", output_dir = output_data_path)
                hpl.plot_pattern_reconstruction(table = false_negatives, original_cool = output_data_path / BASE_MATRIX, rescued_cool = rescued_matrix_path, chromosome = chromosome, threshold = threshold, case = "false_negatives", output_dir = output_data_path)

            hpl.plot_benchmark(original_matrix = BASE_MATRIX, depleted_matrix = UNRESCUED_MATRIX, rescued_matrix = RESCUED_MATRIX, chromosomes = chromosome_set, output_dir = output_data_path)

            # Fixed value written to the "nb_reads" column
            number_reads = 10

            logger.info(f"Pearson score : {pearson:9.4f} in mode {sub_mode}")

            # The header is written only when the results file is created
            write_header = not results.exists()
            date = datetime.now().strftime("%d/%m/%Y %H:%M:%S")
            with open(results, "a") as f_out:
                if write_header:
                    f_out.write(header)
                f_out.write(f"{id_tag}\t{date}\t{chromosome}\t{position}\t{strides}\t{trans_chromosome}\t{trans_position}\t{auto}\t{bins}\t{sub_mode}\t{number_reads}\t{pattern}\t{precision}\t{recall}\t{f1_score}\t{pearson:9.4f}\n")

    logger.info("Ending benchmark")

    # Clean up the intermediate files of the run, keeping maps and plots
    for file in list(output_data_path.glob("*")):
        if file.suffix in (".npy", ".tsv", ".bam", ".pairs"):
            file.unlink()

    # Create the flag file if it does not exist yet
    with open(flag_file, 'a'):
        pass