Source code for ephyspy.analysis

#!/usr/bin/env python3
# Copyright 2023 Jonas Beck

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

from __future__ import annotations

from typing import TYPE_CHECKING, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.pyplot import Axes, Figure

from ephyspy.features.spike_features import available_spike_features
from ephyspy.utils import remove_mpl_artist_by_label

if TYPE_CHECKING:
    from ephyspy.sweeps import EphysSweep, EphysSweepSet

############################
### spike level features ###
############################


def plot_spike_feature(
    sweep: EphysSweep, ft: str, ax: Optional[Axes] = None, **kwargs
) -> Axes:
    """Plot a single spike feature by name.

    Args:
        sweep (EphysSweep): Sweep to plot the feature for.
        ft (str): Name of the feature to plot (all lowercase). Must be one of
            the keys of `available_spike_features()`.
        ax (Axes, optional): Matplotlib axes. Forwarded unchanged to the
            feature's `plot` method. Defaults to None.
        **kwargs: Additional kwargs are passed to the plotting function.

    Returns:
        Axes: Whatever the feature's `plot` method returns (Matplotlib axes).

    Raises:
        ValueError: If `ft` is not a key of `available_spike_features()`.
    """
    # Build the registry once; it is needed for both the membership test and
    # the lookup.
    spike_features = available_spike_features()
    if ft not in spike_features:
        raise ValueError(f"Feature {ft} does not exist.")

    # Instantiate the feature for this sweep and let it draw itself.
    return spike_features[ft](sweep).plot(ax=ax, **kwargs)
def plot_spike_features(
    sweep: EphysSweep, window: Tuple = (0.4, 0.45)
) -> Tuple[Figure, Axes]:
    """Plot overview of the extracted spike features for a sweep.

    The figure has three panels: the full voltage trace ("a"), the voltage
    trace restricted to `window` ("b") and the full current trace ("c").
    Every feature in `available_spike_features()` is drawn on top of both
    voltage panels.

    Args:
        sweep (EphysSweep): Sweep to plot the features for.
        window (Tuple, optional): Time window `(start, end)` in seconds to
            zoom in on a subset or single spikes to see more detail. It is
            marked by vertical lines in panels "a" and "c" and used as the
            x-limits of panel "b". Defaults to (0.4, 0.45).

    Returns:
        Tuple[Figure, Axes]: Matplotlib figure and axes. The axes are the
            dict returned by `plt.subplot_mosaic`, keyed by the panel labels
            "a", "b" and "c".
    """
    mosaic = "aaabb\naaabb\ncccbb"
    fig, axes = plt.subplot_mosaic(mosaic, figsize=(12, 4), constrained_layout=True)

    # plot sweep
    # "a": full voltage trace with the zoom window marked.
    axes["a"].plot(sweep.t, sweep.v, color="k")
    axes["a"].set_ylabel("Voltage (mV)")
    axes["a"].axvline(window[0], color="grey", alpha=0.5)
    axes["a"].axvline(window[1], color="grey", alpha=0.5, label="window")

    # "b": same voltage trace, zoomed in on the window.
    axes["b"].plot(sweep.t, sweep.v, color="k")
    axes["b"].set_ylabel("Voltage (mV)")
    axes["b"].set_xlabel("Time (s)")
    axes["b"].set_xlim(window)

    # "c": current trace with the zoom window marked.
    axes["c"].plot(sweep.t, sweep.i, color="k")
    axes["c"].axvline(window[0], color="grey", alpha=0.5)
    axes["c"].axvline(window[1], color="grey", alpha=0.5, label="window")
    # Only two ticks: zero and max + min of the current.
    # NOTE(review): presumably this equals the stimulus amplitude for a step
    # stimulus where one of the two extrema is zero -- confirm.
    axes["c"].set_yticks([0, np.max(sweep.i) + np.min(sweep.i)])
    axes["c"].set_ylabel("Current (pA)")
    axes["c"].set_xlabel("Time (s)")

    # plot ap features
    # The feature registry does not depend on the panel, so fetch it once.
    spike_features = available_spike_features()
    for x in ("a", "b"):
        for ft in spike_features:
            plot_spike_feature(sweep, ft, axes[x])

    # Legend is placed outside of panel "b", to the right of the figure.
    axes["b"].legend(loc="center left", bbox_to_anchor=(1.0, 0.5))
    return fig, axes
def plot_sweepset_diagnostics(
    sweepset: EphysSweepSet,
    figsize=(15, 14),
) -> Tuple[Figure, Axes]:
    """Plot diagnostics overview for the whole sweepset.

    This function is useful to diagnose outliers on the sweepset level.

    The figure is a 4x4 mosaic with the following panels:

    - "set_fts": all sweeps, with the sweeps selected by sweepset features
      highlighted and labelled.
    - "r_input" / "rheobase": plots of the respective sweepset features.
    - "fp_trace": representative spiking sweep (num_ap, ap_freq_adapt,
      ap_amp_slope).
    - "ap_trace" / "ap_window": representative AP sweep and a zoom on the
      selected AP.
    - "sag_fts", "set_hyperpol_fts", "rebound_fts": sag, hyperpolarization
      (tau, v_baseline) and rebound features.

    Args:
        sweepset (EphysSweepSet): sweepset to diagnose. All features from
            `available_spike_features()` are added to it before plotting.
        figsize (tuple, optional): Figure size passed to
            `plt.subplot_mosaic`. Defaults to (15, 14).

    Returns:
        Fig, Axes: figure and axes with plot. The axes are the dict returned
            by `plt.subplot_mosaic`, keyed by the panel names listed above.
    """
    # Imported locally rather than at module level.
    # NOTE(review): presumably to avoid a circular import -- confirm.
    from ephyspy.features.sweepset_features import (
        NullSweepSetFeature,
        available_sweepset_features,
    )

    mosaic = [
        ["set_fts", "set_fts", "set_fts", "r_input"],
        ["fp_trace", "fp_trace", "fp_trace", "rheobase"],
        ["ap_trace", "ap_trace", "ap_trace", "ap_window"],
        ["sag_fts", "set_hyperpol_fts", "set_hyperpol_fts", "rebound_fts"],
    ]
    sweepset.add_features(
        available_spike_features()
    )  # HOTFIX: Spike features are not yet looked up properly
    # Used only as an entry point for the feature lookups below.
    fts = NullSweepSetFeature(sweepset)

    def plot_sweepset_ft(fts, ft, ax, **kwargs):
        """Look up sweepset feature `ft` and plot it onto `ax`."""
        FT = fts.lookup_sweepset_feature(ft, return_value=False)
        return FT.plot(ax=ax, **kwargs)

    def sweep_idx(fts, ft):
        """Return the `selected_idx` diagnostic of sweepset feature `ft`.

        Falls back to the empty slice when the lookup raises KeyError or
        TypeError (features like dfdI don't have a selected_idx).
        """
        try:
            FT = fts.lookup_sweepset_feature(ft, return_value=False)
            return FT.diagnostics["selected_idx"]
        except (KeyError, TypeError):
            return slice(0)

    def spike_idx(fts, ft):
        """Return the `aggregate_idx` diagnostic of sweep feature `ft`,
        taken from the sweep that the sweepset feature `ft` selected."""
        sw_idx = sweep_idx(fts, ft)
        FT = fts.lookup_sweep_feature(ft, return_value=False)
        return FT[sw_idx].diagnostics["aggregate_idx"]

    def stim_amplitude(ft):
        """Return int(max + min) of the current of the sweep selected by `ft`."""
        stim = sweepset[sweep_idx(fts, ft)].i
        return int(np.max(stim) + np.min(stim))

    def add_stim_legend(panel, ft, **legend_kwargs):
        """Add a legend to `panel` titled with the stimulus amplitude of `ft`."""
        axes[panel].legend(title=f"@{stim_amplitude(ft)}pA", **legend_kwargs)

    fig, axes = plt.subplot_mosaic(mosaic, figsize=figsize, constrained_layout=True)
    onset = fts.lookup_sweep_feature("stim_onset")[0]
    end = fts.lookup_sweep_feature("stim_end")[0]

    # By default every panel spans the full time range of the first sweep;
    # individual panels narrow this further below.
    t0, tfin = sweepset.sweeps()[0].t[[0, -1]]
    for ax in axes.values():
        ax.set_xlim(t0, tfin)

    # set
    # Map every sweepset feature to the sweep it selected (None if the
    # selection is empty).
    selected_sweeps = {}
    for ft in available_sweepset_features():
        sweep = sweepset[sweep_idx(fts, ft)]
        selected_sweeps[ft] = sweep if not sweep == [] else None

    # Group features by selected sweep: sweep -> "ft1, ft2, ...".
    unique_sweeps = {}
    for k, v in selected_sweeps.items():
        if v not in unique_sweeps:
            unique_sweeps[v] = k
        else:
            unique_sweeps[v] = unique_sweeps[v] + ", " + k
    # Invert to label -> sweep.
    unique_sweeps = {v: k for k, v in unique_sweeps.items()}
    keys = list(unique_sweeps.keys())

    # combine features for shorter labels
    for combined_ft, tag in [
        ("ap features", "ap_"),
        ("sag features", "sag"),
        ("rebound features", "rebound"),
        ("isi features", "isi"),
    ]:
        keys = [
            ", ".join(
                np.unique([combined_ft if tag in k else k for k in key.split(", ")])
            )
            for key in keys
        ]
    # NOTE(review): if two sweeps end up with the same shortened label, only
    # the last one survives in this dict -- confirm this is acceptable.
    unique_sweeps = dict(zip(keys, unique_sweeps.values()))

    for label, sweep in unique_sweeps.items():
        try:
            sweep.plot(axes["set_fts"], label=label)
        except AttributeError:
            # Entries without a `plot` method (e.g. the None placeholder for
            # features without a selected sweep) are skipped.
            pass

    sweepset.plot(axes["set_fts"], color="grey", alpha=0.2)
    plot_sweepset_ft(fts, "slow_hyperpolarization", axes["set_fts"])
    axes["set_fts"].legend(title="representative sweeps", loc="upper right")

    ap_sweep_idx = sweep_idx(fts, "ap_thresh")
    ap_idx = spike_idx(fts, "ap_amp")

    # fp
    for ft in ("num_ap", "ap_freq_adapt", "ap_amp_slope"):
        plot_sweepset_ft(fts, ft, axes["fp_trace"])
    add_stim_legend("fp_trace", "num_ap")

    # different selection / aggregation
    # plot_sweepset_ft(fts, "ap_amp_adapt", axes["fp_trace"])
    # plot_sweepset_ft(fts, "isi_ff", axes["fp_trace"])
    # plot_sweepset_ft(fts, "isi_cv", axes["fp_trace"])
    # plot_sweepset_ft(fts, "ap_ff", axes["fp_trace"])
    # plot_sweepset_ft(fts, "ap_cv", axes["fp_trace"])
    # plot_sweepset_ft(fts, "isi", axes["fp_trace"])

    # ap
    for ft in (
        "ap_thresh",
        "ap_peak",
        "ap_trough",
        "ap_width",
        "ap_amp",
        "ap_ahp",
        "ap_adp",
        "ap_udr",
    ):
        plot_sweepset_ft(fts, ft, axes["ap_trace"])
    add_stim_legend("ap_trace", "ap_thresh")

    # Zoom panel: all spike features of the AP sweep, one colour per feature.
    ap_sweep = sweepset[ap_sweep_idx]
    for i, ft in enumerate(available_spike_features()):
        plot_spike_feature(ap_sweep, ft, axes["ap_window"], color=f"C{i}")

    # Window from 5 ms before threshold to 5 ms after the fast trough of the
    # selected AP(s).
    ap_start = ap_sweep.spike_feature("threshold_t")[ap_idx] - 5e-3
    ap_end = ap_sweep.spike_feature("fast_trough_t")[ap_idx] + 5e-3
    if isinstance(ap_start, np.ndarray):
        # Several APs selected: span from the first start to the last end.
        ap_start = ap_start[0]
        ap_end = ap_end[-1]
    axes["ap_window"].set_xlim(ap_start, ap_end)
    axes["ap_trace"].axvline(ap_start, color="grey")
    axes["ap_trace"].axvline(ap_end, color="grey", label="selected ap")
    ap_sweep.plot(axes["ap_window"])
    axes["ap_window"].legend(loc="center left", bbox_to_anchor=(1.0, 0.5))

    # hyperpol
    plot_sweepset_ft(fts, "tau", axes["set_hyperpol_fts"])
    plot_sweepset_ft(fts, "v_baseline", axes["set_hyperpol_fts"])
    add_stim_legend("set_hyperpol_fts", "tau")

    # sag
    plot_sweepset_ft(fts, "sag_area", axes["sag_fts"])
    plot_sweepset_ft(fts, "sag_time", axes["sag_fts"])
    plot_sweepset_ft(fts, "sag_ratio", axes["sag_fts"], color="tab:orange")
    # NOTE(review): presumably sag_ratio / sag_fraction also draw an artist
    # labelled "sag", which is dropped here so that it is only shown once by
    # the final "sag" plot -- confirm.
    remove_mpl_artist_by_label(axes["sag_fts"], "sag")
    plot_sweepset_ft(fts, "sag_fraction", axes["sag_fts"], color="tab:green")
    remove_mpl_artist_by_label(axes["sag_fts"], "sag")
    plot_sweepset_ft(fts, "sag", axes["sag_fts"])
    axes["sag_fts"].set_xlim(onset - 0.05, end + 0.05)
    add_stim_legend("sag_fts", "sag")

    # rebound
    for ft in ("rebound", "rebound_latency", "rebound_area", "rebound_avg"):
        plot_sweepset_ft(fts, ft, axes["rebound_fts"])
    axes["rebound_fts"].set_xlim(end - 0.05, None)
    # Title and placement go into a single call: a second `legend()` call
    # replaces the previous legend and would drop the stimulus title.
    add_stim_legend(
        "rebound_fts", "rebound", loc="center left", bbox_to_anchor=(1.0, 0.5)
    )

    # Shared axis labels for the whole figure.
    fig.text(-0.02, 0.5, "U (mV)", va="center", rotation="vertical", fontsize=16)
    fig.text(0.5, -0.02, "t (s)", ha="center", fontsize=16)

    plot_sweepset_ft(fts, "rheobase", axes["rheobase"])
    plot_sweepset_ft(fts, "r_input", axes["r_input"])

    axes["set_fts"].set_title("All sweeps")
    axes["fp_trace"].set_title("Representative spiking sweep")
    axes["ap_trace"].set_title("Representative AP sweep")
    axes["ap_window"].set_title("Representative AP")
    axes["set_hyperpol_fts"].set_title("Hyperpolarization sweeps")
    axes["sag_fts"].set_title("sag")
    axes["rebound_fts"].set_title("rebound")
    axes["rheobase"].set_title("Rheobase")
    return fig, axes