#!/usr/bin/env python3
# Copyright 2023 Jonas Beck
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import Dict, Optional
import numpy as np
from matplotlib.axes import Axes
from ephyspy.features.base import SpikeFeature
from ephyspy.features.utils import fetch_available_fts
from ephyspy.utils import fwhm, has_spike_feature, is_spike_feature, scatter_spike_ft


def available_spike_features(
compute_at_init: bool = False, store_diagnostics: bool = False
) -> Dict[str, SpikeFeature]:
"""Return a dictionary of all implemented spike features.
Looks for all classes that inherit from SpikeFeature and returns a dictionary
of all available features. If compute_at_init is True, the features are
computed at initialization.
Args:
compute_at_init (bool, optional): If True, the features are computed at
initialization. Defaults to False.
store_diagnostics (bool, optional): If True, the features are computed
with diagnostics. Defaults to False.
Returns:
dict[str, SpikeFeature]: Dictionary of all available spike features.
"""
all_features = fetch_available_fts()
features = {ft.__name__.lower(): ft for ft in all_features if is_spike_feature(ft)}
features = {k.replace("spike_", ""): v for k, v in features.items()}
if any((compute_at_init, store_diagnostics)):
return {
            # Bind v via a default argument so every entry wraps its own feature
            # class rather than the last one produced by the comprehension.
            k: lambda *args, v=v, **kwargs: v(
*args,
compute_at_init=compute_at_init,
store_diagnostics=store_diagnostics,
**kwargs,
)
for k, v in features.items()
}
else:
return features
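
# Usage sketch (illustrative, not part of the module; assumes `sweep` is a data
# object accepted by SpikeFeature subclasses, e.g. an ephyspy sweep, and that
# computed features expose their result via `.value`):
#
#     fts = available_spike_features(compute_at_init=True)
#     ap_amp = fts["ap_amp"](sweep)   # wrapper passes compute_at_init through
#     print(ap_amp.value)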


class Spike_AP_upstroke(SpikeFeature):
    """Extract spike level upstroke feature.
    depends on: upstroke_v.
    description: membrane voltage at the maximal upstroke of the AP.
    units: mV.
    """
def __init__(self, data=None, compute_at_init=True):
super().__init__(data, compute_at_init)
def _compute(self, recompute=False, store_diagnostics=True):
upstroke = self.lookup_spike_feature("upstroke_v", recompute=recompute)
return upstroke
def _plot(self, ax: Optional[Axes] = None, selected_idxs=None, **kwargs) -> Axes:
return scatter_spike_ft(
"upstroke", self.data, ax=ax, selected_idxs=selected_idxs, **kwargs
)


class Spike_AP_downstroke(SpikeFeature):
    """Extract spike level downstroke feature.
    depends on: downstroke_v.
    description: membrane voltage at the maximal downstroke of the AP.
    units: mV.
    """
def __init__(self, data=None, compute_at_init=True):
super().__init__(data, compute_at_init)
def _compute(self, recompute=False, store_diagnostics=True):
        downstroke = self.lookup_spike_feature("downstroke_v", recompute=recompute)
        return downstroke
def _plot(self, ax: Optional[Axes] = None, selected_idxs=None, **kwargs) -> Axes:
return scatter_spike_ft(
"downstroke", self.data, ax=ax, selected_idxs=selected_idxs, **kwargs
)


class Spike_AP_fast_trough(SpikeFeature):
    """Extract spike level fast trough feature.
    depends on: fast_trough_v.
    description: fast trough voltage of the AP.
    units: mV.
    """
def __init__(self, data=None, compute_at_init=True):
super().__init__(data, compute_at_init)
def _compute(self, recompute=False, store_diagnostics=True):
        fast_trough = self.lookup_spike_feature("fast_trough_v", recompute=recompute)
        return fast_trough
def _plot(self, ax: Optional[Axes] = None, selected_idxs=None, **kwargs) -> Axes:
return scatter_spike_ft(
"fast_trough", self.data, ax=ax, selected_idxs=selected_idxs, **kwargs
)


class Spike_AP_slow_trough(SpikeFeature):
    """Extract spike level slow trough feature.
    depends on: slow_trough_v.
    description: slow trough voltage of the AP.
    units: mV.
    """
def __init__(self, data=None, compute_at_init=True):
super().__init__(data, compute_at_init)
def _compute(self, recompute=False, store_diagnostics=True):
        slow_trough = self.lookup_spike_feature("slow_trough_v", recompute=recompute)
        return slow_trough
def _plot(self, ax: Optional[Axes] = None, selected_idxs=None, **kwargs) -> Axes:
return scatter_spike_ft(
"slow_trough", self.data, ax=ax, selected_idxs=selected_idxs, **kwargs
)


class Spike_AP_amp(SpikeFeature):
"""Extract spike level peak height feature.
depends on: threshold_v, peak_v.
description: v_peak - threshold_v.
units: mV.
"""
def __init__(self, data=None, compute_at_init=True):
super().__init__(data, compute_at_init)
def _compute(self, recompute=False, store_diagnostics=True):
v_peak = self.lookup_spike_feature("peak_v", recompute=recompute)
threshold_v = self.lookup_spike_feature("threshold_v", recompute=recompute)
peak_height = v_peak - threshold_v
return peak_height if len(v_peak) > 0 else np.array([])
def _plot(self, ax: Optional[Axes] = None, selected_idxs=None, **kwargs) -> Axes:
if has_spike_feature(self.data, "threshold_v"):
idxs = slice(None) if selected_idxs is None else selected_idxs
thresh_v = self.lookup_spike_feature("threshold_v")[idxs]
peak_t = self.lookup_spike_feature("peak_t")[idxs]
peak_v = self.lookup_spike_feature("peak_v")[idxs]
ax.plot(peak_t, peak_v, "x", **kwargs)
ax.vlines(peak_t, thresh_v, peak_v, ls="--", label="ap_amp", **kwargs)
return ax
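
# Direct-use sketch (illustrative; assumes `sweep` is the same kind of data
# object passed to other SpikeFeature subclasses and that the public `plot`
# method wraps `_plot`):
#
#     import matplotlib.pyplot as plt
#     amp = Spike_AP_amp(sweep)   # computed at initialization by default
#     fig, ax = plt.subplots()
#     amp.plot(ax=ax)
#     plt.show()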


class Spike_AP_AHP(SpikeFeature):
    """Extract spike level after-hyperpolarization (AHP) feature.
depends on: threshold_v, fast_trough_v.
description: v_fast_trough - threshold_v.
units: mV.
"""
def __init__(self, data=None, compute_at_init=True):
super().__init__(data, compute_at_init)
def _compute(self, recompute=False, store_diagnostics=True):
v_fast_trough = self.lookup_spike_feature("fast_trough_v", recompute=recompute)
threshold_v = self.lookup_spike_feature("threshold_v", recompute=recompute)
return v_fast_trough - threshold_v
def _plot(self, ax: Optional[Axes] = None, selected_idxs=None, **kwargs) -> Axes:
if has_spike_feature(self.data, "ap_ahp"):
idxs = slice(None) if selected_idxs is None else selected_idxs
trough_t = self.lookup_spike_feature("fast_trough_t")[idxs]
trough_v = self.lookup_spike_feature("fast_trough_v")[idxs]
threshold_t = self.lookup_spike_feature("threshold_t")[idxs]
threshold_v = self.lookup_spike_feature("threshold_v")[idxs]
ax.vlines(
0.5 * (trough_t + threshold_t),
trough_v,
threshold_v,
ls="--",
lw=1,
label="ahp",
**kwargs,
)
return ax


class Spike_AP_ADP(SpikeFeature):
    """Extract spike level after-depolarization (ADP) feature.
depends on: adp_v, fast_trough_v.
description: v_adp - v_fast_trough.
units: mV.
"""
def __init__(self, data=None, compute_at_init=True):
super().__init__(data, compute_at_init)
def _compute(self, recompute=False, store_diagnostics=True):
v_adp = self.lookup_spike_feature("adp_v", recompute=recompute)
v_fast_trough = self.lookup_spike_feature("fast_trough_v", recompute=recompute)
return v_adp - v_fast_trough
def _plot(self, ax: Optional[Axes] = None, selected_idxs=None, **kwargs) -> Axes:
if has_spike_feature(self.data, "ap_adp"):
idxs = slice(None) if selected_idxs is None else selected_idxs
adp_t = self.lookup_spike_feature("adp_t")[idxs]
adp_v = self.lookup_spike_feature("adp_v")[idxs]
trough_t = self.lookup_spike_feature("fast_trough_t")[idxs]
trough_v = self.lookup_spike_feature("fast_trough_v")[idxs]
ax.vlines(
0.5 * (adp_t + trough_t),
adp_v,
trough_v,
ls="--",
lw=1,
label="adp",
**kwargs,
)
return ax


class Spike_AP_peak(SpikeFeature):
"""Extract spike level peak feature.
depends on: peak_v.
description: max voltage of AP.
units: mV.
"""
def __init__(self, data=None, compute_at_init=True):
super().__init__(data, compute_at_init)
def _compute(self, recompute=False, store_diagnostics=True):
v_peak = self.lookup_spike_feature("peak_v", recompute=recompute)
return v_peak
def _plot(self, ax: Optional[Axes] = None, selected_idxs=None, **kwargs) -> Axes:
return scatter_spike_ft(
"peak", self.data, ax=ax, selected_idxs=selected_idxs, **kwargs
)


class Spike_AP_thresh(SpikeFeature):
    """Extract spike level AP threshold feature.
depends on: threshold_v.
description: For details on how AP thresholds are computed see AllenSDK.
units: mV.
"""
def __init__(self, data=None, compute_at_init=True):
super().__init__(data, compute_at_init)
def _compute(self, recompute=False, store_diagnostics=True):
v_thresh = self.lookup_spike_feature("threshold_v", recompute=recompute)
return v_thresh
def _plot(self, ax: Optional[Axes] = None, selected_idxs=None, **kwargs) -> Axes:
return scatter_spike_ft(
"threshold", self.data, ax=ax, selected_idxs=selected_idxs, **kwargs
)


class Spike_AP_trough(SpikeFeature):
    """Extract spike level AP trough feature.
    depends on: trough_v.
description: For details on how AP troughs are computed see AllenSDK.
units: mV.
"""
def __init__(self, data=None, compute_at_init=True):
super().__init__(data, compute_at_init)
def _compute(self, recompute=False, store_diagnostics=True):
        v_trough = self.lookup_spike_feature("trough_v", recompute=recompute)
        return v_trough
def _plot(self, ax: Optional[Axes] = None, selected_idxs=None, **kwargs) -> Axes:
return scatter_spike_ft(
"trough", self.data, ax=ax, selected_idxs=selected_idxs, **kwargs
)


class Spike_AP_width(SpikeFeature):
    """Extract spike level AP width feature.
    depends on: width.
    description: full width at half maximum (FWHM) of the AP.
units: s.
"""
def __init__(self, data=None, compute_at_init=True):
super().__init__(data, compute_at_init)
def _compute(self, recompute=False, store_diagnostics=True):
width = self.lookup_spike_feature("width", recompute=recompute)
return width
def _plot(self, ax: Optional[Axes] = None, selected_idxs=None, **kwargs) -> Axes:
if has_spike_feature(self.data, "width"):
idxs = slice(None) if selected_idxs is None else selected_idxs
# the following is adapted from `allen_sdk.ephys_features.find_widths`
trough_idxs = self.lookup_spike_feature("trough_index").astype(int)
spike_idxs = self.lookup_spike_feature("threshold_index").astype(int)
peak_idxs = self.lookup_spike_feature("peak_index").astype(int)
t = self.data.t
v = self.data.v
ap_height = v[peak_idxs] - v[trough_idxs]
trough_fwhm = ap_height / 2.0 + v[trough_idxs]
thresh_fwhm = (v[peak_idxs] - v[spike_idxs]) / 2.0 + v[spike_idxs]
# Some spikes in burst may have deep trough but short height, so can't use same
# definition for width
fwhm = trough_fwhm.copy()
fwhm[trough_fwhm < v[spike_idxs]] = thresh_fwhm[trough_fwhm < v[spike_idxs]]
width_idx = np.array(
[
pk - np.flatnonzero(v[pk:spk:-1] <= wl)[0]
if np.flatnonzero(v[pk:spk:-1] <= wl).size > 0
else np.nan
for pk, spk, wl in zip(
peak_idxs,
spike_idxs,
fwhm,
)
]
).astype(int)
fwhm = fwhm[idxs]
width_t = t[width_idx][idxs]
width = self.lookup_spike_feature("width")[idxs]
ax.hlines(fwhm, width_t, width_t + width, label="width", ls="--", **kwargs)
return ax


class Spike_AP_UDR(SpikeFeature):
    """Extract spike level AP upstroke-downstroke ratio (UDR) feature.
    depends on: upstroke, downstroke.
    description: upstroke / -downstroke (downstroke is negative, so the ratio
    is reported as a positive number). For details on how upstroke and
    downstroke are computed see AllenSDK.
    units: /.
"""
def __init__(self, data=None, compute_at_init=True):
super().__init__(data, compute_at_init)
def _compute(self, recompute=False, store_diagnostics=True):
upstroke = self.lookup_spike_feature("upstroke", recompute=recompute)
downstroke = self.lookup_spike_feature("downstroke", recompute=recompute)
return upstroke / -downstroke
def _plot(self, ax: Optional[Axes] = None, selected_idxs=None, **kwargs) -> Axes:
if has_spike_feature(self.data, "threshold_t"):
idxs = slice(None) if selected_idxs is None else selected_idxs
upstroke_t = self.lookup_spike_feature("upstroke_t")[idxs]
upstroke_v = self.lookup_spike_feature("upstroke_v")[idxs]
downstroke_t = self.lookup_spike_feature("downstroke_t")[idxs]
downstroke_v = self.lookup_spike_feature("downstroke_v")[idxs]
ax.plot(upstroke_t, upstroke_v, "x", label="upstroke", **kwargs)
            ax.plot(downstroke_t, downstroke_v, "x", label="downstroke", **kwargs)
return ax
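
# Sign-convention note with illustrative numbers (not taken from the library):
# an upstroke of 300 mV/ms and a downstroke of -150 mV/ms give
# UDR = 300 / -(-150) = 2.0, i.e. the ratio is reported as a positive number.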


class Spike_ISI(SpikeFeature):
    """Extract spike level inter-spike-interval feature.
    depends on: threshold_t.
    description: The distance between subsequent spike thresholds, i.e.
    isi[t] = threshold_t[t] - threshold_t[t-1]. The first element is set to 0
    since there is no preceding spike (nan if only a single spike is present).
    units: s.
"""
def __init__(self, data=None, compute_at_init=True):
super().__init__(data, compute_at_init)
def _compute(self, recompute=False, store_diagnostics=True):
spike_times = self.lookup_spike_feature("threshold_t", recompute=recompute)
if len(spike_times) > 1:
isi = np.diff(spike_times)
isi = np.insert(isi, 0, 0)
return isi
elif len(spike_times) == 1:
return np.array([float("nan")])
else:
return np.array([])
def _plot(self, ax: Optional[Axes] = None, selected_idxs=None, **kwargs) -> Axes:
if has_spike_feature(self.data, "isi"):
idxs = slice(None) if selected_idxs is None else selected_idxs
thresh_t = self.lookup_spike_feature("threshold_t")[idxs]
thresh_v = self.lookup_spike_feature("threshold_v")[idxs]
isi = self.lookup_spike_feature("isi")[idxs]
ax.hlines(
thresh_v, thresh_t - isi, thresh_t, ls="--", label="isi", **kwargs
)
ax.plot(thresh_t, thresh_v, "x", **kwargs)
return ax