The Analysis Pipeline

Alex Malz (NYU) & Phil Marshall (SLAC)

In this notebook we use the "survey mode" machinery to demonstrate how one should choose the optimal parametrization for photo-$z$ PDF storage given the nature of the data, the storage constraints, and the fidelity necessary for a science use case.

In [1]:
#comment out for NERSC
%load_ext autoreload

#comment out for NERSC
%autoreload 2
In [2]:
from __future__ import print_function
    
import hickle
import numpy as np
import random
import cProfile
import pstats
import StringIO
import sys
import os
import timeit
import bisect
import re

import qp
from qp.utils import calculate_kl_divergence as make_kld

# np.random.seed(seed=42)
# random.seed(a=42)
In [3]:
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.rcParams['text.usetex'] = True
mpl.rcParams['mathtext.rm'] = 'serif'
mpl.rcParams['font.family'] = 'serif'
mpl.rcParams['font.serif'] = 'Times New Roman'
mpl.rcParams['axes.titlesize'] = 16
mpl.rcParams['axes.labelsize'] = 14
mpl.rcParams['savefig.dpi'] = 250
mpl.rcParams['savefig.format'] = 'pdf'
mpl.rcParams['savefig.bbox'] = 'tight'

#comment out for NERSC
%matplotlib inline

Analysis

We want to compare parametrizations for large catalogs, so we'll need to be more efficient. The qp.Ensemble object is a wrapper around many qp.PDF objects, enabling conversions to be performed and metrics to be calculated in parallel. We'll experiment on small subsamples of the catalog (the sizes list below sets how many galaxies).

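As a preview of the machinery built up in the next few cells, here is a minimal sketch of the round trip we will perform on each catalog: wrap gridded PDFs in a qp.Ensemble, resample and fit GMMs to serve as the "truth", compress every member to a target format, and score the compression with a per-object KLD. The toy Gaussians and argument values here are purely illustrative; the real calls appear in the cells below.

toy_z = np.linspace(0.01, 3.5, 100)
toy_mu = np.array([[0.7], [1.4], [2.1]])
toy_pdfs = np.exp(-0.5 * ((toy_z[np.newaxis, :] - toy_mu) / 0.1) ** 2)    # 3 toy Gaussians on a grid
toy_E0 = qp.Ensemble(3, gridded=(toy_z, toy_pdfs), limits=(0.01, 3.5), vb=False)
toy_samps = toy_E0.sample(100, vb=False)                                  # sample every member in parallel
toy_Ei = qp.Ensemble(3, samples=toy_samps, limits=(0.01, 3.5), vb=False)
toy_gmms = toy_Ei.mix_mod_fit(comps=1, vb=False)                          # parallel GMM fits to the samples
toy_Ef = qp.Ensemble(3, truth=toy_gmms, limits=(0.01, 3.5), vb=False)
toy_quantiles = toy_Ef.quantize(N=10, vb=False)                           # compress each member to 10 quantiles
toy_Eq = qp.Ensemble(3, truth=toy_gmms, quantiles=toy_quantiles, limits=(0.01, 3.5), vb=False)
toy_klds = toy_Eq.kld(using='quantiles', limits=(0.01, 3.5), dx=0.01)     # one KLD value per member
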
In [4]:
def setup_dataset(dataset_key, skip_rows, skip_cols):
    start = timeit.default_timer()
    with open(dataset_info[dataset_key]['filename'], 'rb') as data_file:
        lines = (line.split(None) for line in data_file)
        for r in range(skip_rows):
            lines.next()
        pdfs = np.array([[float(line[k]) for k in range(skip_cols, len(line))] for line in lines])
    print('read in data file in '+str(timeit.default_timer()-start))
    return(pdfs)
In [5]:
def make_instantiation(dataset_key, n_gals_use, pdfs, bonus=None):
    
    start = timeit.default_timer()
    
    n_gals_tot = len(pdfs)
    full_gal_range = range(n_gals_tot)
    subset = np.random.choice(full_gal_range, n_gals_use, replace=False)#range(n_gals_use)
    print('randos for debugging: '+str(subset))
    pdfs_use = pdfs[subset]
    
    modality = []
    dpdfs = pdfs_use[:,1:] - pdfs_use[:,:-1]
    iqrs = []
    for i in range(n_gals_use):
        modality.append(len(np.where(np.diff(np.signbit(dpdfs[i])))[0]))
        cdf = np.cumsum(qp.utils.normalize_integral((dataset_info[dataset_key]['z_grid'], pdfs_use[i]), vb=False)[1])
        iqr_lo = dataset_info[dataset_key]['z_grid'][bisect.bisect_left(cdf, 0.25)]
        iqr_hi = dataset_info[dataset_key]['z_grid'][bisect.bisect_left(cdf, 0.75)]
        iqrs.append(iqr_hi - iqr_lo)
    modality = np.array(modality)
        
    dataset_info[dataset_key]['N_GMM'] = int(np.median(modality))+1
#     print('n_gmm for '+dataset_info[dataset_key]['name']+' = '+str(dataset_info[dataset_key]['N_GMM']))
      
    # using the same grid for output as the native format, but doesn't need to be so
    dataset_info[dataset_key]['in_z_grid'] = dataset_info[dataset_key]['z_grid']
    dataset_info[dataset_key]['metric_z_grid'] = dataset_info[dataset_key]['z_grid']
    
    print('preprocessed data in '+str(timeit.default_timer()-start))
    
    path = os.path.join(dataset_key, str(n_gals_use))
    loc = os.path.join(path, 'pzs'+str(n_gals_use)+dataset_key+bonus)
    with open(loc+'.hkl', 'w') as filename:
        info = {}
        info['randos'] = randos
        info['z_grid'] = dataset_info[dataset_key]['in_z_grid']
        info['pdfs'] = pdfs_use
        info['modes'] = modality
        info['iqrs'] = iqrs
        hickle.dump(info, filename)
    
    return(pdfs_use)
In [6]:
def plot_examples(n_gals_use, dataset_key, bonus=None):
    
    path = os.path.join(dataset_key, str(n_gals_use))
    loc = os.path.join(path, 'pzs'+str(n_gals_use)+dataset_key+bonus)
    with open(loc+'.hkl', 'r') as filename:
        info = hickle.load(filename)
        randos = info['randos']
        z_grid = info['z_grid']
        pdfs = info['pdfs']
    
    plt.figure()
    for i in range(n_plot):
        data = (z_grid, pdfs[randos[i]])
        data = qp.utils.normalize_integral(qp.utils.normalize_gridded(data))
        pz_max.append(np.max(data[1]))  # track the largest p(z) (not z) value for the shared y-axis limit
        plt.plot(data[0], data[1], label=dataset_info[dataset_key]['name']+' \#'+str(randos[i]), color=color_cycle[i])
    plt.xlabel(r'$z$', fontsize=14)
    plt.ylabel(r'$p(z)$', fontsize=14)
    plt.xlim(min(z_grid), max(z_grid))
    plt.ylim(0., max(pz_max))
    plt.title(dataset_info[dataset_key]['name']+' data examples', fontsize=16)
    plt.savefig(loc+'.pdf', dpi=250)
    plt.close()
    
    if 'modes' in info.keys():
        modes = info['modes']
        modes_max.append(np.max(modes))
        plt.figure()
        ax = plt.hist(modes, color='k', alpha=1./n_plot, histtype='stepfilled', bins=range(max(modes_max)+1))
        plt.xlabel('modes')
        plt.ylabel('frequency')
        plt.title(dataset_info[dataset_key]['name']+' data modality distribution (median='+str(dataset_info[dataset_key]['N_GMM'])+')', fontsize=16)
        plt.savefig(loc+'modality.pdf', dpi=250)
        plt.close()
        
    if 'iqrs' in info.keys():
        iqrs = info['iqrs']
        iqr_min.append(min(iqrs))
        iqr_max.append(max(iqrs))
        plot_bins = np.linspace(min(iqr_min), max(iqr_max), 20)
        plt.figure()
        ax = plt.hist(iqrs, bins=plot_bins, color='k', alpha=1./n_plot, histtype='stepfilled')
        plt.xlabel('IQR')
        plt.ylabel('frequency')
        plt.title(dataset_info[dataset_key]['name']+' data IQR distribution', fontsize=16)
        plt.savefig(loc+'iqrs.pdf', dpi=250)
        plt.close()

We'll start by reading in our catalog of gridded PDFs, sampling them, fitting GMMs to the samples, and establishing a new qp.Ensemble object where each member qp.PDF object has qp.PDF.truth$\neq$None.

In [7]:
def setup_from_grid(dataset_key, in_pdfs, z_grid, N_comps, high_res=1000, bonus=None):
    
    #read in the data, happens to be gridded
    zlim = (min(z_grid), max(z_grid))
    N_pdfs = len(in_pdfs)
    
    start = timeit.default_timer()
#     print('making the initial ensemble of '+str(N_pdfs)+' PDFs')
    E0 = qp.Ensemble(N_pdfs, gridded=(z_grid, in_pdfs), limits=dataset_info[dataset_key]['z_lim'], vb=False)
    print('made the initial ensemble of '+str(N_pdfs)+' PDFs in '+str(timeit.default_timer() - start))    
    
    #fit GMMs to gridded pdfs based on samples (faster than fitting to gridded)
    start = timeit.default_timer()
#     print('sampling for the GMM fit')
    samparr = E0.sample(high_res, vb=False)
    print('took '+str(high_res)+' samples in '+str(timeit.default_timer() - start))
    
    start = timeit.default_timer()
#     print('making a new ensemble from samples')
    Ei = qp.Ensemble(N_pdfs, samples=samparr, limits=dataset_info[dataset_key]['z_lim'], vb=False)
    print('made a new ensemble from samples in '+str(timeit.default_timer() - start))
    
    start = timeit.default_timer()
#     print('fitting the GMM to samples')
    GMMs = Ei.mix_mod_fit(comps=N_comps, vb=False)
    print('fit the GMM to samples in '+str(timeit.default_timer() - start))
    
    #set the GMMS as the truth
    start = timeit.default_timer()
#     print('making the final ensemble')
    Ef = qp.Ensemble(N_pdfs, truth=GMMs, limits=dataset_info[dataset_key]['z_lim'], vb=False)
    print('made the final ensemble in '+str(timeit.default_timer() - start))
    
    path = os.path.join(dataset_key, str(N_pdfs))
    loc = os.path.join(path, 'pzs'+str(n_gals_use)+dataset_key+bonus)
    with open(loc+'.hkl', 'w') as filename:
        info = {}
        info['randos'] = randos
        info['z_grid'] = z_grid
        info['pdfs'] = Ef.evaluate(z_grid, using='truth', norm=True, vb=False)[1]
        hickle.dump(info, filename)
        
    start = timeit.default_timer()
#     print('calculating '+str(n_moments_use)+' moments of original PDFs')
    in_moments, vals = [], []
    for n in range(n_moments_use):
        in_moments.append(Ef.moment(n, using='truth', limits=zlim, 
                                    dx=delta_z, vb=False))
        vals.append(n)
    moments = np.array(in_moments)
    print('calculated '+str(n_moments_use)+' moments of original PDFs in '+str(timeit.default_timer() - start))
    
    path = os.path.join(dataset_key, str(N_pdfs))
    loc = os.path.join(path, 'pz_moments'+str(n_gals_use)+dataset_key+bonus)
    with open(loc+'.hkl', 'w') as filename:
        info = {}
        info['truth'] = moments
        info['orders'] = vals
        hickle.dump(info, filename)
    
    return(Ef)

Next, we compute the KLD between each approximation and the truth for every member of the ensemble. We then treat the per-object KLD values returned by qp.Ensemble.kld as samples defining a qp.PDF object of its own, so that we can compare the moments of the KLD distributions for the different parametrizations.

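Throughout, the KLD of an approximation $\hat{p}(z)$ to a true PDF $p(z)$ is computed with the truth as the reference distribution,

$\mathrm{KLD} = \int p(z)\,\log\left[\frac{p(z)}{\hat{p}(z)}\right]\,dz \geq 0,$

which vanishes only when $\hat{p}(z) = p(z)$ everywhere, so smaller values indicate a more faithful compression.
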
In [8]:
def analyze_individual(E, z_grid, N_floats, dataset_key, N_moments=4, i=None, bonus=None):
    zlim = (min(z_grid), max(z_grid))
    z_range = zlim[-1] - zlim[0]
    delta_z = z_range / len(z_grid)
    path = os.path.join(dataset_key, str(n_gals_use))
    
    Eq, Eh, Es = E, E, E
    inits = {}
    for f in formats:
        inits[f] = {}
        for ff in formats:
            inits[f][ff] = None
            
    qstart = timeit.default_timer()
    inits['quantiles']['quantiles'] = Eq.quantize(N=N_floats, vb=False)
    print('finished making in '+str(timeit.default_timer() - qstart))
    hstart = timeit.default_timer()
    inits['histogram']['histogram'] = Eh.histogramize(N=N_floats, binrange=zlim, vb=False)
    print('finished histogramization in '+str(timeit.default_timer() - hstart))
    sstart = timeit.default_timer()
    inits['samples']['samples'] = Es.sample(samps=N_floats, vb=False)
    print('finished sampling in '+str(timeit.default_timer() - sstart))
        
    Eo = {}
    for f in formats:
        fstart = timeit.default_timer()
        Eo[f] = qp.Ensemble(E.n_pdfs, truth=E.truth, 
                            quantiles=inits[f]['quantiles'], 
                            histogram=inits[f]['histogram'],
                            samples=inits[f]['samples'], 
                            limits=dataset_info[dataset_key]['z_lim'])
        fbonus = str(N_floats)+f+str(i)
        loc = os.path.join(path, 'pzs'+str(n_gals_use)+dataset_key+fbonus)
        with open(loc+'.hkl', 'w') as filename:
            info = {}
            info['randos'] = randos
            info['z_grid'] = z_grid
            info['pdfs'] = Eo[f].evaluate(z_grid, using=f, norm=True, vb=False)[1]
            hickle.dump(info, filename)
        print('made '+f+' ensemble in '+str(timeit.default_timer()-fstart))
    
    metric_start = timeit.default_timer()
    inloc = os.path.join(path, 'pz_moments'+str(n_gals_use)+dataset_key+bonus)
    with open(inloc+'.hkl', 'r') as infilename:
        pz_moments = hickle.load(infilename)
    pz_moment_deltas, klds, metrics, kld_moments = {}, {}, {}, {}
    
    for key in Eo.keys():
        key_start = timeit.default_timer()
        klds[key] = Eo[key].kld(using=key, limits=zlim, dx=delta_z)
        samp_metric = qp.PDF(samples=klds[key])
        gmm_metric = samp_metric.mix_mod_fit(n_components=dataset_info[dataset_key]['N_GMM'], 
                                             using='samples', vb=False)
        metrics[key] = qp.PDF(truth=gmm_metric)
        
        pz_moment_deltas[key], pz_moments[key], kld_moments[key] = [], [], []
        for n in range(N_moments):
            kld_moments[key].append(qp.utils.calculate_moment(metrics[key], n,
                                                          using='truth', 
                                                          limits=zlim, 
                                                          dx=delta_z, 
                                                          vb=False))
            new_moment = Eo[key].moment(n, using=key, limits=zlim, 
                                                  dx=delta_z, vb=False)
            pz_moments[key].append(new_moment)
            delta_moment = (new_moment - pz_moments['truth'][n]) / pz_moments['truth'][n]
            pz_moment_deltas[key].append(delta_moment)
        print('calculated the '+key+' individual moments, kld moments in '+str(timeit.default_timer() - key_start))

    loc = os.path.join(path, 'kld_hist'+str(n_gals_use)+dataset_key+str(N_floats)+'_'+str(i))
    with open(loc+'.hkl', 'w') as filename:
        info = {}
        info['z_grid'] = z_grid
        info['N_floats'] = N_floats
        info['pz_klds'] = klds
        hickle.dump(info, filename)

    outloc = os.path.join(path, 'pz_moments'+str(n_gals_use)+dataset_key+str(N_floats)+'_'+str(i))
    with open(outloc+'.hkl', 'w') as outfilename:
        hickle.dump(pz_moments, outfilename)
    
    save_moments(name, size, n_floats_use, kld_moments, 'pz_kld_moments')
    save_moments(name, size, n_floats_use, pz_moments, 'pz_moments')
    save_moments(name, size, n_floats_use, pz_moment_deltas, 'pz_moment_deltas')
    
    return(Eo)#, klds, kld_moments, pz_moments, pz_moment_deltas)
In [9]:
def plot_all_examples(name, size, N_floats, init, bonus={}):
    
    fig, ax = plt.subplots()
    lines = []
    for bonus_key in bonus.keys():
        path = os.path.join(name, str(size))
        loc = os.path.join(path, 'pzs'+str(size)+name+bonus_key)
        with open(loc+'.hkl', 'r') as filename:
            info = hickle.load(filename)
            randos = info['randos']
            z_grid = info['z_grid']
            pdfs = info['pdfs']
        ls = bonus[bonus_key][0]
        a = bonus[bonus_key][1]
        lab = re.sub(r'[\_]', '', bonus_key)
        line, = ax.plot([-1., 0.], [0., 0.], linestyle=ls, alpha=a, color='k', label=lab)
        lines.append(line)
        leg = ax.legend(loc='upper right', handles=lines)
        for i in range(n_plot):
            data = (z_grid, pdfs[randos[i]])
            data = qp.utils.normalize_integral(qp.utils.normalize_gridded(data))
            ax.plot(data[0], data[1], linestyle=ls, alpha=a, color=color_cycle[i])
#     ax.legend(loc='upper right')
    ax.set_xlabel(r'$z$', fontsize=14)
    ax.set_ylabel(r'$p(z)$', fontsize=14)
    ax.set_xlim(min(z_grid), max(z_grid))
    ax.set_title(dataset_info[name]['name']+r' examples with $N_{f}=$'+str(N_floats), fontsize=16)
    saveloc = os.path.join(path, 'pzs'+str(size)+name+str(N_floats)+'_'+str(init))
    fig.savefig(saveloc+'.pdf', dpi=250)
    plt.close()
In [10]:
def plot_individual_kld(n_gals_use, dataset_key, N_floats, i):
    
    path = os.path.join(dataset_key, str(n_gals_use))
    a = 1./len(formats)
    loc = os.path.join(path, 'kld_hist'+str(n_gals_use)+dataset_key+str(N_floats)+'_'+str(i))
    with open(loc+'.hkl', 'r') as filename:
        info = hickle.load(filename)
        z_grid = info['z_grid']
        N_floats = info['N_floats']
        pz_klds = info['pz_klds']
    
    plt.figure()
    plot_bins = np.linspace(-3., 3., 20)
    for key in pz_klds.keys():
        logdata = qp.utils.safelog(pz_klds[key])
        kld_hist = plt.hist(logdata, color=colors[key], alpha=a, histtype='stepfilled', edgecolor='k',
             label=key, normed=True, bins=plot_bins, linestyle=stepstyles[key], ls=stepstyles[key], lw=2)
        hist_max.append(max(kld_hist[0]))
        dist_min.append(min(logdata))
        dist_max.append(max(logdata))
    plt.legend()
    plt.ylabel('frequency', fontsize=14)
    plt.xlabel(r'$\log[KLD]$', fontsize=14)
#     plt.xlim(min(dist_min), max(dist_max))
#     plt.ylim(0., max(hist_max))
    plt.title(dataset_info[dataset_key]['name']+r' data $p(KLD)$ with $N_{f}='+str(N_floats)+r'$', fontsize=16)
    plt.savefig(loc+'.pdf', dpi=250)
    plt.close()
In [11]:
# def plot_individual_moment(n_gals_use, dataset_key, N_floats, i):
    
#     path = os.path.join(dataset_key, str(n_gals_use))
#     a = 1./len(formats)    
#     loc = os.path.join(path, 'pz_moments'+str(n_gals_use)+dataset_key+str(N_floats)+'_'+str(i))
#     with open(loc+'.hkl', 'r') as filename:
#         moments = hickle.load(filename)
#     delta_moments = {}
        
#     plt.figure(figsize=(5, 5 * (n_moments_use-1)))
#     for n in range(1, n_moments_use):
#         ax = plt.subplot(n_moments_use, 1, n)
#         ends = (min(moments['truth'][n]), max(moments['truth'][n]))
#         for key in formats:
#             ends = (min(ends[0], min(moments[key][n])), max(ends[-1], max(moments[key][n])))
#         plot_bins = np.linspace(ends[0], ends[-1], 20)
#         ax.hist([-100], color='k', alpha=a, histtype='stepfilled', edgecolor='k', label='truth', 
#                     linestyle='-', ls='-')
#         ax.hist(moments['truth'][n], bins=plot_bins, color='k', alpha=a, histtype='stepfilled', normed=True)
#         ax.hist(moments['truth'][n], bins=plot_bins, color='k', histtype='step', normed=True, linestyle='-', alpha=a)
#         for key in formats:
#             ax.hist([-100], color=colors[key], alpha=a, histtype='stepfilled', edgecolor='k', label=key, 
#                     linestyle=stepstyles[key], ls=stepstyles[key], lw=2)
#             ax.hist(moments[key][n], bins=plot_bins, color=colors[key], alpha=a, histtype='stepfilled', normed=True)
#             ax.hist(moments[key][n], bins=plot_bins, color='k', histtype='step', normed=True, linestyle=stepstyles[key], alpha=a, lw=2)
#         ax.legend()
#         ax.set_ylabel('frequency', fontsize=14)
#         ax.set_xlabel(moment_names[n], fontsize=14)
#         ax.set_xlim(min(plot_bins), max(plot_bins))
#     plt.suptitle(dataset_info[dataset_key]['name']+r' data moments with $N_{f}='+str(N_floats)+r'$', fontsize=16)
#     plt.tight_layout()
#     plt.subplots_adjust(top=0.95)
#     plt.savefig(loc+'.pdf', dpi=250)
#     plt.close()
        
#     ngood = {}
#     normarr = np.ones(n_gals_use)
#     for key in formats:
#         ngood[key] = np.zeros(n_moments_use)
#     plt.figure(figsize=(5, 5 * (n_moments_use-1)))
#     for n in range(1, n_moments_use):
#         ax = plt.subplot(n_moments_use, 1, n)
#         ends = (100., -100.)
#         for key in formats:
#             delta_moments[key] = (moments[key] - moments['truth']) / moments['truth']
#             ngood[key][n] = np.sum(normarr[np.abs(delta_moments[key][n]) < 0.01]) / float(n_gals_use)
#             ends = (min(ends[0], min(delta_moments[key][n])), max(ends[-1], max(delta_moments[key][n])))
#         plot_bins = np.linspace(ends[0], ends[-1], 20)
#         for key in formats:
#             ax.hist([-100], color=colors[key], alpha=a, histtype='stepfilled', edgecolor='k', label=key, 
#                     linestyle=stepstyles[key], ls=stepstyles[key], lw=2)
#             ax.hist(delta_moments[key][n], bins=plot_bins, color=colors[key], alpha=a, histtype='stepfilled', normed=True)
#             ax.hist(delta_moments[key][n], bins=plot_bins, color='k', histtype='step', normed=True, linestyle=stepstyles[key], alpha=a, lw=2)
#         ax.legend()
#         ax.set_ylabel('frequency', fontsize=14)
#         ax.set_xlabel(r'fractional error on '+moment_names[n], fontsize=14)
#         ax.set_xlim(min(plot_bins), max(plot_bins))
#     plt.tight_layout()
#     plt.subplots_adjust(top=0.95)
#     plt.suptitle(dataset_info[dataset_key]['name']+r' data moment fractional errors with $N_{f}='+str(N_floats)+r'$', fontsize=16)
#     plt.savefig(loc+'_delta.pdf', dpi=250)
#     plt.close()
    
#     #TO DO: move this calculation and saving out of this plot, then eliminate the plot!
#     save_moments(dataset_key, n_gals_use, N_floats, ngood, 'pz_moment_deltas')

Finally, we calculate metrics on the stacked estimator $\hat{n}(z)$, which is the average of all members of the ensemble.

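Explicitly, for an ensemble of $N$ galaxies the stacked estimator is

$\hat{n}(z) = \frac{1}{N}\sum_{k=1}^{N}\hat{p}_{k}(z),$

and we compute it once from the original (truth) PDFs and once from each compressed format, comparing the results via their KLD and moments.
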
In [12]:
def analyze_stacked(E0, E, z_grid, n_floats_use, dataset_key, i=None):
    
    zlim = (min(z_grid), max(z_grid))
    z_range = zlim[-1] - zlim[0]
    delta_z = z_range / len(z_grid)
    
#     print('stacking the ensembles')
#     stack_start = timeit.default_timer()
    stacked_pdfs, stacks = {}, {}
    for key in formats:
        start = timeit.default_timer()
        stacked_pdfs[key] = qp.PDF(gridded=E[key].stack(z_grid, using=key, 
                                                        vb=False)[key])
        stacks[key] = stacked_pdfs[key].evaluate(z_grid, using='gridded', norm=True, vb=False)[1]
        print('stacked '+key+ ' in '+str(timeit.default_timer()-start))
    
    stack_start = timeit.default_timer()
    stacked_pdfs['truth'] = qp.PDF(gridded=E0.stack(z_grid, using='truth', 
                                                    vb=False)['truth'])
    
    stacks['truth'] = stacked_pdfs['truth'].evaluate(z_grid, using='gridded', norm=True, vb=False)[1]
    print('stacked truth in '+str(timeit.default_timer() - stack_start))
    
#     print('calculating the metrics')
#     metric_start = timeit.default_timer()
#     for n in range(n_moments_use):
#         moments['truth'].append(qp.utils.calculate_moment(stacked_pdfs['truth'], n, 
#                                                           limits=zlim, 
#                                                           dx=delta_z, 
#                                                           vb=False))
#     print('calculated the true moments in '+str(timeit.default_timer() - metric_start))
    
    klds = {}
    for key in formats:
        kld_start = timeit.default_timer()
        klds[key] = qp.utils.calculate_kl_divergence(stacked_pdfs['truth'],
                                                     stacked_pdfs[key], 
                                                     limits=zlim, dx=delta_z)
        print('calculated the '+key+' stacked kld in '+str(timeit.default_timer() - kld_start))
    save_nz_metrics(name, size, n_floats_use, klds, 'nz_klds')
        
    moments = {}
    for key in formats_plus:
        moment_start = timeit.default_timer()
        moments[key] = []
        for n in range(n_moments_use):
            moments[key].append(qp.utils.calculate_moment(stacked_pdfs[key], n, 
                                                          limits=zlim, 
                                                          dx=delta_z, 
                                                          vb=False))
            print('calculated the '+key+' stacked moments in '+str(timeit.default_timer() - moment_start))
    save_moments(name, size, n_floats_use, moments, 'nz_moments') 
    
    path = os.path.join(dataset_key, str(E0.n_pdfs))
    loc = os.path.join(path, 'nz_comp'+str(n_gals_use)+dataset_key+str(n_floats_use)+'_'+str(i))
    with open(loc+'.hkl', 'w') as filename:
        info = {}
        info['z_grid'] = z_grid
        info['stacks'] = stacks
        info['klds'] = klds
        info['moments'] = moments
        hickle.dump(info, filename)
    
    return(stacked_pdfs)
In [13]:
def plot_estimators(n_gals_use, dataset_key, n_floats_use, i=None):
    
    path = os.path.join(dataset_key, str(n_gals_use))
    loc = os.path.join(path, 'nz_comp'+str(n_gals_use)+dataset_key+str(n_floats_use)+'_'+str(i))
    with open(loc+'.hkl', 'r') as filename:
        info = hickle.load(filename)
        z_grid = info['z_grid']
        stacks = info['stacks']
        klds = info['klds']
    
    plt.figure()
    plt.plot(z_grid, stacks['truth'], color='black', lw=3, alpha=0.3, label='original')
    nz_max.append(max(stacks['truth']))
    for key in formats:
        nz_max.append(max(stacks[key]))
        plt.plot(z_grid, stacks[key], label=key+r' KLD='+str(klds[key])[:8], color=colors[key], linestyle=styles[key])
    plt.xlabel(r'$z$', fontsize=14)
    plt.ylabel(r'$\hat{n}(z)$', fontsize=14)
    plt.xlim(min(z_grid), max(z_grid))
#     plt.ylim(0., max(nz_max))
    plt.legend()
    plt.title(dataset_info[dataset_key]['name']+r' data $\hat{n}(z)$ with $N_{f}='+str(n_floats_use)+r'$', fontsize=16)
    plt.savefig(loc+'.pdf', dpi=250)
    plt.close()

We save the data so we can remake the plots later without running everything again.

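For instance, once the pipeline below has been run (so the file exists), one of the $\hat{n}(z)$ comparison files can be reloaded directly; the filename here simply follows the pattern used in analyze_stacked for dataset 'ss' with 10 galaxies, $N_{f}=10$, instantiation 2.

example_loc = os.path.join('ss', '10', 'nz_comp10ss10_2')
with open(example_loc+'.hkl', 'r') as filename:
    info = hickle.load(filename)
# info['z_grid'], info['stacks'], and info['klds'] are exactly what plot_estimators() reads,
# so the figures can be remade from disk without recomputing anything.
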
Scaling

We'd like to do this for many values of $N_{f}$ as well as larger catalog subsamples, repeating the analysis many times to establish error bars on the KLD as a function of format, $N_{f}$, and dataset. The things we want to plot across multiple datasets/numbers of parameters are:

  1. KLD of stacked estimator, i.e. N_f vs. nz_output[dataset][format][instantiation][KLD_val_for_N_f]
  2. moments of KLD of individual PDFs, i.e. n_moment, N_f vs. pz_output[dataset][format][n_moment][instantiation][moment_val_for_N_f]

So, we need to make sure these are saved!

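Concretely, save_moments and save_nz_metrics below accumulate these into nested lists keyed by format. A sketch of how to index back into the saved structure after the pipeline has run (the filename follows save_moments' naming pattern for dataset 'ss' with 10 galaxies):

with open(os.path.join('ss', '10', 'pz_kld_moments10ss.hkl'), 'r') as stat_file:
    pz_stats = hickle.load(stat_file)
where_N_f = pz_stats['N_f'].index(10)                # position of N_f = 10 among the values run so far
# pz_stats[format][moment][where_N_f] is a list over instantiations of that KLD-distribution moment,
# e.g. the mean (moment 1) of the quantile-format KLD distribution at N_f = 10:
quantile_kld_means = pz_stats['quantiles'][1][where_N_f]
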
We want to plot the moments of the KLD distribution for each format as $N_{f}$ changes.

In [14]:
def save_moments(dataset_key, n_gals_use, N_f, stat, stat_name):

    path = os.path.join(dataset_key, str(n_gals_use))
    loc = os.path.join(path, stat_name+str(n_gals_use)+dataset_key)
    
    if os.path.exists(loc+'.hkl'):
        with open(loc+'.hkl', 'r') as stat_file:
        #read in content of list/dict
            stats = hickle.load(stat_file)
    else:
        stats = {}
        stats['N_f'] = []
        for f in stat.keys():
            stats[f] = []
            for m in range(n_moments_use):
                stats[f].append([])

    if N_f not in stats['N_f']:
        stats['N_f'].append(N_f)
        for f in stat.keys():
            for m in range(n_moments_use):
                stats[f][m].append([])
        
    where_N_f = stats['N_f'].index(N_f)
        
    for f in stat.keys():
        for m in range(n_moments_use):
            stats[f][m][where_N_f].append(stat[f][m])

    with open(loc+'.hkl', 'w') as stat_file:
        hickle.dump(stats, stat_file)
In [15]:
def plot_pz_metrics(dataset_key, n_gals_use):

    path = os.path.join(dataset_key, str(n_gals_use))
    loc = os.path.join(path, 'pz_kld_moments'+str(n_gals_use)+dataset_key)
    with open(loc+'.hkl', 'r') as pz_file:
        pz_stats = hickle.load(pz_file)
#     if len(instantiations) == 10:
#         for f in formats:
#             for n in range(n_moments_use):
#                 if not np.shape(pz_stats[f][n]) == (4, 10):
#                     for s in range(len(pz_stats[f][n])):
#                         pz_stats[f][n][s] = np.array(np.array(pz_stats[f][n][s])[:10]).flatten()
        
    flat_floats = np.array(pz_stats['N_f']).flatten()
    in_x = np.log(flat_floats)

    def make_patch_spines_invisible(ax):
        ax.set_frame_on(True)
        ax.patch.set_visible(False)
        for sp in ax.spines.values():
            sp.set_visible(False)

    shapes = moment_shapes
    marksize = 10
    a = 1./len(formats)
    
    fig, ax = plt.subplots()
    fig.subplots_adjust(right=1.)
    ax_n = ax
    for key in formats:
        ax.plot([-1], [0], color=colors[key], label=key, linewidth=1, linestyle=styles[key])
    for n in range(1, n_moments_use):
        ax.scatter([-1], [0], color='k', alpha=a, marker=shapes[n], s=2*marksize, label=moment_names[n])
        n_factor = 0.1 * (n - 2)
        if n>1:
            ax_n = ax.twinx()
            rot_ang = 270
            label_space = 15.
        else:
            rot_ang = 90
            label_space = 0.
        if n>2:
            ax_n.spines["right"].set_position(("axes", 1. + 0.1 * (n-1)))
            make_patch_spines_invisible(ax_n)
            ax_n.spines["right"].set_visible(True)
        for s in range(len(formats)):
            f = formats[s]
            f_factor = 0.05 * (s - 1)
#             print('pz metrics data shape '+str(pz_stats[f][n]))
            data_arr = np.log(np.swapaxes(np.array(pz_stats[f][n]), 0, 1))#go from n_floats*instantiations to instantiations*n_floats
            mean = np.mean(data_arr, axis=0).flatten()
            std = np.std(data_arr, axis=0).flatten()
            y_plus = mean + std
            y_minus = mean - std
            y_cor = np.array([y_minus[:-1], y_plus[:-1], y_plus[1:], y_minus[1:]])
            ax_n.plot(np.exp(in_x+n_factor), mean, marker=shapes[n], markersize=marksize, linestyle=styles[f], alpha=a, color=colors[f])
            ax_n.vlines(np.exp(in_x+n_factor), y_minus, y_plus, linewidth=3., alpha=a, color=colors[f])
            pz_mean_max[n] = max(pz_mean_max[n], np.max(y_plus))
            pz_mean_min[n] = min(pz_mean_min[n], np.min(y_minus))
        ax_n.set_ylabel(r'$\log[\mathrm{'+moment_names[n]+r'}]$', rotation=rot_ang, fontsize=14, labelpad=label_space)
        ax_n.set_ylim((pz_mean_min[n]-1., pz_mean_max[n]+1.))
    ax.set_xscale('log')
    ax.set_xticks(flat_floats)
    ax.get_xaxis().set_major_formatter(mpl.ticker.ScalarFormatter())
    ax.set_xlim(np.exp(min(in_x)-0.25), np.exp(max(in_x)+0.25))
    ax.set_xlabel('number of parameters', fontsize=14)
    ax.set_title(dataset_info[dataset_key]['name']+r' data $\log[KLD]$ log-moments', fontsize=16)
    ax.legend(loc='lower left')
    fig.tight_layout()
    fig.savefig(loc+'_clean.pdf', dpi=250)
    plt.close()
    
    fig, ax = plt.subplots()
    fig.subplots_adjust(right=1.)
    ax_n = ax
    for key in formats:
        ax_n.plot([-1], [0], color=colors[key], label=key, linestyle=styles[key], linewidth=1)
    for n in range(1, n_moments_use):
        n_factor = 0.1 * (n - 2)
        ax.scatter([-1], [0], color='k', alpha=a, marker=shapes[n], s=2*marksize, label=moment_names[n])
        if n>1:
            ax_n = ax.twinx()
            rot_ang = 270
            label_space = 15.
        else:
            rot_ang = 90
            label_space = 0.
        if n>2:
            ax_n.spines["right"].set_position(("axes", 1. + 0.1 * (n-1)))
            make_patch_spines_invisible(ax_n)
            ax_n.spines["right"].set_visible(True)
        for s in range(len(formats)):
            f = formats[s]
            f_factor = 0.05 * (s - 1)
#             print('pz metrics data shape '+str(pz_stats[f][n]))
            data_arr = np.log(np.swapaxes(np.array(pz_stats[f][n]), 0, 1))#go from n_floats*instantiations to instantiations*n_floats
            for i in data_arr:
                ax_n.plot(np.exp(in_x+n_factor), i, linestyle=styles[f], marker=shapes[n], markersize=marksize, color=colors[f], alpha=a)
#                 pz_moment_max[n-1].append(max(i))
        ax_n.set_ylabel(r'$\log[\mathrm{'+moment_names[n]+r'}]$', rotation=rot_ang, fontsize=14, labelpad=label_space)
        ax_n.set_ylim(pz_mean_min[n]-1., pz_mean_max[n]+1.)
    ax.set_xscale('log')
    ax.set_xticks(flat_floats)
    ax.get_xaxis().set_major_formatter(mpl.ticker.ScalarFormatter())
    ax.set_xlim(np.exp(min(in_x)-0.25), np.exp(max(in_x)+0.25))
    ax.set_xlabel('number of parameters', fontsize=14)
    ax.set_title(dataset_info[dataset_key]['name']+r' data $\log[KLD]$ log-moments', fontsize=16)
    ax.legend(loc='lower left')
    fig.tight_layout()
    fig.savefig(loc+'_all.pdf', dpi=250)
    plt.close()
In [16]:
def plot_pz_delta_moments(name, size):
    n_gals_use = size
    
    # should look like nz_moments
    path = os.path.join(name, str(n_gals_use))
    loc = os.path.join(path, 'pz_moment_deltas'+str(n_gals_use)+name)
    with open(loc+'.hkl', 'r') as pz_file:
        pz_stats = hickle.load(pz_file)
    flat_floats = np.array(pz_stats['N_f']).flatten()
    in_x = np.log(flat_floats)
    a = 1./len(formats)
    shapes = moment_shapes
    marksize = 10
    
    def make_patch_spines_invisible(ax):
        ax.set_frame_on(True)
        ax.patch.set_visible(False)
        for sp in ax.spines.values():
            sp.set_visible(False)   
            
    fig, ax = plt.subplots()
    fig.subplots_adjust(right=1.)
    ax_n = ax
    for key in formats:
        ax.plot([-10], [0], color=colors[key], label=key, linestyle=styles[key], linewidth=1)
    for n in range(1, n_moments_use):
        ax.scatter([-10], [0], color='k', alpha=a, marker=shapes[n], s=2*marksize, label=moment_names[n])
        n_factor = 0.1 * (n - 2)
        if n>1:
            ax_n = ax.twinx()
            rot_ang = 270
            label_space = 15.
        else:
            rot_ang = 90
            label_space = 0.
        if n>2:
            ax_n.spines["right"].set_position(("axes", 1. + 0.1 * (n-1)))
            make_patch_spines_invisible(ax_n)
            ax_n.spines["right"].set_visible(True)
        for s in range(len(formats)):
            f = formats[s]
            f_factor = 0.05 * (s - 1)
#             print(str(np.shape(pz_stats[f][n]))+' should be n_floats * n_instantiations')
            data_arr = np.swapaxes(np.array(pz_stats[f][n]), 0, 1)#go from n_floats*instantiations to instantiations*n_floats
#             print(str(np.shape(data_arr))+' should be n_instantiations * n_floats')
            data_arr = np.median(data_arr, axis=2) * 100.
            mean = np.mean(data_arr, axis=0).flatten()
            std = np.std(data_arr, axis=0).flatten()
            y_plus = mean + std
            y_minus = mean - std
            y_cor = np.array([y_minus[:-1], y_plus[:-1], y_plus[1:], y_minus[1:]])
            ax_n.plot(np.exp(in_x+n_factor), mean, linestyle=styles[f], marker=shapes[n], markersize=marksize, alpha=a, color=colors[f])
            ax_n.vlines(np.exp(in_x+n_factor), y_minus, y_plus, linewidth=3., alpha=a, color=colors[f])
            n_delta_max[n] = max(n_delta_max[n], np.max(y_plus))
            n_delta_min[n] = min(n_delta_min[n], np.min(y_minus))
        ax_n.set_ylabel(r'median percent error on '+moment_names[n], rotation=rot_ang, fontsize=14, labelpad=label_space)
        ax_n.set_ylim((min(n_delta_min)-1., max(n_delta_max)+1.))
    ax.set_xscale('log')
    ax.set_xticks(flat_floats)
    ax.get_xaxis().set_major_formatter(mpl.ticker.ScalarFormatter())
    ax.set_xlim(np.exp(min(in_x)-0.25), np.exp(max(in_x)+0.25))
    ax.set_xlabel('number of parameters', fontsize=14)
    ax.set_title(dataset_info[name]['name']+r' data $\hat{p}(z)$ moment errors', fontsize=16)
    ax.legend(loc='upper right')
    fig.tight_layout()
    fig.savefig(loc+'_clean.pdf', dpi=250)
    plt.close()
            
    fig, ax = plt.subplots()
    fig.subplots_adjust(right=1.)
    ax_n = ax
    for key in formats:
        ax_n.plot([-10], [0], color=colors[key], label=key, linestyle=styles[key], linewidth=1)
    for n in range(1, n_moments_use):
        n_factor = 0.1 * (n - 2)
        ax.scatter([-10], [0], color='k', alpha=a, marker=shapes[n], s=2*marksize, label=moment_names[n])
        if n>1:
            ax_n = ax.twinx()
            rot_ang = 270
            label_space = 15.
        else:
            rot_ang = 90
            label_space = 0.
        if n>2:
            ax_n.spines["right"].set_position(("axes", 1. + 0.1 * (n-1)))
            make_patch_spines_invisible(ax_n)
            ax_n.spines["right"].set_visible(True)
        for s in range(len(formats)):
            f = formats[s]
            f_factor = 0.05 * (s - 1)
            data_arr = np.swapaxes(np.array(pz_stats[f][n]), 0, 1)
            data_arr = np.median(data_arr, axis=2) * 100.
            for i in data_arr:
                ax_n.plot(np.exp(in_x+n_factor), i, linestyle=styles[f], marker=shapes[n], markersize=marksize, color=colors[f], alpha=a)
        ax_n.set_ylabel(r'median percent error on '+moment_names[n], rotation=rot_ang, fontsize=14, labelpad=label_space)
        ax_n.set_ylim((min(n_delta_min)-1., max(n_delta_max)+1.))
    ax.set_xscale('log')
    ax.set_xticks(flat_floats)
    ax.get_xaxis().set_major_formatter(mpl.ticker.ScalarFormatter())
    ax.set_xlim(np.exp(min(in_x)-0.25), np.exp(max(in_x)+0.25))
    ax.set_xlabel('number of parameters', fontsize=14)
    ax.set_title(dataset_info[name]['name']+r' data $\hat{p}(z)$ moment errors', fontsize=16)
    ax.legend(loc='upper right')
    fig.tight_layout()
    fig.savefig(loc+'_all.pdf', dpi=250)
    plt.close()

We want to plot the KLD on $\hat{n}(z)$ for all formats as $N_{f}$ changes. We want to repeat this for many subsamples of the catalog to establish error bars on the KLD values.

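The error bars come from the spread over instantiations: each format's entry in the saved nz_klds structure is reshaped from $N_{f}\times$ instantiations to instantiations $\times N_{f}$ and summarized by its mean and standard deviation, as in plot_nz_klds below. A minimal sketch for one format, assuming nz_stats has been loaded as in that function:

data_arr = np.swapaxes(np.array(nz_stats['quantiles']), 0, 1)    # instantiations x N_f
kld_mean = np.mean(data_arr, axis=0)                             # mean KLD at each N_f
kld_scatter = np.std(data_arr, axis=0)                           # 1-sigma spread over instantiations
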
In [17]:
def save_nz_metrics(dataset_key, n_gals_use, N_f, nz_klds, stat_name):
    
    path = os.path.join(dataset_key, str(n_gals_use))
    loc = os.path.join(path, stat_name+str(n_gals_use)+dataset_key)
    if os.path.exists(loc+'.hkl'):
        with open(loc+'.hkl', 'r') as nz_file:
        #read in content of list/dict
            nz_stats = hickle.load(nz_file)
    else:
        nz_stats = {}
        nz_stats['N_f'] = []
        for f in formats:
            nz_stats[f] = []
    
    if N_f not in nz_stats['N_f']:
        nz_stats['N_f'].append(N_f)
        for f in formats:
            nz_stats[f].append([])
        
    where_N_f = nz_stats['N_f'].index(N_f) 
    
    for f in formats:
        nz_stats[f][where_N_f].append(nz_klds[f])

    with open(loc+'.hkl', 'w') as nz_file:
        hickle.dump(nz_stats, nz_file)
In [18]:
def plot_nz_klds(dataset_key, n_gals_use):
    
    path = os.path.join(dataset_key, str(n_gals_use))
    loc = os.path.join(path, 'nz_klds'+str(n_gals_use)+dataset_key)
    with open(loc+'.hkl', 'r') as nz_file:
        nz_stats = hickle.load(nz_file)
    if len(instantiations) == 10:
        for f in formats:
            if not np.shape(nz_stats[f]) == (4, 10):
                for s in range(len(floats)):
                    nz_stats[f][s] = np.array(np.array(nz_stats[f][s])[:10]).flatten()

    flat_floats = np.array(nz_stats['N_f']).flatten()
    
    plt.figure(figsize=(5, 5))
    for f in formats:
#         print('nz klds data shape '+str(nz_stats[f][n]))
        data_arr = np.swapaxes(np.array(nz_stats[f]), 0, 1)#turn N_f * instantiations into instantiations * N_f
        n_i = len(data_arr)
        a = 1./len(formats)#1./n_i
        plt.plot([10. * max(flat_floats), 10. * max(flat_floats)], [1., 10.], color=colors[f], alpha=a, label=f, linestyle=styles[f])
        for i in data_arr:
            plt.plot(flat_floats, i, color=colors[f], alpha=a, linestyle=styles[f])
            kld_min.append(min(i))
            kld_max.append(max(i))
    plt.semilogy()
    plt.semilogx()
    plt.xticks(flat_floats, [str(ff) for ff in flat_floats])
    plt.ylim(min(kld_min) / 10., 10. *  max(kld_max))
    plt.xlim(min(flat_floats) / 3., max(flat_floats) * 3.)
    plt.xlabel(r'number of parameters', fontsize=14)
    plt.ylabel(r'KLD', fontsize=14)
    plt.legend(loc='upper right')
    plt.title(r'$\hat{n}(z)$ KLD on '+str(n_gals_use)+' galaxies from '+dataset_info[dataset_key]['name']+' mock catalog', fontsize=16)
    plt.savefig(loc+'_all.pdf', dpi=250)
    plt.close()

    plt.figure(figsize=(5, 5))
    a = 1./len(formats)
    for f in formats:
#         print('nz klds data shape '+str(nz_stats[f][n]))
        data_arr = np.swapaxes(np.array(nz_stats[f]), 0, 1)#turn N_f * instantiations into instantiations * N_f
        plt.plot([10. * max(flat_floats), 10. * max(flat_floats)], [1., 10.], color=colors[f], label=f, linestyle=styles[f])
        kld_min.append(np.min(data_arr))
        kld_max.append(np.max(data_arr))
        mean = np.mean(data_arr, axis=0)
        std = np.std(data_arr, axis=0)
        x_cor = np.array([flat_floats[:-1], flat_floats[:-1], flat_floats[1:], flat_floats[1:]])
        y_plus = mean + std
        y_minus = mean - std
        y_cor = np.array([y_minus[:-1], y_plus[:-1], y_plus[1:], y_minus[1:]])
        plt.plot(flat_floats, mean, color=colors[f], linestyle=styles[f])
        plt.fill(x_cor, y_cor, color=colors[f], alpha=a, linewidth=0.)
    plt.semilogy()
    plt.semilogx()
    plt.xticks(flat_floats, [str(ff) for ff in flat_floats])
    plt.ylim(min(kld_min) / 10., 10. *  max(kld_max))
    plt.xlim(min(flat_floats), max(flat_floats))
    plt.xlabel(r'number of parameters', fontsize=14)
    plt.ylabel(r'KLD', fontsize=14)
    plt.legend(loc='upper right')
    plt.title(dataset_info[dataset_key]['name']+r' data $\hat{n}(z)$ KLD', fontsize=16)
    plt.savefig(loc+'_clean.pdf', dpi=250)
    plt.close()
In [19]:
def plot_nz_moments(dataset_key, n_gals_use):

    path = os.path.join(dataset_key, str(n_gals_use))
    loc = os.path.join(path, 'nz_moments'+str(n_gals_use)+dataset_key)
    with open(loc+'.hkl', 'r') as nz_file:
        nz_stats = hickle.load(nz_file)
    flat_floats = np.array(nz_stats['N_f']).flatten()
    in_x = np.log(flat_floats)
    a = 1./len(formats)
    shapes = moment_shapes
    marksize = 10
    
    def make_patch_spines_invisible(ax):
        ax.set_frame_on(True)
        ax.patch.set_visible(False)
        for sp in ax.spines.values():
            sp.set_visible(False)   
            
    fig, ax = plt.subplots()
    fig.subplots_adjust(right=1.)
    ax_n = ax
    for key in formats:
        ax.plot([-10], [0], color=colors[key], label=key, linestyle=styles[key], linewidth=1)
#     ax.plot([-10], [0], color='k', label='original', linewidth=0.5, alpha=1.)
    for n in range(1, n_moments_use):
        ax.scatter([-10], [0], color='k', alpha=a, marker=shapes[n], s=2*marksize, label=moment_names[n])
        n_factor = 0.1 * (n - 2)
        truth = np.swapaxes(np.array(nz_stats['truth'][n]), 0, 1)
        if n>1:
            ax_n = ax.twinx()
            rot_ang = 270
            label_space = 15.
        else:
            rot_ang = 90
            label_space = 0.
        if n>2:
            ax_n.spines["right"].set_position(("axes", 1. + 0.1 * (n-1)))
            make_patch_spines_invisible(ax_n)
            ax_n.spines["right"].set_visible(True)
        for s in range(len(formats)):
            f = formats[s]
            f_factor = 0.05 * (s - 1)
#             print('nz moments data shape '+str(nz_stats[f][n]))
            data_arr = (np.swapaxes(np.array(nz_stats[f][n]), 0, 1) - truth) / truth * 100.#np.log(np.swapaxes(np.array(nz_stats[f]), 0, 1)[:][:][n])#go from n_floats*instantiations to instantiations*n_floats
            mean = np.mean(data_arr, axis=0).flatten()
            std = np.std(data_arr, axis=0).flatten()
            y_plus = mean + std
            y_minus = mean - std
            y_cor = np.array([y_minus[:-1], y_plus[:-1], y_plus[1:], y_minus[1:]])
            ax_n.plot(np.exp(in_x+n_factor), mean, linestyle=styles[f], marker=shapes[n], markersize=marksize, alpha=a, color=colors[f])
            ax_n.vlines(np.exp(in_x+n_factor), y_minus, y_plus, linewidth=3., alpha=a, color=colors[f])
            nz_mean_max[n] = max(nz_mean_max[n], np.max(y_plus))
            nz_mean_min[n] = min(nz_mean_min[n], np.min(y_minus))
#         data_arr = np.log(np.swapaxes(np.array(nz_stats['truth'][n]), 0, 1))
#         mean = np.mean(data_arr, axis=0).flatten()
#         std = np.std(data_arr, axis=0).flatten()
#         y_plus = mean + std
#         y_minus = mean - std
#         y_cor = np.array([y_minus[:-1], y_plus[:-1], y_plus[1:], y_minus[1:]])
#         ax_n.plot(np.exp(in_x+n_factor), mean, linestyle='-', marker=shapes[n], markersize=marksize, alpha=a, color='k', linewidth=0.5)
#         ax_n.vlines(np.exp(in_x+n_factor), y_minus, y_plus, linewidth=3., alpha=a, color='k')
#         nz_mean_max[n] = max(nz_mean_max[n], np.max(y_plus))
#         nz_mean_min[n] = min(nz_mean_min[n], np.min(y_minus))
#         ax_n.plot(np.exp(in_x+n_factor), np.log(nz_stats['truth'][n]), linestyle='-', marker=shapes[n], markersize=marksize, alpha=a, linewidth=0.5, color='k')
        ax_n.set_ylabel(r'percent error on '+moment_names[n], rotation=rot_ang, fontsize=14, labelpad=label_space)
        ax_n.set_ylim((min(nz_mean_min)-1., max(nz_mean_max)+1.))
    ax.set_xscale('log')
    ax.set_xticks(flat_floats)
    ax.get_xaxis().set_major_formatter(mpl.ticker.ScalarFormatter())
    ax.set_xlim(np.exp(min(in_x)-0.25), np.exp(max(in_x)+0.25))
    ax.set_xlabel('number of parameters', fontsize=14)
    ax.set_title(dataset_info[dataset_key]['name']+r' data $\hat{n}(z)$ moments', fontsize=16)
    ax.legend(loc='upper right')
    fig.tight_layout()
    fig.savefig(loc+'_clean.pdf', dpi=250)
    plt.close()
            
    fig, ax = plt.subplots()
    fig.subplots_adjust(right=1.)
    ax_n = ax
    for key in formats:
        ax_n.plot([-10], [0], color=colors[key], label=key, linestyle=styles[key], linewidth=1)
#     ax.plot([-10], [0], color='k', label='original', linewidth=0.5, alpha=1.)
    for n in range(1, n_moments_use):
        n_factor = 0.1 * (n - 2)
        ax.scatter([-10], [0], color='k', alpha=a, marker=shapes[n], s=2*marksize, label=moment_names[n])
        truth = np.swapaxes(np.array(nz_stats['truth'][n]), 0, 1)
        if n>1:
            ax_n = ax.twinx()
            rot_ang = 270
            label_space = 15.
        else:
            rot_ang = 90
            label_space = 0.
        if n>2:
            ax_n.spines["right"].set_position(("axes", 1. + 0.1 * (n-1)))
            make_patch_spines_invisible(ax_n)
            ax_n.spines["right"].set_visible(True)
        for s in range(len(formats)):
            f = formats[s]
            f_factor = 0.05 * (s - 1)
#             print('nz moments data shape '+str(nz_stats[f][n]))
            data_arr = (np.swapaxes(np.array(nz_stats[f][n]), 0, 1) - truth) / truth * 100.
            for i in data_arr:
                ax_n.plot(np.exp(in_x+n_factor), i, linestyle=styles[f], marker=shapes[n], markersize=marksize, color=colors[f], alpha=a)
#                 nz_moment_max[n-1].append(max(i))
        data_arr = np.log(np.swapaxes(np.array(nz_stats['truth'][n]), 0, 1))
#         for i in data_arr:
#             ax_n.plot(np.exp(in_x+n_factor), i, linestyle='-', marker=shapes[n], markersize=marksize, color='k', alpha=a)
# #         ax_n.plot(np.exp(in_x+n_factor), np.log(nz_stats['truth'][n]), linestyle='-', marker=shapes[n], markersize=marksize, alpha=a, linewidth=0.5, color='k')
        ax_n.set_ylabel(r'percent error on '+moment_names[n], rotation=rot_ang, fontsize=14, labelpad=label_space)
        ax_n.set_ylim((min(nz_mean_min)-1., max(nz_mean_max)+1.))
    ax.set_xscale('log')
    ax.set_xticks(flat_floats)
    ax.get_xaxis().set_major_formatter(mpl.ticker.ScalarFormatter())
    ax.set_xlim(np.exp(min(in_x)-0.25), np.exp(max(in_x)+0.25))
    ax.set_xlabel('number of parameters', fontsize=14)
    ax.set_title(dataset_info[dataset_key]['name']+r' data $\hat{n}(z)$ moments', fontsize=16)
    ax.legend(loc='upper right')
    fig.tight_layout()
    fig.savefig(loc+'_all.pdf', dpi=250)
    plt.close()

Okay, now all I have to do is loop this over both datasets, numbers of galaxies, numbers of floats, and instantiations!

Note: It takes about 5 minutes per value of $N_{f}$ considered for 100 galaxies, and about 40 minutes per value of $N_{f}$ for 1000 galaxies. (So, yes, it scales more or less as expected: roughly linearly in catalog size, with $\sim8\times$ the time for $10\times$ the galaxies.)

In [20]:
dataset_info = {}
delta = 0.01

dataset_keys = ['mg', 'ss']

for dataset_key in dataset_keys:
    dataset_info[dataset_key] = {}
    if dataset_key == 'mg':
        datafilename = 'bpz_euclid_test_10_3.probs'
        z_low = 0.01
        z_high = 3.51
        nc_needed = 3
        plotname = 'brighter'
        skip_rows = 1
        skip_cols = 1
    elif dataset_key == 'ss':
        datafilename = 'test_magscat_trainingfile_probs.out'
        z_low = 0.005
        z_high = 2.11
        nc_needed = 5
        plotname = 'fainter'
        skip_rows = 1
        skip_cols = 1
    dataset_info[dataset_key]['filename'] = datafilename  
    
    dataset_info[dataset_key]['z_lim'] = (z_low, z_high)
    z_grid = np.arange(z_low, z_high, delta, dtype='float')#np.arange(z_low, z_high + delta, delta, dtype='float')
    z_range = z_high - z_low
    delta_z = z_range / len(z_grid)
    dataset_info[dataset_key]['z_grid'] = z_grid
    dataset_info[dataset_key]['delta_z'] = delta_z

    dataset_info[dataset_key]['N_GMM'] = nc_needed# will be overwritten later
    dataset_info[dataset_key]['name'] = plotname
In [21]:
high_res = 300
color_cycle = np.array([(230, 159, 0), (86, 180, 233), (0, 158, 115), (240, 228, 66), (0, 114, 178), (213, 94, 0), (204, 121, 167)])/256.
n_plot = len(color_cycle)
n_moments_use = 4
moment_names = ['integral', 'mean', 'variance', 'kurtosis']
moment_shapes = ['o', '*', 'P', 'X']

#make this a more clever structure, i.e. a dict
formats = ['quantiles', 'histogram', 'samples']
colors = {'quantiles': 'blueviolet', 'histogram': 'darkorange', 'samples': 'forestgreen'}
styles = {'quantiles': '--', 'histogram': ':', 'samples': '-.'}
stepstyles = {'quantiles': 'dashed', 'histogram': 'dotted', 'samples': 'dashdot'}

formats_plus = ['quantiles', 'histogram', 'samples', 'truth']
colors_plus = {'quantiles': 'blueviolet', 'histogram': 'darkorange', 'samples': 'forestgreen', 'truth':'black'}
styles_plus = {'quantiles': '--', 'histogram': ':', 'samples': '-.', 'truth': '-'}

iqr_min = [3.5]
iqr_max = [delta]
modes_max = [0]
pz_max = [1.]
nz_max = [1.]
hist_max = [1.]
dist_min = [0.]
dist_max = [0.]
pz_mean_max = -10.*np.ones(n_moments_use)
pz_mean_min = 10.*np.ones(n_moments_use)
kld_min = [1.]
kld_max = [1.]
nz_mean_max = -10.*np.ones(n_moments_use)
nz_mean_min = 10.*np.ones(n_moments_use)
n_delta_max = -10.*np.ones(n_moments_use)
n_delta_min = 10.*np.ones(n_moments_use)
In [22]:
#change all for NERSC

floats = [3, 10, 30, 100]
sizes = [10]#[10, 100, 1000]
names = dataset_info.keys()
instantiations = range(2, 3)#0)

all_randos = [[np.random.choice(size, n_plot, replace=False) for size in sizes] for name in names]

The "pipeline" is a bunch of nested for loops because qp.Ensemble makes heavy use of multiprocessing. Doing multiprocessing within multiprocessing may or may not cause problems, but I am certain that it makes debugging a nightmare.

Okay, without further ado, let's do it!

In [23]:
# the "pipeline"
global_start = timeit.default_timer()
for n in range(len(names)):
    name = names[n]
    
    dataset_start = timeit.default_timer()
    print('started '+name)
    
    pdfs = setup_dataset(name, skip_rows, skip_cols)
    
    for s in range(len(sizes)):
        size=sizes[s]
        
        size_start = timeit.default_timer()
        print('started '+name+str(size))
        
        path = os.path.join(name, str(size))
        if not os.path.exists(path):
            os.makedirs(path)
        
        n_gals_use = size
        
        randos = all_randos[n][s]
        
        for i in instantiations:
#             top_bonusdict = {}
            i_start = timeit.default_timer()
            print('started '+name+str(size)+' #'+str(i))
        
            original = '_original'+str(i)
            pdfs_use = make_instantiation(name, size, pdfs, bonus=original)
#             plot = plot_examples(size, name, bonus=original)
#             top_bonusdict[original] = ['-', 0.25]
        
            z_grid = dataset_info[name]['in_z_grid']
            N_comps = dataset_info[name]['N_GMM']
        
            postfit = '_postfit'+str(i)
            catalog = setup_from_grid(name, pdfs_use, z_grid, N_comps, high_res=high_res, bonus=postfit)
#             plot = plot_examples(size, name, bonus=postfit)
#             top_bonusdict[postfit] = ['-', 0.5]
        
            for n_floats_use in floats:
#                 bonusdict = top_bonusdict.copy()
                float_start = timeit.default_timer()
                print('started '+name+str(size)+' #'+str(i)+' with '+str(n_floats_use))
        
                ensembles = analyze_individual(catalog, z_grid, n_floats_use, name, n_moments_use, i=i, bonus=postfit)
                
#                 for f in formats:
#                     fname = str(n_floats_use)+f+str(i)
#                     plot = plot_examples(size, name, bonus=fname)
#                     bonusdict[fname] = [styles[f], 0.5]
#                 plot = plot_all_examples(name, size, n_floats_use, i, bonus=bonusdict)
#                 plot = plot_individual_kld(size, name, n_floats_use, i=i)
            
                stack_evals = analyze_stacked(catalog, ensembles, z_grid, n_floats_use, name, i=i)
#                 plot = plot_estimators(size, name, n_floats_use, i=i)
            
                print('FINISHED '+name+str(size)+' #'+str(i)+' with '+str(n_floats_use)+' in '+str(timeit.default_timer() - float_start))
            print('FINISHED '+name+str(size)+' #'+str(i)+' in '+str(timeit.default_timer() - i_start))
#         plot = plot_pz_metrics(name, size)
#         plot = plot_pz_delta_moments(name, size)      
#         plot = plot_nz_klds(name, size)
#         plot = plot_nz_moments(name, size)
        
        print('FINISHED '+name+str(size)+' in '+str(timeit.default_timer() - size_start))
        
    print('FINISHED '+name+' in '+str(timeit.default_timer() - dataset_start))
print('FINISHED everything in '+str(timeit.default_timer() - global_start))
started ss
read in data file in 11.2739341259
started ss10
started ss10 #2
randos for debugging: [21538 91754 37805 55875  5972 56011 72367 67397 25966 71019]
preprocessed data in 0.0339078903198
made the pool of 4 in 0.143560171127
made the catalog in 0.0556938648224
made the initial ensemble of 10 PDFs in 0.200202941895
took 300 samples in 1.17442893982
made the pool of 4 in 9.79900360107e-05
made the catalog in 0.0416679382324
made a new ensemble from samples in 0.0423328876495
fit the GMM to samples in 0.167675018311
made the pool of 4 in 2.90870666504e-05
made the catalog in 0.769010782242
made the final ensemble in 0.769263029099
calculated 4 moments of original PDFs in 3.48636102676
started ss10 #2 with 3
finished making in 0.969101905823
finished histogramization in 0.932945013046
finished sampling in 2.92821788788
made the pool of 4 in 5.10215759277e-05
made the catalog in 0.376241922379
made quantiles ensemble in 1.13677692413
made the pool of 4 in 5.10215759277e-05
made the catalog in 0.425587892532
made histogram ensemble in 1.29570603371
made the pool of 4 in 6.31809234619e-05
made the catalog in 0.48353600502
made samples ensemble in 1.30219697952
calculated the quantiles individual moments, kld moments in 4.10330486298
calculated the samples individual moments, kld moments in 4.03081297874
calculated the histogram individual moments, kld moments in 4.66959190369
stacked quantiles in 0.880034923553
stacked histogram in 0.823406934738
stacked samples in 0.851178884506
stacked truth in 1.29819512367
calculated the quantiles stacked kld in 0.00104308128357
calculated the histogram stacked kld in 0.000702142715454
calculated the samples stacked kld in 0.000648021697998
calculated the quantiles stacked moments in 0.000274896621704
calculated the quantiles stacked moments in 0.000715970993042
calculated the quantiles stacked moments in 0.00103878974915
calculated the quantiles stacked moments in 0.00144100189209
calculated the histogram stacked moments in 0.000185012817383
calculated the histogram stacked moments in 0.000486135482788
calculated the histogram stacked moments in 0.000741004943848
calculated the histogram stacked moments in 0.00104212760925
calculated the samples stacked moments in 0.000212907791138
calculated the samples stacked moments in 0.000457048416138
calculated the samples stacked moments in 0.00075101852417
calculated the samples stacked moments in 0.00113701820374
calculated the truth stacked moments in 0.000216007232666
calculated the truth stacked moments in 0.000505924224854
calculated the truth stacked moments in 0.000817060470581
calculated the truth stacked moments in 0.00114703178406
FINISHED ss10 #2 with 3 in 25.5616438389
started ss10 #2 with 10
finished making in 0.969255208969
finished histogramization in 1.01541614532
finished sampling in 1.30525708199
made the pool of 4 in 6.103515625e-05
made the catalog in 0.468002796173
made quantiles ensemble in 1.53850412369
made the pool of 4 in 0.000102043151855
made the catalog in 0.702124118805
made histogram ensemble in 1.95024609566
made the pool of 4 in 5.88893890381e-05
made the catalog in 0.522737979889
made samples ensemble in 1.39210891724
calculated the quantiles individual moments, kld moments in 4.09236812592
calculated the samples individual moments, kld moments in 4.13031983376
calculated the histogram individual moments, kld moments in 5.16090512276
stacked quantiles in 0.836797952652
stacked histogram in 0.853886842728
stacked samples in 0.849081993103
stacked truth in 0.855072021484
calculated the quantiles stacked kld in 0.000715017318726
calculated the histogram stacked kld in 0.000553131103516
calculated the samples stacked kld in 0.000545978546143
calculated the quantiles stacked moments in 0.000257968902588
calculated the quantiles stacked moments in 0.000645875930786
calculated the quantiles stacked moments in 0.000994920730591
calculated the quantiles stacked moments in 0.00130200386047
calculated the histogram stacked moments in 0.000175952911377
calculated the histogram stacked moments in 0.000493049621582
calculated the histogram stacked moments in 0.000676155090332
calculated the histogram stacked moments in 0.000903129577637
calculated the samples stacked moments in 0.0001380443573
calculated the samples stacked moments in 0.000327110290527
calculated the samples stacked moments in 0.000588178634644
calculated the samples stacked moments in 0.000787019729614
calculated the truth stacked moments in 0.000263929367065
calculated the truth stacked moments in 0.000633001327515
calculated the truth stacked moments in 0.000967979431152
calculated the truth stacked moments in 0.0012469291687
FINISHED ss10 #2 with 10 in 25.4388132095
started ss10 #2 with 30
finished making in 0.874783039093
finished histogramization in 1.15561914444
finished sampling in 0.931715965271
made the pool of 4 in 5.79357147217e-05
made the catalog in 0.467746973038
made quantiles ensemble in 1.32912802696
made the pool of 4 in 6.41345977783e-05
made the catalog in 0.474547147751
made histogram ensemble in 1.61009693146
made the pool of 4 in 8.10623168945e-05
made the catalog in 0.535789012909
made samples ensemble in 1.46645712852
calculated the quantiles individual moments, kld moments in 8.42942214012
calculated the samples individual moments, kld moments in 6.12721705437
calculated the histogram individual moments, kld moments in 4.38112401962
stacked quantiles in 0.854387044907
stacked histogram in 0.876643896103
stacked samples in 0.838201999664
stacked truth in 0.777998924255
calculated the quantiles stacked kld in 0.000617027282715
calculated the histogram stacked kld in 0.000613927841187
calculated the samples stacked kld in 0.000825881958008
calculated the quantiles stacked moments in 0.000220060348511
calculated the quantiles stacked moments in 0.000627040863037
calculated the quantiles stacked moments in 0.000865936279297
calculated the quantiles stacked moments in 0.00115704536438
calculated the histogram stacked moments in 0.000209808349609
calculated the histogram stacked moments in 0.000463962554932
calculated the histogram stacked moments in 0.000783920288086
calculated the histogram stacked moments in 0.00114488601685
calculated the samples stacked moments in 0.000169992446899
calculated the samples stacked moments in 0.000415086746216
calculated the samples stacked moments in 0.00068211555481
calculated the samples stacked moments in 0.000911951065063
calculated the truth stacked moments in 0.000155925750732
calculated the truth stacked moments in 0.000379085540771
calculated the truth stacked moments in 0.000571966171265
calculated the truth stacked moments in 0.00083589553833
FINISHED ss10 #2 with 30 in 30.3066999912
started ss10 #2 with 100
finished making in 2.44939208031
finished histogramization in 0.818332910538
finished sampling in 0.83979511261
made the pool of 4 in 5.3882598877e-05
made the catalog in 0.529491901398
made quantiles ensemble in 1.34206700325
made the pool of 4 in 5.29289245605e-05
made the catalog in 0.467664003372
made histogram ensemble in 1.25780892372
made the pool of 4 in 5.29289245605e-05
made the catalog in 0.435877799988
made samples ensemble in 1.25956916809
calculated the quantiles individual moments, kld moments in 6.02705717087
calculated the samples individual moments, kld moments in 5.44899702072
calculated the histogram individual moments, kld moments in 4.46730804443
stacked quantiles in 1.02569699287
stacked histogram in 0.787800073624
stacked samples in 1.18201804161
stacked truth in 1.13130617142
calculated the quantiles stacked kld in 0.00105404853821
calculated the histogram stacked kld in 0.000844955444336
calculated the samples stacked kld in 0.000648021697998
calculated the quantiles stacked moments in 0.000410079956055
calculated the quantiles stacked moments in 0.000878095626831
calculated the quantiles stacked moments in 0.00110411643982
calculated the quantiles stacked moments in 0.00160813331604
calculated the histogram stacked moments in 0.000263929367065
calculated the histogram stacked moments in 0.000639915466309
calculated the histogram stacked moments in 0.000905990600586
calculated the histogram stacked moments in 0.00126791000366
calculated the samples stacked moments in 0.000211000442505
calculated the samples stacked moments in 0.000504016876221
calculated the samples stacked moments in 0.00091814994812
calculated the samples stacked moments in 0.00122618675232
calculated the truth stacked moments in 0.000243902206421
calculated the truth stacked moments in 0.000509023666382
calculated the truth stacked moments in 0.000752925872803
calculated the truth stacked moments in 0.00108003616333
FINISHED ss10 #2 with 100 in 28.9493198395
FINISHED ss10 #2 in 117.245790005
FINISHED ss10 in 117.246392965
FINISHED ss in 128.52078104
started mg
read in data file in 22.7191698551
started mg10
started mg10 #2
randos for debugging: [51107 68537 53635 23399  9697 77903 25869 12059 40991 63275]
preprocessed data in 0.0304780006409
made the pool of 4 in 6.60419464111e-05
made the catalog in 0.039901971817
made the initial ensemble of 10 PDFs in 0.0404348373413
took 300 samples in 1.94261884689
made the pool of 4 in 5.79357147217e-05
made the catalog in 0.0352969169617
made a new ensemble from samples in 0.0355360507965
fit the GMM to samples in 0.147353172302
made the pool of 4 in 2.40802764893e-05
made the catalog in 0.674927949905
made the final ensemble in 0.675142049789
calculated 4 moments of original PDFs in 3.03729391098
started mg10 #2 with 3
finished making in 0.780121803284
finished histogramization in 0.774260044098
finished sampling in 0.808349132538
made the pool of 4 in 5.48362731934e-05
made the catalog in 0.423326015472
made quantiles ensemble in 1.21974611282
made the pool of 4 in 4.79221343994e-05
made the catalog in 0.456423997879
made histogram ensemble in 1.30866789818
made the pool of 4 in 4.88758087158e-05
made the catalog in 0.464884996414
made samples ensemble in 1.31574010849
calculated the quantiles individual moments, kld moments in 4.12014389038
calculated the samples individual moments, kld moments in 3.89831805229
calculated the histogram individual moments, kld moments in 4.01975798607
stacked quantiles in 1.21437597275
stacked histogram in 0.838474988937
stacked samples in 1.42362308502
stacked truth in 1.11074781418
calculated the quantiles stacked kld in 0.00150585174561
calculated the histogram stacked kld in 0.00222396850586
calculated the samples stacked kld in 0.00371885299683
calculated the quantiles stacked moments in 0.000260829925537
calculated the quantiles stacked moments in 0.000839948654175
calculated the quantiles stacked moments in 0.00129890441895
calculated the quantiles stacked moments in 0.00178384780884
calculated the histogram stacked moments in 0.000329971313477
calculated the histogram stacked moments in 0.000756025314331
calculated the histogram stacked moments in 0.0011899471283
calculated the histogram stacked moments in 0.00180697441101
calculated the samples stacked moments in 0.000293970108032
calculated the samples stacked moments in 0.000711917877197
calculated the samples stacked moments in 0.00110793113708
calculated the samples stacked moments in 0.00156378746033
calculated the truth stacked moments in 0.000337839126587
calculated the truth stacked moments in 0.000759840011597
calculated the truth stacked moments in 0.00116181373596
calculated the truth stacked moments in 0.00162887573242
FINISHED mg10 #2 with 3 in 23.3168389797
started mg10 #2 with 10
finished making in 1.51137495041
finished histogramization in 1.39197802544
finished sampling in 1.23923683167
made the pool of 4 in 6.41345977783e-05
made the catalog in 0.524284124374
made quantiles ensemble in 1.66794490814
made the pool of 4 in 8.51154327393e-05
made the catalog in 0.745748996735
made histogram ensemble in 2.26602888107
made the pool of 4 in 7.39097595215e-05
made the catalog in 0.647045850754
made samples ensemble in 1.62145590782
calculated the quantiles individual moments, kld moments in 6.71728801727
calculated the samples individual moments, kld moments in 5.40768003464
calculated the histogram individual moments, kld moments in 6.78759598732
stacked quantiles in 1.14174818993
stacked histogram in 0.903980970383
stacked samples in 0.96507692337
stacked truth in 0.863555908203
calculated the quantiles stacked kld in 0.000929117202759
calculated the histogram stacked kld in 0.000620126724243
calculated the samples stacked kld in 0.000907897949219
calculated the quantiles stacked moments in 0.000237941741943
calculated the quantiles stacked moments in 0.00062894821167
calculated the quantiles stacked moments in 0.000795841217041
calculated the quantiles stacked moments in 0.00125288963318
calculated the histogram stacked moments in 0.000307083129883
calculated the histogram stacked moments in 0.000663042068481
calculated the histogram stacked moments in 0.000990867614746
calculated the histogram stacked moments in 0.001384973526
calculated the samples stacked moments in 0.000274181365967
calculated the samples stacked moments in 0.000648021697998
calculated the samples stacked moments in 0.000982999801636
calculated the samples stacked moments in 0.00139117240906
calculated the truth stacked moments in 0.000205993652344
calculated the truth stacked moments in 0.00053596496582
calculated the truth stacked moments in 0.000850915908813
calculated the truth stacked moments in 0.00121688842773
FINISHED mg10 #2 with 10 in 32.9919991493
started mg10 #2 with 30
finished making in 1.60738801956
finished histogramization in 0.867848157883
finished sampling in 1.39367294312
made the pool of 4 in 0.000163078308105
made the catalog in 0.82582116127
made quantiles ensemble in 2.11856389046
made the pool of 4 in 5.72204589844e-05
made the catalog in 0.520595788956
made histogram ensemble in 1.46121788025
made the pool of 4 in 0.000101804733276
made the catalog in 0.525902032852
made samples ensemble in 1.46025896072
calculated the quantiles individual moments, kld moments in 4.10835695267
calculated the samples individual moments, kld moments in 4.14074587822
calculated the histogram individual moments, kld moments in 3.88139796257
stacked quantiles in 0.777491092682
stacked histogram in 0.864093065262
stacked samples in 0.810467004776
stacked truth in 0.804461956024
calculated the quantiles stacked kld in 0.000840187072754
calculated the histogram stacked kld in 0.000622987747192
calculated the samples stacked kld in 0.000609874725342
calculated the quantiles stacked moments in 0.000240087509155
calculated the quantiles stacked moments in 0.000703096389771
calculated the quantiles stacked moments in 0.00105404853821
calculated the quantiles stacked moments in 0.00142908096313
calculated the histogram stacked moments in 0.000262022018433
calculated the histogram stacked moments in 0.000568151473999
calculated the histogram stacked moments in 0.000870227813721
calculated the histogram stacked moments in 0.00122117996216
calculated the samples stacked moments in 0.000231027603149
calculated the samples stacked moments in 0.000545978546143
calculated the samples stacked moments in 0.000859975814819
calculated the samples stacked moments in 0.00120496749878
calculated the truth stacked moments in 0.000247001647949
calculated the truth stacked moments in 0.000550031661987
calculated the truth stacked moments in 0.000839948654175
calculated the truth stacked moments in 0.00118112564087
FINISHED mg10 #2 with 30 in 24.9179830551
started mg10 #2 with 100
finished making in 4.54299402237
finished histogramization in 0.877863883972
finished sampling in 0.789620876312
made the pool of 4 in 5.48362731934e-05
made the catalog in 0.45265007019
made quantiles ensemble in 1.30355715752
made the pool of 4 in 4.91142272949e-05
made the catalog in 0.463873147964
made histogram ensemble in 1.32496881485
made the pool of 4 in 5.31673431396e-05
made the catalog in 0.430168867111
made samples ensemble in 1.3029999733
calculated the quantiles individual moments, kld moments in 4.18964004517
calculated the samples individual moments, kld moments in 3.98618006706
calculated the histogram individual moments, kld moments in 4.04275989532
stacked quantiles in 0.791258096695
stacked histogram in 0.834949970245
stacked samples in 0.800794839859
stacked truth in 0.832922935486
calculated the quantiles stacked kld in 0.00085711479187
calculated the histogram stacked kld in 0.000594854354858
calculated the samples stacked kld in 0.000599145889282
calculated the quantiles stacked moments in 0.000240802764893
calculated the quantiles stacked moments in 0.000664949417114
calculated the quantiles stacked moments in 0.000977993011475
calculated the quantiles stacked moments in 0.00132584571838
calculated the histogram stacked moments in 0.00029993057251
calculated the histogram stacked moments in 0.000585079193115
calculated the histogram stacked moments in 0.000874042510986
calculated the histogram stacked moments in 0.00120496749878
calculated the samples stacked moments in 0.000239849090576
calculated the samples stacked moments in 0.000517845153809
calculated the samples stacked moments in 0.000797033309937
calculated the samples stacked moments in 0.00112199783325
calculated the truth stacked moments in 0.000216007232666
calculated the truth stacked moments in 0.000493049621582
calculated the truth stacked moments in 0.000772953033447
calculated the truth stacked moments in 0.0010929107666
FINISHED mg10 #2 with 100 in 26.3846879005
FINISHED mg10 #2 in 114.395443916
FINISHED mg10 in 114.39635396
FINISHED mg in 137.116107941
FINISHED everything in 265.638630867
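
The raw timing lines above are hard to compare by eye. If the printed output is captured to a text file (a hypothetical timing_log.txt here; neither the file nor this summary step is part of the pipeline), the per-format cost of the individual-moment and KLD calculations can be totaled with a few lines of standard-library Python, sketched below.

In [ ]:
# minimal sketch: total the "individual moments, kld moments" timings per parametrization,
# assuming the log text above was saved to a hypothetical file 'timing_log.txt'
import re
from collections import defaultdict

timing_pattern = re.compile(r'calculated the (\w+) individual moments, kld moments in ([0-9.eE+-]+)')

totals = defaultdict(float)
counts = defaultdict(int)
with open('timing_log.txt') as log_file:
    for line in log_file:
        match = timing_pattern.search(line)
        if match is not None:
            fmt = match.group(1)
            totals[fmt] += float(match.group(2))
            counts[fmt] += 1

for fmt in sorted(totals):
    print(fmt+': '+str(totals[fmt])+' s over '+str(counts[fmt])+' runs')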

Remake the plots so that corresponding panels share axis limits, making the parametrizations directly comparable.

In [26]:
floats = [3, 10, 30, 100]  # numbers of stored parameters per PDF to compare
sizes = [10]  # catalog subsample sizes; [10, 100, 1000] in the full run
names = dataset_info.keys()  # dataset keys
instantiations = range(2, 3)  # only instantiation #2 is re-plotted here

# n_plot (defined earlier) example galaxies per dataset and size, chosen at random
all_randos = [[np.random.choice(size, n_plot, replace=False) for size in sizes] for name in names]
In [27]:
# comment out for NERSC
# run this cell twice so the remade plots share matching axis limits

for name in names:
    for size in sizes:
        # per-instantiation plots of example PDFs, per-object KLDs, and point estimators
        for i in instantiations:
            # map each ensemble's file suffix to its [linestyle, alpha] for overplotting
            top_bonusdict = {}
            bo = '_original'+str(i)
#             plot = plot_examples(size, name, bonus=bo)
            top_bonusdict[bo] = ['-', 0.25]
            bp = '_postfit'+str(i)
#             plot = plot_examples(size, name, bonus=bp)
            top_bonusdict[bp] = ['-', 0.5]
            for n in range(len(floats)):
                bonusdict = top_bonusdict.copy()
                n_floats_use = floats[n]
                for f in formats:
                    fname = str(n_floats_use)+f+str(i)
#                     plot = plot_examples(size, name, bonus=fname)
                    bonusdict[fname] = [styles[f], 0.5]
                plot = plot_all_examples(name, size, n_floats_use, i, bonus=bonusdict)
                plot = plot_individual_kld(size, name, n_floats_use, i)
                plot = plot_estimators(size, name, n_floats_use, i)
        # summary plots over all numbers of stored parameters for this dataset and size
        plot = plot_pz_metrics(name, size)
        plot = plot_pz_delta_moments(name, size)
        plot = plot_nz_klds(name, size)
        plot = plot_nz_moments(name, size)
/home/aimalz/.local/lib/python2.7/site-packages/matplotlib/axes/_axes.py:6198: RuntimeWarning: invalid value encountered in true_divide
  m = (m.astype(float) / db) / m.sum()
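
The RuntimeWarning above most likely comes from matplotlib normalizing a histogram whose bin contents sum to zero (a 0/0 in the division by m.sum()); it does not prevent the figures from being written. If the message is unwanted, one option is to filter it before rerunning the cell, as in this sketch:

In [ ]:
# optional: silence the benign true_divide warning from normalized histograms
import warnings
warnings.filterwarnings('ignore', message='invalid value encountered in true_divide')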