Coverage for kwave/kgrid.py: 46%

275 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-10-24 11:55 -0700

1import math 

2from dataclasses import dataclass 

3 

4import numpy as np 

5import sys 

6 

7from kwave.data import Array 

8from kwave.enums import DiscreteCosine, DiscreteSine 

9 

# Courant-Friedrichs-Lewy number used by kWaveGrid.makeTime when the caller
# does not supply one
CFL_DEFAULT: float = 0.3

# Absolute tolerance equal to one hundred double-precision epsilons; the
# kWaveGrid.t_array setter uses it to decide whether a time array is evenly
# spaced
MACHINE_PRECISION: float = sys.float_info.epsilon * 100

15 

16 

@dataclass
class kWaveGrid(object):
    """
    kWaveGrid is the grid class used across the k-Wave Toolbox. An object
    of the kWaveGrid class contains the grid coordinates and wavenumber
    matrices used within the simulation and reconstruction functions in
    k-Wave. The grid matrices are indexed as: (x, 1) in 1D; (x, y) in
    2D; and (x, y, z) in 3D. The grid is assumed to be a regularly spaced
    Cartesian grid, with grid spacing given by dx, dy, dz (typically the
    grid spacing in each direction is constant).

    NOTE(review): ``@dataclass`` is applied to a class that declares no
    annotated fields and defines its own ``__init__`` (which dataclass keeps).
    The generated ``__eq__`` therefore compares empty field tuples, so any two
    kWaveGrid instances compare equal and instances are unhashable. Left
    unchanged because callers may rely on it; confirm whether it is intended.
    """

    def __init__(self, N, spacing):
        """
        Create a grid and precompute its wavenumber vectors and the scalar
        wavenumber matrix.

        Args:
            N: grid size in each dimension [grid points]; scalar or 1-3 elements
            spacing: grid point spacing in each direction [m]; same number of
                elements as N
        """
        N, spacing = np.atleast_1d(N), np.atleast_1d(spacing)  # if inputs are lists
        assert N.ndim == 1 and spacing.ndim == 1  # ensure no multidimensional lists
        assert (1 <= N.size <= 3) and (1 <= spacing.size <= 3)  # ensure valid dimensionality
        assert N.size == spacing.size, "Size list N and spacing list do not have the same size."

        self.N = N.astype(int)  #: grid size in each dimension [grid points]
        self.spacing = spacing  #: grid point spacing in each direction [m]
        self.dim = self.N.size  #: Number of dimensions (1, 2 or 3)

        self.nonuniform = False  #: flag that indicates grid non-uniformity
        self.dt = 'auto'  #: size of time step [s]
        self.Nt = 'auto'  #: number of time steps

        # originally there was [xn_vec, yn_vec, zn_vec]
        self.n_vec = Array([0] * self.dim)  #: position vectors for the grid points in [0, 1]
        # originally there was [xn_vec_sgx, yn_vec_sgy, zn_vec_sgz]
        self.n_vec_sg = Array([0] * self.dim)  #: position vectors for the staggered grid points in [0, 1]

        # originally there was [dxudxn, dyudyn, dzudzn]
        self.dudn = Array([0] * self.dim)  #: transformation gradients between uniform and staggered grids
        # originally there was [dxudxn_sgx, dyudyn_sgy, dzudzn_sgz]
        self.dudn_sg = Array([0] * self.dim)  #: transformation gradients between uniform and staggered grids

        # assign the grid parameters for the x spatial direction
        # originally kx_vec
        self.k_vec = Array([self.makeDim(self.Nx, self.dx)])  #: Nx x 1 vector of wavenumber components in the x-direction [rad/m]

        if self.dim == 1:
            # define the scalar wavenumber based on the wavenumber components
            self.k = abs(self.k_vec.x)  #: scalar wavenumber

        if self.dim >= 2:
            # Ny x 1 vector of wavenumber components in the y-direction [rad/m]
            self.k_vec.append(self.makeDim(self.Ny, self.dy))

        if self.dim == 2:
            # k = sqrt(kx^2 + ky^2), broadcast from column (x) and row (y) vectors
            self.k = np.zeros((self.Nx, self.Ny))
            self.k = np.reshape(self.k_vec.x, (-1, 1)) ** 2 + self.k
            self.k = np.reshape(self.k_vec.y, (1, -1)) ** 2 + self.k
            self.k = np.sqrt(self.k)  #: scalar wavenumber

        if self.dim == 3:
            # Nz x 1 vector of wavenumber components in the z-direction [rad/m]
            self.k_vec.append(self.makeDim(self.Nz, self.dz))

            # k = sqrt(kx^2 + ky^2 + kz^2), each component broadcast along its own axis
            self.k = np.zeros((self.Nx, self.Ny, self.Nz))
            self.k = np.reshape(self.k_vec.x, (-1, 1, 1)) ** 2 + self.k
            self.k = np.reshape(self.k_vec.y, (1, -1, 1)) ** 2 + self.k
            self.k = np.reshape(self.k_vec.z, (1, 1, -1)) ** 2 + self.k
            self.k = np.sqrt(self.k)  #: scalar wavenumber

    @property
    def t_array(self):
        """
        Time array [s] as a (1, Nt) row vector, or the string 'auto' when
        either Nt or dt has not been set.
        """
        if self.Nt == 'auto' or self.dt == 'auto':
            return 'auto'
        else:
            t_array = np.arange(0, self.Nt) * self.dt
            return np.expand_dims(t_array, axis=0)

    @t_array.setter
    def t_array(self, t_array):
        """
        Set Nt and dt from an evenly spaced time array starting at zero, or
        reset both to 'auto' when given the string 'auto'.
        """
        # check for 'auto' input; the isinstance guard prevents an element-wise
        # (and ambiguous truth-value) comparison when an array is passed
        if isinstance(t_array, str) and t_array == 'auto':
            self.Nt = 'auto'
            self.dt = 'auto'
            return

        # flatten so any array-like works, including the (1, Nt) row vector
        # returned by the getter
        t_array = np.asarray(t_array).ravel()

        # extract property values
        Nt_temp = t_array.size
        dt_temp = t_array[1] - t_array[0]

        # check the time array begins at zero
        assert t_array[0] == 0, 't_array must begin at zero.'

        # check the time array is evenly spaced; the absolute value ensures
        # steps shorter than dt (or negative steps) are rejected too
        assert np.abs(np.diff(t_array) - dt_temp).max() < MACHINE_PRECISION, \
            't_array must be evenly spaced.'

        # check the time steps are increasing
        assert dt_temp > 0, 't_array must be monotonically increasing.'

        # assign values
        self.Nt = Nt_temp
        self.dt = dt_temp

    def setTime(self, Nt, dt) -> None:
        """
        Set Nt and dt based on user input

        Args:
            Nt: number of time steps; a positive Python or NumPy integer
            dt: time step [s]; must be positive

        Returns: None
        """
        # check the value for Nt; np.integer covers every NumPy integer width
        # (np.int was removed in NumPy 1.24)
        assert isinstance(Nt, (int, np.integer)) and Nt > 0, 'Nt must be a positive integer.'

        # check the value for dt
        assert dt > 0, 'dt must be positive.'

        # assign values
        self.Nt = Nt
        self.dt = dt

    @property
    def Nx(self):
        """
        grid size in x-direction [grid points]
        """
        return self.N[0]

    @property
    def Ny(self):
        """
        grid size in y-direction [grid points] (0 for a 1D grid)
        """
        return self.N[1] if self.N.size >= 2 else 0

    @property
    def Nz(self):
        """
        grid size in z-direction [grid points] (0 for a 1D or 2D grid)
        """
        return self.N[2] if self.N.size == 3 else 0

    @property
    def dx(self):
        """
        grid point spacing in x-direction [m]
        """
        return self.spacing[0]

    @property
    def dy(self):
        """
        grid point spacing in y-direction [m] (0 for a 1D grid)
        """
        return self.spacing[1] if self.spacing.size >= 2 else 0

    @property
    def dz(self):
        """
        grid point spacing in z-direction [m] (0 for a 1D or 2D grid)
        """
        return self.spacing[2] if self.spacing.size == 3 else 0

    @property
    def x_vec(self):
        """
        Nx x 1 vector of the grid coordinates in the x-direction [m]
        """
        # calculate x_vec based on kx_vec
        return self.size[0] * self.k_vec.x * self.dx / (2 * np.pi)

    @property
    def y_vec(self):
        """
        Ny x 1 vector of the grid coordinates in the y-direction [m] (NaN for 1D grids)
        """
        # calculate y_vec based on ky_vec
        if self.dim < 2:
            return np.nan
        return self.size[1] * self.k_vec.y * self.dy / (2 * np.pi)

    @property
    def z_vec(self):
        """
        Nz x 1 vector of the grid coordinates in the z-direction [m] (NaN below 3D)
        """
        # calculate z_vec based on kz_vec
        if self.dim < 3:
            return np.nan
        return self.size[2] * self.k_vec.z * self.dz / (2 * np.pi)

    @property
    def x(self):
        """
        Nx x Ny x Nz grid containing repeated copies of the grid coordinates in the x-direction [m]
        """
        return self.size[0] * self.kx * self.dx / (2 * math.pi)

    @property
    def y(self):
        """
        Nx x Ny x Nz grid containing repeated copies of the grid coordinates in the y-direction [m]
        """
        if self.dim < 2:
            return np.nan
        return self.size[1] * self.ky * self.dy / (2 * math.pi)

    @property
    def z(self):
        """
        Nx x Ny x Nz grid containing repeated copies of the grid coordinates in the z-direction [m]
        """
        if self.dim < 3:
            return np.nan
        return self.size[2] * self.kz * self.dz / (2 * math.pi)

    @property
    def xn(self):
        """
        Plaid non-uniform grid positions in the x-direction.

        Returns:
            plaid xn matrix (Nx x 1, Nx x Ny or Nx x Ny x Nz), or 0 for a uniform grid
        """
        if self.dim == 1:
            return self.n_vec.x if self.nonuniform else 0
        if not self.nonuniform:
            return 0
        if self.dim == 2:
            return np.tile(np.reshape(self.n_vec.x, (-1, 1)), (1, self.Ny))
        # reshape to Nx x 1 x 1 first: tiling a 2-D array with 3-D reps would
        # prepend a new leading axis instead of repeating along y and z
        return np.tile(np.reshape(self.n_vec.x, (-1, 1, 1)), (1, self.Ny, self.Nz))

    @property
    def yn(self):
        """
        Plaid non-uniform grid positions in the y-direction.

        Returns:
            plaid yn matrix (Nx x Ny or Nx x Ny x Nz), 0 for a uniform grid,
            or NaN for a 1D grid
        """
        if self.dim < 2:
            return np.nan
        if not self.nonuniform:
            return 0
        if self.dim == 2:
            return np.tile(np.reshape(self.n_vec.y, (1, -1)), (self.Nx, 1))
        return np.tile(np.reshape(self.n_vec.y, (1, -1, 1)), (self.Nx, 1, self.Nz))

    @property
    def zn(self):
        """
        Plaid non-uniform grid positions in the z-direction.

        Returns:
            plaid zn matrix (Nx x Ny x Nz), 0 for a uniform grid, or NaN below 3D
        """
        if self.dim < 3:
            return np.nan
        if not self.nonuniform:
            return 0
        return np.tile(np.reshape(self.n_vec.z, (1, 1, -1)), (self.Nx, self.Ny, 1))

    @property
    def size(self):
        """
        Size of grid in the all directions [m]
        """
        return self.N * self.spacing

    @property
    def total_grid_points(self) -> np.ndarray:
        """
        Total number of grid points (equal to Nx * Ny * Nz)
        """
        return np.prod(self.N)

    @property
    def kx(self):
        """
        Nx x Ny x Nz grid containing repeated copies of the wavenumber components in the x-direction [rad/m]

        Returns:
            plaid kx matrix
        """
        if self.dim == 1:
            return self.k_vec.x
        elif self.dim == 2:
            return np.tile(self.k_vec.x, (1, self.Ny))
        else:
            return np.tile(self.k_vec.x[:, :, None], (1, self.Ny, self.Nz))

    @property
    def ky(self):
        """
        Nx x Ny x Nz grid containing repeated copies of the wavenumber components in the y-direction [rad/m]

        Returns:
            plaid ky matrix, or NaN for a 1D grid
        """
        if self.dim == 2:
            return np.tile(self.k_vec.y.T, (self.Nx, 1))
        elif self.dim == 3:
            return np.tile(self.k_vec.y[None, :, :], (self.Nx, 1, self.Nz))
        return np.nan

    @property
    def kz(self):
        """
        Nx x Ny x Nz grid containing repeated copies of the wavenumber components in the z-direction [rad/m]

        Returns:
            plaid kz matrix, or NaN below 3D
        """
        if self.dim == 3:
            return np.tile(self.k_vec.z.T[None, :, :], (self.Nx, self.Ny, 1))
        else:
            return np.nan

    @property
    def y_size(self):
        """
        Size of grid in the y-direction [m]
        """
        return self.Ny * self.dy

    @property
    def z_size(self):
        """
        Size of grid in the z-direction [m]
        """
        return self.Nz * self.dz

    @property
    def k_max(self):  # added by us, not the same as kWave k_max (see k_max_all for KwaveGrid.k_max)
        """
        Maximum supported spatial frequency in the 3 directions [rad/m]

        Returns:
            Vector of 3 elements each in [rad/m]. Value for higher dimensions set to NaN
        """
        kx_max = np.abs(self.k_vec.x).max()
        ky_max = np.abs(self.k_vec.y).max() if self.dim >= 2 else np.nan
        kz_max = np.abs(self.k_vec.z).max() if self.dim == 3 else np.nan
        return Array([kx_max, ky_max, kz_max])

    @property
    def k_max_all(self):
        """
        Maximum supported spatial frequency in all directions [rad/m]
        Originally k_max in kWave.kWaveGrid!

        Returns:
            Scalar in [rad/m]: the smallest non-NaN entry of k_max
        """
        return np.nanmin(self.k_max.numpy())

    ########################################
    # functions that can only be accessed by class members
    ########################################
    @staticmethod
    def makeDim(num_points, spacing):
        """
        Create the grid parameters for a single spatial direction

        Args:
            num_points: number of grid points in this direction
            spacing: grid point spacing in this direction [m]

        Returns:
            num_points x 1 column vector of wavenumber components [rad/m]
        """
        # define the discretisation of the spatial dimension such that there is always a DC component
        if num_points % 2 == 0:
            # grid dimension has an even number of points
            nx = np.arange(-num_points / 2, num_points / 2) / num_points
        else:
            # grid dimension has an odd number of points
            nx = np.arange(-(num_points - 1) / 2, (num_points - 1) / 2 + 1) / num_points

        # force middle value to be zero in case 1/Nx is a recurring
        # number and the series doesn't give exactly zero
        nx[int(num_points // 2)] = 0

        # define the wavenumber vector components
        res = (2 * math.pi / spacing) * nx
        return res[:, None]

    def highest_prime_factors(self, axisymmetric=None) -> np.ndarray:
        """
        calculate highest prime factors

        Args:
            axisymmetric: None, 'WSWA' or 'WSWS'; selects how the y-dimension
                size is expanded before factoring

        Returns:
            Vector of three elements (x, y, z)

        Raises:
            ValueError: if axisymmetric is not None, 'WSWA' or 'WSWS'
        """
        # import statement placed here in order to avoid circular dependencies
        from kwave.utils import largest_prime_factor
        if axisymmetric is not None:
            if axisymmetric == 'WSWA':
                prime_facs = [largest_prime_factor(self.Nx),
                              largest_prime_factor(self.Ny * 4),
                              largest_prime_factor(self.Nz)]
            elif axisymmetric == 'WSWS':
                prime_facs = [largest_prime_factor(self.Nx),
                              largest_prime_factor(self.Ny * 2 - 2),
                              largest_prime_factor(self.Nz)]
            else:
                raise ValueError('Unknown axisymmetric symmetry.')
        else:
            prime_facs = [largest_prime_factor(self.Nx),
                          largest_prime_factor(self.Ny),
                          largest_prime_factor(self.Nz)]
        return np.array(prime_facs)

    def makeTime(self, c, cfl=CFL_DEFAULT, t_end=None):
        """
        Compute Nt and dt based on the cfl number and grid size, where
        the number of time-steps is chosen based on the time it takes to
        travel from one corner of the grid to the geometrically opposite
        corner. Note, if c is given as a matrix, the calculation for dt
        is based on the maximum value, and the calculation for t_end
        based on the minimum value.

        Args:
            c: sound speed [m/s], scalar or array
            cfl: Courant-Friedrichs-Lewy number used to pick dt
            t_end: simulation end time [s]; defaults to the grid-diagonal
                travel time at the minimum sound speed

        Returns:
            Nothing (sets self.dt and self.Nt)
        """
        # if c is a matrix, find the minimum and maximum values
        c = np.array(c)
        c_min, c_max = c.min(), c.max()

        # check for user defined t_end, otherwise set the simulation
        # length based on the size of the grid diagonal and the minimum
        # sound speed in the medium
        if t_end is None:
            t_end = np.linalg.norm(self.size, ord=2) / c_min

        # extract the smallest grid spacing
        min_grid_dim = self.spacing.min()

        # assign time step based on CFL stability criterion
        self.dt = cfl * min_grid_dim / c_max

        # assign number of time steps based on t_end
        self.Nt = int(t_end / self.dt) + 1

        # catch case where dt is a recurring number
        if (int(t_end / self.dt) != math.ceil(t_end / self.dt)) and (t_end % self.dt == 0):
            self.Nt = self.Nt + 1

    ##################################################
    ####
    #### FUNCTIONS BELOW WERE NOT TESTED FOR CORRECTNESS!
    ####
    ##################################################
    def kx_vec_dtt(self, dtt_type):
        """
        Compute the DTT wavenumber vector in the x-direction

        Args:
            dtt_type: transform type (integer 1-8 or DiscreteCosine/DiscreteSine member)

        Returns:
            (wavenumber vector with Nx elements, implied period M)
        """
        kx_vec_dtt, M = self.makeDTTDim(self.Nx, self.dx, dtt_type)
        return kx_vec_dtt, M

    def ky_vec_dtt(self, dtt_type):
        """
        Compute the DTT wavenumber vector in the y-direction

        Args:
            dtt_type: transform type (integer 1-8 or DiscreteCosine/DiscreteSine member)

        Returns:
            (wavenumber vector with Ny elements, implied period M)
        """
        ky_vec_dtt, M = self.makeDTTDim(self.Ny, self.dy, dtt_type)
        return ky_vec_dtt, M

    def kz_vec_dtt(self, dtt_type):
        """
        Compute the DTT wavenumber vector in the z-direction

        Args:
            dtt_type: transform type (integer 1-8 or DiscreteCosine/DiscreteSine member)

        Returns:
            (wavenumber vector with Nz elements, implied period M)
        """
        kz_vec_dtt, M = self.makeDTTDim(self.Nz, self.dz, dtt_type)
        return kz_vec_dtt, M

    @staticmethod
    def _dtt_type_code(dtt_type):
        """Normalise a DTT type (int 1-8 or DiscreteCosine/DiscreteSine member) to an int 1-8."""
        # unwrap single-element arrays such as the 0-d arrays produced by np.array(scalar)
        if isinstance(dtt_type, np.ndarray) and dtt_type.size == 1:
            dtt_type = dtt_type.reshape(-1)[0]

        # map enum members by identity so the result does not depend on the
        # enums' underlying values; order follows the numbering in k_dtt
        members = (DiscreteCosine.TYPE_1, DiscreteCosine.TYPE_2,
                   DiscreteCosine.TYPE_3, DiscreteCosine.TYPE_4,
                   DiscreteSine.TYPE_1, DiscreteSine.TYPE_2,
                   DiscreteSine.TYPE_3, DiscreteSine.TYPE_4)
        for code, member in enumerate(members, start=1):
            if dtt_type is member:
                return code

        if isinstance(dtt_type, (int, np.integer)) and 1 <= dtt_type <= 8:
            return int(dtt_type)
        raise ValueError(f'Unknown discrete trigonometric transform type: {dtt_type!r}')

    @staticmethod
    def makeDTTDim(Nx, dx, dtt_type):
        """
        Create the DTT grid parameters for a single spatial direction

        Args:
            Nx: number of grid points in this direction
            dx: grid point spacing in this direction [m]
            dtt_type: transform type, either an integer 1-8 (1-4 DCT-I..IV,
                5-8 DST-I..IV) or a DiscreteCosine/DiscreteSine member

        Returns:
            (1D wavenumber vector with Nx elements [rad/m], implied period M)

        Raises:
            ValueError: if dtt_type is not a recognised transform type
        """
        code = kWaveGrid._dtt_type_code(dtt_type)

        # compute the implied period of the input function
        if code == 1:  # DCT-I (WSWS)
            M = 2 * (Nx - 1)
        elif code == 5:  # DST-I (WAWA)
            M = 2 * (Nx + 1)
        else:
            M = 2 * Nx

        # calculate the wavenumbers; every type yields exactly Nx components so
        # the vectors broadcast against Nx x Ny (x Nz) grids in k_dtt
        if code in (1, 2):
            # whole-wavenumber DTTs: WSWS / DCT-I, HSHS / DCT-II
            n = np.arange(0, Nx)
            kx_vec = 2 * math.pi * n / (M * dx)
        elif code in (5, 6):
            # whole-wavenumber DTTs: WAWA / DST-I, HAHA / DST-II
            n = np.arange(1, Nx + 1)
            kx_vec = 2 * math.pi * n / (M * dx)
        else:
            # half-wavenumber DTTs:
            # WSWA / DCT-III, HSHA / DCT-IV, WAWS / DST-III, HAHS / DST-IV
            n = np.arange(0, Nx)
            kx_vec = 2 * math.pi * (n + 0.5) / (M * dx)

        return kx_vec, M

    ########################################
    # functions for non-uniform grids
    ########################################
    def setNUGrid(self, dim, n_vec, dudn, n_vec_sg, dudn_sg):
        """
        Function to set non-uniform grid parameters in specified dimension

        Args:
            dim: dimension to set (1 = x, 2 = y, 3 = z)
            n_vec: position vector for the grid points in [0, 1]
            dudn: transformation gradient for the grid points
            n_vec_sg: position vector for the staggered grid points in [0, 1]
            dudn_sg: transformation gradient for the staggered grid points

        Returns:
            None (sets self.nonuniform to True)
        """
        # check the dimension to set the nonuniform grid is appropriate
        assert dim <= self.dim, f'Cannot set nonuniform parameters for dimension {dim} of {self.dim}-dimensional grid.'

        # force non-uniform grid spacing to be column vectors, and the
        # gradients to be oriented along their own axis for broadcasting
        n_vec = np.reshape(n_vec, (-1, 1))
        n_vec_sg = np.reshape(n_vec_sg, (-1, 1))

        if dim == 1:
            dudn = np.reshape(dudn, (-1, 1))
            dudn_sg = np.reshape(dudn_sg, (-1, 1))
        elif dim == 2:
            dudn = np.reshape(dudn, (1, -1))
            dudn_sg = np.reshape(dudn_sg, (1, -1))
        elif dim == 3:
            dudn = np.reshape(dudn, (1, 1, -1))
            dudn_sg = np.reshape(dudn_sg, (1, 1, -1))

        # store every parameter in the requested dimension `dim`
        # NOTE(review): assumes Array.assign_dim takes the same 1-based index
        # (1 = x, 2 = y, 3 = z) as `dim` — confirm against kwave.data.Array
        self.n_vec.assign_dim(dim, n_vec)
        self.n_vec_sg.assign_dim(dim, n_vec_sg)

        self.dudn.assign_dim(dim, dudn)
        self.dudn_sg.assign_dim(dim, dudn_sg)

        # set non-uniform flag
        self.nonuniform = True

    def k_dtt(self, dtt_type):  # Not tested for correctness!
        """
        compute the individual wavenumber vectors, where dtt_type is the
        type of discrete trigonometric transform, which corresponds to
        the assumed input symmetry of the input function, where:

        1. DCT-I    WSWS
        2. DCT-II   HSHS
        3. DCT-III  WSWA
        4. DCT-IV   HSHA
        5. DST-I    WAWA
        6. DST-II   HAHA
        7. DST-III  WAWS
        8. DST-IV   HAHS

        Args:
            dtt_type: a single type used for every dimension, or one type per
                dimension (x, y[, z])

        Returns:
            (k, M): the wavenumber vector (1D) or Nx x Ny (x Nz) scalar
            wavenumber matrix, and the product of the implied periods
        """
        # flatten so scalars (0-d arrays) can be indexed like vectors
        dtt_type = np.asarray(dtt_type).ravel()
        assert (dtt_type.size in [1, self.dim]), f'dtt_type must be a scalar, or {self.dim}D vector'

        if self.dim == 1:
            k, M = self.kx_vec_dtt(dtt_type[0])
            return k, M
        elif self.dim == 2:
            # assign the grid parameters for the x and y spatial directions
            kx_vec_dtt, Mx = self.kx_vec_dtt(dtt_type[0])
            ky_vec_dtt, My = self.ky_vec_dtt(dtt_type[-1])

            # define the wavenumber based on the wavenumber components
            k = np.zeros((self.Nx, self.Ny))
            k = np.reshape(kx_vec_dtt, (-1, 1)) ** 2 + k
            k = np.reshape(ky_vec_dtt, (1, -1)) ** 2 + k
            k = np.sqrt(k)

            # define product of implied period
            M = Mx * My
            return k, M
        else:
            # assign the grid parameters for the x, y, and z spatial directions
            kx_vec_dtt, Mx = self.kx_vec_dtt(dtt_type[0])
            ky_vec_dtt, My = self.ky_vec_dtt(dtt_type[dtt_type.size // 2])
            kz_vec_dtt, Mz = self.kz_vec_dtt(dtt_type[-1])

            # define the wavenumber based on the wavenumber components
            k = np.zeros((self.Nx, self.Ny, self.Nz))
            k = np.reshape(kx_vec_dtt, (-1, 1, 1)) ** 2 + k
            k = np.reshape(ky_vec_dtt, (1, -1, 1)) ** 2 + k
            k = np.reshape(kz_vec_dtt, (1, 1, -1)) ** 2 + k
            k = np.sqrt(k)

            # define product of implied period
            M = Mx * My * Mz
            return k, M