#!/usr/bin/python
# -*- coding: latin-1 -*-
"""
A Python wrapper for the SVO Filter Profile Service
"""
from glob import glob
import inspect
import os
import pickle
from pkg_resources import resource_filename
import warnings
import itertools
import astropy.table as at
import astropy.io.votable as vo
import astropy.units as q
import astropy.constants as ac
from astropy.utils.exceptions import AstropyWarning
from bokeh.plotting import figure, show
import bokeh.palettes as bpal
import numpy as np
# Ignore every AstropyWarning-category warning for the whole process:
# warnings.simplefilter edits the interpreter-wide filter list, not just this
# module's. Presumably meant to quiet VOTable parsing noise - confirm.
warnings.simplefilter('ignore', category=AstropyWarning)
# Extinction vector R for each band name; Filter stores the matching value in
# ``ext_vector`` (0 when the band is not listed here).
EXTINCTION = {
    # Pan-STARRS1
    'PS1.g': 3.384,
    'PS1.r': 2.483,
    'PS1.i': 1.838,
    'PS1.z': 1.414,
    'PS1.y': 1.126,
    # SDSS
    'SDSS.u': 4.0,
    'SDSS.g': 3.384,
    'SDSS.r': 2.483,
    'SDSS.i': 1.838,
    'SDSS.z': 1.414,
    # 2MASS
    '2MASS.J': 0.650,
    '2MASS.H': 0.327,
    '2MASS.Ks': 0.161,
}
class Filter:
    """
    Creates a Filter object to store a photometric filter profile
    and metadata

    Attributes
    ----------
    path: str
        The filepath of the XML bandpass data ('' for top hat and txt
        filters)
    refs: list
        The references for the bandpass data
    raw: np.ndarray
        The unbinned wavelength [Angstrom] and throughput rows as loaded
    rsr: np.ndarray
        The wavelength and relative spectral response (RSR) arrays,
        shape (n_bins, 2, pixels_per_bin)
    wave_min, wave_max, wave_eff, wave_center, wave_mean, wave_peak,
    wave_phot, wave_pivot, width_eff, fwhm: astropy.units.Quantity
        The minimum, maximum, effective, central, mean, peak, photon-based
        effective and pivot wavelengths, the effective width and the FWHM,
        all expressed in ``wave_units``
    zp: astropy.units.Quantity
        The zero point flux density, in ``flux_units``
    name: str
        The last '/'-separated component of filterID
    ext_vector: float
        The extinction vector R for this band, 0 if unknown
    n_bins: int
        The number of wavelength bins
    pixels_per_bin: int
        The number of points in each bin

    The following metadata are copied from the SVO XML file when it
    provides them (top hat filters set placeholder values):

    Band: str
        The band name
    CalibrationReference: str
        The paper detailing the calibration
    Facility: str
        The telescope facility
    FilterProfileService: str
        The SVO source
    MagSys: str
        The magnitude system
    PhotCalID: str
        The calibration standard
    PhotSystem: str
        The photometric system
    ProfileReference: str
        The SVO reference
    WavelengthUCD: str
        The SVO wavelength unit
    ZeroPointType: str
        The system of the zero point
    filterID: str
        The SVO filter ID
    """
    # Wavelength-valued attributes kept in ``wave_units`` by the setter
    _WAVE_ATTRS = ('wave_min', 'wave_max', 'wave_eff', 'wave_center',
                   'wave_mean', 'wave_peak', 'wave_phot', 'wave_pivot',
                   'width_eff', 'fwhm')

    def __init__(self, band, filter_directory=None,
                 wave_units=q.um, flux_units=q.erg/q.s/q.cm**2/q.AA,
                 **kwargs):
        """
        Loads the bandpass data into the Filter object

        Parameters
        ----------
        band: str
            The bandpass filename (e.g. 2MASS.J), or 'TopHat' for a
            synthetic top hat filter
        filter_directory: str
            The directory containing the filter files
        wave_units: str, astropy.units.Unit (optional)
            The wavelength units
        flux_units: str, astropy.units.Unit (optional)
            The zeropoint flux units
        **kwargs
            'wave_min' and 'wave_max' (required for a top hat), 'n_pixels'
            (top hat resolution, default 100) and any keyword accepted by
            :meth:`bin`, which is then applied

        Raises
        ------
        ValueError
            If a top hat is requested without 'wave_min' and 'wave_max'
        IOError
            If no file in filter_directory matches band
        TypeError
            If the matching file is neither XML nor a .txt file
        """
        if filter_directory is None:
            filter_directory = resource_filename('svo_filters', 'data/filters/')

        # Top hat filters are synthesized rather than read from disk
        if band.lower().replace('-', '').replace(' ', '') == 'tophat':
            wave_min = kwargs.get('wave_min')
            wave_max = kwargs.get('wave_max')
            if wave_min is None or wave_max is None:
                raise ValueError("Please provide **{'wave_min', 'wave_max'} to create top hat filter.")
            self.load_TopHat(wave_min, wave_max, kwargs.get('n_pixels', 100))

        else:
            # Map each file path (minus a trailing '.txt') to the real file,
            # so bands can be requested without their extension
            files = glob(os.path.join(filter_directory, '*'))
            no_ext = {(f[:-len('.txt')] if f.endswith('.txt') else f): f
                      for f in files}
            bands = [os.path.basename(b) for b in no_ext]
            fp = os.path.join(filter_directory, band)
            filepath = no_ext.get(fp, fp)

            if band not in bands:
                err = ("No filters match {}\n\nCurrent filters: {}\n\n"
                       "A full list of available filters from the\n"
                       "SVO Filter Profile Service can be found at\n"
                       "http://svo2.cab.inta-csic.es/theory/fps3/\n\n"
                       "Place the desired filter XML file in your\n"
                       "filter directory and try again.").format(filepath, ', '.join(bands))
                raise IOError(err)

            # Sniff the first line to choose a loader
            with open(filepath) as f:
                top = f.readline()
            if top.startswith('<?xml'):
                self.load_xml(filepath)
            elif filepath.endswith('.txt'):
                self.load_txt(filepath)
            else:
                raise TypeError("File must be XML or ascii format.")

        # Store the full-resolution curve as a single bin, in Angstroms
        self._wave_units = q.AA
        self._wave = np.array([self.raw[0]]) * self.wave_units
        self._throughput = np.array([self.raw[1]])
        self.n_bins = 1
        self.pixels_per_bin = self.raw.shape[-1]

        # Rename the loaded wavelength metadata and attach units
        self.wave_min = self.WavelengthMin * self.wave_units
        self.wave_max = self.WavelengthMax * self.wave_units
        self.wave_eff = self.WavelengthEff * self.wave_units
        self.wave_center = self.WavelengthCen * self.wave_units
        self.wave_mean = self.WavelengthMean * self.wave_units
        self.wave_peak = self.WavelengthPeak * self.wave_units
        self.wave_phot = self.WavelengthPhot * self.wave_units
        self.wave_pivot = self.WavelengthPivot * self.wave_units
        self.width_eff = self.WidthEff * self.wave_units
        self.fwhm = self.FWHM * self.wave_units
        self.zp = self.ZeroPoint * q.Unit(self.ZeroPointUnit)

        # Delete the now-redundant unitless attributes
        del self.WavelengthMin, self.WavelengthMax, self.WavelengthEff
        del self.WavelengthCen, self.WavelengthMean, self.WavelengthPeak
        del self.WavelengthPhot, self.WavelengthPivot, self.WidthEff, self.FWHM
        del self.ZeroPointUnit, self.ZeroPoint
        try:
            # Only some loaders set WavelengthUnit
            del self.WavelengthUnit
        except AttributeError:
            pass

        # The flux_units getter must work even when flux_units is None
        self._flux_units = self.zp.unit
        if wave_units is not None:
            self.wave_units = wave_units
        if flux_units is not None:
            self.flux_units = flux_units

        # Get references (the part after the last '=' of the calibration URL)
        calref = getattr(self, 'CalibrationReference', None)
        self.CalibrationReference = calref
        self.refs = []
        if isinstance(calref, str) and calref:
            self.refs = [calref.split('=')[-1]]

        # Set a base name
        self.name = self.filterID.split('/')[-1]

        # Try to get the extinction vector R from Green et al. (2018)
        self.ext_vector = EXTINCTION.get(self.name, 0)

        # Bin with whichever kwargs bin() understands
        if kwargs:
            bwargs = {k: v for k, v in kwargs.items() if k in
                      inspect.signature(self.bin).parameters.keys()}
            self.bin(**bwargs)

    def apply(self, spectrum, plot=False):
        """
        Apply the filter to the given spectrum

        Each spectrum is linearly interpolated onto every bin's wavelength
        grid and multiplied by that bin's throughput.

        Parameters
        ----------
        spectrum: array-like
            The wavelength and flux (one spectrum, or a 2D stack of spectra
            sharing the wavelength array) to apply the filter to; any
            further elements (e.g. uncertainties) are ignored. Plain
            wavelengths must be in ``wave_units`` (microns by default); an
            astropy Quantity is converted to ``wave_units`` first.
        plot: bool
            Plot the original and filtered spectrum

        Returns
        -------
        np.ndarray
            The filtered spectrum, shape (n_bins, n_spectra, pixels_per_bin)
            with length-1 axes squeezed out
        """
        wav, flx, *_ = spectrum
        if isinstance(wav, q.quantity.Quantity):
            wav = wav.to(self.wave_units).value
        wav = np.asarray(wav)
        flx = np.asarray(flx)

        # Treat a single spectrum as a stack of one
        if flx.ndim == 1:
            flx = np.expand_dims(flx, axis=0)

        # rsr is a fresh (n_bins, 2, pixels_per_bin) array on every access
        rsr = self.rsr
        filtered = np.zeros((rsr.shape[0], flx.shape[0], rsr.shape[2]))

        # Rebin the input spectra to the filter wavelength array
        # and apply the RSR curve to the spectrum
        for i, (bin_wave, bin_thru) in enumerate(rsr):
            for j, f in enumerate(flx):
                filtered[i][j] = np.interp(bin_wave, wav, f) * bin_thru

        if plot:
            COLORS = color_gen('Category10')
            xlab = 'Wavelength [{}]'.format(self.wave_units)
            ylab = 'Flux Density [{}]'.format(self.flux_units)
            fig = figure(title=self.filterID, x_axis_label=xlab, y_axis_label=ylab)

            # Plot the unfiltered spectrum, then the first spectrum in each bin
            fig.line(wav, flx[0], legend_label='Input spectrum', color='black')
            for bin_wave, bn in zip(self.wave, filtered):
                fig.line(bin_wave.value, bn[0], color=next(COLORS))
            show(fig)

        return filtered.squeeze()

    def bin(self, n_bins=1, pixels_per_bin=None, wave_min=None, wave_max=None):
        """
        Break the filter up into bins and apply a throughput to each bin,
        useful for G141, G102, and other grisms

        Parameters
        ----------
        n_bins: int
            The number of bins to dice the throughput curve into (ignored
            when pixels_per_bin is given)
        pixels_per_bin: int (optional)
            The number of channels per bin, which will be used
            to calculate n_bins
        wave_min: astropy.units.quantity (optional)
            The minimum wavelength to use
        wave_max: astropy.units.quantity (optional)
            The maximum wavelength to use

        Raises
        ------
        ValueError
            If neither n_bins nor pixels_per_bin is an integer, or the
            requested value is not between 1 and the number of points in
            the trimmed bandpass
        """
        if wave_min is not None:
            self.wave_min = wave_min
        if wave_max is not None:
            self.wave_max = wave_max

        # Trim the full-resolution curve (stored in Angstroms) to the limits
        raw_wave = self.raw[0] * q.AA
        keep = (raw_wave >= self.wave_min) & (raw_wave <= self.wave_max)
        self.wave = raw_wave[keep].to(self.wave_units)
        self.throughput = self.raw[1][keep]
        print('Bandpass trimmed to',
              '{} - {}'.format(self.wave_min, self.wave_max))

        # Calculate the number of bins and channels
        pts = len(self.wave)
        integer = (int, np.integer)
        if isinstance(pixels_per_bin, integer):
            if not 1 <= pixels_per_bin <= pts:
                raise ValueError("'pixels_per_bin' must be between 1 and {}.".format(pts))
            self.pixels_per_bin = int(pixels_per_bin)
            self.n_bins = pts // self.pixels_per_bin
        elif isinstance(n_bins, integer):
            if not 1 <= n_bins <= pts:
                raise ValueError("'n_bins' must be between 1 and {}.".format(pts))
            self.n_bins = int(n_bins)
            self.pixels_per_bin = pts // self.n_bins
        else:
            raise ValueError("Please specify 'n_bins' OR 'pixels_per_bin' as integers.")
        print('{} bins of {} pixels each.'.format(self.n_bins,
                                                  self.pixels_per_bin))

        # Trim both ends evenly so the points divide into whole bins
        new_len = self.n_bins * self.pixels_per_bin
        start = (pts - new_len) // 2
        self.wave = self.wave[start:new_len + start].reshape(self.n_bins, self.pixels_per_bin)
        self.throughput = self.throughput[start:new_len + start].reshape(self.n_bins, self.pixels_per_bin)

    @property
    def centers(self):
        """A getter for the wavelength bin centers and average fluxes"""
        w_cen = np.nanmean(self.wave.value, axis=1)
        f_cen = np.nanmean(self.throughput, axis=1)
        return np.asarray([w_cen, f_cen])

    @property
    def flux_units(self):
        """A getter for the flux units"""
        return self._flux_units

    @flux_units.setter
    def flux_units(self, units):
        """
        A setter for the flux units

        Parameters
        ----------
        units: str, astropy.units.Unit
            The desired units of the zeropoint flux density

        Raises
        ------
        ValueError
            If units cannot be interpreted as astropy units
        """
        if isinstance(units, str):
            units = q.Unit(units)

        # Any astropy unit (named, prefixed or composite) is acceptable
        if not isinstance(units, (q.core.UnitBase, q.quantity.Quantity)):
            raise ValueError(units, "units not understood.")

        if units != self.flux_units:
            # F_lambda <-> F_nu conversion needs a reference wavelength
            sfd = q.spectral_density(self.wave_eff)
            self.zp = self.zp.to(units, equivalencies=sfd)
            self._flux_units = units

    def info(self, fetch=False):
        """
        Print a table of info about the current filter

        Parameters
        ----------
        fetch: bool
            Return the table instead of printing it

        Returns
        -------
        astropy.table.Table, None
            The sorted (attribute, value-string) table when fetch is True
        """
        tp = (int, bytes, bool, str, float, tuple, list, np.ndarray)
        skip = ('rsr', 'raw', 'centers')
        info = [[k, str(v)] for k, v in vars(self).items()
                if isinstance(v, tp) and k not in skip and not k.startswith('_')]

        table = at.Table(np.asarray(info).reshape(len(info), 2),
                         names=['Attributes', 'Values'])
        table.sort('Attributes')

        if fetch:
            return table
        table.pprint(max_width=-1, max_lines=-1, align=['>', '<'])

    def load_TopHat(self, wave_min, wave_max, pixels_per_bin=100):
        """
        Loads a top hat filter given wavelength min and max values

        Parameters
        ----------
        wave_min: astropy.units.quantity
            The minimum wavelength to use
        wave_max: astropy.units.quantity
            The maximum wavelength to use
        pixels_per_bin: int
            The number of pixels for the filter
        """
        self.pixels_per_bin = pixels_per_bin
        self.n_bins = 1
        self._wave_units = q.AA
        wave_min = wave_min.to(self.wave_units)
        wave_max = wave_max.to(self.wave_units)

        # Create the RSR curve: a unitless throughput of 1 across the band
        self._wave = np.linspace(wave_min, wave_max, pixels_per_bin)
        self._throughput = np.ones(pixels_per_bin)
        self.raw = np.array([self.wave.value, self.throughput])

        # For a flat band every characteristic wavelength is the midpoint
        wave_eff = ((wave_min + wave_max) / 2.).value
        width = (wave_max - wave_min).value

        # Add the attributes
        self.path = ''
        self.refs = ''
        self.Band = 'Top Hat'
        self.CalibrationReference = ''
        self.FWHM = width
        self.Facility = '-'
        self.FilterProfileService = '-'
        self.MagSys = '-'
        self.PhotCalID = ''
        self.PhotSystem = ''
        self.ProfileReference = ''
        self.WavelengthMin = wave_min.value
        self.WavelengthMax = wave_max.value
        self.WavelengthCen = wave_eff
        self.WavelengthEff = wave_eff
        self.WavelengthMean = wave_eff
        self.WavelengthPeak = wave_eff
        self.WavelengthPhot = wave_eff
        self.WavelengthPivot = wave_eff
        self.WavelengthUCD = ''
        self.WidthEff = width
        self.ZeroPoint = 0
        self.ZeroPointType = ''
        self.ZeroPointUnit = 'Jy'
        self.filterID = 'Top Hat'

    def load_txt(self, filepath):
        """Load the filter from a txt file

        The file must hold a wavelength column (Angstroms, or microns when
        the last wavelength is below 100) followed by a throughput column.
        The zero point and the wavelength metadata are computed using the
        bundled Vega spectrum.

        Parameters
        ----------
        filepath: str
            The filepath
        """
        self.raw = np.genfromtxt(filepath, unpack=True)

        # Convert to Angstroms if microns
        if self.raw[0][-1] < 100:
            self.raw[0] = self.raw[0] * 10000
        self.WavelengthUnit = str(q.AA)
        self.ZeroPointUnit = str(q.erg/q.s/q.cm**2/q.AA)
        x, f = self.raw[:2]

        # Get a spectrum of Vega (wavelengths in that file are in microns)
        vega_file = resource_filename('svo_filters', 'data/spectra/vega.txt')
        vega = np.genfromtxt(vega_file, unpack=True)[:2]
        vega[0] = vega[0] * 10000
        # NOTE(review): rebin_spec preserves summed flux rather than flux
        # density, so these values scale with the grid spacing - confirm
        # this is intended for the zero point.
        vega = rebin_spec(vega, x) * q.erg/q.s/q.cm**2/q.AA

        # Zero point: throughput-weighted mean Vega flux density. rebin_spec
        # returns only the flux array, so all of it is used here.
        flam = np.trapz((vega * f).to(q.erg/q.s/q.cm**2/q.AA), x=x)
        thru = np.trapz(f, x=x)
        self.ZeroPoint = (flam / thru).to(q.erg/q.s/q.cm**2/q.AA).value

        # Calculate the filter's properties
        self.filterID = os.path.splitext(os.path.basename(filepath))[0]
        fmax = np.max(f)
        self.WavelengthPeak = x[np.argmax(f)]

        # Band limits: outermost points above 1% of the peak throughput
        above = np.flatnonzero(f >= fmax / 100.)
        self.WavelengthMin = np.min(x[above])
        self.WavelengthMax = np.max(x[above])

        # FWHM and its centre from the outermost points at >= half maximum
        half = np.flatnonzero(f >= fmax / 2.)
        half_lo, half_hi = np.min(x[half]), np.max(x[half])
        self.FWHM = half_hi - half_lo
        self.WavelengthCen = (half_lo + half_hi) / 2.

        self.WavelengthEff = np.trapz(f*x*vega, x=x) / np.trapz(f*vega, x=x)
        self.WavelengthMean = np.trapz(f*x, x=x) / np.trapz(f, x=x)
        # Width of a rectangle of height max(f) with the curve's area
        self.WidthEff = np.trapz(f, x=x) / fmax
        piv = np.trapz(f*x, x=x)
        self.WavelengthPivot = np.sqrt(piv / np.trapz(f/x, x=x))
        pht = f*vega*x**2
        self.WavelengthPhot = np.trapz(pht, x=x) / np.trapz(f*vega*x, x=x)

        # Add missing attributes
        self.path = ''
        self.pixels_per_bin = self.raw.shape[-1]
        self.n_bins = 1

    def load_xml(self, filepath):
        """Load the filter from an SVO VOTable (XML) file

        Every PARAM except 'Description' becomes an attribute named after
        the PARAM; float/double PARAMs are stored as floats and all others
        as strings.

        Parameters
        ----------
        filepath: str
            The filepath for the filter
        """
        vot = vo.parse_single_table(filepath)
        self.raw = np.array([list(i) for i in vot.array]).T

        # Read PARAM metadata through the votable API instead of splitting
        # its XML text on whitespace, which truncated values with spaces
        for param in vot.params:
            key = param.name or param.ID
            if key == 'Description':
                continue
            val = param.value
            if param.datatype in ('float', 'double'):
                val = float(val)
            else:
                if isinstance(val, bytes):
                    val = val.decode()
                val = str(val)
            setattr(self, key, val)

        # Create some attributes
        self.path = filepath
        self.pixels_per_bin = self.raw.shape[-1]
        self.n_bins = 1

    def overlap(self, spectrum):
        """Tests for overlap of this filter with a spectrum

        Only wavelengths where the throughput is non-zero count as the
        filter's extent.

        Example of full overlap:

            |---------- spectrum ----------|
               |------ self ------|

        Examples of partial overlap:

            |---------- self ----------|
                       |------ spectrum ------|

            |---- spectrum ----|
                       |----- self -----|

        Examples of no overlap:

            |---- spectrum ----|  |---- self ----|

            |---- self ----|  |---- spectrum ----|

        Parameters
        ----------
        spectrum: sequence
            The [W, F] spectrum with astropy units

        Returns
        -------
        ans : {'full', 'partial', 'none'}
            Overlap status.
        """
        swave = self.wave[np.where(self.throughput != 0)]
        s1, s2 = swave.min(), swave.max()

        owave = spectrum[0]
        o1, o2 = owave.min(), owave.max()

        if s1 >= o1 and s2 <= o2:
            return 'full'
        if s2 < o1 or o2 < s1:
            return 'none'
        return 'partial'

    def plot(self, fig=None):
        """
        Plot the filter

        Parameters
        ----------
        fig: bokeh figure (optional)
            An existing figure to draw on; a new one is made when None
        """
        COLORS = color_gen('Category10')

        if fig is None:
            xlab = 'Wavelength [{}]'.format(self.wave_units)
            ylab = 'Throughput'
            fig = figure(title=self.filterID, x_axis_label=xlab, y_axis_label=ylab)

        # Plot the raw curve as a faint, wide backdrop
        fig.line((self.raw[0]*q.AA).to(self.wave_units).value, self.raw[1],
                 alpha=0.1, line_width=8, color='black')

        # Plot each bin with its center
        for x, y in self.rsr:
            fig.line(x, y, color=next(COLORS), line_width=2)
        fig.scatter(*self.centers, size=8, color='black')

        show(fig)

    @property
    def rsr(self):
        """A getter for the relative spectral response (rsr) curve,
        shape (n_bins, 2, pixels_per_bin)"""
        return np.array([self.wave.value, self.throughput]).swapaxes(0, 1)

    @property
    def throughput(self):
        """A getter for the throughput"""
        return self._throughput

    @throughput.setter
    def throughput(self, points):
        """A setter for the throughput

        Parameters
        ----------
        points: np.ndarray
            The array of throughput points, same shape as ``wave``

        Raises
        ------
        ValueError
            If the shape differs from the wavelength array
        """
        if not points.shape == self.wave.shape:
            raise ValueError("Throughput and wavelength must be same shape.")
        self._throughput = points

    @property
    def wave(self):
        """A getter for the wavelength"""
        return self._wave

    @wave.setter
    def wave(self, wavelength):
        """A setter for the wavelength; also adopts its unit as wave_units

        Parameters
        ----------
        wavelength: astropy.units.quantity.Quantity
            The array with units

        Raises
        ------
        ValueError
            If wavelength is not a Quantity
        """
        if not isinstance(wavelength, q.quantity.Quantity):
            raise ValueError("Wavelength must be in length units.")
        self._wave = wavelength
        self.wave_units = wavelength.unit

    @property
    def wave_units(self):
        """A getter for the wavelength units"""
        return self._wave_units

    @wave_units.setter
    def wave_units(self, units):
        """
        A setter for the wavelength units; converts every stored wavelength

        Parameters
        ----------
        units: str, astropy.units.Unit
            The wavelength units

        Raises
        ------
        ValueError
            If units are not a length
        """
        if isinstance(units, str):
            units = q.Unit(units)
        if not units.is_equivalent(q.m):
            raise ValueError(units, ": New wavelength units must be a length.")

        self._wave_units = units

        # NOTE(review): round(5) keeps 5 decimal places in the *new* units,
        # which discards precision for large units such as q.cm or q.m.
        self._wave = self.wave.to(units).round(5)
        for attr in self._WAVE_ATTRS:
            setattr(self, attr, getattr(self, attr).to(units).round(5))
def color_gen(colormap='viridis', key=None, n=15):
    """Color generator for Bokeh plots

    The arguments are validated immediately, so a bad colormap raises here
    rather than on the first ``next()``.

    Parameters
    ----------
    colormap: str, sequence
        The name of a ``bokeh.palettes`` palette (a dict of sized palettes,
        a palette function, or a fixed list/tuple of colors), or a sequence
        of color hex values
    key: int (optional)
        For dict palettes, the key to use; defaults to the first key
    n: int
        For palette functions, the number of colors to generate

    Returns
    -------
    iterator
        An endless iterator cycling through the palette colors

    Raises
    ------
    TypeError
        If colormap is neither a bokeh palette name nor a sequence
    """
    if isinstance(colormap, str) and hasattr(bpal, colormap):
        palette = getattr(bpal, colormap)
        if isinstance(palette, dict):
            # Dicts map a key (used via palette[key]) to a color sequence
            if key is None:
                key = list(palette.keys())[0]
            palette = palette[key]
        elif callable(palette):
            palette = palette(n)
        elif not isinstance(palette, (list, tuple)):
            raise TypeError("pallette must be a bokeh palette name or a sequence of color hex values.")
    elif isinstance(colormap, (list, tuple)):
        palette = colormap
    else:
        raise TypeError("pallette must be a bokeh palette name or a sequence of color hex values.")

    return itertools.cycle(palette)
def filters(filter_directory=None, update=False, fmt='table', **kwargs):
    """
    Get a list of the available filters

    The metadata of every filter is cached in the pickle 'filter_list.p'
    inside filter_directory. The cache is (re)built when update is True or
    when it is missing, zero-byte, or holds no filters.

    Parameters
    ----------
    filter_directory: str
        The directory containing the filter relative spectral response curves
    update: bool
        Check the filter directory for new filters and generate pickle of table
    fmt: str
        The format for the returned table: 'table' (an astropy Table indexed
        by 'Band') or 'dict' (band name -> {attribute: value})
    **kwargs
        Passed to Filter when (re)building the cache

    Returns
    -------
    astropy.table.Table, dict, None
        The filter metadata, or None if no filters were found
    """
    if filter_directory is None:
        filter_directory = resource_filename('svo_filters', 'data/filters/')
    p_path = os.path.join(filter_directory, 'filter_list.p')

    # Load the cached table unless a rebuild was requested; a zero-byte file
    # (e.g. one created by `touch`) counts as no cache
    data = None
    if not update and os.path.isfile(p_path) and os.path.getsize(p_path) > 0:
        # NOTE: pickle.load can execute arbitrary code; only load a cache
        # file written by this function.
        with open(p_path, 'rb') as file:
            data = pickle.load(file)

    # Build the cache at most once when requested, missing, or empty
    if data is None or len(data) == 0:
        data = _build_filter_table(filter_directory, p_path, **kwargs)

    if data is None or len(data) == 0:
        print('No filters found in', filter_directory)
        return None

    if fmt == 'dict':
        cols = [c for c in data.colnames if c != 'Band']
        return {row['Band']: {c: row[c].value if hasattr(row[c], 'unit') else row[c]
                              for c in cols}
                for row in data}

    # Add Band as index
    data.add_index('Band')
    return data


def _build_filter_table(filter_directory, p_path, **kwargs):
    """Load every filter file in filter_directory into one metadata table,
    pickle it to p_path and return it (None when there are no filters)."""
    print('Loading filters into table...')

    # Every regular file except the pickle cache is a filter
    files = [f for f in glob(os.path.join(filter_directory, '*'))
             if os.path.isfile(f) and not f.endswith('.p')]

    tables = []
    for path in files:
        # Filter() looks bands up without their '.txt' extension
        band = os.path.basename(path)
        if band.endswith('.txt'):
            band = band[:-len('.txt')]
        filt = Filter(band, filter_directory=filter_directory, **kwargs)
        filt.Band = band

        # Put metadata into table with correct dtypes
        info = filt.info(True)
        vals = [float(i) if i.replace('.', '').replace('-', '')
                .replace('+', '').isnumeric() else i
                for i in info['Values']]
        dtypes = np.array([type(i) for i in vals])
        tables.append(at.Table(np.array([vals]), names=info['Attributes'],
                               dtype=dtypes))

    if not tables:
        return None

    table = at.vstack(tables)
    with open(p_path, 'wb') as file:
        pickle.dump(table, file)
    return table
def rebin_spec(spec, wavnew, oversamp=100, plot=False):
    """
    Rebin a spectrum to a new wavelength array while preserving
    the total flux

    The input is linearly interpolated onto a grid ``oversamp`` times finer
    than ``spec``, with each fine sample carrying 1/oversamp of the local
    flux. The fine samples are then summed into bins around each point of
    ``wavnew``. Bin edges are the midpoints between neighbouring points, and
    the outer edges extend by the largest spacing in ``wavnew``.

    Parameters
    ----------
    spec: array-like
        The wavelength and flux to be binned (wavelengths increasing)
    wavnew: array-like
        The new wavelength array, at least two points
    oversamp: int
        The oversampling factor of the intermediate grid
    plot: bool
        Unused; kept for backward compatibility

    Returns
    -------
    np.ndarray
        The rebinned flux, one value per point of wavnew

    Raises
    ------
    ValueError
        If wavnew has fewer than two points
    """
    wave, flux = spec

    # Accept any array-like: list arithmetic below would concatenate
    wavnew = np.asarray(wavnew, dtype=float)
    if wavnew.size < 2:
        raise ValueError("wavnew must contain at least two wavelengths.")

    # Oversample the input spectrum
    nlam = len(wave)
    x0 = np.arange(nlam, dtype=float)
    x0int = np.arange((nlam - 1.) * oversamp + 1., dtype=float) / oversamp
    w0int = np.interp(x0int, x0, wave)
    spec0int = np.interp(w0int, wave, flux) / oversamp

    # Set up the bin edges for down-binning
    maxdiffw1 = np.diff(wavnew).max()
    w1bins = np.sort(np.concatenate(([wavnew[0] - maxdiffw1],
                                     0.5 * (wavnew[1:] + wavnew[:-1]),
                                     [wavnew[-1] + maxdiffw1])))

    # Bin down the interpolated spectrum: sum the fine samples in each bin
    lo = w0int.searchsorted(w1bins[:-1], side='left')
    hi = w0int.searchsorted(w1bins[1:], side='left')
    return np.array([spec0int[a:b].sum() for a, b in zip(lo, hi)])