Coverage for /home/martinb/.local/share/virtualenvs/camcops/lib/python3.6/site-packages/scipy/signal/wavelets.py : 9%

Hot-keys on this page:
r, m, x, p: toggle line displays
j, k: next/previous highlighted chunk
0 (zero): top of page
1 (one): first highlighted chunk
1import numpy as np
2from scipy.linalg import eig
3from scipy.special import comb
4from scipy.signal import convolve
# Public names exported by ``from ... import *``.
__all__ = [
    'daub',
    'qmf',
    'cascade',
    'morlet',
    'ricker',
    'morlet2',
    'cwt',
]
def daub(p):
    """
    The coefficients for the FIR low-pass filter producing Daubechies wavelets.

    p>=1 gives the order of the zero at f=1/2.
    There are 2p filter coefficients.

    Parameters
    ----------
    p : int
        Order of the zero at f=1/2, can have values from 1 to 34.

    Returns
    -------
    daub : ndarray
        The ``2*p`` low-pass filter coefficients, scaled so that they sum
        to ``sqrt(2)``.

    Raises
    ------
    ValueError
        If `p` is less than 1 or greater than 34.

    Notes
    -----
    Orders 1 to 3 use closed-form coefficients. Higher orders factor the
    polynomial ``sum_k comb(p-1+k, k) * y**k`` numerically and, for each of
    its roots, keep the corresponding ``z`` root lying outside the unit
    circle.
    """
    sqrt = np.sqrt
    if p < 1:
        raise ValueError("p must be at least 1.")
    if p == 1:
        c = 1 / sqrt(2)
        return np.array([c, c])
    elif p == 2:
        f = sqrt(2) / 8
        c = sqrt(3)
        return f * np.array([1 + c, 3 + c, 3 - c, 1 - c])
    elif p == 3:
        tmp = 12 * sqrt(10)
        z1 = 1.5 + sqrt(15 + tmp) / 6 - 1j * (sqrt(15) + sqrt(tmp - 15)) / 6
        z1c = np.conj(z1)
        f = sqrt(2) / 8
        d0 = np.real((1 - z1) * (1 - z1c))
        a0 = np.real(z1 * z1c)
        a1 = 2 * np.real(z1)
        return f / d0 * np.array([a0, 3 * a0 - a1, 3 * a0 - 3 * a1 + 1,
                                  a0 - 3 * a1 + 3, 3 - a1, 1])
    elif p < 35:
        # Construct the polynomial (highest power first, as np.roots
        # expects) and factor it. This branch is only reached for p < 35;
        # larger orders are rejected below.
        P = [comb(p - 1 + k, k, exact=True) for k in range(p)][::-1]
        yj = np.roots(P)
        # For each root, compute two z roots and select the one with |z|>1.
        # Build up the final polynomial from (1 + z)**p and those roots.
        c = np.poly1d([1, 1])**p
        q = np.poly1d([1])
        for yval in yj:
            part = 2 * sqrt(yval * (yval - 1))
            const = 1 - 2 * yval
            z1 = const + part
            if abs(z1) < 1:
                z1 = const - part
            q = q * [1, -z1]

        q = c * np.real(q)
        # Normalize result so that the coefficients sum to sqrt(2).
        q = q / np.sum(q) * sqrt(2)
        return q.c[::-1]
    else:
        raise ValueError("Polynomial factorization does not work "
                         "well for p too large.")
def qmf(hk):
    """
    Return high-pass qmf filter from low-pass

    Parameters
    ----------
    hk : array_like
        Coefficients of low-pass filter.

    Returns
    -------
    ndarray
        The high-pass filter ``gk`` with ``gk[k] = (-1)**k * hk[N - k]``,
        where ``N = len(hk) - 1``, i.e. `hk` reversed with every other
        sign flipped (starting with ``+``).
    """
    hk = np.asarray(hk)
    # Alternating +1, -1, +1, ... applied to the reversed low-pass filter.
    signs = (-1) ** np.arange(len(hk))
    return hk[::-1] * signs
def cascade(hk, J=7):
    """
    Return (x, phi, psi) at dyadic points ``K/2**J`` from filter coefficients.

    Parameters
    ----------
    hk : array_like
        Coefficients of low-pass filter. At least two coefficients are
        required.
    J : int, optional
        Values will be computed at grid points ``K/2**J``. Default is 7.

    Returns
    -------
    x : ndarray
        The dyadic points ``K/2**J`` for ``K=0...N * (2**J)-1`` where
        ``len(hk) = len(gk) = N+1``.
    phi : ndarray
        The scaling function ``phi(x)`` at `x`:
        ``phi(x) = sum(hk * phi(2x-k))``, where k is from 0 to N.
    psi : ndarray
        The wavelet function ``psi(x)`` at `x`:
        ``psi(x) = sum(gk * phi(2x-k))``, where k is from 0 to N and
        ``gk = qmf(hk)`` is the matching high-pass filter. `psi` is always
        returned.

    Raises
    ------
    ValueError
        If `hk` has fewer than two coefficients, if ``J < 1``, or if `J`
        exceeds ``30 - log2(len(hk))``.

    Notes
    -----
    The algorithm uses the vector cascade algorithm described by Strang and
    Nguyen in "Wavelets and Filter Banks". It builds a dictionary of values
    and slices for quick reuse. Then inserts vectors into final vector at the
    end.

    """
    N = len(hk) - 1

    # With fewer than two taps the N x N matrices below are empty and the
    # eigenvector search cannot succeed, so reject such input up front.
    if N < 1:
        raise ValueError("hk must contain at least two coefficients.")
    if (J > 30 - np.log2(N + 1)):
        raise ValueError("Too many levels.")
    if (J < 1):
        raise ValueError("Too few levels.")

    # construct matrices needed
    nn, kk = np.ogrid[:N, :N]
    s2 = np.sqrt(2)
    # append a zero so that take works: indices clipped to -1 or N+1 both
    # select this trailing zero.
    thk = np.r_[hk, 0]
    gk = qmf(hk)
    tgk = np.r_[gk, 0]

    indx1 = np.clip(2 * nn - kk, -1, N + 1)
    indx2 = np.clip(2 * nn - kk + 1, -1, N + 1)
    # m[i, j] is sqrt(2) times filter i (0: hk, 1: gk) sampled at
    # index 2*n - k + j.
    m = np.zeros((2, 2, N, N), 'd')
    m[0, 0] = np.take(thk, indx1, 0)
    m[0, 1] = np.take(thk, indx2, 0)
    m[1, 0] = np.take(tgk, indx1, 0)
    m[1, 1] = np.take(tgk, indx2, 0)
    m *= s2

    # construct the grid of points
    x = np.arange(0, N * (1 << J), dtype=float) / (1 << J)
    phi = np.zeros_like(x)
    psi = np.zeros_like(x)

    # find phi0, and phi1: phi at the integer points is taken from the
    # eigenvector of m[0, 0] whose eigenvalue is closest to 1.
    lam, v = eig(m[0, 0])
    ind = np.argmin(np.absolute(lam - 1))
    v = np.real(v[:, ind])
    # need scaling function to integrate to 1 so find
    # eigenvector normalized to sum(v,axis=0)=1
    sm = np.sum(v)
    if sm < 0:  # need scaling function to integrate to 1
        v = -v
        sm = -sm
    # a dictionary with a binary representation of the
    # evaluation points x < 1 -- i.e. position is 0.xxxx
    bitdic = {'0': v / sm}
    bitdic['1'] = np.dot(m[0, 1], bitdic['0'])
    step = 1 << J
    phi[::step] = bitdic['0']
    phi[(1 << (J - 1))::step] = bitdic['1']
    psi[::step] = np.dot(m[1, 0], bitdic['0'])
    psi[(1 << (J - 1))::step] = np.dot(m[1, 1], bitdic['0'])
    # descend down the levels inserting more and more values
    # into bitdic -- store the values in the correct location once we
    # have computed them -- stored in the dictionary
    # for quicker use later.
    prevkeys = ['1']
    for level in range(2, J + 1):
        newkeys = ['%d%s' % (xx, yy) for xx in [0, 1] for yy in prevkeys]
        fac = 1 << (J - level)
        for key in newkeys:
            # Each key is a `level`-digit binary string, so its integer
            # value is the position index at this level.
            num = int(key, 2)
            pastphi = bitdic[key[1:]]
            ii = int(key[0])
            temp = np.dot(m[0, ii], pastphi)
            bitdic[key] = temp
            phi[num * fac::step] = temp
            psi[num * fac::step] = np.dot(m[1, ii], pastphi)
        prevkeys = newkeys

    return x, phi, psi
def morlet(M, w=5.0, s=1.0, complete=True):
    """
    Complex Morlet wavelet.

    Parameters
    ----------
    M : int
        Number of samples in the returned wavelet.
    w : float, optional
        Omega0, the angular frequency of the carrier. Default is 5
    s : float, optional
        Scaling factor; samples are taken evenly from ``-s*2*pi`` to
        ``+s*2*pi``. Default is 1.
    complete : bool, optional
        If True, return the complete version; otherwise the standard one.

    Returns
    -------
    morlet : (M,) ndarray

    See Also
    --------
    morlet2 : Implementation of Morlet wavelet, compatible with `cwt`.
    scipy.signal.gausspulse

    Notes
    -----
    The standard version::

        pi**-0.25 * exp(1j*w*x) * exp(-0.5*(x**2))

    This commonly used wavelet is often referred to simply as the
    Morlet wavelet. Note that this simplified version can cause
    admissibility problems at low values of `w`.

    The complete version::

        pi**-0.25 * (exp(1j*w*x) - exp(-0.5*(w**2))) * exp(-0.5*(x**2))

    This version adds a correction term to improve admissibility. For `w`
    greater than 5, the correction term is negligible.

    The energy of the returned wavelet is not normalised according to `s`.

    The fundamental frequency of this wavelet in Hz is given
    by ``f = 2*s*w*r / M`` where `r` is the sampling rate.

    Note: This function was created before `cwt` and is not compatible
    with it.

    """
    half_span = s * 2 * np.pi
    grid = np.linspace(-half_span, half_span, M)

    carrier = np.exp(1j * w * grid)
    if complete:
        # Admissibility correction term (constant in `grid`).
        carrier = carrier - np.exp(-0.5 * (w**2))

    envelope = np.exp(-0.5 * (grid**2)) * np.pi**(-0.25)
    return carrier * envelope
def ricker(points, a):
    """
    Return a Ricker wavelet, also known as the "Mexican hat wavelet".

    It models the function:

        ``A * (1 - (x/a)**2) * exp(-0.5*(x/a)**2)``,

    where ``A = 2/(sqrt(3*a)*(pi**0.25))``.

    Parameters
    ----------
    points : int
        Number of samples in the returned vector, centered around 0.
    a : scalar
        Width parameter of the wavelet.

    Returns
    -------
    vector : (N,) ndarray
        Array of length `points` in shape of ricker curve.

    Examples
    --------
    >>> from scipy import signal
    >>> import matplotlib.pyplot as plt

    >>> points = 100
    >>> a = 4.0
    >>> vec2 = signal.ricker(points, a)
    >>> print(len(vec2))
    100
    >>> plt.plot(vec2)
    >>> plt.show()

    """
    amplitude = 2 / (np.sqrt(3 * a) * (np.pi**0.25))
    width_sq = a**2
    # Sample positions symmetric about zero.
    offsets = np.arange(0, points) - (points - 1.0) / 2
    offsets_sq = offsets**2
    envelope = np.exp(-offsets_sq / (2 * width_sq))
    return amplitude * (1 - offsets_sq / width_sq) * envelope
def morlet2(M, s, w=5):
    """
    Complex Morlet wavelet, designed to work with `cwt`.

    Returns the complete version of morlet wavelet, normalised
    according to `s`::

        exp(1j*w*x/s) * exp(-0.5*(x/s)**2) * pi**(-0.25) * sqrt(1/s)

    Parameters
    ----------
    M : int
        Number of samples in the returned wavelet.
    s : float
        Width parameter of the wavelet.
    w : float, optional
        Omega0. Default is 5

    Returns
    -------
    morlet : (M,) ndarray

    See Also
    --------
    morlet : Implementation of Morlet wavelet, incompatible with `cwt`

    Notes
    -----

    .. versionadded:: 1.4.0

    This function was designed to work with `cwt`. Because `morlet2`
    returns an array of complex numbers, the `dtype` argument of `cwt`
    should be set to `complex128` for best results.

    Note the difference in implementation with `morlet`.
    The fundamental frequency of this wavelet in Hz is given by::

        f = w*fs / (2*s*np.pi)

    where ``fs`` is the sampling rate and `s` is the wavelet width parameter.
    Similarly we can get the wavelet width parameter at ``f``::

        s = w*fs / (2*f*np.pi)

    Examples
    --------
    >>> from scipy import signal
    >>> import matplotlib.pyplot as plt

    >>> M = 100
    >>> s = 4.0
    >>> w = 2.0
    >>> wavelet = signal.morlet2(M, s, w)
    >>> plt.plot(abs(wavelet))
    >>> plt.show()

    This example shows basic use of `morlet2` with `cwt` in time-frequency
    analysis:

    >>> from scipy import signal
    >>> import matplotlib.pyplot as plt
    >>> t, dt = np.linspace(0, 1, 200, retstep=True)
    >>> fs = 1/dt
    >>> w = 6.
    >>> sig = np.cos(2*np.pi*(50 + 10*t)*t) + np.sin(40*np.pi*t)
    >>> freq = np.linspace(1, fs/2, 100)
    >>> widths = w*fs / (2*freq*np.pi)
    >>> cwtm = signal.cwt(sig, signal.morlet2, widths, w=w)
    >>> plt.pcolormesh(t, freq, np.abs(cwtm), cmap='viridis', shading='gouraud')
    >>> plt.show()

    """
    # Centered sample positions, expressed in units of the width `s`.
    scaled = (np.arange(0, M) - (M - 1.0) / 2) / s
    gaussian = np.exp(-0.5 * scaled**2)
    shape = np.exp(1j * w * scaled) * gaussian * np.pi**(-0.25)
    return np.sqrt(1/s) * shape
def cwt(data, wavelet, widths, dtype=None, **kwargs):
    """
    Continuous wavelet transform.

    Performs a continuous wavelet transform on `data`,
    using the `wavelet` function. A CWT performs a convolution
    with `data` using the `wavelet` function, which is characterized
    by a width parameter and length parameter. The `wavelet` function
    is allowed to be complex.

    Parameters
    ----------
    data : (N,) ndarray
        data on which to perform the transform.
    wavelet : function
        Wavelet function, which should take 2 arguments.
        The first argument is the number of points that the returned vector
        will have (len(wavelet(length,width)) == length).
        The second is a width parameter, defining the size of the wavelet
        (e.g. standard deviation of a gaussian). See `ricker`, which
        satisfies these requirements.
    widths : (M,) sequence
        Widths to use for transform.
    dtype : data-type, optional
        The desired data type of output. Defaults to ``float64`` if the
        output of `wavelet` is real and ``complex128`` if it is complex.

        .. versionadded:: 1.4.0

    kwargs
        Keyword arguments passed to wavelet function.

        .. versionadded:: 1.4.0

    Returns
    -------
    cwt: (M, N) ndarray
        Will have shape of (len(widths), len(data)).

    Notes
    -----

    .. versionadded:: 1.4.0

    For non-symmetric, complex-valued wavelets, the input signal is convolved
    with the time-reversed complex-conjugate of the wavelet data [1].

    ::

        length = min(10 * width[ii], len(data))
        cwt[ii,:] = signal.convolve(data, np.conj(wavelet(length, width[ii],
                                        **kwargs))[::-1], mode='same')

    References
    ----------
    .. [1] S. Mallat, "A Wavelet Tour of Signal Processing (3rd Edition)",
        Academic Press, 2009.

    Examples
    --------
    >>> from scipy import signal
    >>> import matplotlib.pyplot as plt
    >>> t = np.linspace(-1, 1, 200, endpoint=False)
    >>> sig = np.cos(2 * np.pi * 7 * t) + signal.gausspulse(t - 0.4, fc=2)
    >>> widths = np.arange(1, 31)
    >>> cwtmatr = signal.cwt(sig, signal.ricker, widths)
    >>> plt.imshow(cwtmatr, extent=[-1, 1, 1, 31], cmap='PRGn', aspect='auto',
    ...            vmax=abs(cwtmatr).max(), vmin=-abs(cwtmatr).max())
    >>> plt.show()
    """
    is_ricker = wavelet == ricker
    if is_ricker:
        # For `ricker` only, `window_size` is removed from kwargs (so it is
        # not forwarded to the wavelet); when supplied, the integer-length
        # shim in the loop below is skipped.
        window_size = kwargs.pop('window_size', None)

    if dtype is None:
        # Probe the wavelet with a single sample to pick real vs complex.
        probe = np.asarray(wavelet(1, widths[0], **kwargs))
        dtype = np.complex128 if probe.dtype.char in 'FDG' else np.float64

    n_data = len(data)
    output = np.zeros((len(widths), n_data), dtype=dtype)
    for row, width in enumerate(widths):
        length = np.min([10 * width, n_data])
        # Shim for 32-bit arch + NumPy <= 1.14.5 (gh-11095): truncate a
        # non-integral length to int for `ricker`. This and the
        # window_size pop above may be removed eventually.
        if is_ricker and window_size is None and np.ceil(length) != length:
            length = int(length)
        # Time-reversed complex conjugate of the sampled wavelet.
        kernel = np.conj(wavelet(length, width, **kwargs)[::-1])
        output[row] = convolve(data, kernel, mode='same')
    return output