Coverage for kwave/utils/interputils.py: 20%

177 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-10-24 11:55 -0700

1from typing import List 

2 

3import numpy as np 

4from numpy.fft import fft, fftshift 

5from scipy.interpolate import interpn 

6from scipy.signal import resample 

7from kwave.utils.tictoc import TicToc 

8from kwave.utils.checkutils import num_dim 

9from kwave.utils.conversionutils import scale_time 

10 

11 

def sortrows(arr: np.ndarray, index: int):
    """
    Sort the rows of a 2D array by the values of one column (MATLAB ``sortrows`` analogue).

    The sort is stable, like MATLAB's ``sortrows``: rows with equal keys keep their
    original relative order.

    Args:
        arr: 2-dimensional array whose rows are sorted.
        index: column used as the sort key (negative values count from the last column).

    Returns:
        A new array holding the rows of ``arr`` in ascending order of ``arr[:, index]``.

    Raises:
        AssertionError: if ``arr`` is not 2-dimensional.
    """
    if arr.ndim != 2:
        # raised explicitly (not via ``assert``) so the check also runs under ``python -O``
        raise AssertionError("'sortrows' currently supports only 2-dimensional matrices")
    # 'stable' keeps equal-key rows in input order; the default quicksort does not guarantee it
    row_order = np.argsort(arr[:, index], kind='stable')
    return arr[row_order]

15 

16 

def interpolate3D(grid_points: List[np.ndarray], grid_values: np.ndarray, interp_locs: List[np.ndarray],
                  method='linear', copy_nans=True) -> np.ndarray:
    """
    Interpolates input grid values at the given locations
    Added by Farid

    Matlab version of this function assumes unstructured grid. Interpolating such grid in Python using
    SciPy is very expensive. Thankfully, working with structured grid is fine for our purposes.
    We still support 3D (meshgrid-style) arguments for backward compatibility even though they are
    mapped to 1D axis vectors. While mapping we assume that only one axis per 3D grid changes
    throughout the grid.

    Args:
        grid_points: list of three 1D axis vectors, or three 3D 'ij' meshgrid arrays
        grid_values: 3D Numpy array which holds values at grid_points
        interp_locs: list of three 1D axis vectors, or three 3D 'ij' meshgrid arrays, defining the
            structured query grid
        method: interpolation method passed to ``scipy.interpolate.interpn`` (default 'linear')
        copy_nans: if True (default), query points outside the grid take the value of
            ``grid_values`` at the same index; this requires the query grid to have the same shape
            as ``grid_values``

    Returns:
        Array of shape ``(len(q_x), len(q_y), len(q_z))`` with the interpolated values.

    Raises:
        AssertionError: if not exactly three axes are given, or if ``copy_nans`` is set and the
            query grid shape differs from ``grid_values.shape``.
    """
    if len(grid_points) != 3:
        raise AssertionError('interpolate3D supports only 3D interpolation')
    if len(grid_points) != len(interp_locs):
        raise AssertionError('grid_points and interp_locs must have the same number of axes')

    def unpack_and_make_1D(pts):
        # a 3D meshgrid array varies along one axis only; keep that single line as the axis vector
        pts_x, pts_y, pts_z = pts
        if pts_x.ndim == 3:
            pts_x = pts_x[:, 0, 0]
        if pts_y.ndim == 3:
            pts_y = pts_y[0, :, 0]
        if pts_z.ndim == 3:
            pts_z = pts_z[0, 0, :]
        return pts_x, pts_y, pts_z

    g_x, g_y, g_z = unpack_and_make_1D(grid_points)
    q_x, q_y, q_z = unpack_and_make_1D(interp_locs)

    # 'ij' indexing is crucial for Matlab compatibility; flatten to an [N, 3] list of query points
    queries = np.array(np.meshgrid(q_x, q_y, q_z, indexing='ij')).reshape(3, -1).T

    # out-of-bound points get NaN values
    result = interpn((g_x, g_y, g_z), grid_values, queries, method=method, bounds_error=False,
                     fill_value=np.nan)
    # go back from the list of interpolated values to a volume shaped like the QUERY grid
    result = result.reshape((q_x.size, q_y.size, q_z.size))

    if copy_nans:
        if result.shape != grid_values.shape:
            raise AssertionError('copy_nans requires the query grid to have the same shape as grid_values')
        # set values outside of the interpolation range to the original values
        outside = np.isnan(result)
        result[outside] = grid_values[outside]
    return result

61 

62 

def interpolate2D(grid_points: List[np.ndarray], grid_values: np.ndarray, interp_locs: List[np.ndarray],
                  method='linear', copy_nans=True) -> np.ndarray:
    """
    Interpolate 2D grid values on a structured query grid.
    Added by Farid

    Matlab version of this function assumes unstructured grid. Interpolating such grid in Python using
    SciPy is very expensive, so a structured grid is used instead. 2D (meshgrid-style) arguments are
    still accepted for backward compatibility; each is reduced to its 1D axis vector, assuming that
    only one axis changes throughout each 2D array.

    Args:
        grid_points: two 1D axis vectors, or two 2D 'ij' meshgrid arrays
        grid_values: 2D Numpy array holding the values at grid_points
        interp_locs: two 1D axis vectors, or two 2D 'ij' meshgrid arrays, defining the query grid
        method: interpolation method forwarded to ``scipy.interpolate.interpn``
        copy_nans: when True, query points outside the grid take the value of ``grid_values`` at
            the same index (the query grid must then match ``grid_values.shape``)

    Returns:
        Array of shape ``(len(q_x), len(q_y))`` with the interpolated values.
    """
    assert len(grid_points) == 2, 'interpolate2D supports only 2D interpolation'
    assert len(grid_points) == len(interp_locs)

    def _axis_vectors(pair):
        first, second = pair
        # a 2D meshgrid array varies only along its own axis, so one line of it is the axis vector
        if first.ndim == 2:
            first = first[:, 0]
        if second.ndim == 2:
            second = second[0, :]
        return first, second

    grid_x, grid_y = _axis_vectors(grid_points)
    query_x, query_y = _axis_vectors(interp_locs)

    # every (x, y) query pair in 'ij' (Matlab-compatible) order, as an [N, 2] list of 2D points
    mesh_x, mesh_y = np.meshgrid(query_x, query_y, indexing='ij')
    query_points = np.stack((mesh_x.ravel(), mesh_y.ravel()), axis=-1)

    # points outside the grid come back as NaN
    values = interpn((grid_x, grid_y), grid_values, query_points, method=method,
                     bounds_error=False, fill_value=np.nan)
    # back from the flat list to the 2D query grid
    values = values.reshape((query_x.size, query_y.size))

    if not copy_nans:
        return values

    assert values.shape == grid_values.shape
    # fall back to the original grid values wherever interpolation was out of range
    outside = np.isnan(values)
    values[outside] = grid_values[outside]
    return values

109 

110 

def interpolate2D_with_queries(
    grid_points: List[np.ndarray],
    grid_values: np.ndarray,
    queries: np.ndarray,
    method='linear',
    copy_nans=True
) -> np.ndarray:
    """
    Interpolates input grid values at the given query points
    Added by Farid

    Simplified version of `interpolate2D`.
    Expects `queries` to be [N, 2] coordinates of the interpolation locations.
    Unlike `interpolate2D`, it does not build a meshgrid from the query locations.
    WARNING: supports only 2D interpolation!

    Args:
        grid_points: two 1D axis vectors of the structured grid
        grid_values: array holding the values at grid_points (shape ``(len(g_x), len(g_y), ...)``)
        queries: Numpy array with shape [N, 2]
        method: interpolation method forwarded to ``scipy.interpolate.interpn``
        copy_nans: if True (default), out-of-range queries take the original grid value they
            correspond to. Queries are paired with grid values one-to-one in C (row-major) order,
            which is the order `interpolate2D` flattens its 'ij' meshgrid into; this requires
            exactly one query per grid node.

    Returns:
        Array of length N (plus any trailing value dimensions) with the interpolated values.

    Raises:
        AssertionError: on malformed inputs, or if ``copy_nans`` is set and the number of queries
            does not match the number of grid values.
    """
    if len(grid_points) != 2:
        raise AssertionError('interpolate2D supports only 2D interpolation')

    g_x, g_y = grid_points

    if g_x.ndim != 1 or g_y.ndim != 1:
        raise AssertionError('grid_points must be two 1D axis vectors')
    if queries.ndim != 2 or queries.shape[1] != 2:
        raise AssertionError('queries must have shape [N, 2]')

    # out-of-bound points get NaN values
    result = interpn((g_x, g_y), grid_values, queries, method=method, bounds_error=False, fill_value=np.nan)
    if copy_nans:
        # the old shape equality check compared an [N] result with a 2D grid and could never pass;
        # the meaningful pairing is one query per grid node in row-major order
        if result.size != grid_values.size:
            raise AssertionError('copy_nans requires one query per grid value; pass copy_nans=False otherwise')
        original = grid_values.reshape(result.shape)
        # set values outside the interpolation range to the original values
        outside = np.isnan(result)
        result[outside] = original[outside]
    return result

149 

150 

def cart2grid(kgrid, cart_data, axisymmetric=False):
    """
    Interpolate a set of Cartesian points onto a binary grid.

    Each Cartesian point in ``cart_data`` is snapped to the nearest grid point of the kWaveGrid
    object ``kgrid`` (nearest-neighbour interpolation) and marked in a binary matrix. Grid indices
    are 0-based: a coordinate of 0 maps to index ``N // 2`` along each axis, except along y of an
    axisymmetric grid, where it maps to index 0 (the first row).

    Args:
        kgrid: k-Wave grid object; ``dim`` and, for each used axis, ``Nx``/``Ny``/``Nz`` and
            ``dx``/``dy``/``dz`` are read.
        cart_data: Cartesian coordinates indexed as ``cart_data[axis, point]``.
        axisymmetric: (2D only) treat y as the radial coordinate starting at the first grid row.

    Returns:
        grid_data: binary matrix with ones where a Cartesian point landed; shape ``(Nx, 1)`` in 1D,
            ``(Nx, Ny)`` in 2D (both float), ``(Nx, Ny, Nz)`` in 3D (int).
        order_index: ``[N, 1]`` int array; entry i is the 1-based position of the i-th mapped
            Cartesian point within ``reorder_index``.
        reorder_index: 1-based Cartesian point numbers of the occupied grid points, listed in
            column-major (Fortran) order of ``grid_data``; shape ``[N, 1]`` (``[N, 1, 1]`` in 3D).

    Raises:
        AssertionError: if ``axisymmetric`` is used outside 2D, or if any Cartesian point lies
            outside the grid defined by kgrid.
        ValueError: if ``kgrid.dim`` is not 1, 2 or 3.
    """
    # check for axisymmetric input
    if axisymmetric and kgrid.dim != 2:
        raise AssertionError('Axisymmetric flag only supported in 2D.')

    # per-dimension settings; grid dtype and reorder shape mirror the historical outputs
    if kgrid.dim == 1:
        spacings, sizes = (kgrid.dx,), (kgrid.Nx,)
        grid_shape, grid_dtype, reorder_shape = (kgrid.Nx, 1), float, (-1, 1)
    elif kgrid.dim == 2:
        spacings, sizes = (kgrid.dx, kgrid.dy), (kgrid.Nx, kgrid.Ny)
        grid_shape, grid_dtype, reorder_shape = (kgrid.Nx, kgrid.Ny), float, (-1, 1)
    elif kgrid.dim == 3:
        spacings, sizes = (kgrid.dx, kgrid.dy, kgrid.dz), (kgrid.Nx, kgrid.Ny, kgrid.Nz)
        grid_shape, grid_dtype, reorder_shape = (kgrid.Nx, kgrid.Ny, kgrid.Nz), int, (-1, 1, 1)
    else:
        raise ValueError('Input cart_data must be a 1, 2, or 3 dimensional matrix.')

    indices = []
    for axis, (spacing, n_points) in enumerate(zip(spacings, sizes)):
        # scale positions to grid-centred pixel offsets using nearest neighbour interpolation
        idx = np.round(cart_data[axis, :] / spacing).astype(int)
        # shift to 0-based matrix indices; the radial (y) axis of an axisymmetric grid already
        # starts at 0 (the MATLAB original adds 1 only because MATLAB indexing is 1-based)
        if not (axisymmetric and axis == 1):
            idx = idx + int(n_points // 2)
        indices.append(idx)

    # check if the points all lie within the grid: valid 0-based indices are 0 .. n - 1
    for idx, n_points in zip(indices, sizes):
        if idx.min() < 0 or idx.max() >= n_points:
            raise AssertionError('Cartesian points must lie within the grid defined by kgrid.')

    # map 1-based point numbers onto an empty grid; a sequential loop is used on purpose so that
    # when several points share a grid point the last one wins (NumPy does not guarantee the
    # assignment order for repeated indices in vectorised fancy indexing)
    grid_data = np.zeros(grid_shape, dtype=grid_dtype)
    for point in range(indices[0].size):
        grid_data[tuple(idx[point] for idx in indices)] = point + 1

    # extract reordering index in column-major order
    flat = grid_data.flatten(order='F')
    reorder_index = flat[flat != 0].reshape(reorder_shape)

    # compute the reverse ordering index (i.e., what is the index of each point in the
    # reordering vector); the point numbers are unique, so this equals sorting
    # [reorder_index, 1..N] rows by the first column and keeping the second
    order_index = (np.argsort(reorder_index.ravel(), kind='stable') + 1).astype(int).reshape(-1, 1)

    # reset binary grid values
    grid_data[grid_data != 0] = 1

    # check if any Cartesian points have been mapped to the same grid point,
    # thereby reducing the total number of points
    num_discarded_points = cart_data.shape[1] - np.sum(grid_data)
    if num_discarded_points != 0:
        print(f' cart2grid: {num_discarded_points} Cartesian points mapped to overlapping grid points')
    return grid_data, order_index, reorder_index

297 

298 

def get_bli(func, dx=1, up_sampling_factor=20, plot=False):
    """
    Compute the band-limited interpolant (BLI) of a uniformly sampled 1D function.

    The samples are Fourier transformed, and the resulting Fourier series is evaluated on a grid
    ``up_sampling_factor`` times finer than the input sampling. At the original sample positions
    the BLI reproduces the input samples.

    Args:
        func: 1d input function (singleton dimensions are squeezed away)
        dx: spatial sampling [m] (default=1)
        up_sampling_factor: up-sampling factor used to sample the underlying BLI (default=20)
        plot: plotting is not implemented; passing True raises NotImplementedError

    Returns:
        bli: band-limited interpolant
        x_fine: x-grid for BLI

    Raises:
        AssertionError: if ``func`` is not 1D after squeezing.
        NotImplementedError: if ``plot`` is True.
    """
    func = np.squeeze(func)
    if func.ndim != 1:
        raise AssertionError(f"func not 1D but rather {len(func.shape)}D")
    nx = len(func)

    dk = 2 * np.pi / (dx * nx)
    # first wavenumber in fftshift order: symmetric spectrum for odd nx,
    # one extra negative (Nyquist) bin for even nx
    k_min = -np.pi / dx + dk / 2 if nx % 2 else -np.pi / dx
    # build exactly nx wavenumbers; np.arange with a float step can return nx + 1 values,
    # which would break the matrix product below
    k = k_min + dk * np.arange(nx)
    x_fine = np.arange(start=0, stop=((nx - 1) * dx) + dx / up_sampling_factor, step=dx / up_sampling_factor)

    func_k = fftshift(fft(func)) / nx

    # evaluate sum_j func_k[j] * exp(1j * k[j] * x) at every fine grid point
    bli = np.real(func_k @ np.exp(1j * np.outer(k, x_fine)))
    if plot:
        raise NotImplementedError
    return bli, x_fine

336 

337 

def interpCartData(kgrid, cart_sensor_data, cart_sensor_mask, binary_sensor_mask, interp='nearest'):
    """
    interpCartData takes a matrix of time-series data recorded over a set
    of Cartesian sensor points given by cart_sensor_mask and computes the
    equivalent time-series at each sensor position on the binary sensor
    mask binary_sensor_mask using interpolation. The properties of
    binary_sensor_mask are defined by the k-Wave grid object kgrid.
    Two and three dimensional data are supported.

    Usage:
        binary_sensor_data = interpCartData(kgrid, cart_sensor_data, cart_sensor_mask, binary_sensor_mask)
        binary_sensor_data = interpCartData(kgrid, cart_sensor_data, cart_sensor_mask, binary_sensor_mask, interp)

    Args:
        kgrid: k-Wave grid object returned by kWaveGrid
        cart_sensor_data: original sensor data measured over cart_sensor_mask, indexed as
            cart_sensor_data(sensor position, time)
        cart_sensor_mask: Cartesian sensor mask over which cart_sensor_data is measured,
            indexed as cart_sensor_mask(axis, sensor position)
        binary_sensor_mask: binary sensor mask at which equivalent time-series are computed
            via interpolation
        interp: (optional) interpolation mode used to compute the time-series, both 'nearest'
            and 'linear' (two-point, inverse-distance weighted) modes are supported
            (default = 'nearest'). With fewer than two Cartesian points, 'linear' falls back
            to 'nearest'.

    Returns:
        binary_sensor_data: array of time-series corresponding to the sensor positions given
            by binary_sensor_mask

    Raises:
        ValueError: if kgrid is not 2D/3D or interp is not a supported mode.
    """
    # make timer and start the clock
    timer = TicToc()
    timer.tic()

    # extract the number of data points; cast so a non-boolean mask still gives a usable count
    num_cart_data_points, num_time_points = cart_sensor_data.shape
    num_binary_sensor_points = int(np.sum(binary_sensor_mask.flatten()))

    # update command line status
    print('Interpolating Cartesian sensor data...')
    print(f' interpolation mode: {interp}')
    print(f' number of Cartesian sensor points: {num_cart_data_points}')
    print(f' number of binary sensor points: {num_binary_sensor_points}')

    binary_sensor_data = np.zeros((num_binary_sensor_points, num_time_points))

    # Check dimensionality of data passed
    if kgrid.dim not in [2, 3]:
        raise ValueError('Data must be two- or three-dimensional.')
    # validate the mode up front so it is rejected even when there are no points to process
    if interp not in ('nearest', 'linear'):
        raise ValueError('Unknown interpolation option.')

    from kwave.utils.kutils import grid2cart
    cart_bsm, _ = grid2cart(kgrid, binary_sensor_mask)

    for point_index in range(num_binary_sensor_points):
        # Euclidean distance from this binary sensor point to every Cartesian sensor point
        dist = np.linalg.norm(cart_bsm[:, point_index] - cart_sensor_mask.T, ord=2, axis=1)

        if interp == 'nearest' or num_cart_data_points < 2:
            # assign the time-series of the closest measured point
            binary_sensor_data[point_index, :] = cart_sensor_data[np.argmin(dist), :]
        else:
            # the two closest measured points (stable sort: ties keep input order, like MATLAB sortrows)
            closest, second = np.argsort(dist, kind='stable')[:2]
            d_closest, d_second = dist[closest], dist[second]
            total = d_closest + d_second
            # linearly interpolate between the two closest points, weighting each by the other's
            # distance; if both coincide with the binary point, use the closest one (avoids 0 / 0)
            perc = d_second / total if total > 0 else 1.0
            binary_sensor_data[point_index, :] = perc * cart_sensor_data[closest, :] + \
                (1 - perc) * cart_sensor_data[second, :]

    # update command line status
    print(f' computation completed in {scale_time(timer.toc())}')
    return binary_sensor_data

448 

449 

def interpftn(x, sz: tuple, win=None):
    """
    Resample an N-D array to the size given in ``sz`` using Fourier interpolation.

    Usage:
        y = interpftn(x, sz)
        y = interpftn(x, sz, win)

    Every axis whose requested length differs from the input's length is resampled with
    ``scipy.signal.resample``; axes that already have the requested length are skipped.

    Args:
        x: array to interpolate
        sz: list or tuple with the new size of each dimension
        win: (optional) window forwarded to ``scipy.signal.resample``

    Returns:
        y: resampled array (``x`` itself when no axis needs resampling)

    Raises:
        ValueError: if ``len(sz)`` differs from the number of non-singleton dimensions of ``x``.
    """
    input_shape = x.shape

    # the number of requested sizes must match the count of non-singleton input dimensions
    non_singleton_dims = sum(1 for length in input_shape if length != 1)
    if non_singleton_dims != len(sz):
        raise ValueError('The number of scaling coefficients must equal the number of dimensions in x.')

    resampled = x
    for axis, target_length in enumerate(sz):
        if target_length == input_shape[axis]:
            continue  # nothing to interpolate along this axis
        resampled = resample(resampled, target_length, axis=axis, window=win)
    return resampled