Coverage for src/lccalib/zp.py: 0%
147 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-23 12:30 +0000
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-23 12:30 +0000
1"""
2Tools to fit a zp from a color transformation
3"""
5import numpy as np
6import pandas as pd
7import matplotlib.pyplot as plt
9from saltworks.plottools import binplot
10from saltworks.linearmodels import RobustLinearSolver, linear_func, indic
11from saltworks.indextools import make_index
13from .match import match
15# pylint: disable=invalid-name,too-many-locals,too-many-arguments
def get_matched_cats(survey, star_lc_cat, logger, by=None, **kwargs):
    """Align a secondary star catalog with the lc star catalog.

    :param survey: secondary star catalog provider
    :param recarray star_lc_cat: lc averaged catalog (needs ``ra``/``dec``
        columns; a pandas DataFrame is also accepted)
    :param logger: logger used to report catalog sizes and match counts
    :param dict by: optional ``{column: values}`` mapping; only its first
        key is used. For each value, the lc catalog rows with
        ``column == value`` are matched separately, the value is forwarded to
        the survey as ``kwargs[column]``, and all matches are stacked.
    :param kwargs: forwarded to ``_get_matched_cats`` (e.g. ``arcsecrad``)
        and from there to ``survey.get_secondary_star_catalog``
    :return recarray star_selec: aligned secondary star catalog
    :return recarray lc_selec: aligned lc star catalog (same length and
        order as ``star_selec``)
    :raises ValueError: if ``by`` is given with an empty list of values
        (``np.hstack`` has nothing to stack)
    """
    if by is None:
        star_selec, lc_selec, index = _get_matched_cats(survey, star_lc_cat, **kwargs)
        _log_match_stats(logger, star_lc_cat, index, star_selec)
        return star_selec, lc_selec

    key = list(by.keys())[0]
    stack_star = []
    stack_lc = []
    for k in by[key]:
        _star_lc_cat = star_lc_cat[star_lc_cat[key] == k]
        # the survey receives the same selection value as the lc subset
        kwargs[key] = k
        star_selec, lc_selec, index = _get_matched_cats(survey, _star_lc_cat, **kwargs)
        logger.info(k)
        _log_match_stats(logger, _star_lc_cat, index, star_selec)
        stack_star.append(star_selec)
        stack_lc.append(lc_selec)

    return np.hstack(stack_star), np.hstack(stack_lc)


def _log_match_stats(logger, lc_cat, index, star_selec):
    """Log catalog sizes and match count for one matching pass."""
    logger.info(f"number of stars in lc catalog: {len(lc_cat)}")
    logger.info(f"number of stars considered in ref catalog: {len(index)}")
    logger.info(f"number of match with ref catalog: {len(star_selec)}")
def _get_matched_cats(survey, star_lc_cat, arcsecrad=1, **kwargs):
    """Positionally match lc stars against the survey secondary catalog.

    Rows of the lc catalog with non-finite ra/dec are dropped. The secondary
    catalog is restricted to the lc ra/dec bounding box widened by 0.2 on
    each side (same units as the ra/dec columns), then matched with ``match``.

    :param survey: secondary star catalog provider
    :param star_lc_cat: lc catalog (recarray or DataFrame) with ra/dec columns
    :param arcsecrad: matching radius forwarded to ``match``
    :param kwargs: forwarded to ``survey.get_secondary_star_catalog``
    :return: ``(star_selec, lc_selec, index)`` -- matched secondary stars,
        the lc stars they match (same order), and the raw ``match`` index
        array, where ``-1`` marks secondary stars without an lc counterpart
    """
    ref_cat = survey.get_secondary_star_catalog(**kwargs)
    if isinstance(ref_cat, pd.DataFrame):
        ref_cat = ref_cat.to_records()

    # the finite-coordinate mask is built before any DataFrame conversion
    finite = np.isfinite(star_lc_cat["ra"]) & np.isfinite(star_lc_cat["dec"])
    if isinstance(star_lc_cat, pd.DataFrame):
        star_lc_cat = star_lc_cat.to_records()
    lc_cat = star_lc_cat[finite]

    # NOTE(review): a plain min/max box assumes the field does not straddle
    # ra = 0/360 -- confirm for the surveys this is used with.
    pad = 0.2
    ra_lo = lc_cat["ra"].min() - pad
    ra_hi = lc_cat["ra"].max() + pad
    dec_lo = lc_cat["dec"].min() - pad
    dec_hi = lc_cat["dec"].max() + pad
    in_ra = (ref_cat["ra"] > ra_lo) & (ref_cat["ra"] < ra_hi)
    in_dec = (ref_cat["dec"] > dec_lo) & (ref_cat["dec"] < dec_hi)
    ref_cat = ref_cat[in_ra & in_dec]

    index = match(lc_cat, ref_cat, arcsecrad=arcsecrad)
    matched = index != -1
    return ref_cat[matched], lc_cat[index[matched]], index
def plot_diff_mag(
    survey,
    star_selec,
    lc_selec,
    xlabel="mag",
    lims=(-0.2, 0.2),
    fig=None,
    axs=None,
    **kwargs,
):
    """Plot mag difference between lc and secondary star catalogs.

    For each band of ``survey``, ``-2.5 log10(flux_<band>)`` from the lc
    catalog minus the secondary-catalog magnitude is mean-subtracted and
    binned against x. Bands whose magnitude column is absent from
    ``star_selec`` are skipped.

    :param survey: secondary star catalog provider
    :param recarray star_selec: aligned secondary star catalog
    :param recarray lc_selec: aligned light curve star catalog
    :param str xlabel: ``"mag"`` to plot against the catalog magnitude,
        otherwise a survey label key whose two columns are subtracted to
        form x (e.g. a color)
    :param lims: y-axis limits
    :param fig: figure returned as-is when ``axs`` is supplied
    :param axs: flat sequence of axes, one per band; created when None
    :param kwargs: forwarded to ``survey.get_secondary_labels``
    :return: fig, axs
    """
    bands = survey.bands
    nband = len(bands)
    # one column up to 3 bands, otherwise two columns with ceil(nband/2)
    # rows (no spare empty row for even band counts)
    nx, ny = (nband, 1) if nband < 4 else (nband // 2 + nband % 2, 2)

    if axs is None:
        # squeeze=False: a single band still yields an array of axes
        fig, axs = plt.subplots(nx, ny, sharex=True, sharey=True, squeeze=False)
        axs = axs.flatten()
    for ax, band in zip(axs, bands):
        labels = survey.get_secondary_labels(band, **kwargs)
        if labels["mag"] not in star_selec.dtype.names:
            continue
        mag = -2.5 * np.log10(lc_selec[f"flux_{band}"])
        mag0 = star_selec[labels["mag"]]
        diff = mag - mag0
        if xlabel == "mag":
            x = mag0
        else:
            x = star_selec[labels[xlabel][0]] - star_selec[labels[xlabel][1]]
        goods = np.isfinite(x)
        binplot(x[goods], diff[goods] - np.nanmean(diff), ax=ax)
        ax.grid(True)
        ax.set_ylim(*lims)
        # NOTE(review): diff is lc mag minus catalog mag; confirm the sign
        # convention implied by the "aper-psf" label.
        ax.set_ylabel(f"{band}: aper-psf")
    axs[-1].set_xlabel(xlabel)
    if nband > 3:
        axs[-2].set_xlabel(xlabel)
    return fig, axs
def compute_zp(
    survey,
    band,
    lc_selec,
    star_selec,
    color_range,
    zpkey="ccd",
    error_floor=0,
    **kwargs,
):
    """Fit a zp per zpkey (like ccd, name) and a joined linear color term.

    The fitted quantity is ``mag_psf - mag_ap`` (lc magnitude minus secondary
    catalog magnitude), modelled as a linear color term (``alpha``) plus an
    indicator term on ``lc_selec[zpkey]`` (``beta``; presumably one zp per
    distinct zpkey value -- see ``saltworks.linearmodels.indic``). The fit
    uses ``RobustLinearSolver.robust_solution(nsig=3)``.

    Stars enter the fit only if their magnitudes, color and weight are
    finite, the color lies strictly inside ``color_range``, every
    ``labels["goods"]`` column exceeds its threshold, and the catalog
    magnitude lies strictly inside ``labels["mag_cut"]``.

    :param survey: secondary star catalog provider
    :param str band: band name
    :param recarray lc_selec: aligned light curve star catalog
    :param recarray star_selec: aligned secondary star catalog
    :param list color_range: (min, max) color range on which the fit is done
    :param str zpkey: column name of lc_selec on which zp apply
    :param float error_floor: error added in quadrature to the measurement
        errors when building the weights
    :param kwargs: forwarded to ``survey.get_secondary_labels``
    :return dict dfit: ``y`` (mag difference, all stars), ``x`` (fitted
        parameters), ``color``, ``res`` (residuals of the fitted stars),
        ``err`` (sqrt of the covariance diagonal), ``mag`` (catalog mag),
        ``goods`` (fit selection mask), ``bads`` (solver outlier flags),
        ``w`` (weights, 1/sigma), ``model`` (model evaluated at the fitted
        parameters), ``wres`` (weighted residuals)
    """
    labels = survey.get_secondary_labels(band, **kwargs)
    flux = lc_selec[f"flux_{band}"]
    mag_psf = -2.5 * np.log10(flux)
    # sigma_mag = 2.5 / ln(10) * sigma_flux / flux  (2.5 / ln 10 ~= 1.0857);
    # the previous truncated 1.08 underestimated the errors by ~0.5%
    emag_psf = 2.5 / np.log(10) * (lc_selec[f"eflux_{band}"] / flux)

    mag_ap = star_selec[labels["mag"]]
    emag_ap = star_selec[labels["emag"]]

    color = star_selec[labels["color"][0]] - star_selec[labels["color"][1]]
    # per-star weight = 1 / total error
    wcolor = np.sqrt(1 / (emag_psf**2 + emag_ap**2 + error_floor**2))

    goods = np.isfinite(mag_psf) & np.isfinite(mag_ap)
    goods &= np.isfinite(color) & np.isfinite(wcolor)
    goods &= (color > color_range[0]) & (color < color_range[1])
    for k, v in labels["goods"].items():
        goods &= star_selec[k] > v
    goods &= (mag_ap < labels["mag_cut"][1]) & (mag_ap > labels["mag_cut"][0])

    y = np.array(mag_psf - mag_ap)
    w = np.array(wcolor)

    model = linear_func(color[goods], name="alpha") + indic(
        np.array(lc_selec[zpkey])[goods], name="beta"
    )
    solver = RobustLinearSolver(model, y[goods], weights=w[goods])
    x = solver.robust_solution(nsig=3)
    model.params.free = x
    res = solver.get_res(y[goods], x)
    err = np.sqrt(solver.get_cov().diagonal())

    return {
        "y": y,
        "x": x,
        "color": color,
        "res": res,
        "err": err,
        "mag": mag_ap,
        "goods": goods,
        "bads": solver.bads,
        "w": w,
        "model": model(),
        "wres": solver.get_wres(x=x),
    }
def plot_zpfit_res(
    zpfits,
    xlabel="mag",
    lims=(-0.03, 0.03),
    fig=None,
    axs=None,
):
    """Plot zp fit residuals, one panel per band.

    :param dict zpfits: ``{band: dfit}``; each ``dfit`` must provide
        ``goods``, ``res`` and the ``xlabel`` column (e.g. a result of
        ``compute_zp``)
    :param str xlabel: can be mag or color
    :param lims: y-axis limits
    :param fig: figure returned as-is when ``axs`` is supplied
    :param axs: flat sequence of axes, one per band; created when None
    :return: fig, axs
    """
    bands = list(zpfits.keys())
    nband = len(bands)
    # one column up to 3 bands, otherwise two columns with ceil(nband/2) rows
    nx, ny = (nband, 1) if nband < 4 else (nband // 2 + nband % 2, 2)

    if axs is None:
        # squeeze=False: a single band still yields an array of axes
        fig, axs = plt.subplots(nx, ny, sharex=True, sharey=True, squeeze=False)
        axs = axs.flatten()
    for ax, band in zip(axs, bands):
        dfit = zpfits[band]
        # residuals exist only for the stars selected by the fit
        binplot(
            np.array(dfit[xlabel][dfit["goods"]]),
            np.array(dfit["res"]),
            ax=ax,
            data=True,
            label=band,
        )
        ax.grid(True)
        ax.set_ylim(*lims)
        ax.legend()
    axs[-1].set_xlabel(xlabel)
    if nband > 3:
        axs[-2].set_xlabel(xlabel)
    plt.tight_layout()
    return fig, axs
def zpfit_diagnostic(dfit, nbins=15):
    """Plot zpfit diagnostic including rms of the residual compared to
    measurement error and chi2.

    Left column is binned in catalog magnitude, right column in color.
    The top row shows the per-bin rms of the residuals next to the mean
    predicted error ``1/w``. The bottom row shows the per-bin chi2 / dof of
    the weighted residuals (log scale). Points flagged in ``dfit["bads"]``
    are excluded from rms and chi2; the predicted-error mean uses every
    fitted point of the bin.

    :param dict dfit: fit result dict (reads ``bads``, ``res``, ``wres``,
        ``mag``, ``color``, ``goods`` and ``w``), e.g. from ``compute_zp``
    :param int nbins: number of bins along each x axis
    :return: fig, list of the 4 axes
    """
    bads = dfit["bads"]
    res = dfit["res"]
    wres = dfit["wres"]
    goods = dfit["goods"]
    # predicted error of each fitted star, from its fit weight (1/sigma)
    yerr_all = np.sqrt(1 / dfit["w"][goods] ** 2)

    fig, ax = plt.subplots(2, 2, figsize=(15, 5), sharex="col")
    ax = list(ax.flatten())
    for x, (ax0, ax1), xlabel in zip(
        [dfit["mag"][goods], dfit["color"][goods]],
        [[ax[0], ax[2]], [ax[1], ax[3]]],
        ["mag", "color"],
    ):
        _, xbinned, xerr, index = make_bins(x, nbins)
        keeps = [~bads[e] for e in index]
        ngood = np.array([keep.sum() for keep in keeps])

        mean_y2 = np.array([(res[e][keep] ** 2).sum() for e, keep in zip(index, keeps)]) / ngood
        mean_2y = (np.array([res[e][keep].sum() for e, keep in zip(index, keeps)]) / ngood) ** 2
        chi2 = np.array([(wres[e][keep] ** 2).sum() for e, keep in zip(index, keeps)]) / (ngood - 1)
        # clip tiny negative round-off so sqrt does not return NaN
        rms = np.sqrt(np.maximum(mean_y2 - mean_2y, 0))
        yerr = [yerr_all[e].mean() for e in index]

        ax0.errorbar(xbinned, rms, xerr=xerr, ls="None", marker="+", label="res rms")
        ax0.errorbar(
            xbinned, yerr, xerr=xerr, ls="None", marker="+", label="predicted errors"
        )
        ax1.errorbar(
            xbinned, chi2, xerr=xerr, ls="None", marker="+", label="chi2 / dof"
        )
        ax1.set_ylim(0.5, 10)
        ax1.set_yscale("log")
        ax0.legend()
        ax1.legend()
        ax1.axhline(y=1, color="k", ls="--")
        ax1.set_xlabel(xlabel)
    return fig, ax
def make_bins(x, nbins):
    """Define nbins equal-width bins spanning the range of x.

    Only non-empty bins are reported in the returned arrays.

    :param array x: values to bin (non-empty, finite)
    :param int nbins: number of bins
    :return array bins: upper edge of each non-empty bin
    :return array xbinned: center of each non-empty bin
    :return array xerr: shape (2, n) array; both rows hold the distance from
        bin center to upper edge (half bin width), for errorbar ``xerr``
    :return index: ``make_index`` grouping of x positions by bin
        (presumably one index array per non-empty bin, in increasing bin
        order -- TODO confirm against saltworks.indextools)
    """
    xmin, xmax = x.min(), x.max()
    # The last edge must lie strictly above max(x): np.digitize uses
    # right-open bins, so a value equal to the last edge would get bin
    # nbins + 1 and index past the end of the centers array. The relative
    # pad vanishes when max(x) == 0, hence the nextafter fallback.
    upper = xmax + abs(xmax * 1e-7)
    if upper <= xmax:
        upper = np.nextafter(xmax, np.inf)
    edges = np.linspace(xmin, upper, nbins + 1)

    yd = np.digitize(x, edges)
    index = make_index(yd)
    usedbins = np.unique(yd) - 1  # sorted, 0-based ids of non-empty bins

    centers = 0.5 * (edges[:-1] + edges[1:])
    xbinned = centers[usedbins]
    bins = edges[usedbins + 1]
    xerr = np.array([bins, bins]) - np.array([xbinned, xbinned])
    return bins, xbinned, xerr, index