#!/usr/bin/env python3
# Copyright 2023 Jonas Beck
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
import warnings
from typing import Callable, Dict, Optional
import numpy as np
from matplotlib.pyplot import Axes
from numpy import ndarray
from scipy.integrate import cumulative_trapezoid
from scipy.optimize import curve_fit
import ephyspy.allen_sdk.ephys_features as ft
from ephyspy.features.base import SweepFeature
from ephyspy.features.utils import (
FeatureError,
fetch_available_fts,
get_sweep_burst_metrics,
get_sweep_sag_idxs,
has_rebound,
has_spikes,
has_stimulus,
is_hyperpol,
median_idx,
where_stimulus,
)
from ephyspy.utils import (
is_sweep_feature,
parse_desc,
relabel_line,
unpack,
where_between,
)
[docs]def available_sweep_features(
compute_at_init: bool = False, store_diagnostics: bool = False
) -> Dict[str, SweepFeature]:
"""Return a dictionary of all implemented sweep features.
Looks for all classes that inherit from SweepFeature and returns a dictionary
of all available features. If compute_at_init is True, the features are
computed at initialization.
Args:
compute_at_init (bool, optional): If True, the features are computed at
initialization. Defaults to False.
store_diagnostics (bool, optional): If True, the features are computed
with diagnostics. Defaults to False.
Returns:
dict[str, SweepFeature]: Dictionary of all available spike features.
"""
all_features = fetch_available_fts()
features = {ft.__name__.lower(): ft for ft in all_features if is_sweep_feature(ft)}
features = {k.replace("sweep_", ""): v for k, v in features.items()}
if any((compute_at_init, store_diagnostics)):
return {
k: lambda *args, **kwargs: v(
*args,
compute_at_init=compute_at_init,
store_diagnostics=store_diagnostics,
**kwargs,
)
for k, v in features.items()
}
else:
return features
[docs]class NullSweepFeature(SweepFeature):
"""Dummy sweep level feature.
Dummy feature that can be used as a placeholder to compute sweepset level
features using `SweepSetFeature` if no sweep level feature for it is available.
depends on: /.
description: Only the corresponding sweepset level feature exsits.
units: /."""
def __init__(self, data=None, compute_at_init=True, name=None):
super().__init__(data, compute_at_init, name=name)
def _compute(self, recompute=False, store_diagnostics=True):
return
[docs]class Sweep_Stim_amp(SweepFeature):
"""Extract sweep level stimulus ampltiude feature.
depends on: /.
description: maximum amplitude of stimulus.
units: pA."""
def __init__(self, data=None, compute_at_init=True):
super().__init__(data, compute_at_init)
def _compute(self, recompute=False, store_diagnostics=True):
idx = np.argmax(abs(self.data.i).T, axis=0)
if store_diagnostics:
self._update_diagnostics({"idx": idx, "t": self.data.t[idx]})
return self.data.i[idx]
def _plot(self, ax: Optional[Axes] = None, **kwargs) -> Axes:
i_amp = self.value
t = self.diagnostics["t"]
ax.plot(
t,
i_amp,
"x",
label=self.name,
**kwargs,
)
return ax
[docs]class Sweep_Stim_onset(SweepFeature):
"""Extract sweep level stimulus onset feature.
depends on: /.
description: time of stimulus onset.
units: s."""
def __init__(self, data=None, compute_at_init=True):
super().__init__(data, compute_at_init)
def _compute(self, recompute=False, store_diagnostics=True):
stim_onset = float("nan")
if has_stimulus(self.data):
where_stim = where_stimulus(self.data)
stim_onset = self.data.t[where_stim][0]
i_onset = self.data.i[where_stim][0]
idx_onset = np.arange(len(where_stim))[where_stim][0]
if store_diagnostics:
self._update_diagnostics(
{
"i_onset": i_onset,
"where_stim": where_stim,
"idx_onset": idx_onset,
}
)
return stim_onset
def _plot(self, ax: Optional[Axes] = None, **kwargs) -> Axes:
onset = self.value
i_onset = self.diagnostics["i_onset"]
ax.plot(onset, i_onset, "x", label=self.name, **kwargs)
return ax
[docs]class Sweep_Stim_end(SweepFeature):
"""Extract sweep level stimulus end feature.
depends on: /.
description: time of stimulus end.
units: s."""
def __init__(self, data=None, compute_at_init=True):
super().__init__(data, compute_at_init)
def _compute(self, recompute=False, store_diagnostics=True):
stim_end = float("nan")
if has_stimulus(self.data):
where_stim = where_stimulus(self.data)
stim_end = self.data.t[where_stim][-1]
i_end = self.data.i[where_stim][-1]
idx_end = np.arange(len(where_stim))[where_stim][-1]
if store_diagnostics:
self._update_diagnostics(
{"i_end": i_end, "where_stim": where_stim, "idx_end": idx_end}
)
return stim_end
def _plot(self, ax: Optional[Axes] = None, **kwargs) -> Axes:
end = self.value
i_end = self.diagnostics["i_end"]
ax.plot(end, i_end, "x", label=self.name, **kwargs)
return ax
[docs]class Sweep_Num_AP(SweepFeature):
"""Extract sweep level spike count feature.
depends on: stim_onset, stim_end.
description: # peaks during stimulus.
units: /."""
def __init__(self, data=None, compute_at_init=True):
super().__init__(data, compute_at_init)
def _compute(self, recompute=False, store_diagnostics=True):
peak_t = self.lookup_spike_feature("peak_t", recompute=recompute)
onset = self.lookup_sweep_feature("stim_onset")
end = self.lookup_sweep_feature("stim_end")
stim_window = where_between(peak_t, onset, end)
peak_i = self.lookup_spike_feature("peak_index")[stim_window]
num_ap = len(peak_i)
if num_ap <= 0:
num_ap = float("nan")
if store_diagnostics:
peak_t = peak_t[stim_window]
peak_v = self.lookup_spike_feature("peak_v")[stim_window]
self._update_diagnostics(
{
"peak_i": peak_i,
"peak_t": peak_t,
"peak_v": peak_v,
}
)
return num_ap
def _plot(self, ax: Optional[Axes] = None, **kwargs) -> Axes:
peak_t, peak_v = unpack(self.diagnostics, ["peak_t", "peak_v"])
ax.plot(peak_t, peak_v, "x", label=self.name, **kwargs)
return ax
[docs]class Sweep_AP_freq(SweepFeature):
"""Extract sweep level spike rate feature.
depends on: numap.
description: # peaks during stimulus / stimulus duration.
units: Hz."""
def __init__(self, data=None, compute_at_init=True):
super().__init__(data, compute_at_init)
def _compute(self, recompute=False, store_diagnostics=True):
num_ap = self.lookup_sweep_feature("num_ap", recompute=recompute)
onset = self.lookup_sweep_feature("stim_onset", recompute=recompute)
end = self.lookup_sweep_feature("stim_end", recompute=recompute)
ap_freq = num_ap / (end - onset)
if store_diagnostics:
self._update_diagnostics(
{"ap_freq": ap_freq, "num_ap": num_ap, "onset": onset, "end": end}
)
return ap_freq
def _plot(self, ax: Optional[Axes] = None, **kwargs) -> Axes:
num_ap = self.lookup_sweep_feature("num_ap", return_value=False)
ax = num_ap.plot(ax=ax, **kwargs)
return ax
[docs]class Sweep_AP_latency(SweepFeature):
"""Extract sweep level ap_latency feature.
depends on: stim_onset.
description: time of first spike after stimulus onset.
units: s."""
def __init__(self, data=None, compute_at_init=True):
super().__init__(data, compute_at_init)
def _compute(self, recompute=False, store_diagnostics=True):
ap_latency = float("nan")
if has_stimulus(self.data):
onset = self.lookup_sweep_feature("stim_onset", recompute=recompute)
end = self.lookup_sweep_feature("stim_end", recompute=recompute)
thresh_t = self.lookup_spike_feature("threshold_t", recompute=recompute)
thresholds = self.lookup_spike_feature("threshold_v", recompute=recompute)
stim_window = where_between(thresh_t, onset, end)
thresh_t_stim = thresh_t[stim_window]
if len(thresh_t_stim) > 0:
v_first_spike = thresholds[stim_window][0]
t_first_spike = thresh_t_stim[0]
ap_latency = t_first_spike - onset
if store_diagnostics:
self._update_diagnostics(
{
"onset": onset,
"end": end,
"spike_times_during_stim": thresh_t_stim,
"t_first": t_first_spike,
"v_first": v_first_spike,
}
)
return ap_latency
def _plot(self, ax: Optional[Axes] = None, **kwargs) -> Axes:
v_first, t_first, onset = unpack(
self.diagnostics, ["v_first", "t_first", "onset"]
)
ax.hlines(v_first, onset, t_first, label=self.name, **kwargs)
return ax
[docs]class Sweep_V_baseline(SweepFeature):
"""Extract sweep level baseline voltage feature.
depends on: stim_onset.
description: average voltage in baseline_interval (in s) before stimulus onset.
units: mV."""
def __init__(self, data=None, compute_at_init=True):
super().__init__(data, compute_at_init)
def _compute(self, recompute=False, store_diagnostics=True):
v_baseline_avg = float("nan")
onset = self.lookup_sweep_feature("stim_onset", recompute=recompute)
if np.isnan(onset):
baseline_interval = float("nan")
where_baseline = np.ones_like(self.data.t, dtype=bool) # for I=0pA
else:
baseline_interval = self.data.baseline_interval
where_baseline = where_between(
self.data.t, onset - baseline_interval, onset
)
t_baseline = self.data.t[where_baseline]
v_baseline = self.data.v[where_baseline]
v_baseline_avg = np.mean(v_baseline)
# v_baseline_avg = sweep._get_baseline_voltage() # bad since start is set to t[0]
if store_diagnostics:
self._update_diagnostics(
{
"where_baseline": where_baseline,
"t_baseline": t_baseline,
"v_baseline": v_baseline,
"baseline_interval": baseline_interval,
"stim_onset": onset,
}
)
return v_baseline_avg
def _plot(self, ax: Optional[Axes] = None, **kwargs) -> Axes:
t_base, v_base = unpack(self.diagnostics, ["t_baseline", "v_baseline"])
ax.plot(t_base, v_base, label=self.name + " interval", **kwargs)
ax.axhline(self.value, ls="--", label=self.name, **kwargs)
return ax
[docs]class Sweep_V_deflect(SweepFeature):
"""Extract sweep level voltage deflection feature.
depends on: stim_end.
description: average voltage during last 100 ms of stimulus.
units: mV."""
def __init__(self, data=None, compute_at_init=True):
super().__init__(data, compute_at_init)
def _compute(self, recompute=False, store_diagnostics=True):
v_deflect_avg = float("nan")
if has_stimulus(self.data) and is_hyperpol(self.data):
# v_deflect_avg = self.data.voltage_deflection()[0]
end = self.lookup_sweep_feature("stim_end", recompute=recompute)
v_deflect_avg = ft.average_voltage(
self.data.v, self.data.t, start=end - 0.1, end=end
)
idx_deflect = np.where(where_between(self.data.t, end - 0.1, end))[0]
t_deflect = self.data.t[idx_deflect]
v_deflect = self.data.v[idx_deflect]
if store_diagnostics:
self._update_diagnostics(
{
"idx_deflect": idx_deflect,
"t_deflect": t_deflect,
"v_deflect": v_deflect,
}
)
return v_deflect_avg
def _plot(self, ax: Optional[Axes] = None, **kwargs) -> Axes:
t_deflect, v_deflect = unpack(self.diagnostics, ["t_deflect", "v_deflect"])
t_bar = np.ones_like(t_deflect) * self.value
ax.plot(t_deflect, v_deflect, label=self.name + " interval", **kwargs)
ax.plot(t_deflect, t_bar, ls="--", label=self.name, **kwargs)
return ax
[docs]class Sweep_Tau(SweepFeature):
"""Extract sweep level time constant feature.
depends on: v_baseline, stim_onset.
description: time constant of exponential fit to voltage deflection.
units: s."""
def __init__(self, data=None, compute_at_init=True):
super().__init__(data, compute_at_init)
def _compute(self, recompute=False, store_diagnostics=True):
tau = float("nan")
if is_hyperpol(self.data):
"""The following code block is copied and adapted from sweep.estimate_time_constant()."""
v_peak, peak_index = self.data.voltage_deflection("min")
v_baseline = self.lookup_sweep_feature("v_baseline", recompute=recompute)
if 5 < v_baseline - v_peak:
stim_onset = self.lookup_sweep_feature(
"stim_onset", recompute=recompute
)
onset_idx = ft.find_time_index(self.data.t, stim_onset)
frac = 0.1
search_result = np.flatnonzero(
self.data.v[onset_idx:] <= frac * (v_peak - v_baseline) + v_baseline
)
if not search_result.size:
raise FeatureError(
"could not find interval for time constant estimate"
)
fit_start = self.data.t[search_result[0] + onset_idx]
fit_end = self.data.t[peak_index]
if self.data.v[peak_index] < -200:
warnings.warn(
"A DOWNWARD PEAK WAS OBSERVED GOING TO LESS THAN 200 MV!!!"
)
# Look for another local minimum closer to stimulus onset
# We look for a couple of milliseconds after stimulus onset to 50 ms before the downward peak
end_index = (onset_idx + 50) + np.argmin(
self.data.v[onset_idx + 50 : peak_index - 1250]
)
fit_end = self.data.t[end_index]
fit_start = self.data.t[onset_idx + 50]
a, inv_tau, y0 = ft.fit_membrane_time_constant(
self.data.v, self.data.t, fit_start, fit_end
)
tau = 1.0 / inv_tau * 1000
if store_diagnostics:
self._update_diagnostics(
{
"a": a,
"inv_tau": inv_tau,
"y0": y0,
"fit_start": fit_start,
"fit_end": fit_end,
"equation": "y0 + a * exp(-inv_tau * x)",
}
)
return tau
def _plot(self, ax: Optional[Axes] = None, **kwargs) -> Axes:
y0, a, inv_tau = unpack(self.diagnostics, ["y0", "a", "inv_tau"])
fit_start, fit_end = unpack(self.diagnostics, ["fit_start", "fit_end"])
t, v = self.data.t, self.data.v
y = lambda t: y0 + a * np.exp(-inv_tau * t)
where_fit = where_between(t, fit_start, fit_end)
t, v = t[where_fit], v[where_fit]
t_offset = t[0]
t_fit = t - t_offset
ax.plot(t, v, label=self.name + " interval", **kwargs)
ax.plot(t, y(t_fit), ls="--", color="k", label=self.name + " fit")
return ax
[docs]class Sweep_AP_freq_adapt(SweepFeature):
"""Extract sweep level spike frequency adaptation feature.
depends on: stim_onset, stim_end, num_ap.
description: ratio of spikes in second and first half half of stimulus interval, if there is at least 5 spikes in total.
units: /."""
def __init__(self, data=None, compute_at_init=True):
super().__init__(data, compute_at_init)
def _compute(self, recompute=False, store_diagnostics=True):
ap_freq_adapt = float("nan")
num_ap = self.lookup_sweep_feature("num_ap", recompute=recompute)
if num_ap > 5 and has_stimulus(self.data):
onset = self.lookup_sweep_feature("stim_onset", recompute=recompute)
end = self.lookup_sweep_feature("stim_end", recompute=recompute)
t_half = (end - onset) / 2 + onset
where_1st_half = where_between(self.data.t, onset, t_half)
where_2nd_half = where_between(self.data.t, t_half, end)
t_1st_half = self.data.t[where_1st_half]
t_2nd_half = self.data.t[where_2nd_half]
peak_t = self.lookup_spike_feature("peak_t", recompute=recompute)
onset = self.lookup_sweep_feature("stim_onset", recompute=recompute)
end = self.lookup_sweep_feature("stim_end", recompute=recompute)
stim_window = where_between(peak_t, onset, end)
peak_t = peak_t[stim_window]
spikes_1st_half = peak_t[peak_t < t_half]
spikes_2nd_half = peak_t[peak_t > t_half]
num_spikes_1st_half = len(spikes_1st_half)
num_spikes_2nd_half = len(spikes_2nd_half)
ap_freq_adapt = num_spikes_2nd_half / num_spikes_1st_half
if store_diagnostics:
self._update_diagnostics(
{
"num_spikes_1st_half": num_spikes_1st_half,
"num_spikes_2nd_half": num_spikes_2nd_half,
"where_1st_half": where_1st_half,
"where_2nd_half": where_2nd_half,
"t_1st_half": t_1st_half,
"t_2nd_half": t_2nd_half,
}
)
return ap_freq_adapt
def _plot(self, ax: Optional[Axes] = None, **kwargs) -> Axes:
num_ap = self.lookup_sweep_feature("num_ap", return_value=False)
peaks_t, peaks_v = unpack(num_ap.diagnostics, ["peak_t", "peak_v"])
half1, half2 = unpack(self.diagnostics, ["t_1st_half", "t_2nd_half"])
for i, (half, m) in enumerate(zip([half1, half2], ["+", "x"])):
in_half = where_between(peaks_t, *half[[0, -1]])
ax.plot(peaks_t[in_half], peaks_v[in_half], m, label=f"{i+1}/2", **kwargs)
return ax
[docs]class Sweep_AP_amp_slope(SweepFeature):
"""Extract sweep level spike count feature.
depends on: stim_onset, stim_end.
description: spike amplitude adaptation as the slope of a linear fit v_peak(t_peak)
during the stimulus interval.
units: mV/s."""
def __init__(self, data=None, compute_at_init=True):
super().__init__(data, compute_at_init)
def _compute(self, recompute=False, store_diagnostics=True):
ap_amp_slope = float("nan")
onset = self.lookup_sweep_feature("stim_onset")
end = self.lookup_sweep_feature("stim_end")
peak_t = self.lookup_spike_feature("peak_t", recompute=recompute)
peak_v = self.lookup_spike_feature("peak_v", recompute=recompute)
stim_window = where_between(peak_t, onset, end)
peak_t = peak_t[stim_window]
peak_v = peak_v[stim_window]
if len(peak_v) > 5:
y = lambda x, m, b: m * x + b
(m, b), e = curve_fit(y, peak_t, peak_v)
ap_amp_slope = m
if store_diagnostics:
self._update_diagnostics(
{
"peak_t": peak_t,
"peak_v": peak_v,
"slope": m,
"intercept": b,
}
)
return ap_amp_slope
def _plot(self, ax: Optional[Axes] = None, **kwargs) -> Axes:
intercept, slope, peak_t = unpack(
self.diagnostics, ["intercept", "slope", "peak_t"]
)
y = lambda t: intercept + slope * t
if not np.isnan(self.value):
ts = peak_t # or ts = self.data.t
ax.plot(ts, y(ts), "--", label=self.name, **kwargs)
return ax
[docs]class Sweep_ISI_FF(SweepFeature):
"""Extract sweep level inter-spike-interval (ISI) Fano factor feature.
depends on: ISIs.
description: Var(ISIs) / Mean(ISIs).
units: s."""
def __init__(self, data=None, compute_at_init=True):
super().__init__(data, compute_at_init)
def _compute(self, recompute=False, store_diagnostics=True):
isi_ff = float("nan")
if has_spikes(self.data):
isi = self.lookup_spike_feature("isi", recompute=recompute)[1:]
if len(isi) > 1:
isi_ff = np.nanvar(isi) / np.nanmean(isi)
if store_diagnostics:
self._update_diagnostics(
{
"isi": isi,
"isi_var": np.nanvar(isi),
"isi_mean": np.nanmean(isi),
}
)
return isi_ff
def _plot(self, ax: Optional[Axes] = None, **kwargs) -> Axes:
warnings.warn(f" {self.name} plotting not implemented.")
return ax
[docs]class Sweep_ISI_CV(SweepFeature):
"""Extract sweep level inter-spike-interval (ISI) coefficient of variation (CV) feature.
depends on: ISIs.
description: Std(ISIs) / Mean(ISIs).
units: /."""
def __init__(self, data=None, compute_at_init=True):
super().__init__(data, compute_at_init)
def _compute(self, recompute=False, store_diagnostics=True):
isi_cv = float("nan")
if has_spikes(self.data):
isi = self.lookup_spike_feature("isi", recompute=recompute)[1:]
if len(isi) > 1:
isi_cv = np.nanstd(isi) / np.nanmean(isi)
if store_diagnostics:
self._update_diagnostics(
{
"isi": isi,
"isi_std": np.nanstd(isi),
"isi_mean": np.nanmean(isi),
}
)
return isi_cv
def _plot(self, ax: Optional[Axes] = None, **kwargs) -> Axes:
warnings.warn(f" {self.name} plotting not implemented.")
return ax
[docs]class Sweep_AP_FF(SweepFeature):
"""Extract sweep level AP amplitude Fano factor feature.
depends on: ap_amp.
description: Var(ap_amp) / Mean(ap_amp).
units: mV."""
def __init__(self, data=None, compute_at_init=True):
super().__init__(data, compute_at_init)
def _compute(self, recompute=False, store_diagnostics=True):
ap_ff = float("nan")
if has_spikes(self.data):
ap_amp = self.lookup_spike_feature("ap_amp", recompute=recompute)
if len(ap_amp) > 1:
ap_ff = np.nanvar(ap_amp) / np.nanmean(ap_amp)
if store_diagnostics:
self._update_diagnostics(
{
"ap_amp": ap_amp,
"ap_amp_var": np.nanvar(ap_amp),
"ap_amp_mean": np.nanmean(ap_amp),
}
)
return ap_ff
def _plot(self, ax: Optional[Axes] = None, **kwargs) -> Axes:
warnings.warn(f" {self.name} plotting not implemented.")
return ax
[docs]class Sweep_AP_CV(SweepFeature):
"""Extract sweep level AP amplitude coefficient of variation (CV) feature.
depends on: ap_amp.
description: Std(ap_amp) / Mean(ap_amp).
units: /."""
def __init__(self, data=None, compute_at_init=True):
super().__init__(data, compute_at_init)
def _compute(self, recompute=False, store_diagnostics=True):
ap_cv = float("nan")
if has_spikes(self.data):
ap_amp = self.lookup_spike_feature("ap_amp", recompute=recompute)
if len(ap_amp) > 1:
ap_cv = np.nanstd(ap_amp) / np.nanmean(ap_amp)
if store_diagnostics:
self._update_diagnostics(
{
"ap_amp": ap_amp,
"ap_amp_std": np.nanstd(ap_amp),
"ap_amp_mean": np.nanmean(ap_amp),
}
)
return ap_cv
def _plot(self, ax: Optional[Axes] = None, **kwargs) -> Axes:
warnings.warn(f" {self.name} plotting not implemented.")
return ax
[docs]class Sweep_V_sag(SweepFeature):
"""Extract sweep level sag voltage feature.
depends on: v_deflect, v_baseline.
description: Average voltage around max deflection.
units: mV."""
def __init__(self, data=None, compute_at_init=True, peak_width=0.005):
super().__init__(data, compute_at_init=False)
self.peak_width = peak_width
if compute_at_init and data is not None: # because of peak_width
self.get_value()
def _compute(self, recompute=False, store_diagnostics=True):
v_sag = float("nan")
if is_hyperpol(self.data):
where_sag = get_sweep_sag_idxs(self, store_diagnostics=store_diagnostics)
if np.sum(where_sag) > 10: # TODO: what should be min sag duration!?
# The following can also be found in sweep.estimate_sag()
v_deflect, idx_deflect = self.data.voltage_deflection("min")
if self.data.v[idx_deflect] < -200:
warnings.warn("Downward peak < 200 mV")
# Look for another local minimum closer to stimulus onset
idx_deflect -= ft.find_time_index(
self.data.t, 0.12
) - ft.find_time_index(self.data.t, 0.1)
t_deflect = self.data.t[idx_deflect]
stim_onset = self.lookup_sweep_feature(
"stim_onset", recompute=recompute
)
stim_end = self.lookup_sweep_feature("stim_end", recompute=recompute)
if ( # TODO: Check if stricter criterion is sensible, i.e. t_deflect < t_half_stim
stim_onset < t_deflect < stim_end
): # in some rare cases this is not the case
start = t_deflect - self.peak_width / 2.0
end = t_deflect + self.peak_width / 2.0
v_sag = ft.average_voltage(
self.data.v,
self.data.t,
start=start,
end=end,
)
if store_diagnostics:
self._update_diagnostics(
{
"where_sag": where_sag,
"v_deflect": v_deflect,
"idx_deflect": idx_deflect,
"t_deflect": t_deflect,
"start": start,
"end": end,
}
)
return v_sag
def _plot(self, ax: Optional[Axes] = None, **kwargs) -> Axes:
start, end = unpack(self.diagnostics, ["start", "end"])
ax.hlines(self.value, start, end, ls="--", label=self.name, **kwargs)
return ax
[docs]class Sweep_Sag(SweepFeature):
"""Extract sweep level sag feature.
depends on: v_sag.
description: magnitude of the depolarization peak.
units: mV."""
def __init__(self, data=None, compute_at_init=True, peak_width=0.005):
self.peak_width = peak_width
super().__init__(data, compute_at_init=False)
if compute_at_init and data is not None: # because of peak_width
self.get_value()
def _compute(self, recompute=False, store_diagnostics=True):
sag = float("nan")
if is_hyperpol(self.data):
where_sag = get_sweep_sag_idxs(self, store_diagnostics=store_diagnostics)
if np.sum(where_sag) > 10: # TODO: what should be min sag duration!?
v_sag = self.lookup_sweep_feature("v_sag", recompute=recompute)
v_baseline = self.lookup_sweep_feature(
"v_baseline", recompute=recompute
)
sag = v_sag - v_baseline
if store_diagnostics:
self._update_diagnostics(
{
"v_sag": v_sag,
"where_sag": where_sag,
"v_baseline": v_baseline,
}
)
return sag
def _plot(self, ax: Optional[Axes] = None, **kwargs) -> Axes:
v_baseline, v_sag = unpack(self.diagnostics, ["v_baseline", "v_sag"])
sag_voltage_ft = self.lookup_sweep_feature("v_sag", return_value=False)
t_deflect = unpack(sag_voltage_ft.diagnostics, ["t_deflect"])
ax.vlines(t_deflect, v_baseline, v_sag, label="sag", **kwargs)
return ax
[docs]class Sweep_V_steady(SweepFeature):
"""Extract sweep level hyperpol steady state feature.
depends on: stim_end.
description: hyperpol steady state voltage.
units: /."""
def __init__(self, data=None, compute_at_init=True):
super().__init__(data, compute_at_init)
def _compute(self, recompute=False, store_diagnostics=True):
v_steady = float("nan")
if is_hyperpol(self.data):
stim_end = self.lookup_sweep_feature("stim_end", recompute=recompute)
start = stim_end - self.data.baseline_interval
v_steady = ft.average_voltage(
self.data.v,
self.data.t,
start=start,
end=stim_end,
)
if store_diagnostics:
self._update_diagnostics({"start": start, "end": stim_end})
return v_steady
def _plot(self, ax: Optional[Axes] = None, **kwargs) -> Axes:
start, end = unpack(self.diagnostics, ["start", "end"])
ax.hlines(self.value, start, end, ls="--", label=self.name, **kwargs)
return ax
[docs]class Sweep_Sag_fraction(SweepFeature):
"""Extract sweep level sag fraction feature.
depends on: /.
description: fraction that membrane potential relaxes back to baseline.
units: /."""
def __init__(self, data=None, compute_at_init=True, peak_width=0.005):
self.peak_width = peak_width
super().__init__(data, compute_at_init=False)
if compute_at_init and data is not None: # because of peak_width
self.get_value()
def _compute(self, recompute=False, store_diagnostics=True):
sag_fraction = float("nan")
if is_hyperpol(self.data):
where_sag = get_sweep_sag_idxs(self, store_diagnostics=store_diagnostics)
if np.sum(where_sag) > 10: # TODO: what should be min sag duration!?
sag = self.lookup_sweep_feature("sag", recompute=recompute)
v_sag = self.lookup_sweep_feature("v_sag", recompute=recompute)
v_steady = self.lookup_sweep_feature("v_steady", recompute=recompute)
sag_fraction = (v_sag - v_steady) / sag
if store_diagnostics:
self._update_diagnostics(
{
"sag": sag,
"v_sag": v_sag,
"where_sag": where_sag,
"v_steady": v_steady,
}
)
return sag_fraction
def _plot(self, ax: Optional[Axes] = None, **kwargs) -> Axes:
sag_v_ft = self.lookup_sweep_feature("v_sag", return_value=False)
t_deflect = unpack(sag_v_ft.diagnostics, "t_deflect")
v_baseline = self.lookup_sweep_feature("v_baseline")
stim_end = self.lookup_sweep_feature("stim_end")
v_sag, v_steady = unpack(self.diagnostics, ["v_sag", "v_steady"])
ax.vlines(t_deflect, v_baseline, v_sag, label="sag", **kwargs)
ax.vlines(stim_end, v_steady, v_sag, label="v_sag - v_steady", **kwargs)
return ax
[docs]class Sweep_Sag_ratio(SweepFeature):
"""Extract sweep level sag ratio feature.
depends on: /.
description: ratio of steady state voltage decrease to the largest voltage decrease.
units: /."""
def __init__(self, data=None, compute_at_init=True):
super().__init__(data, compute_at_init)
def _compute(self, recompute=False, store_diagnostics=True):
sag_ratio = float("nan")
if is_hyperpol(self.data):
where_sag = get_sweep_sag_idxs(self, store_diagnostics=store_diagnostics)
if np.sum(where_sag) > 10: # TODO: what should be min sag duration!?
sag = self.lookup_sweep_feature("sag", recompute=recompute)
v_steady = self.lookup_sweep_feature("v_steady", recompute=recompute)
v_baseline = self.lookup_sweep_feature(
"v_baseline", recompute=recompute
)
sag_ratio = sag / (v_steady - v_baseline)
if store_diagnostics:
self._update_diagnostics(
{
"sag": sag,
"where_sag": where_sag,
"v_baseline": v_baseline,
"v_steady": v_steady,
}
)
return sag_ratio
def _plot(self, ax: Optional[Axes] = None, **kwargs) -> Axes:
sag_v_ft = self.lookup_sweep_feature("v_sag", return_value=False)
t_deflect = unpack(sag_v_ft.diagnostics, "t_deflect")
v_sag = self.lookup_sweep_feature("v_sag")
stim_end = self.lookup_sweep_feature("stim_end")
v_baseline, v_steady = unpack(self.diagnostics, ["v_baseline", "v_steady"])
ax.vlines(stim_end, v_steady, v_baseline, label="v_steady - v_base", **kwargs)
ax.vlines(t_deflect, v_baseline, v_sag, label="sag", **kwargs)
return ax
[docs]class Sweep_Sag_area(SweepFeature):
"""Extract sweep level sag area feature.
depends on: v_deflect, stim_onset, stim_end.
description: area under the sag.
units: mV*s."""
def __init__(self, data=None, compute_at_init=True):
super().__init__(data, compute_at_init)
def _compute(self, recompute=False, store_diagnostics=True):
sag_area = float("nan")
if is_hyperpol(self.data):
where_sag = get_sweep_sag_idxs(self, store_diagnostics=store_diagnostics)
if np.sum(where_sag) > 10: # TODO: what should be min sag duration!?
v_sag = self.data.v[where_sag]
t_sag = self.data.t[where_sag]
v_sagline = v_sag[0]
# Take running average of v?
if len(v_sag) > 10: # at least 10 points to integrate
sag_area = cumulative_trapezoid(v_sagline - v_sag, t_sag)[-1]
if store_diagnostics:
self._update_diagnostics(
{
"where_sag": where_sag,
"v_sag": v_sag,
"t_sag": t_sag,
"v_sagline": v_sagline,
}
)
return sag_area
def _plot(self, ax: Optional[Axes] = None, **kwargs) -> Axes:
t_sag, v_sag, v_sagline = unpack(
self.diagnostics, ["t_sag", "v_sag", "v_sagline"]
)
ax.plot(t_sag, v_sag, label="sag interval", **kwargs)
ax.fill_between(t_sag, v_sag, v_sagline, alpha=0.5, label=self.name)
return ax
[docs]class Sweep_Sag_time(SweepFeature):
"""Extract sweep level sag duration feature.
depends on: v_deflect, stim_onset, stim_end.
description: duration of the sag.
units: s."""
def __init__(self, data=None, compute_at_init=True):
super().__init__(data, compute_at_init)
def _compute(self, recompute=False, store_diagnostics=True):
sag_time = float("nan")
if is_hyperpol(self.data):
where_sag = get_sweep_sag_idxs(self, store_diagnostics=store_diagnostics)
if np.sum(where_sag) > 10: # TODO: what should be min sag duration!?
sag_t_start, sag_t_end = self.data.t[where_sag][[0, -1]]
sag_time = sag_t_end - sag_t_start
if store_diagnostics:
self._update_diagnostics(
{
"where_sag": where_sag,
"sag_t_start": sag_t_start,
"sag_t_end": sag_t_end,
}
)
return sag_time
def _plot(self, ax: Optional[Axes] = None, **kwargs) -> Axes:
where_sag, sag_start, sag_end = unpack(
self.diagnostics, ["where_sag", "sag_t_start", "sag_t_end"]
)
v = self.data.v[where_sag][0]
ax.hlines(v, xmin=sag_start, xmax=sag_end, label=self.name, **kwargs)
return ax
[docs]class Sweep_V_plateau(SweepFeature):
"""Extract sweep level plataeu voltage feature.
depends on: stim_end.
description: average voltage during the plateau.
units: mV."""
def __init__(self, data=None, compute_at_init=True, T_plateau=0.1):
self.T_plateau = T_plateau
super().__init__(data, compute_at_init=False)
if compute_at_init and data is not None: # because of T_plateau
self.get_value()
def _compute(self, recompute=False, store_diagnostics=True):
v_avg_plateau = float("nan")
if is_hyperpol(self.data):
end = self.lookup_sweep_feature("stim_end", recompute=recompute)
# same as voltage deflection
where_plateau = where_between(self.data.t, end - self.T_plateau, end)
v_plateau = self.data.v[where_plateau]
t_plateau = self.data.t[where_plateau]
v_avg_plateau = ft.average_voltage(v_plateau, t_plateau)
if store_diagnostics:
self._update_diagnostics(
{
"where_plateau": where_plateau,
"v_plateau": v_plateau,
"t_plateau": t_plateau,
}
)
return v_avg_plateau
def _plot(self, ax: Optional[Axes] = None, **kwargs) -> Axes:
t, v = unpack(self.diagnostics, ["t_plateau", "v_plateau"])
ax.plot(t, v, label=self.name + " interval", **kwargs)
ax.hlines(self.value, *t[[0, -1]], ls="--", label=self.name, **kwargs)
return ax
[docs]class Sweep_Rebound(SweepFeature):
"""Extract sweep level rebound feature.
depends on: v_baseline, stim_end.
description: V_max during stimulus_end and stimulus_end + T_rebound - V_baseline.
units: mV."""
def __init__(self, data=None, compute_at_init=True, T_rebound=0.3):
self.T_rebound = T_rebound
super().__init__(data, compute_at_init=False)
if compute_at_init and data is not None: # because of T_rebound
self.get_value()
def _compute(self, recompute=False, store_diagnostics=True):
rebound = float("nan")
if has_rebound(self, self.T_rebound):
v_baseline = self.lookup_sweep_feature("v_baseline", recompute=recompute)
end = self.lookup_sweep_feature("stim_end", recompute=recompute)
where_rebound = where_between(self.data.t, end, end + self.T_rebound)
where_rebound = np.logical_and(where_rebound, self.data.v > v_baseline)
t_rebound = self.data.t[where_rebound]
v_rebound = self.data.v[where_rebound]
if len(v_rebound) > 10: # at least 10 time points with rebound
idx_rebound = np.argmax(self.data.v[where_rebound] - v_baseline)
idx_rebound = np.where(where_rebound)[0][idx_rebound]
max_rebound = self.data.v[idx_rebound]
rebound = max_rebound - v_baseline
if store_diagnostics:
self._update_diagnostics(
{
"idx_rebound": idx_rebound,
"t_rebound": t_rebound,
"v_rebound": v_rebound,
"v_baseline": v_baseline,
"max_rebound": max_rebound,
"where_rebound": where_rebound,
}
)
return rebound
def _plot(self, ax=None, include_details=False, **kwargs):
t_rebound, v_rebound, idx_rebound, v_baseline = unpack(
self.diagnostics, ["t_rebound", "v_rebound", "idx_rebound", "v_baseline"]
)
t = self.data.t[idx_rebound]
v = self.data.v[idx_rebound]
ax.vlines(t, v_baseline, v, label=self.name, **kwargs)
if include_details:
ax.plot(t_rebound, v_rebound, label=self.name + " interval", **kwargs)
return ax
[docs]class Sweep_Rebound_APs(SweepFeature):
"""Extract sweep level number of rebounding spikes feature.
depends on: stim_end.
description: number of spikes during stimulus_end and stimulus_end + T_rebound.
units: /."""
def __init__(self, data=None, compute_at_init=True, T_rebound=0.3):
self.T_rebound = T_rebound
super().__init__(data, compute_at_init=False)
if compute_at_init and data is not None: # because of T_rebound
self.get_value()
def _compute(self, recompute=False, store_diagnostics=True):
num_rebound_aps = float("nan")
if has_rebound(self, self.T_rebound):
t_spike = self.lookup_spike_feature("peak_t", recompute=recompute)
idx_spike = self.lookup_spike_feature("peak_index", recompute=recompute)
v_spike = self.lookup_spike_feature("peak_v", recompute=recompute)
if len(t_spike) != 0:
end = self.lookup_sweep_feature("stim_end", recompute=recompute)
w_rebound = where_between(t_spike, end, end + self.T_rebound)
idx_rebound = idx_spike[w_rebound]
t_rebound = t_spike[w_rebound]
v_rebound = v_spike[w_rebound]
num_rebound_aps = np.sum(w_rebound)
if num_rebound_aps > 0:
if store_diagnostics:
self._update_diagnostics(
{
"idx_rebound": idx_rebound,
"t_rebound": t_rebound,
"v_rebound": v_rebound,
}
)
return num_rebound_aps
def _plot(self, ax: Optional[Axes] = None, **kwargs) -> Axes:
t_rebound, v_rebound = unpack(self.diagnostics, ["t_rebound", "v_rebound"])
ax.plot(t_rebound, v_rebound, "x", label=self.name, **kwargs)
return ax
[docs]class Sweep_Rebound_area(SweepFeature):
"""Extract sweep level rebound area feature.
depends on: v_baseline, stim_end.
description: area between rebound curve and baseline voltage from stimulus_end
to stimulus_end + T_rebound.
units: mV*s."""
def __init__(self, data=None, compute_at_init=True, T_rebound=0.3):
self.T_rebound = T_rebound
super().__init__(data, compute_at_init=False)
if compute_at_init and data is not None: # because of T_rebound
self.get_value()
def _compute(self, recompute=False, store_diagnostics=True):
rebound_area = float("nan")
if has_rebound(self, self.T_rebound):
v_baseline = self.lookup_sweep_feature("v_baseline", recompute=recompute)
end = self.lookup_sweep_feature("stim_end", recompute=recompute)
where_rebound = where_between(self.data.t, end, end + self.T_rebound)
where_rebound = np.logical_and(where_rebound, self.data.v > v_baseline)
v_rebound = self.data.v[where_rebound]
t_rebound = self.data.t[where_rebound]
if len(v_rebound) > 10: # at least 10 points to integrate
rebound_area = cumulative_trapezoid(v_rebound - v_baseline, t_rebound)[
-1
]
if store_diagnostics:
self._update_diagnostics(
{
"where_rebound": where_rebound,
"t_rebound": t_rebound,
"v_rebound": v_rebound,
"v_baseline": v_baseline,
}
)
return rebound_area
def _plot(self, ax: Optional[Axes] = None, **kwargs) -> Axes:
t, v = self.data.t, self.data.v
v_baseline, where_rebound = unpack(
self.diagnostics, ["v_baseline", "where_rebound"]
)
ax.fill_between(
t, v, v_baseline, where=where_rebound, alpha=0.5, label=self.name, **kwargs
)
return ax
[docs]class Sweep_Rebound_latency(SweepFeature):
"""Extract sweep level rebound latency feature.
depends on: v_baseline, stim_end.
description: duration from stimulus_end to when the voltage reaches above
baseline for the first time. t_rebound = t_off + rebound_latency.
units: s."""
def __init__(self, data=None, compute_at_init=True, T_rebound=0.3):
self.T_rebound = T_rebound
super().__init__(data, compute_at_init=False)
if compute_at_init and data is not None: # because of T_rebound
self.get_value()
def _compute(self, recompute=False, store_diagnostics=True):
rebound_latency = float("nan")
if has_rebound(self, self.T_rebound):
v_baseline = self.lookup_sweep_feature("v_baseline", recompute=recompute)
end = self.lookup_sweep_feature("stim_end", recompute=recompute)
where_rebound = where_between(self.data.t, end, end + self.T_rebound)
where_rebound = np.logical_and(where_rebound, self.data.v > v_baseline)
t_rebound = self.data.t[where_rebound]
v_rebound = self.data.v[where_rebound]
if len(v_rebound) > 10: # at least 10 time points with rebound
idx_rebound_reached = np.where(where_rebound)[0]
t_rebound_reached = self.data.t[idx_rebound_reached][0]
rebound_latency = t_rebound_reached - end
if store_diagnostics:
self._update_diagnostics(
{
"idx_rebound_reached": idx_rebound_reached,
"t_rebound_reached": t_rebound_reached,
"where_rebound": where_rebound,
"t_rebound": t_rebound,
"v_rebound": v_rebound,
"v_baseline": v_baseline,
}
)
return rebound_latency
def _plot(self, ax: Optional[Axes] = None, **kwargs) -> Axes:
t, v = self.data.t, self.data.v
t_rebound_reached = unpack(self.diagnostics, "t_rebound_reached")
stim_end = self.lookup_sweep_feature("stim_end", return_value=False)
end_idx = unpack(stim_end.diagnostics, "idx_end")
until_rebound = where_between(t, t[end_idx], t_rebound_reached)
ax.fill_between(
t, v, v[end_idx], where=until_rebound, alpha=0.5, label=self.name, **kwargs
)
return ax
[docs]class Sweep_Rebound_avg(SweepFeature):
"""Extract sweep level average rebound feature.
depends on: v_baseline, stim_end.
description: average voltage between stimulus_end
and stimulus_end + T_rebound - baseline voltage.
units: mV."""
def __init__(self, data=None, compute_at_init=True, T_rebound=0.3):
self.T_rebound = T_rebound
super().__init__(data, compute_at_init=False)
if compute_at_init and data is not None: # because of T_rebound
self.get_value()
def _compute(self, recompute=False, store_diagnostics=True):
v_rebound_avg = float("nan")
if has_rebound(self, self.T_rebound):
v_baseline = self.lookup_sweep_feature("v_baseline", recompute=recompute)
end = self.lookup_sweep_feature("stim_end", recompute=recompute)
where_rebound = where_between(self.data.t, end, end + self.T_rebound)
where_rebound = np.logical_and(where_rebound, self.data.v > v_baseline)
v_rebound = self.data.v[where_rebound]
t_rebound = self.data.t[where_rebound]
if len(v_rebound) > 10: # at least 10 rebound points
v_rebound_avg = ft.average_voltage(v_rebound, t_rebound) - v_baseline
if store_diagnostics:
self._update_diagnostics(
{
"where_rebound": where_rebound,
"t_rebound": t_rebound,
"v_rebound": v_rebound,
"v_baseline": v_baseline,
}
)
return v_rebound_avg
def _plot(self, ax: Optional[Axes] = None, **kwargs) -> Axes:
t_rebound, v_rebound = unpack(self.diagnostics, ["t_rebound", "v_rebound"])
v_baseline = unpack(self.diagnostics, "v_baseline")
ax.plot(t_rebound, v_rebound, label=self.name + " interval", **kwargs)
ax.hlines(
[self.value + v_baseline],
# np.mean(v_rebound),
*t_rebound[[0, -1]],
ls="--",
label=self.name,
**kwargs,
)
return ax
[docs]class Sweep_V_rest(SweepFeature):
"""Extract sweep level resting potential feature.
depends on: v_baseline, r_input, dc_offset.
description: v_rest = v_baseline - r_input*dc_offset.
units: mV."""
def __init__(self, data=None, compute_at_init=True, dc_offset=0):
self.dc_offset = dc_offset
super().__init__(data, compute_at_init=False)
if compute_at_init and data is not None: # because of dc_offset
self.get_value()
def _compute(self, recompute=False, store_diagnostics=True):
v_rest = float("nan")
v_baseline = self.lookup_sweep_feature("v_baseline", recompute=recompute)
r_input = self.lookup_sweep_feature("r_input", recompute=recompute)
try:
v_rest = v_baseline - r_input * 1e-3 * self.dc_offset
if store_diagnostics:
self._update_diagnostics(
{
"v_baseline": v_baseline,
"r_input": r_input,
"dc_offset": self.dc_offset,
}
)
except KeyError:
pass
return v_rest
def _plot(self, ax: Optional[Axes] = None, **kwargs) -> Axes:
r_input, dc_offset = unpack(self.diagnostics, ["r_input", "dc_offset"])
t, v = self.data.t, self.data.v
v -= r_input * dc_offset * 1e-3
ax.plot(t, v, label="v(t) - r_in*dc_offset", **kwargs)
ax.axhline(self.value, ls="--", label=self.name)
return ax
[docs]class Sweep_Num_bursts(SweepFeature):
"""Extract sweep level number of bursts feature.
depends on: num_ap.
description: Number of detected bursts.
units: /."""
def __init__(self, data=None, compute_at_init=True):
super().__init__(data, compute_at_init)
def _compute(self, recompute=False, store_diagnostics=True):
num_bursts = float("nan")
num_ap = self.lookup_sweep_feature("num_ap", recompute=recompute)
if num_ap > 5 and has_stimulus(self.data):
idx_burst, idx_burst_start, idx_burst_end = get_sweep_burst_metrics(
self.data
)
peak_t = self.lookup_spike_feature("peak_t", recompute=recompute)
if not np.isnan(idx_burst).any():
t_burst_start = peak_t[idx_burst_start]
t_burst_end = peak_t[idx_burst_end]
num_bursts = len(idx_burst)
num_bursts = float("nan") if num_bursts == 0 else num_bursts
if store_diagnostics:
self._update_diagnostics(
{
"idx_burst": idx_burst,
"idx_burst_start": idx_burst_start,
"idx_burst_end": idx_burst_end,
"t_burst_start": t_burst_start,
"t_burst_end": t_burst_end,
}
)
return num_bursts
def _plot(self, ax: Optional[Axes] = None, **kwargs) -> Axes:
t_burst_start, t_burst_end = unpack(
self.diagnostics, ["t_burst_start", "t_burst_end"]
)
for i, (t_start, t_end) in enumerate(zip(t_burst_start, t_burst_end)):
ax.axvspan(
t_start,
t_end,
alpha=0.5,
label=f"burst {i+1}",
**kwargs,
)
return ax
[docs]class Sweep_Burstiness(SweepFeature):
"""Extract sweep level burstiness feature.
depends on: num_ap.
description: max "burstiness" index across detected bursts.
units: /."""
def __init__(self, data=None, compute_at_init=True):
super().__init__(data, compute_at_init)
def _compute(self, recompute=False, store_diagnostics=True):
max_burstiness = float("nan")
num_ap = self.lookup_sweep_feature("num_ap", recompute=recompute)
if num_ap > 5 and has_stimulus(self.data):
idx_burst, idx_burst_start, idx_burst_end = get_sweep_burst_metrics(
self.data
)
peak_t = self.lookup_spike_feature("peak_t", recompute=recompute)
if not np.isnan(idx_burst).any():
t_burst_start = peak_t[idx_burst_start]
t_burst_end = peak_t[idx_burst_end]
num_bursts = len(idx_burst)
max_burstiness = idx_burst.max() if num_bursts > 0 else float("nan")
max_burstiness = (
float("nan") if max_burstiness < 0 else max_burstiness
) # don't consider negative burstiness
if store_diagnostics:
self._update_diagnostics(
{
"idx_burst": idx_burst,
"idx_burst_start": idx_burst_start,
"idx_burst_end": idx_burst_end,
"t_burst_start": t_burst_start,
"t_burst_end": t_burst_end,
}
)
return max_burstiness
def _plot(self, ax: Optional[Axes] = None, **kwargs) -> Axes:
num_bursts = self.lookup_sweep_feature("num_bursts", return_value=False)
ax = num_bursts.plot(ax=ax, **kwargs)
return ax
[docs]class Sweep_ISI_adapt(SweepFeature):
"""Extract sweep level inter-spike-interval (ISI) adaptation index feature.
depends on: ISIs.
description: /.
units: /."""
def __init__(self, data=None, compute_at_init=True):
super().__init__(data, compute_at_init)
def _compute(self, recompute=False, store_diagnostics=True):
isi_adapt = float("nan")
if has_spikes(self.data):
isi = self.lookup_spike_feature("isi", recompute=recompute)[1:]
if len(isi) > 1:
isi_adapt = isi[1] / isi[0]
if store_diagnostics:
self._update_diagnostics({"isi": isi})
return isi_adapt
def _plot(self, ax: Optional[Axes] = None, **kwargs) -> Axes:
plot_isi = self.data.added_spike_features["isi"].plot
ax = plot_isi(ax=ax, selected_idxs=[1, 2], **kwargs)
relabel_line(ax, "isi", self.name)
return ax
[docs]class Sweep_ISI_adapt_avg(SweepFeature):
"""Extract sweep level average inter-spike-interval (ISI) adaptation index feature.
depends on: ISIs.
description: /.
units: /."""
def __init__(self, data=None, compute_at_init=True):
super().__init__(data, compute_at_init)
def _compute(self, recompute=False, store_diagnostics=True):
isi_adapt_avg = float("nan")
if has_spikes(self.data):
isi = self.lookup_spike_feature("isi", recompute=recompute)[1:]
if len(isi) > 2:
isi_changes = isi[1:] / isi[:-1]
isi_adapt_avg = isi_changes.mean()
if store_diagnostics:
self._update_diagnostics({"isi": isi})
return isi_adapt_avg
def _plot(self, ax: Optional[Axes] = None, **kwargs) -> Axes:
plot_isi = self.data.added_spike_features["isi"].plot
ax = plot_isi(ax=ax, **kwargs)
relabel_line(ax, "isi", self.name)
return ax
[docs]class Sweep_AP_amp_adapt(SweepFeature):
"""Extract sweep level AP amplitude adaptation index feature.
depends on: ap_amp.
description: /.
units: mV/s."""
def __init__(self, data=None, compute_at_init=True):
super().__init__(data, compute_at_init)
def _compute(self, recompute=False, store_diagnostics=True):
ap_amp_adapt = float("nan")
if has_spikes(self.data):
ap_amp = self.lookup_spike_feature("ap_amp", recompute=recompute)
if len(ap_amp) > 1:
ap_amp_adapt = ap_amp[1] / ap_amp[0]
if store_diagnostics:
self._update_diagnostics({"ap_amp": ap_amp})
return ap_amp_adapt
def _plot(self, ax: Optional[Axes] = None, **kwargs) -> Axes:
plot_ap_amp = self.data.added_spike_features["ap_amp"].plot
ax = plot_ap_amp(ax=ax, selected_idxs=[0, 1], **kwargs)
relabel_line(ax, "ap_amp", self.name)
return ax
[docs]class Sweep_AP_amp_adapt_avg(SweepFeature):
"""Extract sweep level average AP amplitude adaptation index feature.
depends on: ap_amp.
description: /.
units: /."""
def __init__(self, data=None, compute_at_init=True):
super().__init__(data, compute_at_init)
def _compute(self, recompute=False, store_diagnostics=True):
ap_amp_adapt_avg = float("nan")
if has_spikes(self.data):
ap_amp = self.lookup_spike_feature("ap_amp", recompute=recompute)
if len(ap_amp) > 2:
ap_amp_changes = ap_amp[1:] / ap_amp[:-1]
ap_amp_adapt_avg = ap_amp_changes.mean()
if store_diagnostics:
self._update_diagnostics({"ap_amp": ap_amp})
return ap_amp_adapt_avg
def _plot(self, ax: Optional[Axes] = None, **kwargs) -> Axes:
plot_ap_amp = self.data.added_spike_features["ap_amp"].plot
ax = plot_ap_amp(ax=ax, **kwargs)
relabel_line(ax, "ap_amp", self.name)
return ax
[docs]class Sweep_Wildness(SweepFeature):
"""Extract sweep level wildness feature.
depends on: /.
description: Wildness is the number of spikes that occur outside of the stimulus interval.
units: /."""
def __init__(self, data=None, compute_at_init=True):
super().__init__(data, compute_at_init)
def _compute(self, recompute=False, store_diagnostics=True):
num_wild_spikes = float("nan")
if has_spikes(self.data):
onset = self.lookup_sweep_feature("stim_onset", recompute=recompute)
end = self.lookup_sweep_feature("stim_end", recompute=recompute)
peak_t = self.lookup_spike_feature("peak_t", recompute=recompute)
peak_idx = self.lookup_spike_feature("peak_index", recompute=recompute)
peak_v = self.lookup_spike_feature("peak_v", recompute=recompute)
stim_window = where_between(peak_t, onset, end)
i_wild_spikes = peak_idx[~stim_window]
t_wild_spikes = peak_t[~stim_window]
v_wild_spikes = peak_v[~stim_window]
if len(i_wild_spikes) > 0:
num_wild_spikes = len(i_wild_spikes)
if store_diagnostics:
self._update_diagnostics(
{
"i_wild_spikes": i_wild_spikes,
"t_wild_spikes": t_wild_spikes,
"v_wild_spikes": v_wild_spikes,
}
)
return num_wild_spikes
def _plot(self, ax: Optional[Axes] = None, **kwargs) -> Axes:
t, v = unpack(self.diagnostics, ["t_wild_spikes", "v_wild_spikes"])
ax.plot(t, v, "x", label=self.name, **kwargs)
return ax
[docs]class APSweepFeature(SweepFeature):
"""Extract sweep level AP feature.
description: Action potential feature to represent a sweep."""
def __init__(
self,
data=None,
compute_at_init=True,
ft_name: Optional[str] = None,
ap_selector: Optional[Callable] = None,
ft_aggregator: Optional[Callable] = None,
):
"""
Args:
ft_name (Optional[str], optional): Name of the spike feature
ap_selector (Optional[Callable], optional): Function which selects a
representative ap or set of aps based on a given criterion.
Function expects a EphysSweepSetFeatureExtractor object as input and
returns indices for the selected aps. If none is provided, falls
back to selecting all aps.
ft_aggregator (Optional[Callable], optional): Function which aggregates
a list of feature values into a single value. Function expects a
list or ndarray of numbers as input. If none is provided, falls back
to `np.nanmedian` (equates to pass through for single sweeps)."""
self.ap_selector = ap_selector
self.ft_aggregator = ft_aggregator
super().__init__(data, compute_at_init)
if ft_name is not None:
self.name = ft_name
def _select(self, data):
"""Function expects a EphysSweepSetFeatureExtractor object as input and
returns indices for the selected aps.
description: Select a representative ap or set of aps based on a
given criterion. If none is provided, falls back to selecting all aps during stimulus.
"""
if self.ap_selector is None:
peak_t = self.lookup_spike_feature("peak_t")
onset = self.lookup_sweep_feature("stim_onset")
end = self.lookup_sweep_feature("stim_end")
stim_window = where_between(peak_t, onset, end)
return stim_window
else:
return self.ap_selector(data)
def _aggregate(self, X):
"""Function expects a list or ndarray of numbers as input.
description: Aggregate a list of feature values into a single value. If none is provided, falls back
to `np.nanmedian` (equates to pass through for single sweeps)."""
if np.isnan(X).all():
return float("nan")
elif self.ft_aggregator is None:
self._update_diagnostics({"aggregate_idx": median_idx(X)})
return np.nanmedian(X).item()
else:
return self.ft_aggregator(X)
def _compute(self, recompute=False, store_diagnostics=True):
feature = self.lookup_spike_feature(self.name, recompute=recompute)
ft_agg = float("nan")
if len(feature) > 0:
selected_idx = self._select(self.data)
fts_selected = feature[selected_idx]
if isinstance(fts_selected, (float, int, np.float64, np.int64)):
ft_agg = fts_selected
elif isinstance(fts_selected, ndarray):
if len(fts_selected.flat) == 0:
ft_agg = float("nan")
else:
ft_agg = self._aggregate(feature)
if store_diagnostics:
self._update_diagnostics(
{
"selected_idx": selected_idx,
"selected_fts": fts_selected,
"selection": parse_desc(self._select),
"aggregation": parse_desc(self._aggregate),
}
)
return ft_agg
def _plot(self, ax: Optional[Axes] = None, **kwargs) -> Axes:
idxs = unpack(self.diagnostics, "selected_idx")
plot_spike_feature = self.data.added_spike_features[self.name].plot
ax = plot_spike_feature(ax=ax, selected_idxs=idxs, **kwargs)
return ax
[docs]class Sweep_AP_AHP(APSweepFeature):
"""Extract sweep level Afterhyperpolarization feature.
depends on: /.
description: Afterhyperpolarization (AHP) for representative AP. Difference
between the fast trough and the threshold.
units: mV."""
def __init__(
self,
data=None,
compute_at_init=True,
ap_selector: Optional[Callable] = None,
ft_aggregator: Optional[Callable] = None,
):
super().__init__(data, compute_at_init, "ap_ahp", ap_selector, ft_aggregator)
[docs]class Sweep_AP_ADP(APSweepFeature):
"""Extract sweep level Afterdepolarization feature.
depends on: /.
description: Afterdepolarization (ADP) for representative AP. Difference between the ADP and the fast trough.
units: mV."""
def __init__(
self,
data=None,
compute_at_init=True,
ap_selector: Optional[Callable] = None,
ft_aggregator: Optional[Callable] = None,
):
super().__init__(data, compute_at_init, "ap_adp", ap_selector, ft_aggregator)
[docs]class Sweep_AP_thresh(APSweepFeature):
"""Extract sweep level AP threshold feature.
depends on: /.
description: AP threshold for representative AP.
units: mV."""
def __init__(
self,
data=None,
compute_at_init=True,
ap_selector: Optional[Callable] = None,
ft_aggregator: Optional[Callable] = None,
):
super().__init__(data, compute_at_init, "ap_thresh", ap_selector, ft_aggregator)
[docs]class Sweep_AP_amp(APSweepFeature):
"""Extract sweep level AP amplitude feature.
depends on: /.
description: AP amplitude for representative AP.
units: mV."""
def __init__(
self,
data=None,
compute_at_init=True,
ap_selector: Optional[Callable] = None,
ft_aggregator: Optional[Callable] = None,
):
super().__init__(data, compute_at_init, "ap_amp", ap_selector, ft_aggregator)
[docs]class Sweep_AP_width(APSweepFeature):
"""Extract sweep level AP width feature.
depends on: /.
description: AP width for representative AP.
units: s."""
def __init__(
self,
data=None,
compute_at_init=True,
ap_selector: Optional[Callable] = None,
ft_aggregator: Optional[Callable] = None,
):
super().__init__(data, compute_at_init, "ap_width", ap_selector, ft_aggregator)
[docs]class Sweep_AP_peak(APSweepFeature):
"""Extract sweep level AP peak feature.
depends on: /.
description: AP peak for representative AP.
units: mV."""
def __init__(
self,
data=None,
compute_at_init=True,
ap_selector: Optional[Callable] = None,
ft_aggregator: Optional[Callable] = None,
):
super().__init__(data, compute_at_init, "ap_peak", ap_selector, ft_aggregator)
[docs]class Sweep_AP_trough(APSweepFeature):
"""Extract sweep level AP trough feature.
depends on: /.
description: AP trough for representative AP.
units: mV."""
def __init__(
self,
data=None,
compute_at_init=True,
ap_selector: Optional[Callable] = None,
ft_aggregator: Optional[Callable] = None,
):
super().__init__(data, compute_at_init, "ap_trough", ap_selector, ft_aggregator)
[docs]class Sweep_AP_UDR(APSweepFeature):
"""Extract sweep level Upstroke-to-downstroke ratio feature.
depends on: /.
description: Upstroke-to-downstroke ratio for representative AP.
units: /."""
def __init__(
self,
data=None,
compute_at_init=True,
ap_selector: Optional[Callable] = None,
ft_aggregator: Optional[Callable] = None,
):
super().__init__(data, compute_at_init, "ap_udr", ap_selector, ft_aggregator)
[docs]class Sweep_ISI(APSweepFeature):
"""Extract sweep level ISI ratio feature.
depends on: /.
description: Median interspike interval.
units: /."""
def __init__(
self,
data=None,
compute_at_init=True,
ap_selector: Optional[Callable] = None,
ft_aggregator: Optional[Callable] = None,
):
super().__init__(data, compute_at_init, "isi", ap_selector, ft_aggregator)