Coverage for kwave/utils/filterutils.py: 9%
276 statements
« prev ^ index » next coverage.py v6.5.0, created at 2022-10-24 11:55 -0700
« prev ^ index » next coverage.py v6.5.0, created at 2022-10-24 11:55 -0700
1import numpy as np
2from kwave.utils.kutils import get_win
3from kwave.utils.misc import find_closest, sinc
4import scipy
5from scipy.signal import lfilter
6from scipy.fftpack import fft, ifft, ifftshift, fftshift, fftn, ifftn
7import math
8from math import pi
10from kwave.utils.conversionutils import scale_SI
11from kwave.utils.checkutils import num_dim, num_dim2
def next_pow2(n):
    """
    Return the smallest power of 2 that is greater than or equal to ``n``.

    Unlike a fixed-width bit-twiddling implementation, this works for
    integers of any size. Note that, unlike MATLAB's ``nextpow2``, the
    power itself is returned, not the exponent: ``next_pow2(100) == 128``.

    Args:
        n: a non-negative integer (numpy integer types are accepted).

    Returns:
        The smallest power of 2 that is >= n. Values of ``n`` <= 1 give 1
        (i.e. 2 ** 0).
    """
    n = int(n)
    if n <= 1:
        return 1
    # (n - 1).bit_length() is the number of bits needed to represent n - 1,
    # so 1 shifted by that amount is the first power of 2 >= n; using n - 1
    # leaves exact powers of two unchanged
    return 1 << (n - 1).bit_length()
def single_sided_correction(func_fft, fft_len, dim):
    """
    Correct the magnitude of a single-sided spectrum.

    The symmetric frequency components are multiplied by 2. The DC
    component (index 0) is unique and is never doubled. The Nyquist
    component (the last index) only exists for even FFT lengths and is
    also left untouched in that case.

    Args:
        func_fft: single-sided spectrum; modified in place and returned.
        fft_len: length of the FFT used to compute the full spectrum.
        dim: axis of ``func_fft`` along which the frequency components lie.

    Returns:
        The corrected spectrum (the same array object as ``func_fft``).
    """
    # select the components to double along `dim` only, keeping every other
    # axis whole; this works for an array of any rank
    index = [slice(None)] * np.ndim(func_fft)
    if fft_len % 2:
        # odd FFT length: there is no Nyquist component, double all but DC
        index[dim] = slice(1, None)
    else:
        # even FFT length: double all but DC and Nyquist
        index[dim] = slice(1, -1)
    index = tuple(index)
    func_fft[index] = func_fft[index] * 2
    return func_fft
def spect(func, Fs, dim='auto', fft_len=0, power_two=False, unwrap=False, window='Rectangular'):
    """
    Compute the single-sided amplitude and phase spectra of a signal.

    The signal is windowed and transformed with an FFT along ``dim``
    (zero padded when ``fft_len`` exceeds the signal length). The result
    is reduced to its unique non-negative frequency components. The
    amplitude is corrected for the FFT scaling, the coherent gain of the
    window, and the doubling of the symmetric components.

    Args:
        func: signal to analyse (array with 1, 2, 3, or 4 dimensions).
        Fs: sampling frequency [Hz]
        dim: 0-based axis over which the spectrum is calculated. If 'auto'
            (default), the first non-singleton axis is used.
        fft_len: length of FFT. If the set length is not larger than the
            signal length, the default value is used instead
            (default = signal length).
        power_two: Boolean controlling whether the default FFT length is
            forced to be the next highest power of 2 (default = false).
        unwrap (bool): Boolean controlling whether the phase spectrum is
            unwrapped along ``dim`` (default = false).
        window: parameter string controlling the window type used to
            filter the signal before the FFT is taken (default =
            'Rectangular'). Any valid input types for get_win may be
            used.

    Returns:
        f: frequency array
        func_as: single-sided amplitude spectrum
        func_ps: single-sided phase spectrum

    Raises:
        ValueError: if the input is scalar, has more than 4 dimensions, or
            has no non-singleton dimension when dim='auto'.
    """
    func = np.asarray(func)
    sz = func.shape

    # check input isn't scalar
    if func.size == 1:
        raise ValueError('Input signal cannot be scalar.')

    # check input doesn't have more than 4 dimensions
    if len(sz) > 4:
        raise ValueError('Input signal must have 1, 2, 3, or 4 dimensions.')

    # automatically set dimension to first non-singleton dimension
    if dim == 'auto':
        non_singleton = [axis for axis, length in enumerate(sz) if length > 1]
        if not non_singleton:
            raise ValueError('Input signal must have at least one non-singleton dimension.')
        dim = non_singleton[0]

    # assign the number of points being analysed
    func_length = sz[dim]

    # set the length of the FFT
    if not fft_len > func_length:
        if power_two:
            # next_pow2 already returns the power of two itself (not the
            # exponent, as MATLAB's nextpow2 does), so it is used directly
            fft_len = next_pow2(func_length)
        else:
            fft_len = func_length

    # window the signal, orienting the 1D window along `dim` so that it
    # broadcasts against every other axis of the input
    win, coherent_gain = get_win(func_length, window, symmetric=False)
    win_shape = [1] * len(sz)
    win_shape[dim] = func_length
    func = np.reshape(win, win_shape) * func

    # compute the fft using the defined FFT length; if fft_len > func_length,
    # the input signal is padded with zeros
    func_fft = fft(func, fft_len, dim)

    # correct for the magnitude scaling of the FFT and the coherent gain of the
    # window (note that the correction is equal to func_length NOT fft_len)
    func_fft = func_fft / (func_length * coherent_gain)

    # reduce to a single sided spectrum where the number of unique points for
    # even numbered FFT lengths is given by N / 2 + 1, and for odd (N + 1) / 2
    num_unique_pts = int(np.ceil((fft_len + 1) / 2))
    func_fft = np.take(func_fft, np.arange(num_unique_pts), axis=dim)
    func_fft = single_sided_correction(func_fft, fft_len, dim)

    # create the frequency axis variable
    f = np.arange(0, func_fft.shape[dim]) * Fs / fft_len

    # calculate the amplitude and phase spectra
    func_as = np.abs(func_fft)
    func_ps = np.angle(func_fft)

    # unwrap the phase spectrum if required (the `unwrap` parameter is a
    # flag, so numpy's function is referenced explicitly)
    if unwrap:
        func_ps = np.unwrap(func_ps, axis=dim)

    return f, func_as, func_ps
def extract_amp_phase(data, Fs, source_freq, dim='auto', fft_padding=3, window='Hanning'):
    """
    Extract the amplitude and phase at a given frequency from time series.

    The input is a vector or matrix of time series data. By default the
    time dimension is set to the highest non-singleton dimension. The
    amplitude and phase are extracted from the frequency spectrum, which is
    calculated using a windowed and zero padded FFT. The values are
    extracted at the frequency closest to source_freq.

    Args:
        data: matrix of time signals [s]
        Fs: sampling frequency [Hz]
        source_freq: frequency at which the amplitude and phase should be
            extracted [Hz]
        dim: 0-based time axis of ``data`` (the same convention as
            ``spect``). If 'auto' (default), the last non-singleton axis is
            used.
        fft_padding: the FFT length is ``fft_padding`` times the number of
            time samples (default 3).
        window: window type passed to get_win (default 'Hanning').

    Returns:
        amp: amplitude at the extracted frequency (time axis removed,
            squeezed)
        phase: phase at the extracted frequency (time axis removed,
            squeezed)
        f: the frequency [Hz] at which the values were extracted

    Raises:
        ValueError: if dim='auto' and data has no non-singleton dimension.
    """
    data = np.asarray(data)

    # select the highest non-singleton axis as the time axis
    if dim == 'auto':
        non_singleton = [axis for axis, length in enumerate(data.shape) if length > 1]
        if not non_singleton:
            raise ValueError('Input data must have at least one non-singleton dimension.')
        dim = non_singleton[-1]

    num_samples = data.shape[dim]

    # create 1D window and orient it along the time axis so that it
    # broadcasts against every other axis of the input data
    win, coherent_gain = get_win(num_samples, window)
    win_shape = [1] * data.ndim
    win_shape[dim] = num_samples
    win = np.reshape(win, win_shape)

    # apply window to time dimension of input data
    data = win * data

    # compute amplitude and phase spectra
    f, func_as, func_ps = spect(data, Fs, fft_len=fft_padding * num_samples, dim=dim)

    # correct for coherent gain
    func_as = func_as / coherent_gain

    # find the index of the frequency component closest to source_freq
    _, f_index = find_closest(f, source_freq)

    # extract amplitude and relative phase at f_index along the time axis
    amp = np.take(func_as, f_index, axis=dim)
    phase = np.take(func_ps, f_index, axis=dim)

    return amp.squeeze(), phase.squeeze(), f[f_index]
def create_cw_signals(t_array, freq, amp, phase, ramp_length=4):
    """
    Generate a series of continuous wave (CW) signals.

    The signals are based on the 1D or 2D input matrices amp and phase,
    where each signal is given by:

        amp[i, j] * sin(2 * pi * freq * t_array + phase[i, j])

    To avoid startup transients, a cosine tapered up-ramp is applied to
    the beginning of the signal. By default, the length of this ramp is
    four periods of the wave. The up-ramp can be turned off by setting
    the ramp_length to 0.

    Example:

        # define sampling parameters
        f = 5e6
        T = 1/f
        Fs = 100e6
        dt = 1/Fs
        t_array = np.arange(0, 10*T, dt)

        # define amplitude and phase
        amp, _ = get_win(9, 'Gaussian')
        phase = np.linspace(0, 2*pi, 9)

        # create signals
        cw_signal = create_cw_signals(t_array, f, amp, phase)

    Args:
        t_array: uniformly spaced sample times [s]; any shape whose
            flattened values are the time axis, e.g. (Nt,) or (1, Nt).
        freq: frequency of the CW signals [Hz]
        amp: scalar, 1D, or 2D array of signal amplitudes.
        phase: phase offsets [rad]; either a single value applied to every
            signal, or an array with the same number of elements as amp.
        ramp_length: length of the cosine up-ramp in periods of the wave
            (default 4; 0 disables the ramp).

    Returns:
        cw_signals: array of shape (*amp.shape, Nt), with singleton
            dimensions removed.
    """
    t = np.asarray(t_array).ravel()
    amp = np.atleast_2d(np.asarray(amp))

    # expand a single phase value to every signal, otherwise match amp
    phase = np.asarray(phase)
    if phase.size == 1:
        phase = np.full(amp.shape, phase.item())
    else:
        phase = np.reshape(phase, amp.shape)

    # broadcast each signal's amplitude and phase against the time axis
    cw_signals = amp[..., np.newaxis] * np.sin(2 * pi * freq * t + phase[..., np.newaxis])

    if ramp_length != 0:
        # get period and time-step
        period = 1 / freq
        dt = t[1] - t[0]

        ramp_length_points = int(np.round(ramp_length * period / dt))
        if ramp_length_points > 0:
            # ramp x-axis from 0 towards pi; same values as
            # np.arange(0, pi, pi / n) but always exactly n points
            ramp_axis = np.arange(ramp_length_points) * (pi / ramp_length_points)

            # create ramp using a shifted cosine, never longer than the signal
            ramp = ((-np.cos(ramp_axis) + 1) * 0.5)[:t.size]

            # apply ramp to all signals simultaneously
            cw_signals[..., :ramp.size] *= ramp

    return np.squeeze(cw_signals)
def envelope_detection(signal):
    """
    Extract the envelope of a signal using the Hilbert transform.

    The envelope is the magnitude of the analytic signal produced by
    scipy.signal.hilbert. For a matrix input, the transform (and hence the
    envelope) is taken along the last axis.

    Args:
        signal: input vector or matrix

    Returns:
        signal_envelope: envelope of the input, with the same shape
    """
    analytic_signal = scipy.signal.hilbert(signal)
    signal_envelope = np.abs(analytic_signal)
    return signal_envelope
def brenner_sharpness(im):
    """
    Compute the Brenner gradient sharpness metric of a 2D or 3D image.

    The metric is the sum, over every Cartesian direction, of the squared
    difference between elements two samples apart. Singleton dimensions
    are removed before the metric is computed. Non-floating-point images
    are promoted to float64 so squared differences cannot wrap around
    (e.g. for uint8 data).

    Args:
        im: 2D or 3D image; singleton dimensions are ignored.

    Returns:
        s: the Brenner sharpness metric.

    Raises:
        ValueError: if the squeezed image is not 2D or 3D.
    """
    # index the squeezed array so the axes match the dimension count
    im = np.squeeze(np.asarray(im))
    if im.ndim not in (2, 3):
        raise ValueError("Invalid number of dimensions in im")
    if not np.issubdtype(im.dtype, np.inexact):
        im = im.astype(np.float64)

    s = 0
    for axis in range(im.ndim):
        # bring the current direction to the front to difference along it
        moved = np.moveaxis(im, axis, 0)
        s = s + ((moved[:-2] - moved[2:]) ** 2).sum()
    return s
def norm_var(im):
    """
    Compute the normalised variance sharpness metric of an image.

    The metric is the sum of squared deviations of every element from
    the image mean, divided by that mean.

    Args:
        im: image array of any shape.

    Returns:
        The normalised variance metric.
    """
    mean_value = np.mean(im)
    deviations = im - mean_value
    return np.sum(deviations ** 2) / mean_value
def tenenbaum_sharpness(im):
    """
    Compute the Tenenbaum (Tenengrad) sharpness metric of a 2D or 3D image.

    The metric is the sum over all elements of the squared Sobel gradient
    magnitude. The squared Sobel response along every Cartesian axis is
    summed. Singleton dimensions are removed first. Non-floating-point
    images are promoted to float64 so the Sobel responses cannot overflow.

    Args:
        im: 2D or 3D image; singleton dimensions are ignored.

    Returns:
        The Tenenbaum sharpness metric.

    Raises:
        ValueError: if the squeezed image is not 2D or 3D.
    """
    # imported here so the submodule is loaded even if only `scipy` was imported
    from scipy import ndimage

    im = np.squeeze(np.asarray(im))
    if im.ndim not in (2, 3):
        raise ValueError("Invalid number of dimensions in im")
    if not np.issubdtype(im.dtype, np.inexact):
        im = im.astype(np.float64)

    # ndimage.sobel differentiates along a single axis (the last by default),
    # so the squared response along every axis is accumulated
    return sum((ndimage.sobel(im, axis=axis) ** 2).sum() for axis in range(im.ndim))
def sharpness(im, metric="Brenner"):
    """
    Return a scalar metric related to the sharpness of a 2D or 3D image.

    By default, the metric is based on the Brenner gradient, which returns
    the sum of the centered finite-difference at each matrix element in
    each Cartesian direction. Metrics calculated using the Sobel operator
    (Tenenbaum) or the normalised variance can also be returned by setting
    the input parameter metric.

    For further details, see B. E. Treeby, T. K. Varslot, E. Z. Zhang,
    J. G. Laufer, and P. C. Beard, "Automatic sound speed selection in
    photoacoustic image reconstruction using an autofocus approach," J.
    Biomed. Opt., vol. 16, no. 9, p. 090501, 2011.

    Args:
        im: 2D or 3D image as a numpy array.
        metric (str): one of "Brenner" (default), "Tenenbaum", or
            "NormVariance".

    Returns:
        sharp_met: the sharpness metric.

    Raises:
        AssertionError: if im is not a numpy array.
        ValueError: if metric is not one of the recognised names.
    """
    assert isinstance(im, np.ndarray), "Argument im must be of type numpy array"

    if metric == "Brenner":
        return brenner_sharpness(im)
    if metric == "Tenenbaum":
        return tenenbaum_sharpness(im)
    if metric == "NormVariance":
        return norm_var(im)
    raise ValueError(
        "Unrecognized sharpness metric passed. Valid values are ['Brenner', 'Tenenbaum', 'NormVariance']")
def fwhm(f, x):
    """
    Calculate the Full Width at Half Maximum (FWHM) of a positive 1D
    function f(x) sampled at the points x.

    Starting from the (first) maximum of f, the nearest sample on each
    side whose value is below half of the maximum is located. The exact
    half-maximum crossing is then found by linear interpolation between
    that sample and its neighbour towards the peak.

    Args:
        f: samples of f(x); any array that is 1D after squeezing.
        x: sample positions, with the same number of elements as f.

    Returns:
        fwhm_val: FWHM of f(x), in the units of x.

    Raises:
        ValueError: if f is not 1D, if x and f differ in length, or if f
            does not fall below half of its maximum on both sides of it.
    """
    f = np.asarray(f, dtype=float).squeeze()
    x = np.asarray(x, dtype=float).squeeze()
    if f.ndim != 1:
        raise ValueError("Input function must be 1-dimensional.")
    if x.shape != f.shape:
        raise ValueError("x and f must have the same number of elements.")

    i_max = int(np.argmax(f))
    half = f[i_max] / 2.0
    below = f < half

    # samples below half maximum to the left and right of the peak
    left = np.nonzero(below[:i_max])[0]
    right = np.nonzero(below[i_max + 1:])[0]
    if left.size == 0 or right.size == 0:
        raise ValueError("f(x) must fall below half of its maximum on both sides of the peak.")

    def lin_interp(i):
        # x where the straight line through samples i and i + 1 equals `half`
        return x[i] + (x[i + 1] - x[i]) * ((half - f[i]) / (f[i + 1] - f[i]))

    # the left crossing lies between the last below-half sample and the next
    x_left = lin_interp(left[-1])
    # the right crossing lies between the first below-half sample after the
    # peak (index i_max + 1 + right[0]) and the sample before it
    x_right = lin_interp(i_max + right[0])

    fwhm_val = x_right - x_left
    return fwhm_val
def gaussian(x, magnitude=None, mean=0, variance=1):
    """
    Return a Gaussian distribution f(x) with the specified magnitude, mean,
    and variance.

    If these values are not specified, the magnitude is normalised and
    values of variance = 1 and mean = 0 are used. For example running

        import matplotlib.pyplot as plt
        x = np.arange(-3, 3, 0.05)
        plt.plot(x, gaussian(x))

    will plot a normalised Gaussian distribution.

    Note, the full width at half maximum of the resulting distribution
    can be calculated by FWHM = 2 * sqrt(2 * log(2) * variance).

    Args:
        x: point(s) at which to evaluate the distribution; scalars,
            sequences, and arrays are accepted.
        magnitude: Bell height. Defaults to 1 / sqrt(2 * pi * variance),
            which normalises the distribution to unit area.
        mean (float): mean or expected value. Defaults to 0.
        variance (float): variance ~ bell width. Defaults to 1.

    Returns:
        gauss_distr: Gaussian distribution evaluated at x
    """
    if magnitude is None:
        magnitude = (2 * math.pi * variance) ** -0.5

    # accept plain sequences as well as arrays and scalars
    x = np.asarray(x)
    gauss_distr = magnitude * np.exp(-(x - mean) ** 2 / (2 * variance))

    return gauss_distr
def gaussian_filter(signal, Fs, frequency, bandwidth):
    """
    Apply a frequency domain Gaussian filter to a signal.

    The filter has the specified center frequency and percentage bandwidth.
    If the input signal is given as a matrix, the filter is applied to each
    matrix row (i.e. along the last axis).

    Args:
        signal: signal to filter (vector, or matrix of row signals)
        Fs: sampling frequency [Hz]
        frequency: center frequency of filter [Hz]
        bandwidth: bandwidth of filter, as a percentage of the center
            frequency (full width at half maximum)

    Returns:
        signal: filtered (real-valued) signal, same shape as the input
    """
    signal = np.asarray(signal)

    # the time axis is the last axis, so row vectors and matrices of row
    # signals are both handled
    N = signal.shape[-1]

    # frequency axis in fftshift order: -N/2 .. N/2-1 for even N and
    # -(N-1)/2 .. (N-1)/2 for odd N, scaled by Fs / N
    f = (np.arange(N) - N // 2) * Fs / N

    mean = frequency
    variance = (bandwidth / 100 * frequency / (2 * np.sqrt(2 * np.log(2)))) ** 2
    magnitude = 1

    # create double-sided Gaussian filter
    gfilter = np.fmax(gaussian(f, magnitude, mean, variance), gaussian(f, magnitude, -mean, variance))

    # apply filter along the time axis only
    spectrum = fftshift(fft(signal, axis=-1), axes=-1)
    signal = np.real(ifft(ifftshift(gfilter * spectrum, axes=-1), axis=-1))

    return signal
def filterTimeSeries(kgrid, medium, signal, ppw=3, rppw=0, stop_band_atten=60, transition_width=0.1, zerophase=False,
                     plot_spectrums=False, plot_signals=False):
    """
    Filter signal using the Kaiser windowing method.

    filterTimeSeries filters an input time domain signal using a low pass
    filter applied by apply_filter with a specified cut-off frequency,
    stop-band attenuation, and transition bandwidth. It uses the Kaiser
    windowing method to design the FIR filter, which can be implemented
    as either a zero phase or linear phase filter. The cutoff frequency
    is defined by a minimum number of points per wavelength. A smoothing
    ramp can also be applied to the beginning of the signal to reduce
    high frequency transients.

    Args:
        kgrid: grid object; kgrid.t_array, kgrid.dt and kgrid.k_max_all
            are read.
        medium: medium object; medium.sound_speed (fluid code) or
            medium.sound_speed_compression and medium.sound_speed_shear
            (elastic code) are read.
        signal: time-domain signal to filter (must be a vector).
        ppw: Points-Per-Wavelength of the filter cut-off (default 3);
            0 disables the filter.
        rppw: Ramp-Points-Per-Wavelength (default 0, no ramp); requires
            ppw != 0.
        stop_band_atten: Stop-Band-Attenuation [dB] (default 60)
        transition_width: Transition-width as a proportion of the sampling
            frequency (default 0.1)
        zerophase: use a zero phase filter (default False)
        plot_spectrums: plotting is not implemented; must be False.
        plot_signals: plotting is not implemented; must be False.

    Returns:
        filtered_signal: the filtered signal, in the orientation of the
            input.

    Raises:
        TypeError: if the input signal is not a vector.
        ValueError: if no sound speed is defined, or a ramp is requested
            while the filter is disabled (ppw == 0).
        NotImplementedError: if plotting is requested.
    """
    # check the input is a vector
    if num_dim2(signal) != 1:
        raise TypeError('Input signal must be a vector.')

    # work on a row vector; a column vector is transposed here and back at
    # the end
    rotate_signal = signal.ndim == 2 and signal.shape[1] == 1
    if rotate_signal:
        signal = signal.T

    # the ramp length is derived from the filter cut-off, so it needs ppw
    if rppw != 0 and ppw == 0:
        raise ValueError('A start-up ramp (rppw != 0) requires the filter to be enabled (ppw != 0).')

    # plotting has not been ported; fail before doing any work
    if plot_signals or plot_spectrums:
        raise NotImplementedError

    # update the command line status
    print('Filtering input signal...')

    # extract the time step
    assert not isinstance(kgrid.t_array, str) or kgrid.t_array != 'auto', 'kgrid.t_array must be explicitly defined.'

    # compute the sampling frequency
    Fs = 1 / kgrid.dt

    # extract the minimum sound speed
    if medium.sound_speed is not None:
        # for the fluid code, use medium.sound_speed (scalar or array)
        c0 = np.min(medium.sound_speed)

    elif all(medium.is_defined('sound_speed_compression', 'sound_speed_shear')):  # pragma: no cover
        # for the elastic code, combine the shear and compression sound
        # speeds and ignore zero values; the float copy allows NaN and
        # leaves the medium's arrays untouched
        ss = np.concatenate([np.ravel(medium.sound_speed_compression),
                             np.ravel(medium.sound_speed_shear)]).astype(float)
        ss[ss == 0] = np.nan
        c0 = np.nanmin(ss)

    else:
        raise ValueError(
            'The input fields medium.sound_speed or medium.sound_speed_compression and medium.sound_speed_shear must '
            'be defined.')

    # extract the maximum supported frequency (two points per wavelength)
    f_max = kgrid.k_max_all * c0 / (2 * np.pi)

    # filter the signal if required
    filtered_signal = signal
    if ppw != 0:
        # calculate the filter cut-off frequency
        filter_cutoff_f = 2 * f_max / ppw

        # calculate the wavelength of the filter cut-off frequency as a number of time steps
        filter_wavelength = (2 * np.pi / filter_cutoff_f) / kgrid.dt

        filtered_signal = apply_filter(signal, Fs, float(filter_cutoff_f), 'LowPass',
                                       zero_phase=zerophase, stop_band_atten=float(stop_band_atten),
                                       transition_width=transition_width)

    # add a start-up ramp if required
    if rppw != 0:
        # calculate the length of the ramp in time steps
        ramp_length = int(round(float(rppw * filter_wavelength / (2 * ppw))))

        # the ramp cannot extend past the end of the signal
        n_ramp = min(ramp_length, filtered_signal.shape[-1])
        if n_ramp > 0:
            # create the ramp (raised cosine over ramp_length samples)
            ramp = (-np.cos(np.arange(n_ramp) * np.pi / ramp_length) + 1) / 2

            # apply the ramp to the first samples along the time axis
            filtered_signal[..., :n_ramp] = filtered_signal[..., :n_ramp] * ramp

    # restore the original vector orientation if modified
    if rotate_signal:
        filtered_signal = filtered_signal.T

    # update the command line status
    print(f' maximum frequency supported by kgrid: {scale_SI(f_max)}Hz (2 PPW)')
    if ppw != 0:
        print(f' filter cutoff frequency: {scale_SI(filter_cutoff_f)}Hz ({ppw} PPW)')
    if rppw != 0:
        print(
            f' ramp frequency: {scale_SI(2 * np.pi / (2 * ramp_length * kgrid.dt))}Hz ({rppw} PPW)')
    print(' computation complete.')

    return filtered_signal
def apply_filter(signal, Fs, cutoff_f, filter_type, zero_phase=False, transition_width=0.1, stop_band_atten=60):
    """
    Filter an input signal with a Kaiser-windowed FIR filter.

    The FIR filter coefficients are based on a Kaiser window with the
    specified cut-off frequency and filter type ('HighPass', 'LowPass' or
    'BandPass'). Both causal and zero phase filters can be applied.

    Args:
        signal: input signal; must be a vector, e.g. of shape (L,), (1, L)
            or (L, 1).
        Fs: sampling frequency [Hz]
        cutoff_f: cut-off frequency [Hz]; for 'BandPass' a sequence of two
            frequencies [low, high].
        filter_type: 'LowPass', 'HighPass' or 'BandPass'.
        zero_phase: if True, the signal is filtered forwards and then
            backwards so that no phase shift is introduced.
        transition_width: as proportion of sampling frequency
        stop_band_atten: [dB]

    Returns:
        filtered_signal: the filtered signal as a (1, L) row vector.

    Raises:
        AssertionError: if 'BandPass' is requested without exactly two
            cut-off frequencies.
        ValueError: for an unknown filter type or a non-vector signal.
    """
    # imported here so the submodule is loaded even if only `scipy` was imported
    from scipy.special import iv

    # for a bandpass filter, use apply_filter recursively
    if filter_type == 'BandPass':
        assert isinstance(cutoff_f, (list, tuple, np.ndarray)), \
            "List of two frequencies required as for filter type 'BandPass'"
        assert len(cutoff_f) == 2, "List of two frequencies required as for filter type 'BandPass'"

        # apply the low pass filter
        func_filt_lp = apply_filter(signal, Fs, cutoff_f[1], 'LowPass', stop_band_atten=stop_band_atten,
                                    transition_width=transition_width, zero_phase=zero_phase)

        # apply the high pass filter
        return apply_filter(func_filt_lp, Fs, cutoff_f[0], 'HighPass', stop_band_atten=stop_band_atten,
                            transition_width=transition_width, zero_phase=zero_phase)

    # check filter type
    if filter_type == 'LowPass':
        high_pass = False
    elif filter_type == 'HighPass':
        high_pass = True
        # a high pass filter is designed as a low pass filter with the
        # mirrored cut-off, then modulated by (-1)^n below
        cutoff_f = (Fs / 2 - cutoff_f)
    else:
        raise ValueError(f'Unknown filter type {filter_type}. Options are "LowPass, HighPass, BandPass"')

    # accept row, column, or 1D vectors; flatten to 1D
    signal = np.asarray(signal)
    if np.squeeze(signal).ndim > 1:
        raise ValueError('Input signal must be a vector.')
    signal = signal.ravel()

    # correct the stopband attenuation if a zero phase filter is being used
    # (the signal is filtered twice)
    if zero_phase:
        stop_band_atten = stop_band_atten / 2

    # decide the filter order
    N = int(np.ceil((stop_band_atten - 7.95) / (2.285 * (transition_width * np.pi))))

    # construct impulse response of ideal bandpass filter h(n), a sinc function
    fc = cutoff_f / Fs  # normalised cut-off
    n = np.arange(-N / 2, N / 2)
    h = 2 * fc * sinc(2 * np.pi * fc * n)

    # compute Kaiser window parameter beta
    if stop_band_atten > 50:
        beta = 0.1102 * (stop_band_atten - 8.7)
    elif stop_band_atten >= 21:
        beta = 0.5842 * (stop_band_atten - 21) ** 0.4 + 0.07886 * (stop_band_atten - 21)
    else:
        beta = 0

    # construct the Kaiser smoothing window w(n)
    m = np.arange(0, N)
    w = np.real(iv(0, np.pi * beta * np.sqrt(1 - (2 * m / N - 1) ** 2))) / np.real(iv(0, np.pi * beta))

    # window the ideal impulse response with Kaiser window to obtain the FIR filter coefficients hw(n)
    hw = w * h

    # modify to make a high_pass filter: alternate the sign of the taps,
    # (-1)^1, (-1)^2, ...
    if high_pass:
        hw = (-1.0) ** np.arange(1, hw.size + 1) * hw

    # add some zeros to allow the reverse (zero phase) filtering room to work
    filtered_signal = np.concatenate([np.zeros(N), signal])

    # apply the filter, then (for zero phase) filter the reversed output and
    # reverse it back
    filtered_signal = lfilter(hw, 1, filtered_signal)
    if zero_phase:
        filtered_signal = lfilter(hw, 1, filtered_signal[::-1])[::-1]

    # remove the part of the signal corresponding to the added zeros
    filtered_signal = filtered_signal[N:]

    return filtered_signal[np.newaxis]
def smooth(A, restore_max=False, window_type="Blackman"):
    """
    Smooth a matrix.

    smooth filters an input matrix using an n-dimensional frequency
    domain window created using get_win. If no window type is specified,
    a Blackman window is used.

    Args:
        A: spatial distribution to smooth
        restore_max: Boolean controlling whether the maximum absolute value
            is restored after smoothing (default=False).
        window_type: shape of the smoothing window; any valid inputs to
            get_win are supported (default='Blackman').

    Returns:
        A_sm: the smoothed, real-valued distribution. If singleton
            dimensions were squeezed out of A below, they are absent here
            too.
    """
    A = np.asarray(A)

    # get the grid size
    grid_size = A.shape

    # remove singleton dimensions
    if num_dim(A) != len(grid_size):
        A = A.squeeze()
        grid_size = A.shape

    # use a symmetric filter for odd grid sizes, and a non-symmetric filter for
    # even grid sizes to ensure the DC component of the window has a value of
    # unity
    window_symmetry = [bool(n % 2) for n in grid_size]

    # get the window, taking the absolute value to discard machine precision
    # negative values
    win_tmp, _ = get_win(grid_size, window_type, rotation=True, symmetric=window_symmetry)
    win = np.abs(win_tmp)

    # a 1D window may come back as a row or a column; align it with A so the
    # element-wise product below does not broadcast into a matrix
    if win.shape != A.shape and win.size == A.size and np.squeeze(win).ndim <= 1:
        win = np.reshape(win, A.shape)

    # apply the filter
    A_sm = np.real(ifftn(fftn(A) * ifftshift(win)))

    # restore magnitude if required (numpy reductions work for any rank,
    # unlike the builtin max)
    if restore_max:
        A_sm = (np.max(np.abs(A)) / np.max(np.abs(A_sm))) * A_sm

    return A_sm