Coverage for src/lccalib/averaging.py: 0%

243 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-23 12:30 +0000

1""" 

2Tools to make a single light curve catalog with averaged flux 

3""" 

4import os 

5import itertools 

6import ssl 

7import pandas as pd 

8import numpy as np 

9from scipy.stats import norm 

10import matplotlib.pyplot as plt 

11 

12from astroquery.gaia import GaiaClass 

13from astropy.table import vstack 

14 

15from saltworks.linearmodels import LinearModel, RobustLinearSolver 

16from saltworks.plottools import binplot 

17from saltworks.dataproxy import DataProxy 

18 

19from .match import match 

20 

21#pylint: disable=invalid-name,too-many-locals, too-many-arguments,dangerous-default-value 

22 

def get_gaia_match(stars, offset=0.02, maxq=5000):
    """Query Gaia DR3 around the input stars and return the matched subsets.

    The search box is the ra/dec bounding box of ``stars`` enlarged by
    ``offset`` on every side. When the first query returns more than
    ``maxq - 100`` rows, the box is split once (and only once) into four
    quadrants which are queried separately and stacked.

    TODO: apply pm.

    :param stars: catalog exposing "ra" and "dec" columns
    :param float offset: margin added to the ra/dec bounding box
    :param int maxq: row limit (``TOP``) of each Gaia query
    :return: matched Gaia rows converted to pandas, the matched rows of
        ``stars``, and the full index array returned by ``match`` (-1 marks
        an unmatched star)
    """
    # pylint: disable=protected-access
    # NOTE(review): this replaces the default HTTPS context of the whole
    # process by an unverified one (no certificate check), not only for the
    # queries below -- confirm this is still required.
    ssl._create_default_https_context = ssl._create_unverified_context

    gaia = GaiaClass(
        gaia_tap_server="https://gea.esac.esa.int/",
        gaia_data_server="https://gea.esac.esa.int/",
    )

    def _query_box(ra_lo, ra_hi, dec_lo, dec_hi):
        """Run one synchronous box query and return its result table."""
        # A variant restricted to has_xp_continuous = 'True' sources was used
        # previously; the current selection takes every source in the box.
        # pylint: disable=line-too-long
        adql = f"select TOP {maxq} source_id, ra, dec from gaiadr3.gaia_source where ra <= {ra_hi} and ra >= {ra_lo} and dec <= {dec_hi} and dec >= {dec_lo}"
        return gaia.launch_job(adql, dump_to_file=False).get_results()

    ra_lo, ra_hi = stars["ra"].min() - offset, stars["ra"].max() + offset
    dec_lo, dec_hi = stars["dec"].min() - offset, stars["dec"].max() + offset

    gaia_ids = _query_box(ra_lo, ra_hi, dec_lo, dec_hi)
    if len(gaia_ids) > maxq - 100:
        # Result is close to the row limit: split in 4, but no more.
        ra_mid = (ra_hi + ra_lo) / 2
        dec_mid = (dec_hi + dec_lo) / 2
        quadrants = [
            (ra_lo, ra_mid, dec_lo, dec_mid),
            (ra_lo, ra_mid, dec_mid, dec_hi),
            (ra_mid, ra_hi, dec_lo, dec_mid),
            (ra_mid, ra_hi, dec_mid, dec_hi),
        ]
        gaia_ids = vstack([_query_box(*box) for box in quadrants])

    index = match(gaia_ids, stars, arcsecrad=20)
    matched = index != -1
    selected_stars = stars[matched]
    selected_ids = gaia_ids[index[matched]]
    return selected_ids.to_pandas(), selected_stars, index

61 

62 

def cut_from_epoch_number(T, min_d=3):
    """Remove stars observed on fewer than ``min_d`` distinct epochs.

    An epoch is the integer part of "mjd"; several measurements of a star
    sharing the same integer mjd count as a single epoch.

    :param T: catalog (record array or DataFrame) with "index" (integer star
        label) and "mjd" columns
    :param int min_d: minimal number of distinct epochs a star must have to
        be kept; values below 1 disable the cut
    :return: the rows of ``T`` belonging to the kept stars (``T`` itself when
        the cut is disabled or ``T`` is empty)
    """
    if min_d < 1 or len(T) == 0:
        return T
    star = np.asarray(T["index"]).astype(int)
    epoch = np.asarray(T["mjd"]).astype(int)
    # Relabel stars as 0..nstar-1 so that sparse or negative labels are fine.
    star_set, star_idx = np.unique(star, return_inverse=True)
    # One row per distinct (star, epoch) pair: counting the rows of a star
    # gives its number of epochs.
    pairs = np.unique(np.column_stack((star_idx, epoch)), axis=0)
    nepoch = np.bincount(pairs[:, 0], minlength=len(star_set))
    # A star with exactly min_d epochs is kept ("less than min_d" is cut).
    keep = nepoch >= min_d
    return T[keep[star_idx]]

79 

def star_lc_averager(
    T, star_key="star", flux_key="flux", eflux_key="error", error_floor=0, show=False
):
    """Compute the average flux of each star and the associated errors.

    One free parameter per star is fitted to the fluxes with
    ``RobustLinearSolver.robust_solution(nsig=3)``; measurements flagged in
    ``solver.bads`` are excluded from the per-star statistics.

    Stars located at the end of the star index which have no good
    measurement are dropped from the output catalog.

    :param T: input catalog
    :param str star_key: name of the star identifier column
    :param str flux_key: name of the flux column
    :param str eflux_key: name of the flux error column
    :param float error_floor: term added to ``eflux**2`` before taking the
        square root that defines the weights
    :param bool show: if True, draw the control plots
    :return: avg_cat (DataFrame with star, flux, eflux, rms, nmeas, chi2),
        the star index of the good measurements, and the boolean mask of the
        good measurements
    """
    # pylint: disable=E1101
    # pylint: disable=E1130
    dp = DataProxy(T, flux=flux_key, eflux=eflux_key)
    dp.add_field("star", T[star_key].astype(int))
    dp.make_index("star", intmap=True)

    with np.errstate(divide="ignore", invalid="ignore"):
        weights = np.array(1.0 / np.sqrt(dp.eflux**2 + error_floor))
    # Same convention as night_averager: a measurement without a usable
    # error gets a null weight instead of an infinite / NaN one.
    weights[~np.isfinite(weights)] = 0

    model = LinearModel(list(range(len(dp.nt))), dp.star_index, np.ones_like(dp.flux))
    solver = RobustLinearSolver(model, np.array(dp.flux), weights=weights)
    avg_flux = solver.robust_solution(nsig=3)
    solver.model.params.free = avg_flux
    res = solver.get_res(dp.flux)
    wres = solver.get_wres(avg_flux)

    goods = ~solver.bads
    ngood = np.bincount(dp.star_index, goods)
    # np.bincount(index) below stops at the last star having a good
    # measurement: count the trailing stars without any so that every column
    # can be trimmed to the same length. The length test prevents an
    # IndexError when no measurement at all survives.
    nz = 0
    while len(ngood) > 0 and ngood[-1] == 0:
        ngood = ngood[:-1]
        nz += 1

    index = dp.star_index[goods]
    good_flux = dp.flux[goods]
    err = np.sqrt(solver.get_cov().diagonal())
    # Stars with 0 (resp. 1) good measurement give x/0 in the means (resp.
    # in chi2); the resulting NaN/inf are expected, silence the warnings.
    with np.errstate(divide="ignore", invalid="ignore"):
        mean_y2 = np.bincount(index, weights=good_flux**2) / ngood
        mean_2y = (np.bincount(index, weights=good_flux) / ngood) ** 2
        chi2 = np.bincount(index, weights=wres[goods] ** 2) / (ngood - 1)
        avg_cat = pd.DataFrame(
            data={
                "star": dp.star_set if not nz else dp.star_set[:-nz],
                "flux": avg_flux if not nz else avg_flux[:-nz],
                "eflux": err if not nz else err[:-nz],
                "rms": np.sqrt(mean_y2 - mean_2y),
                "nmeas": np.bincount(index),
                "chi2": chi2,
            }
        )

    if show:
        plot_star_averager(avg_cat, res=res, dp=dp, goods=goods)
    return avg_cat, index, goods

123 

124 

def add_columns(avg_cat, T, names, index, goods):
    """Complete an average catalog with per-group means of extra columns.

    :param avg_cat: DataFrame with one row per group (star or night)
    :param T: catalog of individual measurements, indexable by column name
    :param names: names of the columns of ``T`` to average
    :param index: group index (row of ``avg_cat``) of each good measurement
    :param goods: selection of the good measurements of ``T``; the selected
        values must be aligned with ``index``
    :return: a copy of ``avg_cat`` with one additional column per name; a
        group without any good measurement gets NaN
    """
    # minlength keeps one bin per row of avg_cat even when the last groups
    # have no good measurement (np.bincount would otherwise stop at
    # max(index) and the assignment would fail on a length mismatch).
    nrow = len(avg_cat)
    N = np.bincount(index, minlength=nrow)
    with np.errstate(divide="ignore", invalid="ignore"):
        d = {
            n: np.bincount(index, weights=np.asarray(T[n])[goods], minlength=nrow) / N
            for n in names
        }
    return avg_cat.assign(**d)

135 

136 

def plot_night_averager(avg_cat, single=True, **kwargs):
    """Compare the residual dispersion with the expected errors of the night
    average fit, stacked over several nights.

    With ``single=True`` the work is delegated to
    ``plot_single_night_averager`` (``kwargs`` are forwarded to it).
    Otherwise a 2x2 figure is drawn: the top row shows the binned relative
    errors and relative rms, the bottom row the binned chi2, as a function of
    mjd (left column) and of ``-2.5 * log10(flux)`` (right column).

    :param avg_cat: night averaged catalog (mjd, flux, eflux, rms, chi2)
    :param bool single: select the single night plot
    :return: figure and axes (flat list of 4 axes in the stacked case)
    """
    if single:
        return plot_single_night_averager(avg_cat, **kwargs)

    fig, axes = plt.subplots(2, 2)  # , sharex=True, sharey=True)
    ax = list(axes.flatten())

    flux = avg_cat["flux"]
    mag = -2.5 * np.log10(flux)
    predicted = (avg_cat["eflux"] / flux).to_numpy()
    measured = (avg_cat["rms"] / flux).to_numpy()
    chi2 = avg_cat["chi2"].to_numpy()

    # (abscissa, top axis, bottom axis, abscissa label) for each column
    columns = (
        (avg_cat["mjd"].to_numpy(), ax[0], ax[2], "mjd"),
        (mag.to_numpy(), ax[1], ax[3], "mag"),
    )
    for x, top, bottom, xlabel in columns:
        binplot(x, predicted, label="predicted errors", ax=top, data=False)
        binplot(x, measured, color="r", label="res rms", ax=top)
        top.grid()
        top.set_ylabel(r"$\sigma_f / f$")
        top.legend()

        binplot(x, chi2, ax=bottom)
        bottom.set_xlabel(xlabel)
        bottom.set_ylabel("chi2")
        bottom.grid()

    fig.tight_layout()
    return fig, ax

176 

177 

def plot_single_night_averager(avg_cat, res=None, dp=None):
    """Draw the control plots of the night averaging fit.

    Two panels sharing the mjd axis are always drawn: relative rms together
    with ``eflux * sqrt(nmeas) / flux``, then chi2. A third panel with the
    residuals is added when ``res`` is given.

    :param avg_cat: night averaged catalog (mjd, flux, eflux, rms, nmeas,
        chi2)
    :param res: residuals of the individual measurements (optional)
    :param dp: data proxy providing ``mjd`` and ``flux`` of the individual
        measurements; read only when ``res`` is given
    :return: figure and array of axes
    """
    with_res = res is not None
    fig, ax = plt.subplots(3 if with_res else 2, 1, sharex=True)

    mjd = avg_cat["mjd"]
    flux = avg_cat["flux"]

    top = ax[0]
    top.plot(mjd, avg_cat["rms"] / flux, "r.", label="res rms")
    top.plot(
        mjd,
        avg_cat["eflux"] * np.sqrt(avg_cat["nmeas"]) / flux,
        "k.",
        label=r"$\sigma_f \sqrt{N} / f$",
    )
    top.grid()
    top.set_ylabel(r"$\sigma_f / f$")
    top.legend()

    middle = ax[1]
    middle.plot(mjd, avg_cat["chi2"], "k.")
    middle.set_xlabel("mjd")
    middle.set_ylabel("chi2")
    middle.grid()

    if with_res:
        bottom = ax[2]
        binplot(dp.mjd, res, robust=True, ax=bottom)
        bottom.plot(dp.mjd, res / dp.flux, "k.")
        bottom.set_ylabel("res / f")
        bottom.grid()
    return fig, ax

202 

203 

def plot_star_averager(avg_cat, res=None, dp=None, goods=None):
    """Draw the control plots of the star averaging fit.

    Panels share the ``-2.5 * log10(flux)`` axis: binned relative rms and
    ``eflux * sqrt(nmeas) / flux``, then binned chi2 with its histogram in a
    side inset. When ``res`` is given a third panel shows the binned
    residuals, with a side inset holding the histogram of the good residuals
    and a fitted normal distribution (legend: integer part of its scale).

    :param avg_cat: star averaged catalog (flux, eflux, rms, nmeas, chi2)
    :param res: residuals of the individual measurements (optional)
    :param dp: data proxy providing ``flux`` of the individual measurements;
        read only when ``res`` is given
    :param goods: selection of the residuals entering the histogram
    :return: figure and array of axes
    """
    nrows = 3 if res is not None else 2
    fig, ax = plt.subplots(nrows, 1, sharex=True, layout="constrained")
    top, middle = ax[0], ax[1]

    flux = avg_cat["flux"]
    chi2 = avg_cat["chi2"]
    m = -2.5 * np.log10(np.array(flux))

    # Dispersion of the residuals versus predicted errors.
    binplot(m, avg_cat["rms"] / flux, color="r", label="res rms", ax=top)
    binplot(
        m,
        avg_cat["eflux"] * np.sqrt(avg_cat["nmeas"]) / flux,
        color="k",
        label=r"$\sigma_f \sqrt{N}$",
        ax=top,
    )
    top.grid()
    top.set_ylabel(r"$\sigma_f / f$")
    top.legend()

    # chi2 of the finite entries only, histogram in an inset on the right.
    chi2_hist = middle.inset_axes([1.05, 0, 0.25, 1], sharey=middle)
    ok = np.isfinite(chi2)
    chi2_hist.hist(
        chi2[ok],
        bins=10,
        density=True,
        histtype="step",
        color="black",
        orientation="horizontal",
    )
    binplot(m[ok], np.array(chi2)[ok], ax=middle)
    middle.set_xlabel("mag")
    middle.set_ylabel("chi2")
    middle.grid()

    if res is not None:
        bottom = ax[2]
        res_hist = bottom.inset_axes([1.05, 0, 0.25, 1], sharey=bottom)
        y = res[goods]
        res_min, res_max = -10000.0, 10000.0
        xx = np.linspace(res_min, res_max, 1000)
        me, sc = norm.fit(y)
        res_hist.tick_params(axis="y", labelleft=False)
        res_hist.hist(
            y,
            bins=50,
            density=True,
            histtype="step",
            color="black",
            orientation="horizontal",
        )
        res_hist.plot(
            norm.pdf(xx, loc=me, scale=sc), xx, color="black", label=f"{int(sc)}"
        )
        res_hist.legend(fontsize=8)

        binplot(-2.5 * np.log10(dp.flux), res, robust=True, ax=bottom)
        bottom.set_ylabel("res")
        bottom.set_ylim(res_min, res_max)
        bottom.grid()

    fig.tight_layout(h_pad=0.1)
    return fig, ax

264 

265 

def night_averager(
    T, mjd_key="mjd", flux_key="flux", eflux_key="error", error_floor=0, show=False
):
    """Compute mean flux per night.

    Measurements are grouped by the integer part of their mjd and one free
    parameter per night is fitted with
    ``RobustLinearSolver.robust_solution(nsig=3)``; measurements flagged in
    ``solver.bads`` are excluded from the per-night statistics. The "mjd" of
    a night is the mean mjd of all its measurements.

    :param recarray T: input catalog
    :param str mjd_key: name of the date column
    :param str flux_key: name of the flux column
    :param str eflux_key: name of the flux error column
    :param float error_floor: term added to ``eflux**2`` before taking the
        square root that defines the weights
    :param bool show: if True, draw the control plots
    :return DataFrame avg_cat: mjd, flux, eflux, rms, nmeas, chi2
    :return array indices: mjd indices of the good measurements
    :return array goods: boolean mask of the good measurements
    """
    # pylint: disable=E1101
    # pylint: disable=E1130

    dp = DataProxy(T, flux=flux_key, eflux=eflux_key)
    dp.add_field("mjd", T[mjd_key].astype(int))
    dp.make_index("mjd", intmap=True)

    weights = 1.0 / np.sqrt(dp.eflux**2 + error_floor)
    # Measurements without a usable error get a null weight.
    weights[~np.isfinite(weights)] = 0

    model = LinearModel(list(range(len(dp.nt))), dp.mjd_index, np.ones_like(dp.flux))
    solver = RobustLinearSolver(model, np.array(dp.flux), weights=np.array(weights))
    avg_flux = solver.robust_solution(nsig=3)

    solver.model.params.free = avg_flux
    res = solver.get_res(dp.flux)
    wres = solver.get_wres(avg_flux)

    goods = ~solver.bads
    index = dp.mjd_index[goods]
    ngood = np.bincount(dp.mjd_index, goods)
    # ngood has one bin per night. The statistics below only see the good
    # measurements: without minlength they would be shorter than ngood as
    # soon as the last night(s) have all their measurements rejected, and the
    # catalog could not be built.
    nnight = len(ngood)
    good_flux = dp.flux[goods]
    # Nights with 0 (resp. 1) good measurement give x/0 in the means (resp.
    # in chi2); the resulting NaN/inf are expected, silence the warnings.
    with np.errstate(divide="ignore", invalid="ignore"):
        mean_y2 = np.bincount(index, weights=good_flux**2, minlength=nnight) / ngood
        mean_2y = (
            np.bincount(index, weights=good_flux, minlength=nnight) / ngood
        ) ** 2
        avg_cat = pd.DataFrame(
            data={
                "mjd": [
                    float(np.mean(T[mjd_key][dp.mjd_index == i]))
                    for i in range(len(dp.mjd_set))
                ],
                "flux": avg_flux,
                "eflux": np.sqrt(solver.get_cov().diagonal()),
                "rms": np.sqrt(mean_y2 - mean_2y),
                "nmeas": np.bincount(index, minlength=nnight),
                "chi2": np.bincount(
                    index, weights=wres[goods] ** 2, minlength=nnight
                )
                / (ngood - 1),
            }
        )

    if show:
        plot_single_night_averager(avg_cat, res=res, dp=dp)

    return avg_cat, index, goods

316 

317 

def chain_averaging(lccat, logger, gaia_match=False, extra_cols=[]):
    """Chain night and star averaging.

    Each star light curve is first averaged per night. When the night
    averaging of a star fails, a warning is logged and its individual
    measurements are kept as they are (rms and chi2 set to NaN, nmeas to 1).
    The night averaged fluxes are then averaged per star.

    :param recarray lccat: star light curve catalog (a DataFrame or a plain
        structured array is converted to a recarray)
    :param logger: logger receiving the warnings (``warning`` method)
    :param bool gaia_match: if True, keep only the stars matched to Gaia and
        replace their identifier by the Gaia source id
    :param list extra_cols: names of additional columns to average along
        with "ra" and "dec"
    :return DataFrame cat: night averaged catalog
    :return DataFrame star_lc_cat: star averaged catalog
    """
    if not isinstance(lccat, np.recarray):
        if isinstance(lccat, np.ndarray):
            # Plain structured arrays have no to_records method.
            lccat = lccat.view(np.recarray)
        else:
            lccat = lccat.to_records()
    # Built once; extra_cols itself is never modified.
    columns = ["ra", "dec"] + list(extra_cols)

    star_set = list(set(lccat["star"]))
    cat = []
    for s in star_set:
        T = lccat[lccat["star"] == s]
        T = T[T["ra"] != 0]  # todo remove this as soon as mklc is fixed
        try:
            cat_, idx, goods = night_averager(
                T,
                mjd_key="mjd",
                flux_key="flux",
                eflux_key="error",
                show=False,
            )
            cat_ = cat_.assign(star=int(s) * np.ones((len(cat_))).astype(int))
            cat_ = add_columns(cat_, T, columns, idx, goods)
        except Exception:  # pylint: disable=broad-except
            # Best effort: any failure of the fit falls back to the raw
            # measurements. Exception (rather than a bare except) lets
            # KeyboardInterrupt and SystemExit propagate.
            logger.warning(
                f"Night averager failed for star {s}"
                f", number of epochs is {len(set(T['mjd'].astype(int)))}"
            )
            cat_ = pd.DataFrame(
                data={
                    "mjd": T["mjd"],
                    "flux": T["flux"],
                    "eflux": T["error"],
                    "star": int(s) * np.ones((len(T))).astype(int),
                    "rms": np.ones((len(T))) * np.nan,
                    "nmeas": np.ones((len(T))),
                    "chi2": np.ones((len(T))) * np.nan,
                }
            )
            cat_ = cat_.assign(**{k: T[k] for k in columns})
        cat.append(cat_)

    cat = pd.concat(cat)
    cat = cat[np.isfinite(cat["eflux"])]
    cat = cat[np.isfinite(cat["ra"])]

    # cat = cut_from_epoch_number(cat.to_records(), min_d=min_d)
    star_lc_cat, idx, goods = star_lc_averager(
        cat.to_records(),
        star_key="star",
        flux_key="flux",
        eflux_key="eflux",
        show=False,
    )

    star_lc_cat = add_columns(star_lc_cat, cat, columns, idx, goods)

    if gaia_match:
        selected_ids, star_lc_cat, _ = get_gaia_match(
            star_lc_cat, offset=0.02, maxq=5000
        )
        star_lc_cat = star_lc_cat.astype({"star": "int64"})
        star_lc_cat["star"] = selected_ids["SOURCE_ID"].astype("int64").values
    return cat, star_lc_cat

386 

387 

def lc_stack(
    d_iterator, fn_provider, bands, cols=["flux", "eflux", "rms", "nmeas", "chi2"]
):
    """Stack all catalogs in a single one, with one set of columns per band.

    For every combination of the values of ``d_iterator``, the catalogs of
    all bands are gathered and reshaped into one row per star: "ra" and "dec"
    are averaged over the bands, and each column of ``cols`` is copied to a
    ``<col>_<band>`` column (NaN where the star has no entry in that band).

    :param dict d_iterator: keys and values indexing catalogs
    :param func fn_provider: function called with one value per key plus
        ``band``; returns a catalog or a catalog filename (parquet), a
        filename which does not exist is skipped
    :param str list bands: list of band names
    :param str list cols: list of columns to stack, named col_band in stacked
        catalog
    :return DataFrame stacked: stacked catalog
    """

    # pylint: disable=no-member

    stacked = []

    for values in itertools.product(*d_iterator.values()):
        kwargs = dict(zip(d_iterator.keys(), values))

        # Gather the catalogs available for this combination of keys.
        per_band = []
        for band in bands:
            kwargs["band"] = band
            cat_ = fn_provider(**kwargs)
            if isinstance(cat_, str):
                if not os.path.exists(cat_):
                    continue
                cat_ = pd.read_parquet(cat_)
            per_band.append(cat_.assign(band=np.full(len(cat_), band)))
        if len(per_band) == 0:
            continue
        lc_catalog = pd.concat(per_band)

        # flux per band as columns
        dp = DataProxy(lc_catalog)
        dp.add_field("star", lc_catalog["star"].astype(int))
        dp.make_index("star")  # , intmap=True)
        reshaped = pd.DataFrame.from_dict({"star": dp.star_set})

        # Repeat the key values on every row ("band" is handled per column).
        key_columns = {
            _k: [_v] * len(reshaped) for _k, _v in kwargs.items() if _k != "band"
        }
        reshaped = reshaped.assign(**key_columns)

        every_row = np.ones((len(dp.star_index))).astype("bool")
        reshaped = add_columns(
            reshaped, lc_catalog, ["ra", "dec"], dp.star_index, every_row
        )
        N = len(reshaped)

        for band in bands:
            selec = lc_catalog["band"] == band
            # Start from NaN columns, then fill the stars seen in this band.
            nancols = dict(
                zip(
                    [c + f"_{band}" for c in cols],
                    [np.ones((N)) * np.nan for _ in cols],
                )
            )
            reshaped = reshaped.assign(**nancols)
            for col in cols:
                icol = reshaped.columns.get_loc(col + f"_{band}")
                reshaped.iloc[dp.star_index[selec], icol] = lc_catalog[col][selec]
        stacked.append(reshaped)

    return pd.concat(stacked)

457 stacked = pd.concat(stacked) 

458 return stacked