Source code for desk.plotting_seds

import math
import pdb

import astropy.units as u
import matplotlib.pyplot as plt
import numpy as np
from astropy.io import ascii
from astropy.table import Table
from desk import config
from matplotlib import rc

'''
Steve Goldman
Space Telescope Science Institute
May 17, 2018
sgoldman@stsci.edu

This script is for plotting the outputs of the sed_fitting script.
'''


[docs]def create_fig(): """ :return: Runs plotting script """ # rc('text', usetex=True) # plt.rcParams['font.family'] = 'serif' # plt.rcParams['mathtext.fontset'] = 'dejavuserif' # plt.rcParams['text.usetex'] = True # plt.rcParams['text.latex.unicode'] = True full_path = str(__file__.replace('plotting_seds.py', '')) input_file = Table.read('fitting_plotting_outputs.csv') grid_dusty = Table.read(full_path + 'models/' + str(input_file['grid_name'][0]) + '_models.fits') def get_data(filename): """ :param filename: filename of input data. Should be csv with Column 0: wavelength in um and Col 1: flux in Jy :return: two arrays of wavelength (x) and flux (y) in unit specified in config.py """ table = ascii.read(filename, delimiter=',') table.sort(table.colnames[0]) x = np.array(table.columns[0]) y = np.array(table.columns[1]) if config.output['output_unit'] == 'Wm^-2': y = y * u.Jy y = y.to(u.W / (u.m * u.m), equivalencies=u.spectral_density(x * u.um)) elif config.output['output_unit'] == 'Jy': pass else: raise ValueError("Unit in config.py not 'Wm^-2' or 'Jy'") return x, np.array(y) # setting axes if len(input_file) == 1: fig, ax1 = plt.subplots(1, 1, sharex=True, sharey=True, figsize=(8, 5)) elif len(input_file) == 2: fig, axs = plt.subplots(2, 1, sharex=True, sharey=True, figsize=(8, 10)) elif len(input_file) == 3: fig, axs = plt.subplots(3, 1, sharex=True, sharey=True, figsize=(8, 10)) else: fig, axs = plt.subplots(math.ceil(len(input_file) / 3), 3, sharex=True, sharey=True, figsize=(8, 10)) axs = axs.ravel() for counter, target in enumerate(input_file): # gets data for plotting target_name = (target['target_name']).replace('.csv', '').replace('_', ' ') x_data, y_data = get_data(target['data_file']) x_model, y_model = grid_dusty[target['index']] x_model = x_model[np.where(y_model != 0)] y_model = y_model[np.where(y_model != 0)] * input_file[counter]['norm'] if config.output['output_unit'] == 'Wm^-2': axislabel = "log $\lambda$ F$_{\lambda}$ (W m$^{-2}$)" elif config.output['output_unit'] == 'Jy': y_model = y_model * u.W / (u.m * u.m) y_model = y_model / ((x_model * u.um).to(u.Hz, equivalencies=u.spectral())) y_model = y_model.to(u.Jy).value axislabel = "log F$_{\lambda}$ (Jy)" else: raise ValueError("Unit in config.py not 'Wm^-2' or 'Jy'") # logscale x_model = np.log10(x_model) y_model = np.log10(y_model) x_data = np.log10(x_data) y_data = np.log10(y_data) # out = Table((x_model,y_model), names=('wave', 'lamflam')) # out.write('crystalline_fit.csv', format='csv') # pdb.set_trace() # plotting if len(input_file) == 1: ax1.set_xlim(-0.99, 2.49) ax1.set_ylim(np.median(y_model) - 2, np.median(y_model) + 2) ax1.scatter(x_data, y_data, c='blue', label='data') ax1.plot(x_model, y_model, c='k', linewidth=0.5, linestyle='--', zorder=2, label='model') ax1.annotate(target_name.replace('-', r'\textendash'), (0.07, 0.85), xycoords='axes fraction', fontsize=14) ax1.get_xaxis().set_tick_params(which='both', direction='in', labelsize=15) ax1.get_yaxis().set_tick_params(which='both', direction='in', labelsize=15) ax1.set_xlabel('log $\lambda$ ($\mu m$)', labelpad=10) ax1.set_ylabel("log $\lambda$ F$_{\lambda}$ " + "(W m$^{-2}$)", labelpad=10) else: axs[counter].set_xlim(-0.99, 2.49) axs[counter].set_ylim(np.median(y_model) - 2, np.median(y_model) + 2) axs[counter].plot(x_model, y_model, c='k', linewidth=0.4, linestyle='--', zorder=2) axs[counter].scatter(x_data, y_data, c='blue') axs[counter].annotate(target_name.replace('-', r'\textendash'), (0.7, 0.8), xycoords='axes fraction', fontsize=14) axs[counter].get_xaxis().set_tick_params(which='both', direction='in', labelsize=15) axs[counter].get_yaxis().set_tick_params(which='both', direction='in', labelsize=15) axs[counter].set_xlabel('log $\lambda$ ($\mu m$)', labelpad=10) axs[counter].set_ylabel(axislabel, labelpad=10) # pdb.set_trace() plt.subplots_adjust(wspace=0, hspace=0) fig.savefig('output_sed.png', dpi=200, bbox_inches='tight')
[docs]def single_fig(): """ :return: Runs plotting script """ full_path = str(__file__.replace('plotting_seds.py', '')) input_file = Table.read('fitting_plotting_outputs.csv') grid_dusty = Table.read(full_path + 'models/' + str(input_file['grid_name'][0]) + '_models.fits') def get_data(filename): """ :param filename: filename of input data. Should be csv with Column 0: wavelength in um and Col 1: flux in Jy :return: two arrays of wavelength (x) and flux (y) in unit specified in config.py """ table = ascii.read(filename, delimiter=',') table.sort(table.colnames[0]) x = np.array(table.columns[0]) y = np.array(table.columns[1]) if config.output['output_unit'] == 'Wm^-2': y = y * u.Jy y = y.to(u.W / (u.m * u.m), equivalencies=u.spectral_density(x * u.um)) elif config.output['output_unit'] == 'Jy': pass else: raise ValueError("Unit in config.py not 'Wm^-2' or 'Jy'") return x, np.array(y) # setting axes for counter, target in enumerate(input_file): # gets data for plotting target_name = (target['target_name']).replace('.csv', '') x_data, y_data = get_data(target['data_file']) x_model, y_model = grid_dusty[target['index']] x_model = x_model[np.where(y_model != 0)] y_model = y_model[np.where(y_model != 0)] * input_file[counter]['norm'] if config.output['output_unit'] == 'Wm^-2': axislabel = "log $\lambda$ F$_{\lambda}$ (W m$^{-2}$)" elif config.output['output_unit'] == 'Jy': y_model = y_model * u.W / (u.m * u.m) y_model = y_model / ((x_model * u.um).to(u.Hz, equivalencies=u.spectral())) y_model = y_model.to(u.Jy).value axislabel = "log F$_{\lambda}$ (Jy)" else: raise ValueError("Unit in config.py not 'Wm^-2' or 'Jy'") # logscale x_model = np.log10(x_model) y_model = np.log10(y_model) x_data = np.log10(x_data) y_data = np.log10(y_data) # Figure plotting fig, ax1 = plt.subplots(1, 1, sharex=True, sharey=True, figsize=(8, 5)) ax1.set_xlim(-0.99, 2.49) ax1.set_ylim(np.median(y_model) - 3, np.median(y_model) + 3) ax1.scatter(x_data, y_data, c='blue', label='data') ax1.plot(x_model, y_model, c='k', linewidth=0.5, linestyle='--', zorder=2, label='model') ax1.annotate(target_name.replace('_', ' '), (0.07, 0.85), xycoords='axes fraction', fontsize=14) ax1.get_xaxis().set_tick_params(which='both', direction='in', labelsize=15) ax1.get_yaxis().set_tick_params(which='both', direction='in', labelsize=15) ax1.set_xlabel('log $\lambda$ ($\mu m$)', labelpad=10) ax1.set_ylabel("log $\lambda$ F$_{\lambda}$ " + "(W m$^{-2}$)", labelpad=10) plt.subplots_adjust(wspace=0, hspace=0) fig.savefig('output_sed'+str(target_name)+'.png', dpi=200, bbox_inches='tight') plt.close()
# if __name__ == '__main__': # create_fig()