Coverage for src/lccalib/zp.py: 0%

147 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-23 12:30 +0000

1""" 

2Tools to fit a zp from a color transformation 

3""" 

4 

5import numpy as np 

6import pandas as pd 

7import matplotlib.pyplot as plt 

8 

9from saltworks.plottools import binplot 

10from saltworks.linearmodels import RobustLinearSolver, linear_func, indic 

11from saltworks.indextools import make_index 

12 

13from .match import match 

14 

15# pylint: disable=invalid-name,too-many-locals,too-many-arguments 

16 

17 

def get_matched_cats(survey, star_lc_cat, logger, by=None, **kwargs):
    """Cross-match the lc star catalog with the survey secondary star catalog.

    Without ``by`` a single match is performed on the whole lc catalog.
    With ``by``, one match is performed per value listed under its first key:
    the lc catalog is restricted to the rows whose ``key`` column equals the
    value, the value is forwarded to ``_get_matched_cats`` as the ``key``
    keyword argument, and the per-value results are stacked with
    ``np.hstack``.

    :param survey: secondary star catalog provider
    :param recarray star_lc_cat: lc averaged catalog
    :param logger: logger receiving the match statistics
    :param dict by: ``{key: values}``; only the first key is used
    :param kwargs: forwarded to ``_get_matched_cats``
    :return recarray star_selec: aligned secondary star catalog
    :return recarray lc_selec: aligned lc star catalog
    """

    def _report(lc_cat, ref_index, matched):
        # The same three counters are logged after every match.
        logger.info(f"number of stars in lc catalog: {len(lc_cat)}")
        logger.info(f"number of stars considered in ref catalog: {len(ref_index)}")
        logger.info(f"number of match with ref catalog: {len(matched)}")

    if by is not None:
        key = list(by)[0]
        star_parts, lc_parts = [], []
        for value in by[key]:
            lc_subset = star_lc_cat[star_lc_cat[key] == value]
            # The current value also selects the matching reference catalog.
            kwargs[key] = value
            stars, lcs, ref_index = _get_matched_cats(survey, lc_subset, **kwargs)
            logger.info(value)
            _report(lc_subset, ref_index, stars)
            star_parts.append(stars)
            lc_parts.append(lcs)
        return np.hstack(star_parts), np.hstack(lc_parts)

    stars, lcs, ref_index = _get_matched_cats(survey, star_lc_cat, **kwargs)
    _report(star_lc_cat, ref_index, stars)
    return stars, lcs

54 

55 

def _get_matched_cats(survey, star_lc_cat, arcsecrad=1, **kwargs):
    """Match one lc star catalog against the survey secondary star catalog.

    Rows of the lc catalog with a non-finite ``ra`` or ``dec`` are dropped.
    The secondary catalog is cut to the ra/dec bounding box of the remaining
    lc stars, enlarged by 0.2 on each side (same unit as the coordinates),
    before calling ``match``.

    :param survey: secondary star catalog provider
    :param star_lc_cat: lc averaged catalog (recarray or DataFrame)
    :param arcsecrad: forwarded to ``match`` as ``arcsecrad``
    :param kwargs: forwarded to ``survey.get_secondary_star_catalog``
    :return recarray star_selec: matched rows of the secondary catalog
    :return recarray lc_selec: lc catalog rows paired with ``star_selec``
    :return index: raw output of ``match``; entries equal to -1 are treated
        as unmatched
    """
    ref_cat = survey.get_secondary_star_catalog(**kwargs)
    if isinstance(ref_cat, pd.DataFrame):
        ref_cat = ref_cat.to_records()

    # The finite-position mask is built before any DataFrame conversion.
    finite_pos = np.isfinite(star_lc_cat["ra"]) & np.isfinite(star_lc_cat["dec"])
    if isinstance(star_lc_cat, pd.DataFrame):
        star_lc_cat = star_lc_cat.to_records()
    star_lc_cat = star_lc_cat[finite_pos]

    ra_lo, ra_hi = star_lc_cat["ra"].min(), star_lc_cat["ra"].max()
    dec_lo, dec_hi = star_lc_cat["dec"].min(), star_lc_cat["dec"].max()

    in_box = (ref_cat["ra"] > ra_lo - 0.2) & (ref_cat["ra"] < ra_hi + 0.2)
    in_box &= (ref_cat["dec"] > dec_lo - 0.2) & (ref_cat["dec"] < dec_hi + 0.2)
    ref_cat = ref_cat[in_box]

    index = match(star_lc_cat, ref_cat, arcsecrad=arcsecrad)
    # NOTE(review): presumably one entry per reference star holding the
    # position of its lc counterpart -- confirm against match().
    matched = index != -1
    return ref_cat[matched], star_lc_cat[index[matched]], index

82 

83# pylint: disable=dangerous-default-value 

def plot_diff_mag(
    survey,
    star_selec,
    lc_selec,
    xlabel="mag",
    lims=(-0.2, 0.2),
    fig=None,
    axs=None,
    **kwargs,
):
    """Plot mag difference between lc and secondary star catalogs.

    One panel is drawn per band of ``survey.bands``. Bands whose magnitude
    column is absent from ``star_selec`` are skipped and leave an empty panel.

    :param survey: secondary star catalog provider
    :param recarray star_selec: aligned secondary star catalog
    :param recarray lc_selec: aligned light curve star catalog
    :param str xlabel: ``"mag"`` or the name of a label entry holding a pair
        of columns whose difference is used as abscissa (e.g. ``"color"``)
    :param lims: (ymin, ymax) applied to every panel
    :param fig: existing figure, returned unchanged when ``axs`` is given
    :param axs: existing flat sequence of axes; created when None
    :param kwargs: forwarded to ``survey.get_secondary_labels``
    :return: fig, axs
    """
    nband = len(survey.bands)
    # Same grid rule as plot_zpfit_res: two columns from four bands on,
    # without an extra empty row when the number of bands is even.
    nx, ny = (nband, 1) if nband < 4 else (nband // 2 + nband % 2, 2)

    if axs is None:
        # squeeze=False always yields a 2-D array, so flatten() also works
        # when there is a single band.
        fig, axs = plt.subplots(nx, ny, sharex=True, sharey=True, squeeze=False)
        axs = axs.flatten()
    for ax, band in zip(axs, survey.bands):
        labels = survey.get_secondary_labels(band, **kwargs)
        if labels["mag"] not in star_selec.dtype.names:
            continue
        mag = -2.5 * np.log10(lc_selec[f"flux_{band}"])
        mag0 = star_selec[labels["mag"]]
        diff = mag - mag0
        if xlabel == "mag":
            x = mag0
        else:
            first, second = labels[xlabel][0], labels[xlabel][1]
            x = star_selec[first] - star_selec[second]
        goods = np.isfinite(x)
        # Only the mean offset is removed; the scatter is what is inspected.
        binplot(x[goods], diff[goods] - np.nanmean(diff), ax=ax)
        ax.grid("on")
        ax.set_ylim(*lims)
        # NOTE(review): diff is (lc mag - catalog mag); confirm that the
        # "aper-psf" label matches the intended sign convention.
        ax.set_ylabel(f"{band}: aper-psf")
    axs[-1].set_xlabel(xlabel)
    if nband > 3:
        axs[-2].set_xlabel(xlabel)
    return fig, axs

129 

130 

def compute_zp(
    survey,
    band,
    lc_selec,
    star_selec,
    color_range,
    zpkey="ccd",
    error_floor=0,
    **kwargs,
):
    """Fit one zp per ``zpkey`` value together with a common linear color term.

    The fitted quantity is ``mag_psf - mag_ap``, where ``mag_psf`` comes from
    the ``flux_{band}`` column of ``lc_selec`` and ``mag_ap`` from the
    secondary catalog. Only stars passing the finite-value, color-range,
    ``labels["goods"]`` and ``labels["mag_cut"]`` selections enter the fit.

    :param survey: secondary star catalog provider
    :param band: band name
    :param recarray lc_selec: aligned light curve star catalog
    :param recarray star_selec: aligned secondary star catalog
    :param list color_range: (min, max) color range on which the fit is done,
        both bounds excluded
    :param str zpkey: column name of lc_selec on which zp apply
    :param error_floor: added in quadrature to the magnitude errors
    :param kwargs: forwarded to ``survey.get_secondary_labels``
    :return dict dfit: dict with all fitted quantities; ``y``, ``color``,
        ``mag``, ``goods`` and ``w`` cover every input star, while ``res``
        comes from the ``goods`` subset handed to the solver
    """
    labels = survey.get_secondary_labels(band, **kwargs)

    flux = lc_selec[f"flux_{band}"]
    mag_psf = -2.5 * np.log10(flux)
    # NOTE(review): 1.08 presumably approximates 2.5 / ln(10) ~ 1.0857.
    emag_psf = 1.08 * (lc_selec[f"eflux_{band}"] / flux)

    mag_ap = star_selec[labels["mag"]]
    emag_ap = star_selec[labels["emag"]]

    blue, red = labels["color"][0], labels["color"][1]
    color = star_selec[blue] - star_selec[red]
    # Weight is the inverse of the total magnitude error.
    inv_sigma = np.sqrt(1 / (emag_psf**2 + emag_ap**2 + error_floor**2))

    goods = np.isfinite(mag_psf) & np.isfinite(mag_ap)
    goods &= np.isfinite(color) & np.isfinite(inv_sigma)
    goods &= (color_range[0] < color) & (color < color_range[1])
    for column, threshold in labels["goods"].items():
        goods &= star_selec[column] > threshold
    mag_lo, mag_hi = labels["mag_cut"][0], labels["mag_cut"][1]
    goods &= (mag_ap < mag_hi) & (mag_ap > mag_lo)

    y = np.array(mag_psf - mag_ap)
    w = np.array(inv_sigma)

    color_term = linear_func(color[goods], name="alpha")
    zp_term = indic(np.array(lc_selec[zpkey])[goods], name="beta")
    model = color_term + zp_term

    solver = RobustLinearSolver(model, y[goods], weights=w[goods])
    x = solver.robust_solution(nsig=3)
    # The model is evaluated below with the fitted parameters.
    model.params.free = x
    res = solver.get_res(y[goods], x)
    err = np.sqrt(solver.get_cov().diagonal())

    return {
        "y": y,
        "x": x,
        "color": color,
        "res": res,
        "err": err,
        "mag": mag_ap,
        "goods": goods,
        "bads": solver.bads,
        "w": w,
        "model": model(),
        "wres": solver.get_wres(x=x),
    }

196 

197# pylint: disable=dangerous-default-value 

def plot_zpfit_res(
    zpfits,
    xlabel="mag",
    lims=(-0.03, 0.03),
    fig=None,
    axs=None,
):
    """Plot zp fit residuals, one panel per band.

    :param dict zpfits: ``{band: dfit}``; each ``dfit`` must provide the
        ``"goods"`` and ``"res"`` entries plus the one named by ``xlabel``
    :param str xlabel: can be mag or color
    :param lims: (ymin, ymax) applied to every panel
    :param fig: existing figure; created when ``axs`` is None
    :param axs: existing flat sequence of axes; created when None
    :return: fig, ax
    """
    nband = len(zpfits)
    nx, ny = (nband, 1) if nband < 4 else (nband // 2 + nband % 2, 2)

    if axs is None:
        # squeeze=False always yields a 2-D array, so flatten() also works
        # when there is a single band.
        fig, axs = plt.subplots(nx, ny, sharex=True, sharey=True, squeeze=False)
        axs = axs.flatten()
    for ax, (band, fit) in zip(axs, zpfits.items()):
        # "res" is stored for the selected stars only, so the abscissa is
        # restricted with the same "goods" mask.
        binplot(
            np.array(fit[xlabel][fit["goods"]]),
            np.array(fit["res"]),
            ax=ax,
            data=True,
            label=band,
        )
        ax.grid("on")
        ax.set_ylim(*lims)
        ax.legend()
    axs[-1].set_xlabel(xlabel)
    if nband > 3:
        axs[-2].set_xlabel(xlabel)
    # Lay out the figure that holds the panels rather than whichever figure
    # happens to be current.
    if fig is not None:
        fig.tight_layout()
    else:
        plt.tight_layout()
    return fig, axs

236 

237 

def zpfit_diagnostic(dfit, nbins=15):
    """Plot zpfit diagnostic including rms of the residual compared to
    measurement error and chi2.

    The left column is binned in magnitude and the right column in color.
    Top panels show the residual rms and the mean predicted error per bin;
    bottom panels show the sum of squared weighted residuals divided by
    (number of kept points - 1). Entries flagged in ``dfit["bads"]`` are
    excluded from the rms and chi2, but not from the mean predicted error.

    :param dict dfit: fit result providing the ``"bads"``, ``"res"``,
        ``"wres"``, ``"w"``, ``"goods"``, ``"mag"`` and ``"color"`` entries
    :param int nbins: number of bins along each abscissa
    :return: fig, list of the four axes (row-major order)
    """
    bads = dfit["bads"]
    res = dfit["res"]
    wres = dfit["wres"]
    goods = dfit["goods"]
    # Predicted error of the selected stars; independent of the abscissa,
    # so computed once.
    sigma = np.sqrt(1 / dfit["w"][goods] ** 2)

    fig, ax = plt.subplots(2, 2, figsize=(15, 5), sharex="col")
    ax = list(ax.flatten())
    for x, (ax0, ax1), xlabel in zip(
        [dfit["mag"][goods], dfit["color"][goods]],
        [[ax[0], ax[2]], [ax[1], ax[3]]],
        ["mag", "color"],
    ):
        _, xbinned, xerr, index = make_bins(x, nbins)

        # Keep-mask of every bin, evaluated once and reused below.
        keep = [~bads[e] for e in index]
        ngood = np.array([k.sum() for k in keep])
        kept_res = [res[e][k] for e, k in zip(index, keep)]
        kept_wres = [wres[e][k] for e, k in zip(index, keep)]

        mean_y2 = np.array([(r**2).sum() for r in kept_res]) / ngood
        mean_2y = (np.array([r.sum() for r in kept_res]) / ngood) ** 2
        chi2 = np.array([(r**2).sum() for r in kept_wres]) / (ngood - 1)
        rms = np.sqrt(mean_y2 - mean_2y)

        yerr = [sigma[e].mean() for e in index]

        ax0.errorbar(xbinned, rms, xerr=xerr, ls="None", marker="+", label="res rms")
        ax0.errorbar(
            xbinned, yerr, xerr=xerr, ls="None", marker="+", label="predicted errors"
        )
        ax1.errorbar(
            xbinned, chi2, xerr=xerr, ls="None", marker="+", label="chi2 / dof"
        )
        # NOTE(review): a preceding ax1.set_ylim(0, 0.1) was overridden by
        # this call and has been dropped; it may have been meant for ax0.
        ax1.set_ylim(0.5, 10)
        ax1.set_yscale("log")
        ax0.legend()
        ax1.legend()
        ax1.axhline(y=1, color="k", ls="--")
        ax1.set_xlabel(xlabel)
    return fig, ax

281 

282 

def make_bins(x, nbins):
    """Define nbins bin in x.

    The bins are regularly spaced between ``x.min()`` and a value slightly
    above ``x.max()``, so that the maximum falls inside the last bin. Only
    bins holding at least one point are reported.

    :param array x: x
    :param int nbins: number of bins
    :return array bins: upper limit of each populated bin
    :return array xbinned: center of each populated bin
    :return array xerr: (2, n) array of bin half-widths (identical rows)
    :return array index: output of ``make_index`` on the bin numbers
        (presumably the indices of x falling in each populated bin)
    """
    x_min, x_max = x.min(), x.max()
    pad = abs(x_max * 1e-7)
    if pad == 0:
        # A relative padding vanishes when the maximum is exactly 0; the
        # maximum would then land in the overflow bin and break the lookup
        # below. Fall back to a padding based on the span of the data.
        pad = (x_max - x_min) * 1e-7 or 1e-7
    edges = np.linspace(x_min, x_max + pad, nbins + 1)

    # digitize numbers the bins from 1, hence the shift to 0-based below.
    bin_number = np.digitize(x, edges)
    index = make_index(bin_number)
    usedbins = np.unique(bin_number) - 1

    centers = 0.5 * (edges[:-1] + edges[1:])
    xbinned = centers[usedbins]
    bins = edges[usedbins + 1]
    half_width = bins - xbinned
    xerr = np.array([half_width, half_width])
    return bins, xbinned, xerr, index