Source code for spike.plugins.wavelet
#!/usr/bin/env python
# encoding: utf-8
"""A plugin which install wavelet denoising
This plugin is based on the PyWavelet library, which should be installed independently before trying to use this plugin
It can be found at: http://www.pybytes.com/pywavelets/
M-A Delsuc april 2016 from an idea by L Chiron
"""
from __future__ import print_function
import numpy as np
import unittest
try:
import pywt
ok = True
except:
raise Exception('This plugin requires the installation of the PyWavelet library ( www.pybytes.com/pywavelets )')
print('*** The wavelet plugin requires the installation of the PyWavelet library ( www.pybytes.com/pywavelets )')
ok = False
from spike import NPKError
from spike.NPKData import NPKData_plugin
from spike.util.signal_tools import findnoiselevel
########################################################################
[docs]def ilog2(x):
"integer log2 definition"
return int(np.floor(np.log(x)/np.log(2)))
[docs]def denoise1D(data, noiseSigma, wavelet='db3'):
"""performed the 1D denoising
data : a 1D numpy array
wavelet : the wavelet basis used,
"""
levels = ilog2(data.shape[0])
WC = pywt.wavedec(data,wavelet,level=levels)
threshold=noiseSigma*np.sqrt(2*ilog2(data.size))
NWC = map(lambda x: pywt.thresholding.soft(x,threshold), WC)
return pywt.waverec( NWC, wavelet)
[docs]def denoise2D(data, noiseSigma, wavelet='db3'):
"""performed the 2D denoising
data : a 2D numpy array
wavelet : the wavelet basis used
"""
levels = ilog2(data.shape[0])
WC = pywt.wavedec2(data,wavelet,level=levels)
threshold=noiseSigma*np.sqrt(2*ilog2(data.size))
NWC = map(lambda x: pywt.thresholding.soft(x,threshold), WC)
return pywt.waverec2( NWC, wavelet)
[docs]def wavelet(npkd, nsigma=1.0, wavelet='db3'):
"""
Performs the wavelet denoising of a 1D or 2D spectrum.
nsigma the threshold is nsigma times the estimate noise level,
default 1.0 - corresponds to a relatively strong denoising
wavelet the wavelet basis used, default 'db3' (Daubechies 3)
check pywt.wavelist() for the list of possible wavelet
eg:
d.wavelet(nsigma=0.5) # d is cleaned after execution
ref: Donoho DL (1995) De-noising by soft-thresholding. IEEE Trans Inf Theory 41:613–621.
Based on the PyWavelet library
"""
noise = findnoiselevel(npkd.get_buffer())
if npkd.dim == 1:
z = denoise1D(npkd.get_buffer(), nsigma*npkd.get_buffer().std(), wavelet=wavelet)
elif npkd.dim == 2:
z = denoise2D(npkd.get_buffer(), nsigma*npkd.get_buffer().std(), wavelet=wavelet)
else:
raise NPKError("not implemented")
npkd.set_buffer(z)
return npkd
if ok:
NPKData_plugin("wavelet", wavelet)
[docs]class WaveLetTest(unittest.TestCase):
""" - Testing Wavelet plugin- """
[docs] def setUp(self):
self.verbose = 1 # verbose > 0 switches messages on
[docs] def announce(self):
if self.verbose >0:
print("\n========",self.shortDescription(),'===============')
[docs] def test_wave(self):
""" - testing wavelet - """
from spike.util.signal_tools import findnoiselevel
from spike.NPKData import NPKData
if not ok:
print("wavelet plugin not installed")
return
self.announce()
si = 8000
b = np.random.randn(si)
d = NPKData(buffer=b)
d[1234]=100.0
d[4321]=30.0
n = findnoiselevel(d.get_buffer())
self.assertTrue(abs(n-0.95)<0.1)
self.assertAlmostEqual(d[1234],100.0)
self.assertAlmostEqual(d[4321],30.0)
d.wavelet(nsigma=0.5) # relatively strong denoising
n = findnoiselevel(d.get_buffer())
print (n, d[1234], d[4321])
self.assertTrue(n<0.01)
self.assertTrue(d[1234]>80.0)
self.assertTrue(d[4321]>10.0)