Source code for ephyspy.allen_sdk.ephys_extractor

# Allen Institute Software License - This software license is the 2-clause BSD
# license plus a third clause that prohibits redistribution for commercial
# purposes without further permission.
#
# Copyright 2015-2016. Allen Institute. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Redistributions for commercial purposes are not permitted without the
# Allen Institute's written permission.
# For purposes of this license, commercial purposes is the incorporation of the
# Allen Institute's software into anything for which you will charge fees or
# other compensation. Contact terms@alleninstitute.org for commercial licensing
# opportunities.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
import logging
import warnings
from collections import Counter

import numpy as np
import six
from pandas import DataFrame

from . import ephys_features as ft

# Constants for stimulus-specific analysis
RAMPS_START = 1.02
LONG_SQUARES_START = 1.02
LONG_SQUARES_END = 2.02
SHORT_SQUARES_WINDOW_START = 1.02
SHORT_SQUARES_WINDOW_END = 1.021
SHORT_SQUARE_TRIPLE_WINDOW_START = 2.02
SHORT_SQUARE_TRIPLE_WINDOW_END = 2.021


class EphysSweepFeatureExtractor:
    """Feature calculation for a sweep (voltage and/or current time series)."""

    def __init__(
        self,
        t=None,
        v=None,
        i=None,
        start=None,
        end=None,
        filter=10.0,
        dv_cutoff=20.0,
        max_interval=0.005,
        min_height=2.0,
        min_peak=-30.0,
        thresh_frac=0.05,
        baseline_interval=0.1,
        baseline_detect_thresh=0.3,
        id=None,
    ):
        """Initialize SweepFeatures object.

        Parameters
        ----------
        t : ndarray of times (seconds)
        v : ndarray of voltages (mV)
        i : ndarray of currents (pA)
        start : start of time window for feature analysis (optional)
        end : end of time window for feature analysis (optional)
        filter : cutoff frequency for 4-pole low-pass Bessel filter in kHz
        (optional, default 10)
        dv_cutoff : minimum dV/dt to qualify as a spike in V/s (optional,
        default 20)
        max_interval : maximum acceptable time between start of spike and
        time of peak in sec (optional, default 0.005)
        min_height : minimum acceptable height from threshold to peak in mV
        (optional, default 2)
        min_peak : minimum acceptable absolute peak level in mV (optional,
        default -30)
        thresh_frac : fraction of average upstroke for threshold calculation
        (optional, default 0.05)
        baseline_interval: interval length for baseline voltage calculation
        (before start if start is defined, default 0.1)
        baseline_detect_thresh : dV/dt threshold for evaluating flatness of
        baseline region (optional, default 0.3)
        """
        self.id = id
        self.t = t
        self.v = v
        self.i = i
        self.start = start
        self.end = end
        self.filter = filter
        self.dv_cutoff = dv_cutoff
        self.max_interval = max_interval
        self.min_height = min_height
        self.min_peak = min_peak
        self.thresh_frac = thresh_frac
        self.baseline_interval = baseline_interval
        self.baseline_detect_thresh = baseline_detect_thresh
        self.stimulus_amplitude_calculator = None

        self._sweep_features = {}
        self._affected_by_clipping = []

    def process_spikes(self):
        """Perform spike-related feature analysis"""
        self._process_individual_spikes()
        self._process_spike_related_features()

    def _process_individual_spikes(self):
        v = self.v
        t = self.t
        dvdt = ft.calculate_dvdt(v, t, self.filter)

        # Basic features of spikes
        putative_spikes = ft.detect_putative_spikes(
            v, t, self.start, self.end, self.filter, self.dv_cutoff
        )
        peaks = ft.find_peak_indexes(v, t, putative_spikes, self.end)
        putative_spikes, peaks = ft.filter_putative_spikes(
            v,
            t,
            putative_spikes,
            peaks,
            self.min_height,
            self.min_peak,
            dvdt=dvdt,
            filter=self.filter,
        )

        if not putative_spikes.size:
            # Save time if no spikes detected
            self._spikes_df = DataFrame()
            return

        upstrokes = ft.find_upstroke_indexes(
            v, t, putative_spikes, peaks, self.filter, dvdt
        )
        thresholds = ft.refine_threshold_indexes(
            v, t, upstrokes, self.thresh_frac, self.filter, dvdt
        )
        thresholds, peaks, upstrokes, clipped = ft.check_thresholds_and_peaks(
            v,
            t,
            thresholds,
            peaks,
            upstrokes,
            self.end,
            self.max_interval,
            dvdt=dvdt,
            filter=self.filter,
        )
        if not thresholds.size:
            # Save time if no spikes detected
            self._spikes_df = DataFrame()
            return

        # Spike list and thresholds have been refined - now find other features
        upstrokes = ft.find_upstroke_indexes(v, t, thresholds, peaks, self.filter, dvdt)
        troughs = ft.find_trough_indexes(v, t, thresholds, peaks, clipped, self.end)
        downstrokes = ft.find_downstroke_indexes(
            v, t, peaks, troughs, clipped, self.filter, dvdt
        )
        trough_details, clipped = ft.analyze_trough_details(
            v, t, thresholds, peaks, clipped, self.end, self.filter, dvdt=dvdt
        )
        widths = ft.find_widths(v, t, thresholds, peaks, trough_details[1], clipped)

        base_clipped_list = []

        # Points where we care about t, v, and i if available
        vit_data_indexes = {
            "threshold": thresholds,
            "peak": peaks,
            "trough": troughs,
        }
        base_clipped_list += ["trough"]

        # Points where we care about t and dv/dt
        dvdt_data_indexes = {"upstroke": upstrokes, "downstroke": downstrokes}
        base_clipped_list += ["downstroke"]

        # Trough details
        isi_types = trough_details[0]
        trough_detail_indexes = dict(
            zip(["fast_trough", "adp", "slow_trough"], trough_details[1:])
        )
        base_clipped_list += ["fast_trough", "adp", "slow_trough"]

        # Redundant, but ensures that DataFrame has right number of rows
        # Any better way to do it?
        spikes_df = DataFrame(data=thresholds, columns=["threshold_index"])
        spikes_df["clipped"] = clipped

        for k, all_vals in six.iteritems(vit_data_indexes):
            valid_ind = ~np.isnan(all_vals)
            vals = all_vals[valid_ind].astype(int)
            spikes_df[k + "_index"] = np.nan
            spikes_df[k + "_t"] = np.nan
            spikes_df[k + "_v"] = np.nan

            if len(vals) > 0:
                spikes_df.loc[valid_ind, k + "_index"] = vals
                spikes_df.loc[valid_ind, k + "_t"] = t[vals]
                spikes_df.loc[valid_ind, k + "_v"] = v[vals]

            if self.i is not None:
                spikes_df[k + "_i"] = np.nan
                if len(vals) > 0:
                    spikes_df.loc[valid_ind, k + "_i"] = self.i[vals]

            if k in base_clipped_list:
                self._affected_by_clipping += [
                    k + "_index",
                    k + "_t",
                    k + "_v",
                    k + "_i",
                ]

        for k, all_vals in six.iteritems(dvdt_data_indexes):
            valid_ind = ~np.isnan(all_vals)
            vals = all_vals[valid_ind].astype(int)
            spikes_df[k + "_index"] = np.nan
            spikes_df[k] = np.nan
            if len(vals) > 0:
                spikes_df.loc[valid_ind, k + "_index"] = vals
                spikes_df.loc[valid_ind, k + "_t"] = t[vals]
                spikes_df.loc[valid_ind, k + "_v"] = v[vals]
                spikes_df.loc[valid_ind, k] = dvdt[vals]

                if k in base_clipped_list:
                    self._affected_by_clipping += [
                        k + "_index",
                        k + "_t",
                        k + "_v",
                        k,
                    ]

        spikes_df["isi_type"] = isi_types
        self._affected_by_clipping += ["isi_type"]

        for k, all_vals in six.iteritems(trough_detail_indexes):
            valid_ind = ~np.isnan(all_vals)
            vals = all_vals[valid_ind].astype(int)
            spikes_df[k + "_index"] = np.nan
            spikes_df[k + "_t"] = np.nan
            spikes_df[k + "_v"] = np.nan
            if len(vals) > 0:
                spikes_df.loc[valid_ind, k + "_index"] = vals
                spikes_df.loc[valid_ind, k + "_t"] = t[vals]
                spikes_df.loc[valid_ind, k + "_v"] = v[vals]

            if self.i is not None:
                spikes_df[k + "_i"] = np.nan
                if len(vals) > 0:
                    spikes_df.loc[valid_ind, k + "_i"] = self.i[vals]

            if k in base_clipped_list:
                self._affected_by_clipping += [
                    k + "_index",
                    k + "_t",
                    k + "_v",
                    k + "_i",
                ]

        spikes_df["width"] = widths
        self._affected_by_clipping += ["width"]

        spikes_df["upstroke_downstroke_ratio"] = (
            spikes_df["upstroke"] / -spikes_df["downstroke"]
        )
        self._affected_by_clipping += ["upstroke_downstroke_ratio"]

        self._spikes_df = spikes_df

    def _process_spike_related_features(self):
        t = self.t

        if len(self._spikes_df) == 0:
            self._sweep_features["avg_rate"] = 0
            return

        thresholds = self._spikes_df["threshold_index"].values.astype(int)
        isis = ft.get_isis(t, thresholds)
        with warnings.catch_warnings():
            # ignore mean of empty slice warnings here
            warnings.filterwarnings("ignore", category=RuntimeWarning, module="numpy")

            sweep_level_features = {
                "adapt": ft.adaptation_index(isis),
                "latency": ft.latency(t, thresholds, self.start),
                "isi_cv": (isis.std() / isis.mean()) if len(isis) >= 1 else np.nan,
                "mean_isi": isis.mean() if len(isis) > 0 else np.nan,
                "median_isi": np.median(isis),
                "first_isi": isis[0] if len(isis) >= 1 else np.nan,
                "avg_rate": ft.average_rate(t, thresholds, self.start, self.end),
            }

        for k, v in six.iteritems(sweep_level_features):
            self._sweep_features[k] = v

    def _process_pauses(self, cost_weight=1.0):
        # Pauses are unusually long ISIs with a "detour reset" among delay resets
        thresholds = self._spikes_df["threshold_index"].values.astype(int)
        isis = ft.get_isis(self.t, thresholds)
        isi_types = self._spikes_df["isi_type"][:-1].values

        return ft.detect_pauses(isis, isi_types, cost_weight)

    def pause_metrics(self):
        """Estimate average number of pauses and average fraction of time spent in a pause

        Attempts to detect pauses with a variety of conditions and averages results together.

        Pauses that are consistently detected contribute more to estimates.

        Returns
        -------
        avg_n_pauses : average number of pauses detected across conditions
        avg_pause_frac : average fraction of interval (between start and end) spent in a pause
        max_reliability : max fraction of times most reliable pause was detected given weights tested
        n_max_rel_pauses : number of pauses detected with `max_reliability`
        """

        thresholds = self._spikes_df["threshold_index"].values.astype(int)
        isis = ft.get_isis(self.t, thresholds)

        weight = 1.0
        pause_list = self._process_pauses(weight)

        if len(pause_list) == 0:
            return 0, 0.0

        n_pauses = len(pause_list)
        pause_frac = isis[pause_list].sum()
        pause_frac /= self.end - self.start

        return n_pauses, pause_frac

    def _process_bursts(self, tol=0.5, pause_cost=1.0):
        thresholds = self._spikes_df["threshold_index"].values.astype(int)
        isis = ft.get_isis(self.t, thresholds)

        isi_types = self._spikes_df["isi_type"][:-1].values

        fast_tr_v = self._spikes_df["fast_trough_v"].values
        fast_tr_t = self._spikes_df["fast_trough_t"].values
        slow_tr_v = self._spikes_df["slow_trough_v"].values
        slow_tr_t = self._spikes_df["slow_trough_t"].values
        thr_v = self._spikes_df["threshold_v"].values

        bursts = ft.detect_bursts(
            isis,
            isi_types,
            fast_tr_v,
            fast_tr_t,
            slow_tr_v,
            slow_tr_t,
            thr_v,
            tol,
            pause_cost,
        )

        return np.array(bursts)

    def burst_metrics(self):
        """Find bursts and return max "burstiness" index (normalized max rate in burst vs out).

        Returns
        -------
        max_burstiness_index : max "burstiness" index across detected bursts
        num_bursts : number of bursts detected
        """

        burst_info = self._process_bursts()

        if burst_info.shape[0] > 0:
            return burst_info[:, 0].max(), burst_info.shape[0]
        else:
            return 0.0, 0

    def delay_metrics(self):
        """Calculates ratio of latency to dominant time constant of rise before spike

        Returns
        -------
        delay_ratio : ratio of latency to tau (higher means more delay)
        tau : dominant time constant of rise before spike
        """

        if len(self._spikes_df) == 0:
            logging.info("No spikes available for delay calculation")
            return 0.0, 0.0
        start = self.start
        spike_time = self._spikes_df["threshold_t"].values[0]

        tau = ft.fit_prespike_time_constant(self.v, self.t, start, spike_time)
        latency = spike_time - start

        delay_ratio = latency / tau
        return delay_ratio, tau

    def _get_baseline_voltage(self):
        v = self.v
        t = self.t
        filter_frequency = 1.0  # in kHz

        # Look at baseline interval before start if start is defined
        if self.start is not None:
            return ft.average_voltage(
                v, t, self.start - self.baseline_interval, self.start
            )

        # Otherwise try to find an interval where things are pretty flat
        dv = ft.calculate_dvdt(v, t, filter_frequency)
        non_flat_points = np.flatnonzero(np.abs(dv >= self.baseline_detect_thresh))
        flat_intervals = t[non_flat_points[1:]] - t[non_flat_points[:-1]]
        long_flat_intervals = np.flatnonzero(flat_intervals >= self.baseline_interval)
        if long_flat_intervals.size > 0:
            interval_index = long_flat_intervals[0] + 1
            baseline_end_time = t[non_flat_points[interval_index]]
            return ft.average_voltage(
                v, t, baseline_end_time - self.baseline_interval, baseline_end_time
            )
        else:
            logging.info(
                "Could not find sufficiently flat interval for automatic baseline voltage",
                RuntimeWarning,
            )
            return np.nan

    def voltage_deflection(self, deflect_type=None):
        """Measure deflection (min or max, between start and end if specified).

        Parameters
        ----------
        deflect_type : measure minimal ('min') or maximal ('max') voltage deflection
            If not specified, it will check to see if the current (i) is positive or negative
            between start and end, then choose 'max' or 'min', respectively
            If the current is not defined, it will default to 'min'.

        Returns
        -------
        deflect_v : peak
        deflect_index : index of peak deflection
        """

        deflect_dispatch = {
            "min": np.argmin,
            "max": np.argmax,
        }

        start = self.start
        if not start:
            start = 0
        start_index = ft.find_time_index(self.t, start)

        end = self.end
        if not end:
            end = self.t[-1]
        end_index = ft.find_time_index(self.t, end)

        if deflect_type is None:
            if self.i is not None:
                halfway_index = ft.find_time_index(self.t, (end - start) / 2.0 + start)
                if self.i[halfway_index] >= 0:
                    deflect_type = "max"
                else:
                    deflect_type = "min"
            else:
                deflect_type = "min"

        deflect_func = deflect_dispatch[deflect_type]

        v_window = self.v[start_index:end_index]
        deflect_index = deflect_func(v_window) + start_index

        return self.v[deflect_index], deflect_index

    def stimulus_amplitude(self):
        """ """
        if self.stimulus_amplitude_calculator is not None:
            return self.stimulus_amplitude_calculator(self)
        else:
            return np.nan

    def estimate_time_constant(self):
        """Calculate the membrane time constant by fitting the voltage response with a
        single exponential.

        Returns
        -------
        tau : membrane time constant in seconds
        """

        # Assumes this is being done on a hyperpolarizing step
        v_peak, peak_index = self.voltage_deflection("min")
        v_baseline = self.sweep_feature("v_baseline")

        if self.start:
            start_index = ft.find_time_index(self.t, self.start)
        else:
            start_index = 0

        frac = 0.1
        search_result = np.flatnonzero(
            self.v[start_index:] <= frac * (v_peak - v_baseline) + v_baseline
        )
        if not search_result.size:
            raise ft.FeatureError("could not find interval for time constant estimate")
        fit_start = self.t[search_result[0] + start_index]
        fit_end = self.t[peak_index]

        a, inv_tau, y0 = ft.fit_membrane_time_constant(
            self.v, self.t, fit_start, fit_end
        )

        return 1.0 / inv_tau

    def estimate_sag(self, peak_width=0.005):
        """Calculate the sag in a hyperpolarizing voltage response.

        Parameters
        ----------
        peak_width : window width to get more robust peak estimate in sec (default 0.005)

        Returns
        -------
        sag : fraction that membrane potential relaxes back to baseline
        """

        t = self.t
        v = self.v

        start = self.start
        if not start:
            start = 0

        end = self.end
        if not end:
            end = self.t[-1]

        v_peak, peak_index = self.voltage_deflection("min")
        v_peak_avg = ft.average_voltage(
            v,
            t,
            start=t[peak_index] - peak_width / 2.0,
            end=t[peak_index] + peak_width / 2.0,
        )
        v_baseline = self.sweep_feature("v_baseline")
        v_steady = ft.average_voltage(v, t, start=end - self.baseline_interval, end=end)
        sag = (v_peak_avg - v_steady) / (v_peak_avg - v_baseline)
        return sag

    def spikes(self):
        """Get all features for each spike as a list of records."""
        return self._spikes_df.to_dict("records")

    def spike_feature(self, key, include_clipped=False, force_exclude_clipped=False):
        """Get specified feature for every spike.

        Parameters
        ----------
        key : feature name
        include_clipped: return values for every identified spike, even when
        clipping means they will be incorrect/undefined

        Returns
        -------
        spike_feature_values : ndarray of features for each spike
        """

        if not hasattr(self, "_spikes_df"):
            raise AttributeError(
                "EphysSweepFeatureExtractor instance attribute with spike "
                "information does not exist yet - have spikes been processed?"
            )

        if len(self._spikes_df) == 0:
            return np.array([])

        if key not in self._spikes_df.columns:
            raise KeyError("requested feature '{:s}' not available".format(key))

        values = self._spikes_df[key].values

        if include_clipped and force_exclude_clipped:
            raise ValueError(
                "include_clipped and force_exclude_clipped cannot both be true"
            )

        if not include_clipped and self.is_spike_feature_affected_by_clipping(key):
            values = values[~self._spikes_df["clipped"].values]
        elif force_exclude_clipped:
            values = values[~self._spikes_df["clipped"].values]

        return values

    def is_spike_feature_affected_by_clipping(self, key):
        return key in self._affected_by_clipping

    def spike_feature_keys(self):
        """Get list of every available spike feature."""
        return self._spikes_df.columns.values.tolist()

    def sweep_feature(self, key, allow_missing=False):
        """Get sweep-level feature (`key`).

        Parameters
        ----------
        key : name of sweep-level feature
        allow_missing : return np.nan if key is missing for sweep (default False)

        Returns
        -------
        sweep_feature : sweep-level feature value
        """

        on_request_dispatch = {
            "v_baseline": self._get_baseline_voltage,
            "tau": self.estimate_time_constant,
            "sag": self.estimate_sag,
            "peak_deflect": self.voltage_deflection,
            "stim_amp": self.stimulus_amplitude,
        }

        if (
            allow_missing
            and key not in self._sweep_features
            and key not in on_request_dispatch
        ):
            return np.nan
        elif key not in self._sweep_features and key not in on_request_dispatch:
            raise KeyError("requested feature '{:s}' not available".format(key))

        if key not in self._sweep_features and key in on_request_dispatch:
            fn = on_request_dispatch[key]
            if fn is not None:
                self._sweep_features[key] = fn()
            else:
                raise KeyError("requested feature '{:s}' not defined".format(key))

        return self._sweep_features[key]

    def process_new_spike_feature(
        self, feature_name, feature_func, affected_by_clipping=False
    ):
        """Add new spike-level feature calculation function

        The function should take this sweep extractor as its argument. Its results
        can be accessed by calling the method spike_feature(<feature_name>).
        """

        if feature_name in self._spikes_df.columns:
            raise KeyError("Feature {:s} already exists for sweep".format(feature_name))

        self._spikes_df[feature_name] = feature_func(self)

        if affected_by_clipping:
            self._affected_by_clipping.append(feature_name)

    def process_new_sweep_feature(self, feature_name, feature_func):
        """Add new sweep-level feature calculation function

        The function should take this sweep extractor as its argument. Its results
        can be accessed by calling the method sweep_feature(<feature_name>).
        """

        if feature_name in self._sweep_features:
            raise KeyError("Feature {:s} already exists for sweep".format(feature_name))

        self._sweep_features[feature_name] = feature_func(self)

    def set_stimulus_amplitude_calculator(self, function):
        self.stimulus_amplitude_calculator = function

    def sweep_feature_keys(self):
        """Get list of every available sweep-level feature."""
        return self._sweep_features.keys()

    def as_dict(self):
        """Create dict of features and spikes."""
        output_dict = self._sweep_features.copy()
        output_dict["spikes"] = self.spikes()
        if self.id is not None:
            output_dict["id"] = self.id
        return output_dict


[docs]class EphysSweepSetFeatureExtractor: def __init__( self, t_set=None, v_set=None, i_set=None, start=None, end=None, filter=10.0, dv_cutoff=20.0, max_interval=0.005, min_height=2.0, min_peak=-30.0, thresh_frac=0.05, baseline_interval=0.1, baseline_detect_thresh=0.3, id_set=None, ): """Initialize EphysSweepSetFeatureExtractor object. Parameters ---------- t_set : list of ndarray of times in seconds v_set : list of ndarray of voltages in mV i_set : list of ndarray of currents in pA start : start of time window for feature analysis (optional, can be list) end : end of time window for feature analysis (optional, can be list) filter : cutoff frequency for 4-pole low-pass Bessel filter in kHz (optional, default 10) dv_cutoff : minimum dV/dt to qualify as a spike in V/s (optional, default 20) max_interval : maximum acceptable time between start of spike and time of peak in sec (optional, default 0.005) min_height : minimum acceptable height from threshold to peak in mV (optional, default 2) min_peak : minimum acceptable absolute peak level in mV (optional, default -30) thresh_frac : fraction of average upstroke for threshold calculation (optional, default 0.05) baseline_interval: interval length for baseline voltage calculation (before start if start is defined, default 0.1) baseline_detect_thresh : dV/dt threshold for evaluating flatness of baseline region (optional, default 0.3) """ if t_set is not None and v_set is not None: self._set_sweeps( t_set, v_set, i_set, start, end, filter, dv_cutoff, max_interval, min_height, min_peak, thresh_frac, baseline_interval, baseline_detect_thresh, id_set, ) else: self._sweeps = None
[docs] @classmethod def from_sweeps(cls, sweep_list): """Initialize EphysSweepSetFeatureExtractor object with a list of pre-existing sweep feature extractor objects. """ obj = cls() obj._sweeps = sweep_list return obj
def _set_sweeps( self, t_set, v_set, i_set, start, end, filter, dv_cutoff, max_interval, min_height, min_peak, thresh_frac, baseline_interval, baseline_detect_thresh, id_set, ): if type(t_set) != list: raise ValueError("t_set must be a list") if type(v_set) != list: raise ValueError("v_set must be a list") if i_set is not None and type(i_set) != list: raise ValueError("i_set must be a list") if len(t_set) != len(v_set): raise ValueError("t_set and v_set must have the same number of items") if i_set and len(t_set) != len(i_set): raise ValueError("t_set and i_set must have the same number of items") if id_set is None: id_set = range(len(t_set)) if len(id_set) != len(t_set): raise ValueError("t_set and id_set must have the same number of items") sweeps = [] if i_set is None: i_set = [None] * len(t_set) if type(start) is not list: start = [start] * len(t_set) end = [end] * len(t_set) sweeps = [ EphysSweepFeatureExtractor( t, v, i, start, end, filter=filter, dv_cutoff=dv_cutoff, max_interval=max_interval, min_height=min_height, min_peak=min_peak, thresh_frac=thresh_frac, baseline_interval=baseline_interval, baseline_detect_thresh=baseline_detect_thresh, id=sid, ) for t, v, i, start, end, sid in zip(t_set, v_set, i_set, start, end, id_set) ] self._sweeps = sweeps
[docs] def sweeps(self): """Get list of EphysSweepFeatureExtractor objects.""" return self._sweeps
[docs] def process_spikes(self): """Analyze spike features for all sweeps.""" for sweep in self._sweeps: sweep.process_spikes()
[docs] def sweep_features(self, key, allow_missing=False): """Get nparray of sweep-level feature (`key`) for all sweeps Parameters ---------- key : name of sweep-level feature allow_missing : return np.nan if key is missing for sweep (default False) Returns ------- sweep_feature : nparray of sweep-level feature values """ return np.array([swp.sweep_feature(key, allow_missing) for swp in self._sweeps])
[docs] def spike_feature_averages(self, key): """Get nparray of average spike-level feature (`key`) for all sweeps""" return np.array([swp.spike_feature(key).mean() for swp in self._sweeps])
[docs]class EphysCellFeatureExtractor: # Class constants for specific processing SUBTHRESH_MAX_AMP = 0 SAG_TARGET = -100.0 def __init__( self, ramps_ext, short_squares_ext, long_squares_ext, subthresh_min_amp=-100 ): """Initialize EphysCellFeatureExtractor object from EphysSweepSetExtractors for ramp, short square, and long square sweeps. Parameters ---------- dataset : NwbDataSet ramps_ext : EphysSweepSetFeatureExtractor prepared with ramp sweeps short_squares_ext : EphysSweepSetFeatureExtractor prepared with short square sweeps long_squares_ext : EphysSweepSetFeatureExtractor prepared with long square sweeps """ self._ramps_ext = ramps_ext self._short_squares_ext = short_squares_ext self._long_squares_ext = long_squares_ext self._subthresh_min_amp = subthresh_min_amp self._features = { "ramps": {}, "short_squares": {}, "long_squares": {}, } self._spiking_long_squares_ext = None self._subthreshold_long_squares_ext = None self._subthreshold_membrane_property_ext = None
[docs] def process(self, keys=None): """Processes features. Can take a specific key (or set of keys) to do a subset of processing.""" dispatch = { "ramps": self._analyze_ramps, "short_squares": self._analyze_short_squares, "long_squares": self._analyze_long_squares, "long_squares_spiking": self._analyze_long_squares_spiking, } if keys is None: keys = list(dispatch.keys()) if type(keys) is not list: keys = list(keys) for k in [j for j in keys if j in dispatch]: dispatch[k]()
def _analyze_ramps(self): ext = self._ramps_ext ext.process_spikes() self._all_ramps_ext = ext # pull out the spiking sweeps spiking_sweeps = [ sweep for sweep in self._ramps_ext.sweeps() if sweep.sweep_feature("avg_rate") > 0 ] ext = EphysSweepSetFeatureExtractor.from_sweeps(spiking_sweeps) self._ramps_ext = ext self._features["ramps"]["spiking_sweeps"] = ext.sweeps()
[docs] def ramps_features(self, all=False): if all: return self._all_ramps_ext else: return self._ramps_ext
def _analyze_short_squares(self): ext = self._short_squares_ext ext.process_spikes() # Need to count how many had spikes at each amplitude; find most; # ties go to lower amplitude spiking_sweeps = [ sweep for sweep in ext.sweeps() if sweep.sweep_feature("avg_rate") > 0 ] if len(spiking_sweeps) == 0: raise ft.FeatureError( "No spiking short square sweeps, cannot compute cell features." ) most_common = Counter(map(_short_step_stim_amp, spiking_sweeps)).most_common() common_amp, common_count = most_common[0] for c in most_common[1:]: if c[1] < common_count: break if c[0] < common_amp: common_amp = c[0] self._features["short_squares"]["stimulus_amplitude"] = common_amp ext = EphysSweepSetFeatureExtractor.from_sweeps( [ sweep for sweep in spiking_sweeps if _short_step_stim_amp(sweep) == common_amp ] ) self._short_squares_ext = ext self._features["short_squares"]["common_amp_sweeps"] = ext.sweeps() for s in self._features["short_squares"]["common_amp_sweeps"]: s.set_stimulus_amplitude_calculator(_short_step_stim_amp)
[docs] def short_squares_features(self): return self._short_squares_ext
def _analyze_long_squares(self): self._analyze_long_squares_spiking() self._analyze_long_squares_subthreshold() def _analyze_long_squares_spiking(self, force_reprocess=False): if not force_reprocess and self._spiking_long_squares_ext: return ext = self._long_squares_ext ext.process_spikes() self._features["long_squares"]["sweeps"] = ext.sweeps() for s in self._features["long_squares"]["sweeps"]: s.set_stimulus_amplitude_calculator(_step_stim_amp) spiking_indexes = np.flatnonzero(ext.sweep_features("avg_rate")) if len(spiking_indexes) == 0: raise ft.FeatureError( "No spiking long square sweeps, cannot compute cell features." ) amps = ext.sweep_features("stim_amp") # self.long_squares_stim_amps() min_index = np.argmin(amps[spiking_indexes]) rheobase_index = spiking_indexes[min_index] rheobase_i = _step_stim_amp(ext.sweeps()[rheobase_index]) self._features["long_squares"]["rheobase_extractor_index"] = rheobase_index self._features["long_squares"]["rheobase_i"] = rheobase_i self._features["long_squares"]["rheobase_sweep"] = ext.sweeps()[rheobase_index] spiking_sweeps = [ sweep for sweep in ext.sweeps() if sweep.sweep_feature("avg_rate") > 0 ] self._spiking_long_squares_ext = EphysSweepSetFeatureExtractor.from_sweeps( spiking_sweeps ) self._features["long_squares"][ "spiking_sweeps" ] = self._spiking_long_squares_ext.sweeps() self._features["long_squares"]["fi_fit_slope"] = fit_fi_slope( self._spiking_long_squares_ext ) def _analyze_long_squares_subthreshold(self): ext = self._long_squares_ext subthresh_sweeps = [ sweep for sweep in ext.sweeps() if sweep.sweep_feature("avg_rate") == 0 ] subthresh_ext = EphysSweepSetFeatureExtractor.from_sweeps(subthresh_sweeps) self._subthreshold_long_squares_ext = subthresh_ext if len(subthresh_ext.sweeps()) == 0: raise ft.FeatureError( "No subthreshold long square sweeps, cannot evaluate cell features." ) sags = subthresh_ext.sweep_features("sag") sag_eval_levels = np.array( [sweep.voltage_deflection()[0] for sweep in subthresh_ext.sweeps()] ) target_level = self.SAG_TARGET closest_index = np.argmin(np.abs(sag_eval_levels - target_level)) self._features["long_squares"]["sag"] = sags[closest_index] self._features["long_squares"]["vm_for_sag"] = sag_eval_levels[closest_index] self._features["long_squares"]["subthreshold_sweeps"] = subthresh_ext.sweeps() for s in self._features["long_squares"]["subthreshold_sweeps"]: s.set_stimulus_amplitude_calculator(_step_stim_amp) logging.debug("subthresh_sweeps: %d", len(subthresh_sweeps)) calc_subthresh_sweeps = [ sweep for sweep in subthresh_sweeps if self._subthresh_min_amp < sweep.sweep_feature("stim_amp") < self.SUBTHRESH_MAX_AMP ] # noqa F501 logging.debug("calc_subthresh_sweeps: %d", len(calc_subthresh_sweeps)) calc_subthresh_ext = EphysSweepSetFeatureExtractor.from_sweeps( calc_subthresh_sweeps ) self._subthreshold_membrane_property_ext = calc_subthresh_ext self._features["long_squares"][ "subthreshold_membrane_property_sweeps" ] = calc_subthresh_ext.sweeps() self._features["long_squares"]["input_resistance"] = input_resistance( calc_subthresh_ext ) self._features["long_squares"]["tau"] = membrane_time_constant( calc_subthresh_ext ) self._features["long_squares"]["v_baseline"] = np.nanmean( ext.sweep_features("v_baseline") )
[docs] def long_squares_features(self, option=None): option_table = { "spiking": self._spiking_long_squares_ext, "subthreshold": self._subthreshold_long_squares_ext, "subthreshold_membrane_property": self._subthreshold_membrane_property_ext, } if option: return option_table[option] return self._long_squares_ext
[docs] def long_squares_stim_amps(self, option=None): option_table = { "spiking": self._spiking_long_squares_ext, "subthreshold": self._subthreshold_long_squares_ext, "subthreshold_membrane_property": self._subthreshold_membrane_property_ext, } if option: ext = option_table[option] else: ext = self._long_squares_ext return np.array(map(_step_stim_amp, ext.sweeps()))
[docs] def cell_features(self): return self._features
[docs] def as_dict(self): """Create dict of cell features.""" # get shallow copies of the sub-type dictionaries out = { "long_squares": self._features["long_squares"].copy(), "short_squares": self._features["short_squares"].copy(), "ramps": self._features["ramps"].copy(), } # convert feature extractor lists to sweep dictionarsweep extract lists ls_sweeps = [s.as_dict() for s in out["long_squares"]["sweeps"]] ls_spike_sweeps = [s.as_dict() for s in out["long_squares"]["spiking_sweeps"]] rheo_sweep = out["long_squares"]["rheobase_sweep"].as_dict() ls_sub_sweeps = [ s.as_dict() for s in out["long_squares"]["subthreshold_sweeps"] ] ls_sub_mem_sweeps = [ s.as_dict() for s in out["long_squares"]["subthreshold_membrane_property_sweeps"] ] ss_sweeps = [s.as_dict() for s in out["short_squares"]["common_amp_sweeps"]] ramp_sweeps = [s.as_dict() for s in out["ramps"]["spiking_sweeps"]] out["long_squares"]["sweeps"] = ls_sweeps out["long_squares"]["spiking_sweeps"] = ls_spike_sweeps out["long_squares"]["subthreshold_sweeps"] = ls_sub_sweeps out["long_squares"]["subthreshold_membrane_property_sweeps"] = ls_sub_mem_sweeps out["long_squares"]["rheobase_sweep"] = rheo_sweep out["short_squares"]["common_amp_sweeps"] = ss_sweeps out["ramps"]["spiking_sweeps"] = ramp_sweeps return out
[docs]def input_resistance(ext): """Estimate input resistance in MOhms, assuming all sweeps in passed extractor are hyperpolarizing responses.""" sweeps = ext.sweeps() if not sweeps: raise ft.FeatureError("no sweeps available for input resistance calculation") v_vals = [] i_vals = [] for sweep in sweeps: if sweep.i is None: raise ft.FeatureError( "cannot calculate input resistance: i not defined for a sweep" ) v_peak, min_index = sweep.voltage_deflection("min") v_vals.append(v_peak) i_vals.append(sweep.i[min_index]) v = np.array(v_vals) i = np.array(i_vals) if len(v) == 1: # If there's just one sweep, we'll have to use its own baseline to # estimate the input resistance v = np.append(v, sweeps[0].sweep_feature("v_baseline")) i = np.append(i, 0.0) A = np.vstack([i, np.ones_like(i)]).T m, c = np.linalg.lstsq(A, v)[0] return m * 1e3
[docs]def membrane_time_constant(ext): """Average the membrane time constant values estimated from each sweep in passed extractor.""" with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=RuntimeWarning, module="numpy") avg_tau = np.nanmean(ext.sweep_features("tau")) return avg_tau
[docs]def fit_fi_slope(ext): """Fit the rate and stimulus amplitude to a line and return the slope of the fit.""" if len(ext.sweeps()) < 2: raise ft.FeatureError( "Cannot fit f-I curve slope with less than two suprathreshold sweeps" ) x = np.array(list(map(_step_stim_amp, ext.sweeps()))) y = ext.sweep_features("avg_rate") A = np.vstack([x, np.ones_like(x)]).T m, c = np.linalg.lstsq(A, y)[0] return m
[docs]def reset_long_squares_start(when): global LONG_SQUARES_START, LONG_SQUARES_END delta = LONG_SQUARES_END - LONG_SQUARES_START LONG_SQUARES_START = when LONG_SQUARES_END = when + delta
[docs]def cell_extractor_for_nwb( dataset, ramps, short_squares, long_squares, subthresh_min_amp=-100 ): """Initialize EphysCellFeatureExtractor object from NWB data set Parameters ---------- dataset : NwbDataSet ramps : list of sweep numbers of ramp sweeps short_squares : list of sweep numbers of short square sweeps long_squares : list of sweep numbers of long square sweeps """ if len(short_squares) == 0: raise ft.FeatureError("no short square sweep numbers provided") if len(ramps) == 0: raise ft.FeatureError("no ramp sweep numbers provided") if len(long_squares) == 0: raise ft.FeatureError("no long_square sweep numbers provided") ramps_ext = extractor_for_nwb_sweeps(dataset, ramps, fixed_start=RAMPS_START) temp_short_sq_ext = extractor_for_nwb_sweeps(dataset, short_squares) t_set = [s.t for s in temp_short_sq_ext.sweeps()] v_set = [s.v for s in temp_short_sq_ext.sweeps()] cutoff, thresh_frac = ft.estimate_adjusted_detection_parameters( v_set, t_set, SHORT_SQUARES_WINDOW_START, SHORT_SQUARES_WINDOW_END ) thresh_frac = max(thresh_frac, 0.1) short_squares_ext = extractor_for_nwb_sweeps( dataset, short_squares, dv_cutoff=cutoff, thresh_frac=thresh_frac ) long_squares_ext = extractor_for_nwb_sweeps( dataset, long_squares, fixed_start=LONG_SQUARES_START, fixed_end=LONG_SQUARES_END, ) return EphysCellFeatureExtractor( ramps_ext, short_squares_ext, long_squares_ext, subthresh_min_amp )
[docs]def extractor_for_nwb_sweeps( dataset, sweep_numbers, fixed_start=None, fixed_end=None, dv_cutoff=20.0, thresh_frac=0.05, ): v_set = [] t_set = [] i_set = [] start = [] end = [] for sweep_number in sweep_numbers: data = dataset.get_sweep(sweep_number) v = data["response"] * 1e3 # mV i = data["stimulus"] * 1e12 # pA hz = data["sampling_rate"] dt = 1.0 / hz t = np.arange(0, len(v)) * dt # sec s, e = dt * np.array(data["index_range"]) v_set.append(v) i_set.append(i) t_set.append(t) start.append(s) end.append(e) if fixed_start and not fixed_end: start = [fixed_start] * len(end) elif fixed_start and fixed_end: start = fixed_start end = fixed_end return EphysSweepSetFeatureExtractor( t_set, v_set, i_set, start=start, end=end, dv_cutoff=dv_cutoff, thresh_frac=thresh_frac, id_set=sweep_numbers, )
def _step_stim_amp(sweep): t_index = ft.find_time_index(sweep.t, sweep.start) return sweep.i[t_index + 1] def _short_step_stim_amp(sweep): t_index = ft.find_time_index(sweep.t, sweep.start) return sweep.i[t_index + 1 :].max()