# -*- coding: utf-8 -*-
"""
Tests for the functions of the sta module
:copyright: Copyright 2015-2016 by the Elephant team, see AUTHORS.txt.
:license: Modified BSD, see LICENSE.txt for details.
"""
import unittest
import math
import numpy as np
import scipy
from numpy.testing import assert_array_equal
from numpy.testing.utils import assert_array_almost_equal
import neo
from neo import AnalogSignal, SpikeTrain
from elephant.conversion import BinnedSpikeTrain
import quantities as pq
from quantities import ms, mV, Hz
import elephant.sta as sta
import warnings
class sta_TestCase(unittest.TestCase):
    """Tests for ``sta.spike_triggered_average``."""

    @staticmethod
    def _make_signal(*channels):
        """Stack 1-D arrays as columns of an AnalogSignal in mV at 10/ms."""
        return AnalogSignal(
            np.vstack(channels).T, units='mV', sampling_rate=10 / ms)

    def setUp(self):
        """Create sine/cosine/tangent test signals and matching spike trains.

        ``asiga0``, ``asiga1`` and ``asiga2`` hold one, two and three
        channels respectively, all sampled over the same phase axis.
        """
        phase = np.arange(0, 20 * math.pi, 0.1)
        sine, cosine, tangent = np.sin(phase), np.cos(phase), np.tan(phase)
        self.asiga0 = self._make_signal(sine)
        self.asiga1 = self._make_signal(sine, cosine)
        self.asiga2 = self._make_signal(sine, cosine, tangent)

        pi_multiples = [k * math.pi for k in (9, 10, 11, 12)]
        self.st0 = SpikeTrain(
            list(pi_multiples), units='ms', t_stop=self.asiga0.t_stop)
        self.lst = [
            SpikeTrain(
                list(pi_multiples), units='ms', t_stop=self.asiga1.t_stop),
            SpikeTrain([30, 35, 40], units='ms', t_stop=self.asiga1.t_stop),
        ]

    # ***********************************************************************
    # ************************ Test for typical values **********************

    def test_spike_triggered_average_with_n_spikes_on_constant_function(self):
        '''Averaging a constant signal should reproduce the constant'''
        level = 13.8
        signal = self._make_signal(np.full(201, level))
        spikes = SpikeTrain(
            [3, 5.6, 7, 7.1, 16, 16.3], units='ms', t_stop=20)
        window = (-2 * ms, 2 * ms)

        result = sta.spike_triggered_average(signal, spikes, window)

        # The expected cutout has as many samples as the window spans.
        window_span = window[1] - window[0]
        n_samples = int((window_span * signal.sampling_rate).simplified)
        expected = signal[0:n_samples]
        expected.t_start = window[0]
        assert_array_almost_equal(result, expected, 12)

    def test_spike_triggered_average_with_shifted_sin_wave(self):
        '''Averaged sine wave should stay below a small threshold'''
        result = sta.spike_triggered_average(
            self.asiga0, self.st0, (-4 * ms, 4 * ms))
        peak = np.abs(result).max()
        volt_dimensionality = pq.Quantity(1, "V").dimensionality.simplified
        self.assertEqual(peak.dimensionality.simplified, volt_dimensionality)
        self.assertLess(peak, 5e-2 * mV)

    def test_only_one_spike(self):
        '''With a single spike the result equals the signal cutout'''
        rate = 10 / ms
        samples = np.arange(0, 20, 0.1) ** 2
        signal = AnalogSignal(
            np.vstack((samples,)).T, units='mV', sampling_rate=rate)
        spike_time = 8 * ms
        spikes = SpikeTrain(
            [int((spike_time / ms).simplified)], units='ms', t_stop=20)
        window_start, window_end = -3 * ms, 5 * ms

        result = sta.spike_triggered_average(
            signal, spikes, (window_start, window_end))

        first = int(((spike_time + window_start) * rate).simplified)
        last = int(((spike_time + window_end) * rate).simplified)
        expected = signal[first:last]
        expected.t_start = window_start
        assert_array_equal(result, expected)

    def test_usage_of_spikes(self):
        '''Annotations should report three used and one unused spike'''
        spike_times = [k * math.pi for k in (16.5, 17.5, 18.5, 19.5)]
        spikes = SpikeTrain(spike_times, units='ms', t_stop=20 * math.pi)
        result = sta.spike_triggered_average(
            self.asiga0, spikes, (-math.pi * ms, math.pi * ms))
        annotations = result.annotations
        self.assertEqual(annotations['used_spikes'], 3)
        self.assertEqual(annotations['unused_spikes'], 1)

    # ***********************************************************************
    # **** Test for an invalid value, to check that the function raises *****
    # ********* an exception or returns an error code ***********************

    def test_analog_signal_of_wrong_type(self):
        '''Analog signal given as list, but must be AnalogSignal'''
        not_a_signal = [0, 1, 2, 3, 4]
        with self.assertRaises(TypeError):
            sta.spike_triggered_average(
                not_a_signal, self.st0, (-2 * ms, 2 * ms))

    def test_spiketrain_of_list_type_in_wrong_sense(self):
        '''A plain list of numbers as spike train must raise TypeError'''
        plain_list = [10, 11, 12]
        with self.assertRaises(TypeError):
            sta.spike_triggered_average(
                self.asiga0, plain_list, (1 * ms, 2 * ms))

    def test_spiketrain_of_nonlist_and_nonspiketrain_type(self):
        '''A tuple of numbers as spike train must raise TypeError'''
        plain_tuple = (10, 11, 12)
        with self.assertRaises(TypeError):
            sta.spike_triggered_average(
                self.asiga0, plain_tuple, (1 * ms, 2 * ms))

    def test_forgotten_AnalogSignal_argument(self):
        '''Calling with only spike train and window must raise TypeError'''
        with self.assertRaises(TypeError):
            sta.spike_triggered_average(self.st0, (-2 * ms, 2 * ms))

    def test_one_smaller_nrspiketrains_smaller_nranalogsignals(self):
        '''Two spike trains for a three-channel signal must raise'''
        with self.assertRaises(ValueError):
            sta.spike_triggered_average(
                self.asiga2, self.lst, (-2 * ms, 2 * ms))

    def test_more_spiketrains_than_analogsignals_forbidden(self):
        '''Two spike trains for a one-channel signal must raise'''
        with self.assertRaises(ValueError):
            sta.spike_triggered_average(
                self.asiga0, self.lst, (-2 * ms, 2 * ms))

    def test_spike_earlier_than_analogsignal(self):
        '''A spike train starting at negative time must raise ValueError'''
        early = SpikeTrain(
            [-1 * math.pi, 2 * math.pi], units='ms',
            t_start=-2 * math.pi, t_stop=20 * math.pi)
        with self.assertRaises(ValueError):
            sta.spike_triggered_average(
                self.asiga0, early, (-2 * ms, 2 * ms))

    def test_spike_later_than_analogsignal(self):
        '''A spike train with t_stop of 25*pi ms must raise ValueError'''
        late = SpikeTrain(
            [math.pi, 21 * math.pi], units='ms', t_stop=25 * math.pi)
        with self.assertRaises(ValueError):
            sta.spike_triggered_average(
                self.asiga0, late, (-2 * ms, 2 * ms))

    def test_impossible_window(self):
        '''A window whose end precedes its start must raise ValueError'''
        with self.assertRaises(ValueError):
            sta.spike_triggered_average(
                self.asiga0, self.st0, (-2 * ms, -5 * ms))

    def test_window_larger_than_signal(self):
        '''A window of 30*pi ms on the test signal must raise ValueError'''
        oversized = (-15 * math.pi * ms, 15 * math.pi * ms)
        with self.assertRaises(ValueError):
            sta.spike_triggered_average(self.asiga0, self.st0, oversized)

    def test_wrong_window_starttime_unit(self):
        '''A window start given in mV must raise TypeError'''
        with self.assertRaises(TypeError):
            sta.spike_triggered_average(
                self.asiga0, self.st0, (-2 * mV, 2 * ms))

    def test_wrong_window_endtime_unit(self):
        '''A window end given in Hz must raise TypeError'''
        with self.assertRaises(TypeError):
            sta.spike_triggered_average(
                self.asiga0, self.st0, (-2 * ms, 2 * Hz))

    def test_window_borders_as_complex_numbers(self):
        '''Complex-valued window borders must raise TypeError'''
        # Built outside the assertion so that only the call itself is checked.
        complex_window = (
            (-2 * math.pi + 3j) * ms, (2 * math.pi + 3j) * ms)
        with self.assertRaises(TypeError):
            sta.spike_triggered_average(
                self.asiga0, self.st0, complex_window)

    # ***********************************************************************
    # **** Test for an empty value (where the argument is a list, array, ****
    # ********* vector or other container datatype). ************************

    def test_empty_analogsignal(self):
        '''An AnalogSignal without samples must raise ValueError'''
        empty_signal = AnalogSignal([], units='mV', sampling_rate=10 / ms)
        spikes = SpikeTrain([5], units='ms', t_stop=10)
        with self.assertRaises(ValueError):
            sta.spike_triggered_average(
                empty_signal, spikes, (-1 * ms, 1 * ms))

    def test_one_spiketrain_empty(self):
        '''Test for one empty SpikeTrain, but existing spikes in other'''
        stop = self.asiga1.t_stop
        trains = [
            SpikeTrain(
                [k * math.pi for k in (9, 10, 11, 12)],
                units='ms', t_stop=stop),
            SpikeTrain([], units='ms', t_stop=stop),
        ]
        result = sta.spike_triggered_average(
            self.asiga1, trains, (-1 * ms, 1 * ms))

        # 0/0 yields an all-NaN reference for the channel without spikes.
        zeros = self._make_signal(np.zeros(20, dtype=float))
        expected = zeros / 0.
        expected.t_start = -1 * ms
        assert_array_equal(
            result.magnitude[:, 1], expected.magnitude[:, 0])

    def test_all_spiketrains_empty(self):
        '''An empty spike train should warn and give an all-NaN result'''
        spikes = SpikeTrain([], units='ms', t_stop=self.asiga1.t_stop)
        with warnings.catch_warnings(record=True) as caught:
            # Record every warning, including ones that were shown before.
            warnings.simplefilter("always")
            result = sta.spike_triggered_average(
                self.asiga1, spikes, (-1 * ms, 1 * ms))
            self.assertEqual("No spike at all was either found or used "
                             "for averaging", str(caught[-1].message))
            nan_column = np.full(20, np.nan)
            expected = self._make_signal(nan_column, nan_column)
            assert_array_equal(result.magnitude, expected.magnitude)
# =========================================================================
# Tests for new scipy version (with scipy.signal.coherence)
# =========================================================================
@unittest.skipIf(not hasattr(scipy.signal, 'coherence'), "Please update scipy "
                 "to a version >= 0.16")
class sfc_TestCase_new_scipy(unittest.TestCase):
    """Tests for ``sta.spike_field_coherence``.

    Skipped when ``scipy.signal`` has no ``coherence`` attribute.
    """

    def setUp(self):
        """Create the analog signals and binned spike trains used below."""
        # standard testsignals
        tlen0 = 100 * pq.s
        f0 = 20. * pq.Hz
        fs0 = 1 * pq.ms
        t0 = np.arange(
            0, tlen0.rescale(pq.s).magnitude,
            fs0.rescale(pq.s).magnitude) * pq.s
        # dimensionless number of cycles of f0 elapsed at each sample
        cycles = (f0 * t0).simplified.magnitude
        tlen0_ms = tlen0.rescale(pq.ms).magnitude

        self.anasig0 = AnalogSignal(
            np.sin(2 * np.pi * cycles),
            units=pq.mV, t_start=0 * pq.ms, sampling_period=fs0)
        self.st0 = SpikeTrain(
            np.arange(0, tlen0_ms, 50) * pq.ms,
            t_start=0 * pq.ms, t_stop=tlen0)
        self.bst0 = BinnedSpikeTrain(self.st0, binsize=fs0)

        # shortened analogsignals
        self.anasig1 = self.anasig0.time_slice(1 * pq.s, None)
        self.anasig2 = self.anasig0.time_slice(None, 99 * pq.s)

        # increased sampling frequency
        fs1 = 0.1 * pq.ms
        self.anasig3 = AnalogSignal(
            np.sin(2 * np.pi * cycles),
            units=pq.mV, t_start=0 * pq.ms, sampling_period=fs1)
        self.bst1 = BinnedSpikeTrain(
            self.st0.time_slice(self.anasig3.t_start, self.anasig3.t_stop),
            binsize=fs1)

        # analogsignal containing multiple traces
        self.anasig4 = AnalogSignal(
            np.array([
                np.sin(2 * np.pi * cycles),
                np.sin(4 * np.pi * cycles)]).transpose(),
            units=pq.mV, t_start=0 * pq.ms, sampling_period=fs0)

        # shortened spike train
        self.st3 = SpikeTrain(
            np.arange(tlen0_ms * .25, tlen0_ms * .75, 50) * pq.ms,
            t_start=0 * pq.ms, t_stop=tlen0)
        self.bst3 = BinnedSpikeTrain(self.st3, binsize=fs0)

        self.st4 = SpikeTrain(
            np.arange(tlen0_ms * .25, tlen0_ms * .75, 50) * pq.ms,
            t_start=5 * fs0, t_stop=tlen0 - 5 * fs0)
        self.bst4 = BinnedSpikeTrain(self.st4, binsize=fs0)

        # spike train with incompatible binsize
        self.bst5 = BinnedSpikeTrain(self.st3, binsize=fs0 * 2.)

        # spike train with same binsize as the analog signal, but with
        # bin edges not aligned to the time axis of the analog signal
        self.bst6 = BinnedSpikeTrain(
            self.st3, binsize=fs0, t_start=4.5 * fs0, t_stop=tlen0 - 4.5 * fs0)

    # =========================================================================
    # Tests for correct input handling
    # =========================================================================

    def test_start_stop_times_out_of_range(self):
        """Shortened analog signals with the full spike train must raise."""
        for shortened in (self.anasig1, self.anasig2):
            with self.assertRaises(ValueError):
                sta.spike_field_coherence(shortened, self.bst0)

    def test_incompatible_spiketrain_analogsignal(self):
        """Spike trains with incompatible binning must raise ValueError."""
        # These spike trains have incompatible binning (binsize or alignment
        # to time axis of analog signal)
        for incompatible in (self.bst5, self.bst6):
            with self.assertRaises(ValueError):
                sta.spike_field_coherence(self.anasig0, incompatible)

    def test_signal_dimensions(self):
        """Check output ranks for single- and multi-trace analog signals."""
        # single analogsignal trace and single spike train
        s_single, f_single = sta.spike_field_coherence(self.anasig0, self.bst0)

        self.assertEqual(len(f_single.shape), 1)
        self.assertEqual(len(s_single.shape), 2)

        # multiple analogsignal traces and single spike train
        s_multi, f_multi = sta.spike_field_coherence(self.anasig4, self.bst0)

        self.assertEqual(len(f_multi.shape), 1)
        self.assertEqual(len(s_multi.shape), 2)

        # frequencies are identical since same sampling frequency was used
        # in both cases and data length is the same
        assert_array_equal(f_single, f_multi)
        # coherences of s_single and first signal in s_multi are identical,
        # since first analogsignal trace in anasig4 is same as in anasig0
        assert_array_equal(s_single[:, 0], s_multi[:, 0])

    # =========================================================================
    # Tests for correct return values
    # =========================================================================

    def test_spike_field_coherence_perfect_coherence(self):
        """The coherence maximum should sit at the 20 Hz signal frequency."""
        # check for detection of 20Hz peak in anasig0/bst0
        s, f = sta.spike_field_coherence(
            self.anasig0, self.bst0, window='boxcar')

        f_ind = np.where(f >= 19.)[0][0]
        # the first bin is skipped when searching for the maximum
        max_ind = np.argmax(s[1:]) + 1

        self.assertEqual(f_ind, max_ind)
        self.assertAlmostEqual(s[f_ind], 1., delta=0.01)

    def test_output_frequencies(self):
        """Returned frequencies should span 0 to half the sampling rate."""
        nfft = 256
        # Floor division keeps the sample count an int on Python 3;
        # np.linspace rejects a float ``num``.
        n_freqs = nfft // 2 + 1
        _, f = sta.spike_field_coherence(self.anasig3, self.bst1, nfft=nfft)

        # check number of frequency samples
        self.assertEqual(len(f), n_freqs)

        # check values of frequency samples
        nyquist = self.anasig3.sampling_rate.rescale('Hz').magnitude / 2
        assert_array_almost_equal(
            f, np.linspace(0, nyquist, n_freqs) * pq.Hz)

    def test_short_spiketrain(self):
        """A shorter spike train with the same spikes gives equal results."""
        # this spike train has the same length as anasig0
        s1, f1 = sta.spike_field_coherence(
            self.anasig0, self.bst3, window='boxcar')

        # this spike train has the same spikes as above, but is shorter than
        # anasig0
        s2, f2 = sta.spike_field_coherence(
            self.anasig0, self.bst4, window='boxcar')

        # the results above should be the same, nevertheless
        assert_array_equal(s1.magnitude, s2.magnitude)
        assert_array_equal(f1.magnitude, f2.magnitude)
# =========================================================================
# Tests for old scipy version (without scipy.signal.coherence)
# =========================================================================
@unittest.skipIf(hasattr(scipy.signal, 'coherence'),
                 'Applies only for old scipy versions (<0.16)')
class sfc_TestCase_old_scipy(unittest.TestCase):
    """Check ``sta.spike_field_coherence`` without ``scipy.signal.coherence``.

    Skipped whenever ``scipy.signal`` provides ``coherence``.
    """

    def setUp(self):
        """Create one sine-wave analog signal and a regular spike train."""
        # standard testsignals
        duration = 100 * pq.s
        frequency = 20. * pq.Hz
        sampling_period = 1 * pq.ms
        times = pq.s * np.arange(
            0,
            duration.rescale(pq.s).magnitude,
            sampling_period.rescale(pq.s).magnitude)
        sine_wave = np.sin(
            2 * np.pi * (frequency * times).simplified.magnitude)
        self.anasig0 = AnalogSignal(
            sine_wave, units=pq.mV, t_start=0 * pq.ms,
            sampling_period=sampling_period)

        spike_times = np.arange(
            0, duration.rescale(pq.ms).magnitude, 50) * pq.ms
        self.st0 = SpikeTrain(
            spike_times, t_start=0 * pq.ms, t_stop=duration)
        self.bst0 = BinnedSpikeTrain(self.st0, binsize=sampling_period)

    def test_old_scipy_version(self):
        """The call is expected to raise AttributeError."""
        with self.assertRaises(AttributeError):
            sta.spike_field_coherence(self.anasig0, self.bst0)
# Run every test case in this module when it is executed as a script.
if '__main__' == __name__:
    unittest.main()