Coverage for kwave/utils/filterutils.py: 9%

276 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-10-24 11:55 -0700

1import numpy as np 

2from kwave.utils.kutils import get_win 

3from kwave.utils.misc import find_closest, sinc 

4import scipy 

5from scipy.signal import lfilter 

6from scipy.fftpack import fft, ifft, ifftshift, fftshift, fftn, ifftn 

7import math 

8from math import pi 

9 

10from kwave.utils.conversionutils import scale_SI 

11from kwave.utils.checkutils import num_dim, num_dim2 

12 

13 

def next_pow2(n):
    """
    Return the smallest power of 2 that is greater than or equal to `n`.

    Unlike MATLAB's ``nextpow2`` this returns the power itself, not the
    exponent, e.g. ``next_pow2(5) == 8`` and ``next_pow2(8) == 8``.

    Uses ``int.bit_length`` so integers of any size are handled (a 32-bit
    bit-smearing trick silently gives wrong results above 2**32).

    Args:
        n: integer (Python or NumPy integer type)

    Returns:
        The next power of 2 as a Python int. Values of `n` <= 0 return 0.

    Raises:
        TypeError: if `n` is not an integer type.
    """
    import operator

    n = operator.index(n)
    if n <= 0:
        return 0
    return 1 << (n - 1).bit_length()

28 

29 

def single_sided_correction(func_fft, fft_len, dim):
    """
    Correct the single-sided magnitude by multiplying the symmetric points by 2.

    The DC component (index 0) is unique and is never doubled. The Nyquist
    component (last index) only exists for even FFT lengths and is also left
    untouched. So, along axis `dim`, points 1:-1 are doubled for an even
    `fft_len` and points 1: are doubled for an odd `fft_len`.

    Works for any number of dimensions and any axis (hard-coded per-dimension
    slices previously required a 4D array for even lengths, had no case for
    ``dim == 1`` and addressed the wrong axes for the other cases).

    Args:
        func_fft: single-sided spectrum (NumPy array); modified in place
        fft_len: length of the FFT that produced the spectrum
        dim: axis of `func_fft` holding the frequency dimension

    Returns:
        func_fft: the corrected spectrum (the same array object as the input)
    """
    index = [slice(None)] * func_fft.ndim
    # odd length: there is no Nyquist bin, so every point after DC is doubled
    index[dim] = slice(1, None) if fft_len % 2 else slice(1, -1)
    func_fft[tuple(index)] *= 2
    return func_fft

60 

61 

def spect(func, Fs, dim='auto', fft_len=0, power_two=False, unwrap=False, window='Rectangular'):
    """
    Compute the single-sided amplitude and phase spectrum of a signal.

    Args:
        func: signal to analyse (array with 1 to 4 dimensions)
        Fs: sampling frequency [Hz]
        dim: (0-based) axis over which the spectrum is calculated. If 'auto'
            (default), the first non-singleton axis is used.
        fft_len: length of FFT. If the set length is not larger than the
            signal length, the signal length is used instead (or the next
            power of 2 if `power_two` is set).
        power_two: Boolean controlling whether the FFT length is forced to
            be the next highest power of 2 (default = False).
        unwrap: Boolean controlling whether the phase spectrum is unwrapped
            along `dim` (default = False).
        window: parameter string controlling the window type used to
            filter the signal before the FFT is taken (default =
            'Rectangular'). Any valid input types for get_win may be
            used.

    Returns:
        f: frequency array
        func_as: single-sided amplitude spectrum
        func_ps: single-sided phase spectrum

    Raises:
        ValueError: if the input is scalar, has more than 4 dimensions, or
            has no non-singleton axis when `dim` is 'auto'.
    """
    func = np.asarray(func)
    sz = func.shape

    # check input isn't scalar
    if np.size(func) == 1:
        raise ValueError('Input signal cannot be scalar.')

    # check input doesn't have more than 4 dimensions
    if len(sz) > 4:
        raise ValueError('Input signal must have 1, 2, 3, or 4 dimensions.')

    # automatically set dimension to first non-singleton dimension
    if dim == 'auto':
        non_singleton = [axis for axis, length in enumerate(sz) if length > 1]
        if not non_singleton:
            raise ValueError('Input signal must have at least one non-singleton dimension.')
        dim = non_singleton[0]

    # assign the number of points being analysed
    func_length = sz[dim]

    # set the length of the FFT
    if not fft_len > func_length:
        if power_two:
            # next_pow2 returns the power of 2 itself (not the exponent as
            # MATLAB's nextpow2 does), so it must not be raised to 2 again
            fft_len = next_pow2(func_length)
        else:
            # set the FFT length to the function length
            fft_len = func_length

    # window the signal, reshaping the window to lie along axis `dim`
    win, coherent_gain = get_win(func_length, window, symmetric=False)
    win_shape = [1] * len(sz)
    win_shape[dim] = func_length
    win = np.reshape(win, win_shape)
    func = win * func

    # compute the fft using the defined FFT length, if fft_len >
    # func_length, the input signal is padded with zeros
    func_fft = fft(func, fft_len, dim)

    # correct for the magnitude scaling of the FFT and the coherent gain of the
    # window (note that the correction is equal to func_length NOT fft_len)
    func_fft = func_fft / (func_length * coherent_gain)

    # reduce to a single sided spectrum where the number of unique points for
    # even numbered FFT lengths is given by N / 2 + 1, and for odd (N + 1) / 2
    num_unique_pts = int(np.ceil((fft_len + 1) / 2))
    index = [slice(None)] * func_fft.ndim
    index[dim] = slice(0, num_unique_pts)
    func_fft = func_fft[tuple(index)]

    func_fft = single_sided_correction(func_fft, fft_len, dim)

    # create the frequency axis variable
    f = np.arange(0, func_fft.shape[dim]) * Fs / fft_len

    # calculate the amplitude spectrum
    func_as = np.abs(func_fft)

    # calculate the phase spectrum
    func_ps = np.angle(func_fft)

    # unwrap the phase spectrum if required (the `unwrap` argument is a flag,
    # so NumPy's unwrap is called explicitly)
    if unwrap:
        func_ps = np.unwrap(func_ps, axis=dim)

    return f, func_as, func_ps

161 

162 

def extract_amp_phase(data, Fs, source_freq, dim='auto', fft_padding=3, window='Hanning'):
    """
    extract_amp_phase extracts the amplitude and phase information at a
    specified frequency from a vector or matrix of time series data. By
    default the time dimension is set to the highest non-singleton
    dimension. The amplitude and phase are extracted from the frequency
    spectrum, which is calculated using a windowed and zero padded FFT.
    The values are extracted at the frequency closest to source_freq.

    Args:
        data: matrix of time signals [s]
        Fs: sampling frequency [Hz]
        source_freq: frequency at which the amplitude and phase should be
            extracted [Hz]
        dim: (0-based) time axis of `data`, in the range 0 to 3. If 'auto'
            (default), the highest non-singleton axis is used.
        fft_padding: the FFT length is `fft_padding` times the number of
            time samples (default = 3)
        window: window type passed to get_win and applied along the time
            axis before the FFT (default = 'Hanning')

    Returns:
        amp: amplitude at the frequency closest to source_freq, with the
            time axis removed (squeezed)
        phase: phase at that frequency [rad], same shape as `amp`
        f: the frequency at which the values were extracted [Hz]

    Raises:
        ValueError: if `dim` is not a valid axis of `data` in the range 0 to 3.
    """
    data = np.asarray(data)

    # default to the highest non-singleton dimension (0-based)
    if dim == 'auto':
        non_singleton = [axis for axis, length in enumerate(data.shape) if length > 1]
        dim = non_singleton[-1] if non_singleton else data.ndim - 1

    if not 0 <= dim < min(data.ndim, 4):
        raise ValueError('dim must be 0, 1, 2, or 3')

    # create 1D window and orient it along the time axis of the input data
    n_time = data.shape[dim]
    win, coherent_gain = get_win(n_time, window)
    win_shape = [1] * data.ndim
    win_shape[dim] = n_time
    win = np.reshape(win, win_shape)

    # apply window to time dimension of input data
    data = win * data

    # compute amplitude and phase spectra
    f, func_as, func_ps = spect(data, Fs, fft_len=fft_padding * n_time, dim=dim)

    # correct for coherent gain
    func_as = func_as / coherent_gain

    # find the index of the frequency component closest to source_freq
    _, f_index = find_closest(f, source_freq)

    # extract amplitude and relative phase at f_index along the time axis
    index = [slice(None)] * func_as.ndim
    index[dim] = f_index
    index = tuple(index)
    amp = func_as[index]
    phase = func_ps[index]

    return amp.squeeze(), phase.squeeze(), f[f_index]

232 

233 

def create_cw_signals(t_array, freq, amp, phase, ramp_length=4):
    """
    create_cw_signals generates a series of continuous wave (CW) signals
    based on the 1D or 2D input matrices amp and phase, where each signal
    is given by:

        amp[i, j] * sin(2 * pi * freq * t_array + phase[i, j])

    To avoid startup transients, a cosine tapered up-ramp is applied to
    the beginning of the signal. By default, the length of this ramp is
    four periods of the wave. The up-ramp can be turned off by setting
    the ramp_length to 0.

    Example:

        # define sampling parameters
        f = 5e6
        T = 1/f
        Fs = 100e6
        dt = 1/Fs
        t_array = np.arange(0, 10*T, dt)

        # define amplitude and phase
        amp, _ = get_win(9, 'Gaussian')
        phase = np.linspace(0, 2*pi, 9)

        # create signals
        cw_signal = create_cw_signals(t_array, f, amp, phase)

    Args:
        t_array: uniformly spaced time samples [s] (flattened to 1D)
        freq: frequency of the CW signals [Hz]
        amp: amplitudes; a scalar or 1D array is treated as a column (N1, 1)
        phase: phases [rad]; a single value used for every signal, or an
            array with the same number of elements as amp
        ramp_length: length of the up-ramp in periods of the wave; 0
            disables the ramp

    Returns:
        cw_signals: array of shape (N1, N2, len(t_array)) with singleton
            dimensions squeezed out

    Raises:
        ValueError: if amp has more than 2 dimensions, or phase cannot be
            reshaped to the shape of amp.
    """
    t = np.ravel(np.asarray(t_array, dtype=float))

    amp = np.asarray(amp, dtype=float)
    if amp.ndim > 2:
        raise ValueError('amp must be a scalar, 1D or 2D array.')
    if amp.ndim < 2:
        # treat scalars and vectors as a column, (N1, 1)
        amp = amp.reshape(-1, 1)

    phase = np.asarray(phase, dtype=float)
    if phase.size == 1:
        phase = np.full(amp.shape, phase.item())
    else:
        # raises ValueError if the element counts do not match
        phase = phase.reshape(amp.shape)

    # build every signal at once: (N1, N2, 1) broadcast against (Nt,)
    cw_signals = amp[:, :, np.newaxis] * np.sin(2 * pi * freq * t + phase[:, :, np.newaxis])

    # the ramp needs at least two time samples to determine the time step
    if ramp_length != 0 and t.size > 1:
        # get period and time-step
        period = 1 / freq
        dt = t[1] - t[0]

        ramp_length_points = int(np.round(ramp_length * period / dt))
        if ramp_length_points > 0:
            # ramp x-axis in [0, pi) with exactly ramp_length_points samples
            # (np.arange with a float step can return one sample too many)
            ramp_axis = np.arange(ramp_length_points) * (pi / ramp_length_points)

            # create ramp using a shifted cosine
            ramp = (-np.cos(ramp_axis) + 1) * 0.5

            # apply ramp to all signals, truncating it if longer than the signals
            n_ramp = min(ramp_length_points, cw_signals.shape[-1])
            cw_signals[:, :, :n_ramp] *= ramp[:n_ramp]

    return np.squeeze(cw_signals)

303 

304 

def envelope_detection(signal):
    """
    Extract the envelope of a signal using the Hilbert transform.

    The envelope is the magnitude of the analytic signal returned by
    ``scipy.signal.hilbert``. For multi-dimensional input the transform is
    taken along the last axis (the SciPy default).

    Args:
        signal: real-valued input signal

    Returns:
        signal_envelope: envelope of `signal`
    """
    analytic_signal = scipy.signal.hilbert(signal)
    signal_envelope = np.abs(analytic_signal)
    return signal_envelope

320 

321 

def brenner_sharpness(im):
    """
    Brenner gradient sharpness of a 2D or 3D image.

    Sums, over every Cartesian direction, the squared difference between
    elements two samples apart, i.e. (im[i] - im[i + 2]) ** 2.

    Singleton dimensions are removed before the differences are taken, so
    e.g. a (1, Ny, Nx) image is treated as 2D in both directions. Non-floating
    inputs are promoted to float64 so squared differences of integer images
    cannot wrap around.

    Args:
        im: 2D or 3D image (singleton dimensions allowed)

    Returns:
        s: scalar Brenner sharpness

    Raises:
        ValueError: if `im` does not have 2 or 3 non-singleton dimensions.
    """
    im = np.squeeze(np.asarray(im))
    if im.ndim not in (2, 3):
        raise ValueError("Invalid number of dimensions in im")
    if not np.issubdtype(im.dtype, np.inexact):
        im = im.astype(np.float64)

    s = 0
    for axis in range(im.ndim):
        lower = [slice(None)] * im.ndim
        upper = [slice(None)] * im.ndim
        lower[axis] = slice(None, -2)
        upper[axis] = slice(2, None)
        s = s + ((im[tuple(lower)] - im[tuple(upper)]) ** 2).sum()
    return s

337 

338 

def norm_var(im):
    """
    Normalised variance sharpness metric.

    Sum of the squared deviations of every element of `im` from the image
    mean, divided by that mean.

    Args:
        im: image array

    Returns:
        scalar normalised variance
    """
    mean_intensity = np.mean(im)
    squared_deviation = (im - mean_intensity) ** 2
    return np.sum(squared_deviation) / mean_intensity

343 

344 

def tenenbaum_sharpness(im):
    """
    Tenenbaum (Tenengrad) sharpness of a 2D or 3D image.

    The image is filtered with a Sobel operator along each Cartesian axis and
    the squared responses are summed over all elements, i.e. the metric is
    the total squared Sobel gradient magnitude. (Summing the signed response
    of a single-axis Sobel filter largely cancels out and does not measure
    sharpness.)

    Args:
        im: 2D or 3D image (singleton dimensions are removed)

    Returns:
        scalar Tenenbaum sharpness

    Raises:
        ValueError: if `im` does not have 2 or 3 non-singleton dimensions.
    """
    # imported locally: a bare `import scipy` does not guarantee that the
    # ndimage submodule is loaded on older SciPy versions
    from scipy import ndimage

    # float64 so the filter output and its square cannot overflow for integer images
    im = np.squeeze(np.asarray(im, dtype=np.float64))
    if im.ndim not in (2, 3):
        raise ValueError("Invalid number of dimensions in im")

    gradient_sq = sum(ndimage.sobel(im, axis=axis) ** 2 for axis in range(im.ndim))
    return gradient_sq.sum()

355 

356 

def sharpness(im, metric="Brenner"):
    """
    sharpness returns a scalar metric related to the sharpness of the 2D
    or 3D image matrix defined by im. By default, the metric is based on
    the Brenner gradient which returns the sum of the centered
    finite-difference at each matrix element in each Cartesian direction.
    Metrics calculated using the Sobel operator or the normalised
    variance can also be returned by setting the input parameter metric.

    For further details, see B. E. Treeby, T. K. Varslot, E. Z. Zhang,
    J. G. Laufer, and P. C. Beard, "Automatic sound speed selection in
    photoacoustic image reconstruction using an autofocus approach," J.
    Biomed. Opt., vol. 16, no. 9, p. 090501, 2011.

    Args:
        im: 2D or 3D image as a numpy array
        metric (str): one of "Brenner" (default), "Tenenbaum" or
            "NormVariance"

    Returns:
        sharp_met: scalar sharpness metric

    Raises:
        AssertionError: if `im` is not a numpy array
        ValueError: if `metric` is not recognised
    """
    assert isinstance(im, np.ndarray), "Argument im must be of type numpy array"

    if metric == "Brenner":
        sharp_met = brenner_sharpness(im)
    elif metric == "Tenenbaum":
        sharp_met = tenenbaum_sharpness(im)
    elif metric == "NormVariance":
        sharp_met = norm_var(im)
    else:
        # the listed names must match the accepted values exactly
        raise ValueError(
            "Unrecognized sharpness metric passed. Valid values are ['Brenner', 'Tenenbaum', 'NormVariance']")

    return sharp_met

392 

393 

def fwhm(f, x):
    """
    Calculate the Full Width at Half Maximum (FWHM) of a positive 1D function.

    The function values `f` are sampled at the coordinates `x`. The
    half-maximum crossings are located on either side of the global maximum
    of `f` and refined by linear interpolation between neighbouring samples.

    Args:
        f: samples of f(x); any shape that squeezes to 1D
        x: coordinates of the samples, same number of elements as `f`

    Returns:
        fwhm_val: FWHM of f(x), in the units of `x`

    Raises:
        ValueError: if `f` is not 1-dimensional, if `x` and `f` differ in
            length, or if `f` does not fall below half of its maximum on both
            sides of the peak.
    """
    f = np.asarray(f, dtype=float).squeeze()
    if f.ndim != 1:
        raise ValueError("Input function must be 1-dimensional.")
    x = np.asarray(x, dtype=float).squeeze()
    if x.shape != f.shape:
        raise ValueError("x must have the same number of elements as f.")

    peak = int(np.argmax(f))
    half = f[peak] / 2.0
    below = f < half

    def lin_interp(i):
        # x position between samples i and i + 1 where f equals `half`
        return x[i] + (x[i + 1] - x[i]) * ((half - f[i]) / (f[i + 1] - f[i]))

    # leading edge: last sample below half maximum before the peak
    left = np.flatnonzero(below[:peak])
    # trailing edge: first sample below half maximum after the peak
    right = np.flatnonzero(below[peak + 1:])
    if left.size == 0 or right.size == 0:
        raise ValueError("Input function must fall below half of its maximum on both sides of the peak.")

    x_leading = lin_interp(left[-1])
    x_trailing = lin_interp(peak + right[0])

    fwhm_val = x_trailing - x_leading
    return fwhm_val

429 

430 

def gaussian(x, magnitude=None, mean=0, variance=1):
    """
    gaussian returns a Gaussian distribution f(x) with the specified
    magnitude, mean, and variance. If these values are not specified, the
    magnitude is normalised and values of variance = 1 and mean = 0 are
    used. For example running

        import matplotlib.pyplot as plt
        x = np.arange(-3, 3, 0.05)
        plt.plot(x, gaussian(x))

    will plot a normalised Gaussian distribution.

    Note, the full width at half maximum of the resulting distribution
    can be calculated by FWHM = 2 * sqrt(2 * log(2) * variance).

    Args:
        x: point(s) at which the distribution is evaluated
        magnitude: Bell height. Defaults to (2 * pi * variance) ** -0.5,
            which normalises the distribution to unit area.
        mean (float): mean or expected value. Defaults to 0.
        variance (float): variance ~ bell width. Defaults to 1.

    Returns:
        gauss_distr: Gaussian distribution evaluated at x
    """
    if magnitude is None:
        magnitude = (2 * math.pi * variance) ** -0.5

    gauss_distr = magnitude * np.exp(-(x - mean) ** 2 / (2 * variance))

    return gauss_distr

465 

466 

def gaussian_filter(signal, Fs, frequency, bandwidth):
    """
    gaussian_filter applies a frequency domain Gaussian filter with the
    specified center frequency and percentage bandwidth to the input
    signal. The filter is applied along the last axis, so if the input
    signal is given as a matrix, the filter is applied to each matrix row.

    Args:
        signal: signal to filter
        Fs: sampling frequency [Hz]
        frequency: center frequency of filter [Hz]
        bandwidth: bandwidth of filter, as a percentage of `frequency`
            (used as the full width at half maximum of the Gaussian)

    Returns:
        signal: filtered signal (real part of the inverse FFT)
    """
    signal = np.asarray(signal)
    N = signal.shape[-1]

    # double-sided frequency axis in fftshift order:
    # [-N/2, N/2 - 1] for even N, [-(N - 1)/2, (N - 1)/2] for odd N
    f = (np.arange(N) - N // 2) * Fs / N

    # convert the percentage bandwidth (FWHM) to a variance using
    # FWHM = 2 * sqrt(2 * log(2)) * standard deviation
    mean = frequency
    variance = (bandwidth / 100 * frequency / (2 * np.sqrt(2 * np.log(2)))) ** 2
    magnitude = 1

    # create double-sided Gaussian filter
    gfilter = np.fmax(gaussian(f, magnitude, mean, variance), gaussian(f, magnitude, -mean, variance))

    # apply the filter along the last axis only; shifting every axis would
    # mix rows for matrix input
    spectrum = fftshift(fft(signal), axes=-1)
    signal = np.real(ifft(ifftshift(gfilter * spectrum, axes=-1)))

    return signal

501 

502 

def _min_sound_speed(medium):
    """Return the minimum sound speed in `medium` (fluid or elastic), ignoring zero elastic speeds."""
    if medium.sound_speed is not None:
        # for the fluid code, use medium.sound_speed
        return medium.sound_speed.min()

    if all(medium.is_defined('sound_speed_compression', 'sound_speed_shear')):  # pragma: no cover
        # for the elastic code, combine the shear and compression sound speeds and ignore zero values
        # (float so NaN can be assigned even when the inputs are integer arrays)
        ss = np.concatenate([np.ravel(medium.sound_speed_compression),
                             np.ravel(medium.sound_speed_shear)]).astype(float)
        ss[ss == 0] = np.nan
        return np.nanmin(ss)

    raise ValueError(
        'The input fields medium.sound_speed or medium.sound_speed_compression and medium.sound_speed_shear must '
        'be defined.')


def filterTimeSeries(kgrid, medium, signal, ppw=3, rppw=0, stop_band_atten=60, transition_width=0.1, zerophase=False,
                     plot_spectrums=False, plot_signals=False):
    """
    Filter signal using the Kaiser windowing method.

    filterTimeSeries filters an input time domain signal using a low pass
    filter applied by apply_filter with a specified cut-off frequency,
    stop-band attenuation, and transition bandwidth. It uses the Kaiser
    Windowing method to design the FIR filter, which can be implemented
    as either a zero phase or linear phase filter. The cutoff frequency
    is defined by a minimum number of points per wavelength. A smoothing
    ramp can also be applied to the beginning of the signal to reduce
    high frequency transients.

    Args:
        kgrid: k-Wave grid; `t_array`, `dt` and `k_max_all` are read
        medium: k-Wave medium; `sound_speed`, or `sound_speed_compression`
            and `sound_speed_shear`, are read
        signal: time series to filter, given as a row or column vector
        ppw: Points-Per-Wavelength of the filter cut-off (default 3). If 0,
            the signal is returned unfiltered.
        rppw: Ramp-Points-Per-Wavelength (default 0, no ramp); requires ppw != 0
        stop_band_atten: Stop-Band-Attenuation [dB] (default 60)
        transition_width: Transition-width as a proportion of the sampling
            frequency (default 0.1)
        zerophase: apply a zero phase filter (default False)
        plot_spectrums: plotting is not implemented; must be False
        plot_signals: plotting is not implemented; must be False

    Returns:
        filtered_signal: the filtered (and optionally ramped) signal, in the
            same orientation (row or column) as the input

    Raises:
        NotImplementedError: if plotting is requested
        TypeError: if `signal` is not a vector
        ValueError: if no sound speed is defined in `medium`, or if a ramp
            is requested with ppw == 0
    """
    # plotting is not supported; fail before doing any work
    if plot_signals or plot_spectrums:
        raise NotImplementedError

    # check the input is a vector and make it a row vector
    if num_dim2(signal) == 1:
        _, n = signal.shape
        rotate_signal = n == 1
        if rotate_signal:
            signal = signal.T
    else:
        raise TypeError('Input signal must be a vector.')

    # the ramp length is defined relative to the filter cut-off wavelength
    if ppw == 0 and rppw != 0:
        raise ValueError('A start-up ramp (rppw != 0) requires a non-zero ppw.')

    # update the command line status
    print('Filtering input signal...')

    # extract the time step
    assert not isinstance(kgrid.t_array, str) or kgrid.t_array != 'auto', 'kgrid.t_array must be explicitly defined.'

    # compute the sampling frequency
    Fs = 1 / kgrid.dt

    # extract the minimum sound speed
    c0 = _min_sound_speed(medium)

    # extract the maximum supported frequency (two points per wavelength)
    f_max = kgrid.k_max_all * c0 / (2 * np.pi)

    # filter the signal if required
    if ppw != 0:
        # calculate the filter cut-off frequency
        filter_cutoff_f = 2 * f_max / ppw

        # calculate the wavelength of the filter cut-off frequency as a number of time steps
        filter_wavelength = ((2 * np.pi / filter_cutoff_f) / kgrid.dt)

        filtered_signal = apply_filter(signal, Fs, float(filter_cutoff_f), 'LowPass',
                                       zero_phase=zerophase, stop_band_atten=float(stop_band_atten),
                                       transition_width=transition_width)
    else:
        # no filtering requested: return a copy of the (row) input
        filtered_signal = np.array(signal)

    # add a start-up ramp if required
    ramp_length = 0
    if rppw != 0:
        # calculate the length of the ramp in time steps
        ramp_length = int(round(rppw * filter_wavelength / (2 * ppw)))

        if ramp_length > 0:
            # create the ramp
            ramp = (-np.cos(np.arange(ramp_length) * np.pi / ramp_length) + 1) / 2

            # apply the ramp to the first samples along the time (last) axis,
            # truncating it if it is longer than the signal
            n_ramp = min(ramp_length, filtered_signal.shape[-1])
            filtered_signal[..., :n_ramp] = filtered_signal[..., :n_ramp] * ramp[:n_ramp]

    # restore the original vector orientation if modified
    if rotate_signal:
        filtered_signal = filtered_signal.T

    # update the command line status
    print(f' maximum frequency supported by kgrid: {scale_SI(f_max)}Hz (2 PPW)')
    if ppw != 0:
        print(f' filter cutoff frequency: {scale_SI(filter_cutoff_f)}Hz ({ppw} PPW)')
    if ramp_length > 0:
        print(
            f' ramp frequency: {scale_SI(2 * np.pi / (2 * ramp_length * kgrid.dt))}Hz ({rppw} PPW)')
    print(' computation complete.')

    return filtered_signal

615 

616 

def apply_filter(signal, Fs, cutoff_f, filter_type, zero_phase=False, transition_width=0.1, stop_band_atten=60):
    """
    Filter a signal with a Kaiser-windowed FIR filter.

    apply_filter filters an input signal using an FIR filter whose
    coefficients are based on a Kaiser window with the specified cut-off
    frequency and filter type ('HighPass', 'LowPass' or 'BandPass'). Both
    causal and zero phase filters can be applied.

    Args:
        signal: input signal; a 1D array or a row/column vector
        Fs: sampling frequency [Hz]
        cutoff_f: cut-off frequency [Hz]; a list [f_low, f_high] for 'BandPass'
        filter_type: 'LowPass', 'HighPass' or 'BandPass'
        zero_phase: if True the signal is filtered forwards and then
            backwards; the stop-band attenuation of each pass is halved
        transition_width: as proportion of sampling frequency
        stop_band_atten: [dB]

    Returns:
        filtered_signal: filtered signal as a (1, N) row vector

    Raises:
        AssertionError: if cutoff_f is not a two-element list for 'BandPass'
        ValueError: for an unknown filter type
    """
    # for a bandpass filter, use apply_filter recursively
    if filter_type == 'BandPass':
        assert isinstance(cutoff_f, list), "List of two frequencies required as for filter type 'BandPass'"
        assert len(cutoff_f) == 2, "List of two frequencies required as for filter type 'BandPass'"

        # apply the low pass filter
        func_filt_lp = apply_filter(signal, Fs, cutoff_f[1], 'LowPass', stop_band_atten=stop_band_atten,
                                    transition_width=transition_width, zero_phase=zero_phase)

        # apply the high pass filter
        return apply_filter(func_filt_lp, Fs, cutoff_f[0], 'HighPass', stop_band_atten=stop_band_atten,
                            transition_width=transition_width, zero_phase=zero_phase)

    # check filter type
    if filter_type == 'LowPass':
        high_pass = False
    elif filter_type == 'HighPass':
        high_pass = True
        # design a low pass with the mirrored cut-off; it is shifted by Fs/2 below
        cutoff_f = (Fs / 2 - cutoff_f)
    else:
        raise ValueError(f'Unknown filter type {filter_type}. Options are "LowPass, HighPass, BandPass"')

    # make sure input is a row vector (1D input is treated as a row)
    signal = np.atleast_2d(signal)
    rows, cols = signal.shape
    if rows > cols:
        signal = signal.T

    # correct the stopband attenuation if a zero phase filter is being used
    if zero_phase:
        stop_band_atten = stop_band_atten / 2

    # decide the filter order
    N = int(np.ceil((stop_band_atten - 7.95) / (2.285 * (transition_width * np.pi))))

    # construct impulse response of ideal low pass filter h(n), a sinc function
    fc = cutoff_f / Fs  # normalised cut-off
    n = np.arange(-N / 2, N / 2)
    h = 2 * fc * sinc(2 * np.pi * fc * n)

    # compute Kaiser window parameter beta
    if stop_band_atten > 50:
        beta = 0.1102 * (stop_band_atten - 8.7)
    elif stop_band_atten >= 21:
        # `**` is exponentiation; `^` would be bitwise XOR (TypeError for floats)
        beta = 0.5842 * (stop_band_atten - 21) ** 0.4 + 0.07886 * (stop_band_atten - 21)
    else:
        beta = 0

    # construct the Kaiser smoothing window w(n)
    m = np.arange(0, N)
    w = np.real(scipy.special.iv(0, np.pi * beta * np.sqrt(1 - (2 * m / N - 1) ** 2))) / np.real(
        scipy.special.iv(0, np.pi * beta))

    # window the ideal impulse response with Kaiser window to obtain the FIR filter coefficients hw(n)
    hw = w * h

    # modify to make a high pass filter: multiply by (-1)^k, k = 1..N
    if high_pass:
        hw = (-1.0) ** np.arange(1, len(hw) + 1) * hw

    # add some zeros to allow the reverse (zero phase) filtering room to work
    filtered_signal = np.concatenate([np.zeros(N), signal.ravel()])

    # apply the filter; for zero phase, filter the time-reversed output and reverse back
    filtered_signal = lfilter(hw, 1, filtered_signal)
    if zero_phase:
        filtered_signal = lfilter(hw, 1, filtered_signal[::-1])[::-1]

    # remove the part of the signal corresponding to the added zeros
    filtered_signal = filtered_signal[N:]

    return filtered_signal[np.newaxis]

716 

717 

def smooth(A, restore_max=False, window_type="Blackman"):
    """
    Smooth a matrix.

    smooth filters an input matrix using an n-dimensional frequency domain
    window created using get_win. If no window type is specified, a
    Blackman window is used.

    Args:
        A: spatial distribution to smooth
        restore_max: Boolean controlling whether the maximum absolute value
            of A is restored after smoothing (default=False).
        window_type: shape of the smoothing window; any valid inputs to
            get_win are supported (default='Blackman').

    Returns:
        A_sm: smoothed distribution (A is squeezed first when num_dim(A)
            differs from its number of dimensions)
    """
    # get the grid size
    grid_size = A.shape

    # remove singleton dimensions
    if num_dim(A) != len(grid_size):
        A = A.squeeze()
        grid_size = A.shape

    # use a symmetric filter for odd grid sizes, and a non-symmetric filter for
    # even grid sizes to ensure the DC component of the window has a value of
    # unity
    window_symmetry = [bool(n % 2) for n in grid_size]

    # get the window, taking the absolute value to discard machine precision
    # negative values
    win_tmp, _ = get_win(grid_size, window_type, rotation=True, symmetric=window_symmetry)
    win = np.abs(win_tmp)

    # rotate window if input A is (1, N)
    if win.shape[0] == 1:
        win = win.transpose()

    # apply the filter
    A_sm = np.real(ifftn(fftn(A) * ifftshift(win)))

    # restore magnitude if required (np.max over all elements; the builtin
    # max would iterate over rows of an n-d array and fail)
    if restore_max:
        A_sm = (np.max(np.abs(A)) / np.max(np.abs(A_sm))) * A_sm

    return A_sm