"""
Module with a function for plotting spectra.
"""
import os
import math
import itertools
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from species.core import box, constants
from species.read import read_filter
from species.util import plot_util
# Module-level Matplotlib defaults. These modify the global rcParams at import
# time, so they apply to every figure that is created afterwards.

# Use a serif font family, with Bitstream Vera Serif as the listed serif font.
mpl.rcParams.update({'font.family': 'serif',
                     'font.serif': ['Bitstream Vera Serif']})

# Black axes edges with a line width of 2.
mpl.rcParams['axes.edgecolor'] = 'black'
mpl.rcParams['axes.linewidth'] = 2

# Disable 'axisbelow' such that ticks are not hidden behind the plotted data.
mpl.rcParams['axes.axisbelow'] = False
def _set_tick_style(axis, **kwargs):
    """
    Apply the tick style of this module to the major and minor ticks of an axes.

    Parameters
    ----------
    axis : matplotlib.axes.Axes
        Axes of which the ticks are styled.
    **kwargs
        Additional keyword arguments that are passed to ``tick_params``
        (e.g., ``labelbottom``).

    Returns
    -------
    NoneType
        None
    """

    # Major and minor ticks only differ in their length.
    for which, length in (('major', 5), ('minor', 3)):
        axis.tick_params(axis='both', which=which, colors='black', labelcolor='black',
                         direction='in', width=1, length=length, labelsize=12, top=True,
                         bottom=True, left=True, right=True, **kwargs)


def _model_label(parameters, object_type):
    """
    Create the legend label of a model spectrum from its parameters.

    Parameters
    ----------
    parameters : dict
        Model parameters (``ModelBox.parameters``).
    object_type : str
        Object type ('planet' or 'star'). With 'star', the radius and mass are multiplied
        with the Jupiter-to-solar ratio from ``species.core.constants``.

    Returns
    -------
    str
        Comma-separated label. Empty string if none of the parameters is recognized.
    """

    par_key, par_unit = plot_util.quantity_unit(param=list(parameters.keys()),
                                                object_type=object_type)

    # NOTE(review): the values are indexed with the position in par_key, which assumes that
    # quantity_unit returns its entries in the same order as the parameters -- confirm.
    par_val = list(parameters.values())

    entries = []

    for i, item in enumerate(par_key):
        if item == r'$T_\mathregular{eff}$':
            value = f'{par_val[i]:.1f}'

        elif item in (r'$\log\,g$', '[Fe/H]', 'C/O', r'f$_\mathregular{sed}$'):
            value = f'{par_val[i]:.2f}'

        elif item in (r'$R$', r'$M$'):
            if object_type == 'planet':
                value = f'{par_val[i]:.2f}'

            elif object_type == 'star':
                if item == r'$R$':
                    value = f'{par_val[i]*constants.R_JUP/constants.R_SUN:.2f}'
                else:
                    value = f'{par_val[i]*constants.M_JUP/constants.M_SUN:.2f}'

            else:
                # Unknown object type: skip the entry instead of showing a stale
                # or undefined value.
                continue

        elif item == r'$L$':
            value = f'{par_val[i]:.1e}'

        else:
            continue

        if par_unit[i] is None:
            entries.append(item+' = '+value)
        else:
            entries.append(item+' = '+value+' '+par_unit[i])

    # Joining the entries avoids a trailing separator when the last parameters are skipped.
    return ', '.join(entries)


def _plot_boxes(ax1, boxes, colors, scaling, object_type):
    """
    Plot the data of all boxes in the main panel.

    Parameters
    ----------
    ax1 : matplotlib.axes.Axes
        Axes of the main panel.
    boxes : list(species.core.box, )
        Boxes with data.
    colors : list(str, ), None
        Colors of the boxes. Automatic colors are used if set to None.
    scaling : float
        Value by which all fluxes are divided.
    object_type : str
        Object type ('planet' or 'star'), used for the label of a model spectrum.

    Returns
    -------
    str, tuple, None
        Color of the photometry of the last plotted ObjectBox (None if not plotted).
    str, tuple, None
        Color of the spectrum of the last plotted ObjectBox (None if not plotted).

    Raises
    ------
    ValueError
        If the quantity of a PhotometryBox is not 'flux'.
    """

    # The cycle is kept in its own variable such that every PhotometryBox
    # can request the next marker.
    markers = itertools.cycle(('o', 's', '*', 'p', '<', '>', 'P', 'v', '^'))

    color_obj_phot = None
    color_obj_spec = None

    for j, boxitem in enumerate(boxes):
        if isinstance(boxitem, (box.SpectrumBox, box.ModelBox)):
            wavelength = boxitem.wavelength
            flux = boxitem.flux

            if isinstance(wavelength[0], (np.float32, np.float64)):
                # Single spectrum
                data = np.array(flux, dtype=np.float64)
                masked = np.ma.array(data, mask=np.isnan(data))

                if isinstance(boxitem, box.ModelBox):
                    label = _model_label(boxitem.parameters, object_type)
                else:
                    label = None

                if colors:
                    ax1.plot(wavelength, masked/scaling, color=colors[j], lw=0.5,
                             label=label, zorder=2)
                else:
                    ax1.plot(wavelength, masked/scaling, lw=0.5, label=label, zorder=2)

            elif isinstance(wavelength[0], (np.ndarray)):
                # Multiple spectra in a single box, each labelled with its name
                for i, item in enumerate(wavelength):
                    data = np.array(flux[i], dtype=np.float64)
                    masked = np.ma.array(data, mask=np.isnan(data))

                    if isinstance(boxitem.name[i], bytes):
                        label = boxitem.name[i].decode('utf-8')
                    else:
                        label = boxitem.name[i]

                    ax1.plot(item, masked/scaling, lw=0.5, label=label)

        elif isinstance(boxitem, list):
            # List of boxes, plotted as thin and transparent lines
            for item in boxitem:
                data = np.array(item.flux, dtype=np.float64)
                masked = np.ma.array(data, mask=np.isnan(data))

                if colors:
                    ax1.plot(item.wavelength, masked/scaling, lw=0.2, color=colors[j],
                             alpha=0.5, zorder=1)
                else:
                    ax1.plot(item.wavelength, masked/scaling, lw=0.2, alpha=0.5, zorder=1)

        elif isinstance(boxitem, box.PhotometryBox):
            if boxitem.quantity != 'flux':
                raise ValueError(f'The quantity of the PhotometryBox is \'{boxitem.quantity}\' '
                                 f'and not \'flux\'.')

            phot_marker = next(markers)
            phot_color = colors[j] if colors else 'black'

            for i, wavel in enumerate(boxitem.wavelength):
                ax1.plot(wavel, boxitem.flux[i]/scaling, marker=phot_marker, ms=6,
                         color=phot_color, zorder=3)

        elif isinstance(boxitem, box.ObjectBox):
            if boxitem.flux is not None:
                phot_color = None

                for item in boxitem.flux:
                    transmission = read_filter.ReadFilter(item)
                    wavelength = transmission.mean_wavelength()
                    fwhm = transmission.filter_fwhm()

                    if colors:
                        phot_color = colors[j][0]

                    # Without a color, the first data point gets an automatic color
                    # that is reused for the other data points and the residuals.
                    if phot_color is None:
                        style = {}
                    else:
                        style = {'color': phot_color, 'markerfacecolor': phot_color}

                    container = ax1.errorbar(wavelength, boxitem.flux[item][0]/scaling,
                                             xerr=fwhm/2.,
                                             yerr=boxitem.flux[item][1]/scaling,
                                             marker='s', ms=5, zorder=3, **style)

                    if phot_color is None:
                        phot_color = container.lines[0].get_color()

                    color_obj_phot = phot_color

            if boxitem.spectrum is not None:
                masked = np.ma.array(boxitem.spectrum, mask=np.isnan(boxitem.spectrum))

                if colors:
                    color_obj_spec = colors[j][1]

                    ax1.errorbar(masked[:, 0], masked[:, 1]/scaling,
                                 yerr=masked[:, 2]/scaling, marker='o', ms=2, zorder=2.5,
                                 color=color_obj_spec, markerfacecolor=color_obj_spec,
                                 ls='none')

                else:
                    container = ax1.errorbar(masked[:, 0], masked[:, 1]/scaling,
                                             yerr=masked[:, 2]/scaling, ms=2, marker='s',
                                             zorder=2.5, ls='none')

                    # Remember the automatic color for the residuals.
                    color_obj_spec = container.lines[0].get_color()

        elif isinstance(boxitem, box.SynphotBox):
            for item in boxitem.flux:
                transmission = read_filter.ReadFilter(item)
                wavelength = transmission.mean_wavelength()
                fwhm = transmission.filter_fwhm()

                style = {'color': colors[j]} if colors else {}

                ax1.errorbar(wavelength, boxitem.flux[item]/scaling, xerr=fwhm/2.,
                             yerr=None, alpha=0.7, marker='s', ms=5, zorder=4,
                             markerfacecolor='white', **style)

    return color_obj_phot, color_obj_spec


def _plot_residuals(ax3, residuals, color_phot, color_spec):
    """
    Plot the residuals and set symmetric limits for the y-axis of the residuals panel.

    Parameters
    ----------
    ax3 : matplotlib.axes.Axes
        Axes of the residuals panel.
    residuals : species.core.box.ResidualsBox
        Box with residuals of a fit.
    color_phot : str, tuple, None
        Color of the photometry residuals.
    color_spec : str, tuple, None
        Color of the spectrum residuals.

    Returns
    -------
    NoneType
        None
    """

    res_max = 0.

    if residuals.photometry is not None:
        ax3.plot(residuals.photometry[0, ], residuals.photometry[1, ], marker='s',
                 ms=5, linestyle='none', color=color_phot, zorder=2)

        res_max = np.nanmax(np.abs(residuals.photometry[1, ]))

    if residuals.spectrum is not None:
        ax3.plot(residuals.spectrum[0, ], residuals.spectrum[1, ], marker='o',
                 ms=2, linestyle='none', color=color_spec, zorder=1)

        max_tmp = np.nanmax(np.abs(residuals.spectrum[1, ]))

        if max_tmp > res_max:
            res_max = max_tmp

    # Symmetric limits, rounded up to an integer with 10% margin
    res_lim = math.ceil(1.1*res_max)

    ax3.axhline(0.0, linestyle='--', color='gray', dashes=(2, 4), zorder=0.5)
    ax3.set_ylim(-res_lim, res_lim)


def plot_spectrum(boxes,
                  filters=None,
                  residuals=None,
                  colors=None,
                  xlim=None,
                  ylim=None,
                  scale=('linear', 'linear'),
                  title=None,
                  offset=None,
                  legend='upper left',
                  figsize=(7., 5.),
                  object_type='planet',
                  quantity='flux',
                  output='spectrum.pdf'):
    """
    Plot spectra and photometry, optionally with filter profiles and fit residuals, and
    write the figure to a file.

    Parameters
    ----------
    boxes : list(species.core.box, )
        Boxes with data.
    filters : list(str, ), None
        Filter IDs for which the transmission profile is plotted. Not plotted if set to None.
    residuals : species.core.box.ResidualsBox, None
        Box with residuals of a fit. Not plotted if set to None.
    colors : list(str, ), None
        Colors to be used for the different boxes. Note that an ObjectBox requires a tuple
        with two colors (i.e., for the photometry and spectrum). Automatic colors are used if
        set to None.
    xlim : tuple(float, float), None
        Limits of the x-axis. The range from 0.6 to 6 micron is used if set to None.
    ylim : tuple(float, float), None
        Limits of the y-axis. With ``quantity='flux'``, the fluxes are scaled by the order
        of magnitude of the upper limit. Automatic limits are used if set to None.
    scale : tuple(str, str)
        Scale of the axes ('linear' or 'log').
    title : str, None
        Title. Not shown if set to None.
    offset : tuple(float, float), None
        Offset for the label of the x- and y-axis. Default values are used if set to None.
    legend : str, None
        Location of the legend. Not shown if set to None.
    figsize : tuple(float, float)
        Figure size.
    object_type : str
        Object type ('planet' or 'star'). With 'planet', the radius and mass are expressed in
        Jupiter units. With 'star', the radius and mass are expressed in solar units.
    quantity : str
        The quantity of the y-axis ('flux' or 'magnitude').
    output : str
        Output filename. A relative path is interpreted with respect to the current working
        directory.

    Returns
    -------
    NoneType
        None

    Raises
    ------
    ValueError
        If ``quantity`` is not 'flux' or 'magnitude', or if the quantity of a PhotometryBox
        is not 'flux'.
    """

    # Validate before a figure is created, otherwise the flux scaling would be undefined.
    if quantity not in ('flux', 'magnitude'):
        raise ValueError(f'The quantity of the y-axis should be \'flux\' or \'magnitude\' '
                         f'but \'{quantity}\' was provided.')

    # Panel layout: the rows of the main panel (ax1), filter profiles (ax2), and
    # residuals (ax3), with the filter profiles on top and the residuals at the bottom.
    if residuals and filters:
        height_ratios = [1, 3, 1]
        row_main, row_filter, row_resid = 1, 0, 2

    elif residuals:
        height_ratios = [4, 1]
        row_main, row_filter, row_resid = 0, None, 1

    elif filters:
        height_ratios = [1, 4]
        row_main, row_filter, row_resid = 1, 0, None

    else:
        height_ratios = [1]
        row_main, row_filter, row_resid = 0, None, None

    plt.figure(1, figsize=figsize)

    # The figure number is fixed, so the figure is always closed, also after an
    # exception, such that a next call does not continue with a partly drawn figure.
    try:
        gridsp = mpl.gridspec.GridSpec(len(height_ratios), 1, height_ratios=height_ratios)
        gridsp.update(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)

        ax1 = plt.subplot(gridsp[row_main, 0])
        ax2 = plt.subplot(gridsp[row_filter, 0]) if filters else None
        ax3 = plt.subplot(gridsp[row_resid, 0]) if residuals else None

        panels = [axis for axis in (ax1, ax2, ax3) if axis is not None]

        # The wavelength labels are only shown for the bottom panel.
        bottom_axis = ax3 if residuals else ax1

        _set_tick_style(ax1, labelbottom=not residuals)

        if filters:
            _set_tick_style(ax2, labelbottom=False)

        if residuals:
            _set_tick_style(ax3)

        for axis in panels:
            if axis is bottom_axis:
                axis.set_xlabel('Wavelength [micron]', fontsize=13)
            else:
                axis.set_xlabel('', fontsize=13)

        if filters:
            ax2.set_ylabel('Transmission', fontsize=13)

        if residuals:
            ax3.set_ylabel(r'Residual [$\sigma$]', fontsize=13)

        if xlim:
            ax1.set_xlim(xlim[0], xlim[1])
        else:
            ax1.set_xlim(0.6, 6.)

        if quantity == 'magnitude':
            scaling = 1.
            ax1.set_ylabel('Flux contrast [mag]', fontsize=13)

            if ylim:
                ax1.set_ylim(ylim[0], ylim[1])

        elif ylim:
            ax1.set_ylim(ylim[0], ylim[1])
            ymin, ymax = ax1.get_ylim()

            # Factor out the order of magnitude of the upper limit such that the
            # tick labels are of order unity.
            exponent = math.floor(math.log10(ymax))
            scaling = 10.**exponent

            ylabel = r'Flux [10$^{'+str(exponent)+r'}$ W m$^{-2}$ $\mu$m$^{-1}$]'

            ax1.set_ylabel(ylabel, fontsize=13)
            ax1.set_ylim(ymin/scaling, ymax/scaling)

            if ymin < 0.:
                ax1.axhline(0.0, linestyle='--', color='gray', dashes=(2, 4), zorder=0.5)

        else:
            ax1.set_ylabel(r'Flux [W m$^{-2}$ $\mu$m$^{-1}$]', fontsize=13)
            scaling = 1.

        # All panels share the wavelength range and scale of the main panel.
        xmin, xmax = ax1.get_xlim()

        for axis in panels[1:]:
            axis.set_xlim(xmin, xmax)

        if offset:
            bottom_axis.get_xaxis().set_label_coords(0.5, offset[0])

            for axis in panels:
                axis.get_yaxis().set_label_coords(offset[1], 0.5)

        else:
            ax1.get_xaxis().set_label_coords(0.5, -0.12)
            ax1.get_yaxis().set_label_coords(-0.1, 0.5)

        ax1.set_xscale(scale[0])
        ax1.set_yscale(scale[1])

        for axis in panels[1:]:
            axis.set_xscale(scale[0])

        color_obj_phot, color_obj_spec = _plot_boxes(ax1, boxes, colors, scaling, object_type)

        if filters:
            for filter_id in filters:
                transmission = read_filter.ReadFilter(filter_id)
                data = transmission.get_filter()

                ax2.plot(data[0, ], data[1, ], '-', lw=0.7, color='black', zorder=1)

            ax2.set_ylim(0., 1.1)

        if residuals:
            _plot_residuals(ax3, residuals, color_obj_phot, color_obj_spec)

        print(f'Plotting spectrum: {output}...', end='', flush=True)

        if title:
            # The title is placed above the top panel.
            if filters:
                ax2.set_title(title, y=1.02, fontsize=15)
            else:
                ax1.set_title(title, y=1.02, fontsize=15)

        handles, _ = ax1.get_legend_handles_labels()

        if handles and legend:
            ax1.legend(loc=legend, prop={'size': 9}, frameon=False)

        # os.path.join keeps an absolute output path intact.
        plt.savefig(os.path.join(os.getcwd(), output), bbox_inches='tight')

    finally:
        plt.clf()
        plt.close()

    print(' [DONE]')