Coverage for /usr/local/lib/python3.9/site-packages/lccalib/averaging.py: 15%
243 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-23 12:30 +0000
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-23 12:30 +0000
1"""
2Tools to make a single light curve catalog with averaged flux
3"""
4import os
5import itertools
6import ssl
7import pandas as pd
8import numpy as np
9from scipy.stats import norm
10import matplotlib.pyplot as plt
12from astroquery.gaia import GaiaClass
13from astropy.table import vstack
15from saltworks.linearmodels import LinearModel, RobustLinearSolver
16from saltworks.plottools import binplot
17from saltworks.dataproxy import DataProxy
19from .match import match
21#pylint: disable=invalid-name,too-many-locals, too-many-arguments,dangerous-default-value
def get_gaia_match(stars, offset=0.02, maxq=5000):
    """Return aligned Gaia and input star subset catalog.

    Gaia DR3 (``gaiadr3.gaia_source``) is queried for sources inside the
    ra/dec bounding box of ``stars``, enlarged by ``offset`` on every side.
    The queried sources are then matched to ``stars`` with ``match``
    (``arcsecrad=20``).

    If the first query returns more than ``maxq - 100`` rows (it was probably
    truncated by ``TOP maxq``), the box is split once into four quadrants
    that are queried separately. Each quadrant is still limited to ``maxq``
    rows and is not split further.

    :param stars: star catalog with ``ra`` and ``dec`` columns, in the same
        unit as the Gaia ``ra``/``dec`` columns they are compared against
    :param float offset: margin added around the bounding box
    :param int maxq: maximum number of rows returned by one Gaia query
    :return: (matched Gaia rows as a DataFrame, matched subset of ``stars``,
        match index array with one entry per star: the row in the Gaia
        result, or -1 when unmatched)

    TODO: apply pm.
    """
    # pylint: disable=protected-access

    def _run_box_query(gaia, r_min, r_max, d_min, d_max):
        """Run one inclusive ra/dec box query and return its result table."""
        # pylint: disable=line-too-long
        query = f"select TOP {maxq} source_id, ra, dec from gaiadr3.gaia_source where ra <= {r_max} and ra >= {r_min} and dec <= {d_max} and dec >= {d_min}"
        job = gaia.launch_job(query, dump_to_file=False)
        return job.get_results()

    ra_min, dec_min = stars["ra"].min() - offset, stars["dec"].min() - offset
    ra_max, dec_max = stars["ra"].max() + offset, stars["dec"].max() + offset

    # Certificate verification is disabled for the Gaia queries (presumably a
    # workaround for certificate errors against the Gaia server - confirm).
    # The previous context is restored afterwards, so the rest of the
    # process keeps verified HTTPS.
    previous_context = ssl._create_default_https_context
    ssl._create_default_https_context = ssl._create_unverified_context
    try:
        gaia = GaiaClass(
            gaia_tap_server="https://gea.esac.esa.int/",
            gaia_data_server="https://gea.esac.esa.int/",
        )
        gaia_ids = _run_box_query(gaia, ra_min, ra_max, dec_min, dec_max)
        if len(gaia_ids) > maxq - 100:
            # The result is close to the TOP limit: split the box into 4
            # quadrants once (no further recursion).
            ra2 = (ra_max + ra_min) / 2
            dec2 = (dec_max + dec_min) / 2
            quadrants = [
                (ra_min, ra2, dec_min, dec2),
                (ra_min, ra2, dec2, dec_max),
                (ra2, ra_max, dec_min, dec2),
                (ra2, ra_max, dec2, dec_max),
            ]
            gaia_ids = vstack([_run_box_query(gaia, *box) for box in quadrants])
    finally:
        ssl._create_default_https_context = previous_context

    index = match(gaia_ids, stars, arcsecrad=20)
    matched = index != -1
    selected_stars = stars[matched]
    selected_ids = gaia_ids[index[matched]]
    return selected_ids.to_pandas(), selected_stars, index
def cut_from_epoch_number(T, min_d=3):
    """Keep only stars observed on more than ``min_d`` distinct nights.

    A night is the integer part of ``mjd``. Several measurements of the same
    star within one night count as a single epoch. With the default
    ``min_d=3``, stars with 3 or fewer epochs are removed.

    :param T: catalog (record array or DataFrame) with an integer-like
        ``index`` star column (assumed non-negative) and an ``mjd`` column
    :param int min_d: epoch threshold; values below 1 disable the cut
    :return: the rows of ``T`` that belong to the kept stars
    """
    if min_d < 1:
        return T
    star = np.asarray(T["index"]).astype(int)
    if star.size == 0:
        # Empty catalog: nothing to cut.
        return T
    night = np.asarray(T["mjd"]).astype(int)
    # Each distinct (star, night) pair is one epoch of that star.
    pairs = np.unique(np.stack([star, night]), axis=1)
    n_epochs = np.bincount(pairs[0], minlength=star.max() + 1)
    ikeep = n_epochs > min_d
    return T[ikeep[star]]
def star_lc_averager(
    T, star_key="star", flux_key="flux", eflux_key="error", error_floor=0, show=False
):
    """Compute average flux and associated errors per star.

    One free flux per star is fitted with ``RobustLinearSolver`` using
    ``robust_solution(nsig=3)``. Measurements flagged in ``solver.bads``
    are excluded from the rms, nmeas and chi2 statistics.

    :param recarray T: input catalog
    :param str star_key: column holding the star identifier (cast to int)
    :param str flux_key: flux column
    :param str eflux_key: flux error column
    :param float error_floor: added to ``eflux**2`` before computing the
        weights, so it acts as a variance term
    :param bool show: draw control plots with ``plot_star_averager``
    :return: (avg_cat, index, goods). ``avg_cat`` is a DataFrame with columns
        star, flux, eflux, rms, nmeas and chi2. Trailing stars with no kept
        measurement are dropped from it. ``index`` is the star index of every
        kept measurement. ``goods`` is the boolean mask of kept measurements.
    """
    # pylint: disable=E1101
    # pylint: disable=E1130
    dp = DataProxy(T, flux=flux_key, eflux=eflux_key)
    dp.add_field("star", T[star_key].astype(int))
    dp.make_index("star", intmap=True)

    weights = 1.0 / np.sqrt(dp.eflux**2 + error_floor)
    # A zero or NaN error gives a non-finite weight. Since all stars are
    # solved together this would likely spoil the whole fit, so drop the
    # measurement instead (same treatment as in night_averager).
    weights[~np.isfinite(weights)] = 0

    model = LinearModel(list(range(len(dp.nt))), dp.star_index, np.ones_like(dp.flux))
    solver = RobustLinearSolver(model, np.array(dp.flux), weights=np.array(weights))
    avg_flux = solver.robust_solution(nsig=3)
    solver.model.params.free = avg_flux
    res = solver.get_res(dp.flux)
    wres = solver.get_wres(avg_flux)
    goods = ~solver.bads

    # np.bincount(index) below only extends up to the last star that kept a
    # measurement, so trailing stars without any are trimmed everywhere.
    ngood = np.bincount(dp.star_index, goods)
    nz = 0
    while ngood.size and ngood[-1] == 0:
        ngood = ngood[:-1]
        nz += 1

    def _trim(values):
        """Drop the ``nz`` trailing entries (no-op when nz == 0)."""
        return values if not nz else values[:-nz]

    index = dp.star_index[goods]
    good_flux = dp.flux[goods]
    err = np.sqrt(solver.get_cov().diagonal())
    with np.errstate(divide="ignore", invalid="ignore"):
        # Stars with 0 (resp. 1) kept measurement get nan/inf rms (resp. chi2).
        mean_y2 = np.bincount(index, weights=good_flux**2) / ngood
        mean_2y = (np.bincount(index, weights=good_flux) / ngood) ** 2
        chi2 = np.bincount(index, weights=wres[goods] ** 2) / (ngood - 1)
        avg_cat = pd.DataFrame(
            data={
                "star": _trim(dp.star_set),
                "flux": _trim(avg_flux),
                "eflux": _trim(err),
                "rms": np.sqrt(mean_y2 - mean_2y),
                "nmeas": np.bincount(index),
                "chi2": chi2,
            }
        )

    if show:
        plot_star_averager(avg_cat, res=res, dp=dp, goods=goods)
    return avg_cat, index, goods
def add_columns(avg_cat, T, names, index, goods):
    """Complete an averaged catalog with per-group means of extra columns.

    :param DataFrame avg_cat: averaged catalog, one row per group
    :param T: input catalog supporting ``T[name][goods]`` (record array,
        DataFrame or dict of arrays)
    :param list names: columns of ``T`` to average
    :param array index: group (row position in ``avg_cat``) of each selected
        measurement
    :param array goods: boolean mask selecting the measurements of ``T``
        that ``index`` refers to
    :return DataFrame: copy of ``avg_cat`` with one new column per name;
        groups without any selected measurement get NaN
    """
    # Size the counts to the catalog, so that trailing groups with no
    # selected measurement do not produce arrays shorter than avg_cat.
    counts = np.bincount(index, minlength=len(avg_cat))
    with np.errstate(divide="ignore", invalid="ignore"):
        means = {
            name: np.bincount(index, weights=T[name][goods], minlength=len(counts))
            / counts
            for name in names
        }
    return avg_cat.assign(**means)
def plot_night_averager(avg_cat, single=True, **kwargs):
    """Compare residual dispersion and predicted errors of the night average.

    With ``single=True`` (default) this delegates to
    ``plot_single_night_averager``, forwarding ``kwargs``. Otherwise it draws
    a 2x2 figure and ignores ``kwargs``:

    - top row: binned predicted relative error (eflux / flux) and binned
      relative residual rms (rms / flux);
    - bottom row: binned chi2.

    The left column is plotted against mjd and the right column against
    magnitude.

    :param DataFrame avg_cat: night averaged catalog
    :param bool single: use the single-night control plot instead
    :return: (fig, ax) with ``ax`` a flat list of the four axes
    """
    if single:
        return plot_single_night_averager(avg_cat, **kwargs)

    fig, axes = plt.subplots(2, 2)
    ax = list(axes.flatten())

    flux = avg_cat["flux"]
    abscissae = {
        "mjd": avg_cat["mjd"].to_numpy(),
        "mag": (-2.5 * np.log10(flux)).to_numpy(),
    }
    predicted = (avg_cat["eflux"] / flux).to_numpy()
    measured = (avg_cat["rms"] / flux).to_numpy()
    chi2 = avg_cat["chi2"].to_numpy()

    for column, (xlabel, x) in enumerate(abscissae.items()):
        top, bottom = ax[column], ax[column + 2]

        binplot(x, predicted, label="predicted errors", ax=top, data=False)
        binplot(x, measured, color="r", label="res rms", ax=top)
        top.grid()
        top.set_ylabel(r"$\sigma_f / f$")
        top.legend()

        binplot(x, chi2, ax=bottom)
        bottom.set_xlabel(xlabel)
        bottom.set_ylabel("chi2")
        bottom.grid()

    fig.tight_layout()
    return fig, ax
def plot_single_night_averager(avg_cat, res=None, dp=None):
    """Night averager control plots versus mjd.

    The panels share the x axis:

    0. relative residual rms and ``eflux * sqrt(nmeas) / flux``;
    1. chi2;
    2. only when ``res`` is given: relative residuals ``res / flux`` of the
       individual measurements, binned (robust) and scattered.

    :param DataFrame avg_cat: night averaged catalog (mjd, flux, eflux, rms,
        nmeas, chi2)
    :param array res: residuals of the individual measurements (optional)
    :param dp: data proxy exposing ``mjd`` and ``flux`` of the measurements;
        required when ``res`` is given
    :return: (fig, ax)
    """
    nrows = 2 if res is None else 3
    fig, ax = plt.subplots(nrows, 1, sharex=True)

    ax[0].plot(avg_cat["mjd"], avg_cat["rms"] / avg_cat["flux"], "r.", label="res rms")
    ax[0].plot(
        avg_cat["mjd"],
        avg_cat["eflux"] * np.sqrt(avg_cat["nmeas"]) / avg_cat["flux"],
        "k.",
        label=r"$\sigma_f \sqrt{N} / f$",
    )
    ax[0].grid()
    ax[0].set_ylabel(r"$\sigma_f / f$")
    ax[0].legend()

    ax[1].plot(avg_cat["mjd"], avg_cat["chi2"], "k.")
    ax[1].set_xlabel("mjd")
    ax[1].set_ylabel("chi2")
    ax[1].grid()

    if res is not None:
        # Both layers show relative residuals so they match the "res / f"
        # label. Binning raw residuals would put flux units on this axis.
        rel_res = res / dp.flux
        binplot(dp.mjd, rel_res, robust=True, ax=ax[2])
        ax[2].plot(dp.mjd, rel_res, "k.")
        ax[2].set_ylabel("res / f")
        ax[2].grid()
    return fig, ax
def plot_star_averager(avg_cat, res=None, dp=None, goods=None, res_lim=10000.0):
    """Star averager control plots versus magnitude.

    The panels share the x axis:

    0. binned relative residual rms and ``eflux * sqrt(nmeas) / flux``;
    1. binned chi2, with a side histogram of the finite chi2 values;
    2. only when ``res`` is given: binned residuals of the individual
       measurements, with a side histogram of the ``goods`` residuals and a
       fitted normal pdf. The pdf legend shows the fitted sigma truncated to
       an int.

    :param DataFrame avg_cat: star averaged catalog (flux, eflux, rms, nmeas,
        chi2)
    :param array res: residuals of the individual measurements (optional)
    :param dp: data proxy exposing the measurement ``flux``; required when
        ``res`` is given
    :param array goods: boolean mask of the measurements used for the
        residual histogram and normal fit; ``None`` means all
    :param float res_lim: residual axis and pdf range is [-res_lim, res_lim]
    :return: (fig, ax)
    """
    nrows = 2 if res is None else 3
    # Layout is handled by the final tight_layout call. Also requesting
    # layout="constrained" here only made matplotlib warn that the layout
    # changed to tight.
    fig, ax = plt.subplots(nrows, 1, sharex=True)

    mag = -2.5 * np.log10(np.array(avg_cat["flux"]))
    binplot(mag, avg_cat["rms"] / avg_cat["flux"], color="r", label="res rms", ax=ax[0])
    binplot(
        mag,
        avg_cat["eflux"] * np.sqrt(avg_cat["nmeas"]) / avg_cat["flux"],
        color="k",
        label=r"$\sigma_f \sqrt{N}$",
        ax=ax[0],
    )
    ax[0].grid()
    ax[0].set_ylabel(r"$\sigma_f / f$")
    ax[0].legend()

    ax_histy = ax[1].inset_axes([1.05, 0, 0.25, 1], sharey=ax[1])
    ok = np.isfinite(avg_cat["chi2"])
    ax_histy.hist(
        avg_cat["chi2"][ok],
        bins=10,
        density=True,
        histtype="step",
        color="black",
        orientation="horizontal",
    )
    binplot(mag[ok], np.array(avg_cat["chi2"])[ok], ax=ax[1])
    ax[1].set_xlabel("mag")
    ax[1].set_ylabel("chi2")
    ax[1].grid()

    if res is not None:
        ax_histy = ax[2].inset_axes([1.05, 0, 0.25, 1], sharey=ax[2])
        y = res if goods is None else res[goods]
        xx = np.linspace(-res_lim, res_lim, 1000)
        me, sc = norm.fit(y)
        ax_histy.tick_params(axis="y", labelleft=False)
        ax_histy.hist(
            y,
            bins=50,
            density=True,
            histtype="step",
            color="black",
            orientation="horizontal",
        )
        # int(nan) would raise when the fit fails (e.g. no residuals).
        sigma_label = f"{int(sc)}" if np.isfinite(sc) else "nan"
        ax_histy.plot(
            norm.pdf(xx, loc=me, scale=sc), xx, color="black", label=sigma_label
        )
        ax_histy.legend(fontsize=8)

        binplot(-2.5 * np.log10(dp.flux), res, robust=True, ax=ax[2])
        ax[2].set_ylabel("res")
        ax[2].set_ylim(-res_lim, res_lim)
        ax[2].grid()
    fig.tight_layout(h_pad=0.1)
    return fig, ax
def night_averager(
    T, mjd_key="mjd", flux_key="flux", eflux_key="error", error_floor=0, show=False
):
    """Compute mean flux per night.

    Measurements are grouped by the integer part of ``T[mjd_key]``. One free
    flux per night is fitted with ``RobustLinearSolver`` using
    ``robust_solution(nsig=3)``. Measurements flagged in ``solver.bads`` are
    excluded from the rms, nmeas and chi2 statistics.

    :param recarray T: input catalog
    :param str mjd_key: date column
    :param str flux_key: flux column
    :param str eflux_key: flux error column
    :param float error_floor: added to ``eflux**2`` before computing the
        weights, so it acts as a variance term
    :param bool show: draw control plots with ``plot_single_night_averager``
    :return array avg_cat: one row per night with mjd (mean date of all the
        night's measurements), flux, eflux, rms, nmeas and chi2
    :return array indices: night index of every kept measurement
    :return array goods: boolean mask of the kept measurements
    """
    # pylint: disable=E1101
    # pylint: disable=E1130

    dp = DataProxy(T, flux=flux_key, eflux=eflux_key)
    dp.add_field("mjd", T[mjd_key].astype(int))
    dp.make_index("mjd", intmap=True)

    weights = 1.0 / np.sqrt(dp.eflux**2 + error_floor)
    # A zero or NaN error gives a non-finite weight: drop that measurement.
    weights[~np.isfinite(weights)] = 0

    model = LinearModel(list(range(len(dp.nt))), dp.mjd_index, np.ones_like(dp.flux))
    solver = RobustLinearSolver(model, np.array(dp.flux), weights=np.array(weights))
    avg_flux = solver.robust_solution(nsig=3)

    solver.model.params.free = avg_flux
    res = solver.get_res(dp.flux)
    wres = solver.get_wres(avg_flux)
    goods = ~solver.bads
    index = dp.mjd_index[goods]
    ngood = np.bincount(dp.mjd_index, goods)
    # Bincounts over the kept measurements are padded to one entry per
    # night. Otherwise, if the last night(s) lose all their points to
    # clipping, these columns would be shorter than "flux"/"mjd" and the
    # DataFrame construction would fail.
    n_nights = len(ngood)
    good_flux = dp.flux[goods]
    with np.errstate(divide="ignore", invalid="ignore"):
        mean_y2 = np.bincount(index, weights=good_flux**2, minlength=n_nights) / ngood
        mean_2y = (
            np.bincount(index, weights=good_flux, minlength=n_nights) / ngood
        ) ** 2
        avg_cat = pd.DataFrame(
            data={
                "mjd": [
                    float(np.mean(T[mjd_key][dp.mjd_index == i]))
                    for i in range(len(dp.mjd_set))
                ],
                "flux": avg_flux,
                "eflux": np.sqrt(solver.get_cov().diagonal()),
                "rms": np.sqrt(mean_y2 - mean_2y),
                "nmeas": np.bincount(index, minlength=n_nights),
                "chi2": np.bincount(
                    index, weights=wres[goods] ** 2, minlength=n_nights
                )
                / (ngood - 1),
            }
        )

    if show:
        plot_single_night_averager(avg_cat, res=res, dp=dp)

    return avg_cat, index, goods
def chain_averaging(lccat, logger, gaia_match=False, extra_cols=None):
    """Chain night and star averaging.

    Each star's light curve is first averaged per night with
    ``night_averager``. If that fails for a star, a warning is logged and its
    individual measurements are kept as-is (rms and chi2 NaN, nmeas 1). The
    concatenated night catalog, restricted to finite eflux and ra, is then
    averaged per star with ``star_lc_averager``.

    :param lccat: star light curve catalog: a record array, or an object with
        ``to_records()`` such as a DataFrame. It needs star, mjd, flux, error,
        ra, dec and ``extra_cols`` columns.
    :param logger: logger used to report night averaging failures
    :param bool gaia_match: keep only stars matched to Gaia DR3 and replace
        their ``star`` ids with the Gaia ``SOURCE_ID``
    :param list extra_cols: extra columns averaged alongside ra and dec
    :return DataFrame cat: night averaged catalog
    :return DataFrame star_lc_cat: star averaged catalog
    """
    extra_cols = [] if extra_cols is None else list(extra_cols)
    avg_cols = ["ra", "dec"] + extra_cols

    if not isinstance(lccat, np.recarray):
        lccat = lccat.to_records()
    cat = []
    for s in list(set(lccat["star"])):
        T = lccat[lccat["star"] == s]
        T = T[T["ra"] != 0]  # todo remove this as soon as mklc is fixed
        star_id = int(s)
        try:
            cat_, idx, goods = night_averager(
                T,
                mjd_key="mjd",
                flux_key="flux",
                eflux_key="error",
                show=False,
            )
            cat_ = cat_.assign(star=np.full(len(cat_), star_id, dtype=int))
            cat_ = add_columns(cat_, T, avg_cols, idx, goods)
        except Exception:  # pylint: disable=broad-exception-caught
            # Best effort: fall back to the raw measurements of this star.
            logger.warning(
                f"Night averager failed for star {s}"
                f", number of epochs is {len(set(T['mjd'].astype(int)))}"
            )
            cat_ = pd.DataFrame(
                data={
                    "mjd": T["mjd"],
                    "flux": T["flux"],
                    "eflux": T["error"],
                    "star": np.full(len(T), star_id, dtype=int),
                    "rms": np.full(len(T), np.nan),
                    "nmeas": np.ones(len(T)),
                    "chi2": np.full(len(T), np.nan),
                }
            )
            cat_ = cat_.assign(**{k: T[k] for k in avg_cols})
        cat.append(cat_)

    cat = pd.concat(cat)
    cat = cat[np.isfinite(cat["eflux"])]
    cat = cat[np.isfinite(cat["ra"])]

    star_lc_cat, idx, goods = star_lc_averager(
        cat.to_records(),
        star_key="star",
        flux_key="flux",
        eflux_key="eflux",
        show=False,
    )
    star_lc_cat = add_columns(star_lc_cat, cat, avg_cols, idx, goods)

    if gaia_match:
        selected_ids, star_lc_cat, _ = get_gaia_match(
            star_lc_cat, offset=0.02, maxq=5000
        )
        star_lc_cat = star_lc_cat.astype({"star": "int64"})
        star_lc_cat["star"] = selected_ids["SOURCE_ID"].astype("int64").values
    return cat, star_lc_cat
def lc_stack(
    d_iterator, fn_provider, bands, cols=("flux", "eflux", "rms", "nmeas", "chi2")
):
    """Stack all catalogs in a single one.

    For every combination of the values in ``d_iterator``, the per-band
    catalogs returned by ``fn_provider`` are merged into one row per star.
    Each row holds:

    - the ``cols`` of each band, as ``<col>_<band>`` columns (NaN when the
      star is absent from that band);
    - the mean ra/dec over all bands;
    - one column per ``d_iterator`` key, holding that combination's value.

    :param dict d_iterator: keys and iterables of values indexing catalogs;
        all combinations (``itertools.product``) are visited
    :param func fn_provider: called with the key/value pairs plus
        ``band=<band>``; returns a DataFrame or a parquet filename. A missing
        file skips that band.
    :param str list bands: list of band names
    :param str list cols: list of columns to stack, named col_band in stacked
        catalog
    :return DataFrame stacked: stacked catalog
    """
    # pylint: disable=no-member
    stacked = []
    keys = list(d_iterator.keys())
    for values in itertools.product(*d_iterator.values()):
        kwargs = dict(zip(keys, values))

        lc_catalog = []
        for band in bands:
            cat_ = fn_provider(**{**kwargs, "band": band})
            if isinstance(cat_, str):
                if not os.path.exists(cat_):
                    continue
                cat_ = pd.read_parquet(cat_)
            lc_catalog.append(cat_.assign(band=np.full(len(cat_), band)))
        if not lc_catalog:
            continue
        lc_catalog = pd.concat(lc_catalog)

        # One output row per star in dp.star_set; dp.star_index maps every
        # measurement to its row.
        dp = DataProxy(lc_catalog)
        dp.add_field("star", lc_catalog["star"].astype(int))
        dp.make_index("star")
        reshaped = pd.DataFrame.from_dict({"star": dp.star_set})
        reshaped = reshaped.assign(
            **{k: [v] * len(reshaped) for k, v in kwargs.items() if k != "band"}
        )
        reshaped = add_columns(
            reshaped,
            lc_catalog,
            ["ra", "dec"],
            dp.star_index,
            np.ones(len(dp.star_index), dtype=bool),
        )
        n_star = len(reshaped)

        for band in bands:
            selec = (lc_catalog["band"] == band).to_numpy()
            rows = dp.star_index[selec]
            reshaped = reshaped.assign(
                **{f"{c}_{band}": np.full(n_star, np.nan) for c in cols}
            )
            for c in cols:
                j = reshaped.columns.get_loc(f"{c}_{band}")
                # Plain arrays make the assignment explicitly positional.
                # lc_catalog has duplicate index labels after the concat.
                # If a star appears twice in one band, the last value wins.
                reshaped.iloc[rows, j] = lc_catalog[c].to_numpy()[selec]
        stacked.append(reshaped)
    return pd.concat(stacked)