Coverage for kwave/kgrid.py: 46%
275 statements
Generated by coverage.py v6.5.0 on 2022-10-24 11:55 -0700
1import math
2from dataclasses import dataclass
4import numpy as np
5import sys
7from kwave.data import Array
8from kwave.enums import DiscreteCosine, DiscreteSine
# Courant-Friedrichs-Lewy number applied by kWaveGrid.makeTime when the caller
# does not pass one explicitly.
CFL_DEFAULT = 0.3

# Absolute tolerance used when checking floating-point time steps: one hundred
# times the machine epsilon of a Python float.
MACHINE_PRECISION = sys.float_info.epsilon * 100
@dataclass
class kWaveGrid(object):
    """
    kWaveGrid is the grid class used across the k-Wave Toolbox. An object
    of the kWaveGrid class contains the grid coordinates and wavenumber
    matrices used within the simulation and reconstruction functions in
    k-Wave. The grid matrices are indexed as: (x, 1) in 1D; (x, y) in
    2D; and (x, y, z) in 3D. The grid is assumed to be a regularly spaced
    Cartesian grid, with grid spacing given by dx, dy, dz (typically the
    grid spacing in each direction is constant).
    """
    # NOTE(review): @dataclass is applied to a class with no annotated fields
    # and an explicit __init__, so the generated __eq__ compares empty field
    # tuples (any two grids compare equal) and instances are unhashable.
    # Kept for backward compatibility - confirm no caller relies on it.

    def __init__(self, N, spacing):
        """
        Create the grid and precompute its wavenumber vectors and scalar wavenumber.

        Args:
            N: grid size in each dimension [grid points]; scalar or 1-3 element sequence
            spacing: grid point spacing in each direction [m]; same length as N
        """
        N, spacing = np.atleast_1d(N), np.atleast_1d(spacing)  # if inputs are lists
        assert N.ndim == 1 and spacing.ndim == 1  # ensure no multidimensional lists
        assert (1 <= N.size <= 3) and (1 <= spacing.size <= 3)  # ensure valid dimensionality
        assert N.size == spacing.size, "Size list N and spacing list do not have the same size."

        self.N = N.astype(int)  #: grid size in each dimension [grid points]
        self.spacing = spacing  #: grid point spacing in each direction [m]
        self.dim = self.N.size  #: Number of dimensions (1, 2 or 3)

        self.nonuniform = False  #: flag that indicates grid non-uniformity
        self.dt = 'auto'  #: size of time step [s]
        self.Nt = 'auto'  #: number of time steps

        # originally there was [xn_vec, yn_vec, zn_vec]
        self.n_vec = Array([0] * self.dim)  #: position vectors for the grid points in [0, 1]
        # originally there was [xn_vec_sgx, yn_vec_sgy, zn_vec_sgz]
        self.n_vec_sg = Array([0] * self.dim)  #: position vectors for the staggered grid points in [0, 1]

        # originally there was [dxudxn, dyudyn, dzudzn]
        self.dudn = Array([0] * self.dim)  #: transformation gradients between uniform and staggered grids
        # originally there was [dxudxn_sgx, dyudyn_sgy, dzudzn_sgz]
        self.dudn_sg = Array([0] * self.dim)  #: transformation gradients between uniform and staggered grids

        # wavenumber component vectors, one (N_i, 1) column per dimension
        # (originally kx_vec, ky_vec, kz_vec) [rad/m]
        self.k_vec = Array([self.makeDim(self.Nx, self.dx)])

        if self.dim == 1:
            # define the scalar wavenumber based on the wavenumber components
            self.k = abs(self.k_vec.x)  #: scalar wavenumber

        if self.dim >= 2:
            # Ny x 1 vector of wavenumber components in the y-direction [rad/m]
            self.k_vec.append(self.makeDim(self.Ny, self.dy))

        if self.dim == 2:
            # |k| on the (Nx, Ny) grid, built by broadcasting the components
            self.k = np.zeros((self.Nx, self.Ny))
            self.k = np.reshape(self.k_vec.x, (-1, 1)) ** 2 + self.k
            self.k = np.reshape(self.k_vec.y, (1, -1)) ** 2 + self.k
            self.k = np.sqrt(self.k)  #: scalar wavenumber

        if self.dim == 3:
            # Nz x 1 vector of wavenumber components in the z-direction [rad/m]
            self.k_vec.append(self.makeDim(self.Nz, self.dz))

            # |k| on the (Nx, Ny, Nz) grid, built by broadcasting the components
            self.k = np.zeros((self.Nx, self.Ny, self.Nz))
            self.k = np.reshape(self.k_vec.x, (-1, 1, 1)) ** 2 + self.k
            self.k = np.reshape(self.k_vec.y, (1, -1, 1)) ** 2 + self.k
            self.k = np.reshape(self.k_vec.z, (1, 1, -1)) ** 2 + self.k
            self.k = np.sqrt(self.k)  #: scalar wavenumber

    @property
    def t_array(self):
        """
        time array [s]

        Returns:
            'auto' if Nt or dt is still 'auto', otherwise a (1, Nt) array of
            time points 0, dt, 2*dt, ...
        """
        if self.Nt == 'auto' or self.dt == 'auto':
            return 'auto'
        else:
            t_array = np.arange(0, self.Nt) * self.dt
            return np.expand_dims(t_array, axis=0)

    @t_array.setter
    def t_array(self, t_array):
        """
        Set Nt and dt from an explicit time array, or reset both to 'auto'.

        Args:
            t_array: the string 'auto', or an evenly spaced, increasing sequence
                of time points starting at zero [s]. Lists and the (1, Nt) array
                returned by the getter are accepted.
        """
        # Check the type first: comparing a numeric array with a string would be
        # element-wise and make the `if` ambiguous.
        if isinstance(t_array, str) and t_array == 'auto':
            # set values to auto
            self.Nt = 'auto'
            self.dt = 'auto'
            return

        # flatten so that lists and (1, Nt) row vectors are handled uniformly
        t_array = np.asarray(t_array).ravel()
        assert t_array.size >= 2, 't_array must contain at least two time points.'

        # extract property values
        Nt_temp = t_array.size
        dt_temp = t_array[1] - t_array[0]

        # check the time array begins at zero
        assert t_array[0] == 0, 't_array must begin at zero.'

        # check the time array is evenly spaced; abs() so that steps shorter
        # than the first one are rejected as well as longer ones
        assert np.abs(np.diff(t_array) - dt_temp).max() < MACHINE_PRECISION, \
            't_array must be evenly spaced.'

        # check the time steps are increasing
        assert dt_temp > 0, 't_array must be monotonically increasing.'

        # assign values
        self.Nt = Nt_temp
        self.dt = dt_temp

    def setTime(self, Nt, dt) -> None:
        """
        Set Nt and dt based on user input

        Args:
            Nt: number of time steps; a positive Python or NumPy integer
            dt: size of the time step [s]; must be positive

        Returns: None
        """
        # check the value for Nt (np.integer covers NumPy integer scalars)
        assert isinstance(Nt, (int, np.integer)) and Nt > 0, 'Nt must be a positive integer.'

        # check the value for dt
        assert dt > 0, 'dt must be positive.'

        # assign values
        self.Nt = Nt
        self.dt = dt

    @property
    def Nx(self):
        """
        grid size in x-direction [grid points]
        """
        return self.N[0]

    @property
    def Ny(self):
        """
        grid size in y-direction [grid points] (0 for 1D grids)
        """
        return self.N[1] if self.N.size >= 2 else 0

    @property
    def Nz(self):
        """
        grid size in z-direction [grid points] (0 for 1D and 2D grids)
        """
        return self.N[2] if self.N.size == 3 else 0

    @property
    def dx(self):
        """
        grid point spacing in x-direction [m]
        """
        return self.spacing[0]

    @property
    def dy(self):
        """
        grid point spacing in y-direction [m] (0 for 1D grids)
        """
        return self.spacing[1] if self.spacing.size >= 2 else 0

    @property
    def dz(self):
        """
        grid point spacing in z-direction [m] (0 for 1D and 2D grids)
        """
        return self.spacing[2] if self.spacing.size == 3 else 0

    @property
    def x_vec(self):
        """
        Nx x 1 vector of the grid coordinates in the x-direction [m]
        """
        # calculate x_vec based on kx_vec
        return self.size[0] * self.k_vec.x * self.dx / (2 * np.pi)

    @property
    def y_vec(self):
        """
        Ny x 1 vector of the grid coordinates in the y-direction [m] (NaN for 1D grids)
        """
        # calculate y_vec based on ky_vec
        if self.dim < 2:
            return np.nan
        return self.size[1] * self.k_vec.y * self.dy / (2 * np.pi)

    @property
    def z_vec(self):
        """
        Nz x 1 vector of the grid coordinates in the z-direction [m] (NaN below 3D)
        """
        # calculate z_vec based on kz_vec
        if self.dim < 3:
            return np.nan
        return self.size[2] * self.k_vec.z * self.dz / (2 * np.pi)

    @property
    def x(self):
        """
        Nx x Ny x Nz grid containing repeated copies of the grid coordinates in the x-direction [m]
        """
        return self.size[0] * self.kx * self.dx / (2 * math.pi)

    @property
    def y(self):
        """
        Nx x Ny x Nz grid containing repeated copies of the grid coordinates in the y-direction [m]
        """
        if self.dim < 2:
            return np.nan
        return self.size[1] * self.ky * self.dy / (2 * math.pi)

    @property
    def z(self):
        """
        Nx x Ny x Nz grid containing repeated copies of the grid coordinates in the z-direction [m]
        """
        if self.dim < 3:
            return np.nan
        return self.size[2] * self.kz * self.dz / (2 * math.pi)

    @property
    def xn(self):
        """
        Plaid non-uniform grid of the x-positions in [0, 1]

        Returns:
            plaid xn matrix shaped like the grid, or 0 if the grid is uniform
        """
        if not self.nonuniform:
            return 0
        if self.dim == 1:
            return self.n_vec.x
        if self.dim == 2:
            return np.tile(np.reshape(self.n_vec.x, (-1, 1)), (1, self.Ny))
        # Orient the vector along axis 0 before tiling: np.tile with 3-element
        # reps on an (Nx, 1) array prepends an axis and yields (1, Nx*Ny, Nz).
        return np.tile(np.reshape(self.n_vec.x, (-1, 1, 1)), (1, self.Ny, self.Nz))

    @property
    def yn(self):
        """
        Plaid non-uniform grid of the y-positions in [0, 1]

        Returns:
            plaid yn matrix shaped like the grid, 0 if the grid is uniform,
            or NaN for 1D grids
        """
        if self.dim < 2:
            return np.nan
        if not self.nonuniform:
            return 0
        if self.dim == 2:
            return np.tile(np.reshape(self.n_vec.y, (1, -1)), (self.Nx, 1))
        # orient along axis 1 so the tiled result is (Nx, Ny, Nz)
        return np.tile(np.reshape(self.n_vec.y, (1, -1, 1)), (self.Nx, 1, self.Nz))

    @property
    def zn(self):
        """
        Plaid non-uniform grid of the z-positions in [0, 1]

        Returns:
            plaid zn matrix shaped like the grid, 0 if the grid is uniform,
            or NaN below 3D
        """
        if self.dim < 3:
            return np.nan
        if not self.nonuniform:
            return 0
        # orient along axis 2 so the tiled result is (Nx, Ny, Nz)
        return np.tile(np.reshape(self.n_vec.z, (1, 1, -1)), (self.Nx, self.Ny, 1))

    @property
    def size(self):
        """
        Size of grid in the all directions [m]
        """
        return self.N * self.spacing

    @property
    def total_grid_points(self) -> np.ndarray:
        """
        Total number of grid points (equal to Nx * Ny * Nz)
        """
        return np.prod(self.N)

    @property
    def kx(self):
        """
        Nx x Ny x Nz grid containing repeated copies of the wavenumber components in the x-direction [rad/m]

        Returns:
            plaid kx matrix
        """
        if self.dim == 1:
            return self.k_vec.x
        elif self.dim == 2:
            return np.tile(self.k_vec.x, (1, self.Ny))
        else:
            return np.tile(self.k_vec.x[:, :, None], (1, self.Ny, self.Nz))

    @property
    def ky(self):
        """
        Nx x Ny x Nz grid containing repeated copies of the wavenumber components in the y-direction [rad/m]

        Returns:
            plaid ky matrix, or NaN for 1D grids
        """
        if self.dim == 2:
            return np.tile(self.k_vec.y.T, (self.Nx, 1))
        elif self.dim == 3:
            return np.tile(self.k_vec.y[None, :, :], (self.Nx, 1, self.Nz))
        return np.nan

    @property
    def kz(self):
        """
        Nx x Ny x Nz grid containing repeated copies of the wavenumber components in the z-direction [rad/m]

        Returns:
            plaid kz matrix, or NaN below 3D
        """
        if self.dim == 3:
            return np.tile(self.k_vec.z.T[None, :, :], (self.Nx, self.Ny, 1))
        else:
            return np.nan

    @property
    def y_size(self):
        """
        Size of grid in the y-direction [m]
        """
        return self.Ny * self.dy

    @property
    def z_size(self):
        """
        Size of grid in the z-direction [m]
        """
        return self.Nz * self.dz

    @property
    def k_max(self):  # added by us, not the same as kWave k_max (see k_max_all for KwaveGrid.k_max)
        """
        Maximum supported spatial frequency in the 3 directions [rad/m]

        Returns:
            Vector of 3 elements each in [rad/m]. Value for higher dimensions set to NaN
        """
        kx_max = np.abs(self.k_vec.x).max()
        ky_max = np.abs(self.k_vec.y).max() if self.dim >= 2 else np.nan
        kz_max = np.abs(self.k_vec.z).max() if self.dim == 3 else np.nan
        return Array([kx_max, ky_max, kz_max])

    @property
    def k_max_all(self):
        """
        Maximum supported spatial frequency in all directions [rad/m]
        Originally k_max in kWave.kWaveGrid!

        Returns:
            Scalar in [rad/m]: the smallest per-direction maximum, ignoring NaNs
        """
        return np.nanmin(self.k_max.numpy())

    ########################################
    # functions that can only be accessed by class members
    ########################################
    @staticmethod
    def makeDim(num_points, spacing):
        """
        Create the grid parameters for a single spatial direction

        Args:
            num_points: number of grid points in this direction
            spacing: grid point spacing in this direction [m]

        Returns:
            (num_points, 1) column of wavenumber components [rad/m], with an
            exact zero (DC) entry at index num_points // 2
        """
        # define the discretisation of the spatial dimension such that there is always a DC component
        if num_points % 2 == 0:
            # grid dimension has an even number of points
            nx = np.arange(-num_points / 2, num_points / 2) / num_points
        else:
            # grid dimension has an odd number of points
            nx = np.arange(-(num_points - 1) / 2, (num_points - 1) / 2 + 1) / num_points

        # force middle value to be zero in case 1/Nx is a recurring
        # number and the series doesn't give exactly zero
        nx[int(num_points // 2)] = 0

        # define the wavenumber vector components
        res = (2 * math.pi / spacing) * nx
        return res[:, None]

    def highest_prime_factors(self, axisymmetric=None) -> np.ndarray:
        """
        calculate highest prime factors

        Args:
            axisymmetric: None, 'WSWA' or 'WSWS'; for the axisymmetric options
                the y-size is replaced by the length of its symmetric extension

        Returns:
            Vector of three elements

        Raises:
            ValueError: if axisymmetric is not None, 'WSWA' or 'WSWS'
        """
        # import statement place here in order to avoid circular dependencies
        from kwave.utils import largest_prime_factor
        if axisymmetric is not None:
            if axisymmetric == 'WSWA':
                prime_facs = [largest_prime_factor(self.Nx),
                              largest_prime_factor(self.Ny * 4),
                              largest_prime_factor(self.Nz)]
            elif axisymmetric == 'WSWS':
                prime_facs = [largest_prime_factor(self.Nx),
                              largest_prime_factor(self.Ny * 2 - 2),
                              largest_prime_factor(self.Nz)]
            else:
                raise ValueError('Unknown axisymmetric symmetry.')
        else:
            prime_facs = [largest_prime_factor(self.Nx),
                          largest_prime_factor(self.Ny),
                          largest_prime_factor(self.Nz)]
        return np.array(prime_facs)

    def makeTime(self, c, cfl=CFL_DEFAULT, t_end=None):
        """
        Compute Nt and dt based on the cfl number and grid size, where
        the number of time-steps is chosen based on the time it takes to
        travel from one corner of the grid to the geometrically opposite
        corner. Note, if c is given as a matrix, the calculation for dt
        is based on the maximum value, and the calculation for t_end
        based on the minimum value.

        Args:
            c: sound speed, scalar or array (its min and max are used)
            cfl: Courant-Friedrichs-Lewy number used to choose dt
            t_end: simulation length [s]; defaults to the grid diagonal divided
                by the minimum of c

        Returns:
            Nothing (sets self.dt and self.Nt)
        """
        # if c is a matrix, find the minimum and maximum values
        c = np.array(c)
        c_min, c_max = c.min(), c.max()

        # check for user defined t_end, otherwise set the simulation
        # length based on the size of the grid diagonal and the minimum
        # sound speed in the medium (the slowest wave must cross the grid)
        if t_end is None:
            t_end = np.linalg.norm(self.size, ord=2) / c_min

        # extract the smallest grid spacing
        min_grid_dim = self.spacing.min()

        # assign time step based on CFL stability criterion
        self.dt = cfl * min_grid_dim / c_max

        # assign number of time steps based on t_end
        self.Nt = int(t_end / self.dt) + 1

        # catch case where dt is a recurring number
        # NOTE(review): port of the MATLAB floor/ceil/rem test; left unchanged.
        if (int(t_end / self.dt) != math.ceil(t_end / self.dt)) and (t_end % self.dt == 0):
            self.Nt = self.Nt + 1

    ##################################################
    ####
    #### FUNCTIONS BELOW WERE NOT TESTED FOR CORRECTNESS!
    ####
    ##################################################
    def kx_vec_dtt(self, dtt_type):
        """
        Compute the DTT wavenumber vector in the x-direction

        Args:
            dtt_type: transform type, see makeDTTDim

        Returns:
            (kx_vec_dtt, M): wavenumber vector [rad/m] and implied period
        """
        kx_vec_dtt, M = self.makeDTTDim(self.Nx, self.dx, dtt_type)
        return kx_vec_dtt, M

    def ky_vec_dtt(self, dtt_type):
        """
        Compute the DTT wavenumber vector in the y-direction

        Args:
            dtt_type: transform type, see makeDTTDim

        Returns:
            (ky_vec_dtt, M): wavenumber vector [rad/m] and implied period
        """
        ky_vec_dtt, M = self.makeDTTDim(self.Ny, self.dy, dtt_type)
        return ky_vec_dtt, M

    def kz_vec_dtt(self, dtt_type):
        """
        Compute the DTT wavenumber vector in the z-direction

        Args:
            dtt_type: transform type, see makeDTTDim

        Returns:
            (kz_vec_dtt, M): wavenumber vector [rad/m] and implied period
        """
        kz_vec_dtt, M = self.makeDTTDim(self.Nz, self.dz, dtt_type)
        return kz_vec_dtt, M

    @staticmethod
    def _dtt_code(dtt_type):
        """Map a DTT type (integer 1-8 or DiscreteCosine/DiscreteSine member) to its code 1-8."""
        if isinstance(dtt_type, np.ndarray):
            dtt_type = dtt_type.item()  # unwrap 0-d / single-element arrays
        # Identity checks first, so enum members are recognised regardless of
        # their underlying values.
        members = (DiscreteCosine.TYPE_1, DiscreteCosine.TYPE_2,
                   DiscreteCosine.TYPE_3, DiscreteCosine.TYPE_4,
                   DiscreteSine.TYPE_1, DiscreteSine.TYPE_2,
                   DiscreteSine.TYPE_3, DiscreteSine.TYPE_4)
        for code, member in enumerate(members, start=1):
            if dtt_type is member:
                return code
        if isinstance(dtt_type, (int, np.integer)) and 1 <= dtt_type <= 8:
            return int(dtt_type)
        raise ValueError(f'Unknown DTT type: {dtt_type!r}')

    @staticmethod
    def makeDTTDim(Nx, dx, dtt_type):
        """
        Create the DTT grid parameters for a single spatial direction

        Args:
            Nx: number of grid points in this direction
            dx: grid point spacing in this direction [m]
            dtt_type: integer code 1-8 (see k_dtt for the numbering) or the
                matching DiscreteCosine / DiscreteSine member

        Returns:
            (kx_vec, M): 1D array of Nx wavenumbers [rad/m] and the implied
            period M of the symmetrically extended input

        Raises:
            ValueError: if dtt_type is not a recognised transform type
        """
        code = kWaveGrid._dtt_code(dtt_type)

        # compute the implied period of the input function
        if code == 1:
            M = 2 * (Nx - 1)
        elif code == 5:
            M = 2 * (Nx + 1)
        else:
            M = 2 * Nx
        half = M // 2

        # calculate the wavenumbers; every branch yields exactly Nx entries
        if code == 1:
            # whole-wavenumber DTT, WSWS / DCT-I: n = 0 .. M/2
            kx_vec = 2 * math.pi * np.arange(0, half + 1) / (M * dx)
        elif code == 2:
            # whole-wavenumber DTT, HSHS / DCT-II: n = 0 .. M/2 - 1
            kx_vec = 2 * math.pi * np.arange(0, half) / (M * dx)
        elif code == 5:
            # whole-wavenumber DTT, WAWA / DST-I: n = 1 .. M/2 - 1
            kx_vec = 2 * math.pi * np.arange(1, half) / (M * dx)
        elif code == 6:
            # whole-wavenumber DTT, HAHA / DST-II: n = 1 .. M/2
            kx_vec = 2 * math.pi * np.arange(1, half + 1) / (M * dx)
        else:
            # half-wavenumber DTTs (codes 3, 4, 7, 8):
            # WSWA / DCT-III, HSHA / DCT-IV, WAWS / DST-III, HAHS / DST-IV
            kx_vec = 2 * math.pi * (np.arange(0, half) + 0.5) / (M * dx)

        return kx_vec, M

    ########################################
    # functions for non-uniform grids
    ########################################
    def setNUGrid(self, dim, n_vec, dudn, n_vec_sg, dudn_sg):
        """
        Function to set non-uniform grid parameters in specified dimension

        Args:
            dim: dimension to set, 1 (x), 2 (y) or 3 (z); at most self.dim
            n_vec: position vector for the grid points in [0, 1]
            dudn: transformation gradient between uniform and non-uniform grid
            n_vec_sg: position vector for the staggered grid points in [0, 1]
            dudn_sg: transformation gradient for the staggered grid

        Returns:
            None
        """
        # check the dimension to set the nonuniform grid is appropriate
        assert 1 <= dim <= self.dim, \
            f'Cannot set nonuniform parameters for dimension {dim} of {self.dim}-dimensional grid.'

        # force non-uniform grid spacing to be column vectors, and the
        # gradients to point along axis (dim - 1) for broadcasting
        n_vec = np.reshape(n_vec, (-1, 1))
        n_vec_sg = np.reshape(n_vec_sg, (-1, 1))
        if dim == 1:
            dudn = np.reshape(dudn, (-1, 1))
            dudn_sg = np.reshape(dudn_sg, (-1, 1))
        elif dim == 2:
            dudn = np.reshape(dudn, (1, -1))
            dudn_sg = np.reshape(dudn_sg, (1, -1))
        else:
            dudn = np.reshape(dudn, (1, 1, -1))
            dudn_sg = np.reshape(dudn_sg, (1, 1, -1))

        # store into the requested dimension (not always the last one).
        # NOTE(review): assumes Array.assign_dim takes the same 1-based index
        # the original passed as self.dim - confirm against kwave.data.Array.
        self.n_vec.assign_dim(dim, n_vec)
        self.n_vec_sg.assign_dim(dim, n_vec_sg)

        self.dudn.assign_dim(dim, dudn)
        self.dudn_sg.assign_dim(dim, dudn_sg)

        # set non-uniform flag
        self.nonuniform = True

    def k_dtt(self, dtt_type):  # Not tested for correctness!
        """
        compute the individual wavenumber vectors, where dtt_type is the
        type of discrete trigonometric transform, which corresponds to
        the assumed input symmetry of the input function, where:

        1. DCT-I    WSWS
        2. DCT-II   HSHS
        3. DCT-III  WSWA
        4. DCT-IV   HSHA
        5. DST-I    WAWA
        6. DST-II   HAHA
        7. DST-III  WAWS
        8. DST-IV   HAHS

        Args:
            dtt_type: a single type used for every dimension, or one type per
                dimension (integer codes above or DiscreteCosine/DiscreteSine members)

        Returns:
            (k, M): the wavenumber vector (1D) or the |k| grid of shape
            (Nx, Ny[, Nz]) [rad/m], and the product of the implied periods
        """
        # Collect the types as a flat Python list without converting enum
        # members to NumPy values, so a scalar works for every dimension.
        if isinstance(dtt_type, np.ndarray):
            types = dtt_type.ravel().tolist()
        elif isinstance(dtt_type, (list, tuple)):
            types = list(dtt_type)
        else:
            types = [dtt_type]

        # check dtt_type is a scalar or a vector the same size self.dim
        assert (len(types) in [1, self.dim]), f'dtt_type must be a scalar, or {self.dim}D vector'

        if self.dim == 1:
            k, M = self.kx_vec_dtt(types[0])
            return k, M
        elif self.dim == 2:
            # assign the grid parameters for the x and y spatial directions
            kx_vec_dtt, Mx = self.kx_vec_dtt(types[0])
            ky_vec_dtt, My = self.ky_vec_dtt(types[-1])

            # define the wavenumber based on the wavenumber components
            k = np.zeros((self.Nx, self.Ny))
            k = np.reshape(kx_vec_dtt, (-1, 1)) ** 2 + k
            k = np.reshape(ky_vec_dtt, (1, -1)) ** 2 + k
            k = np.sqrt(k)

            # define product of implied period
            M = Mx * My
            return k, M
        else:
            # assign the grid parameters for the x, y, and z spatial directions
            kx_vec_dtt, Mx = self.kx_vec_dtt(types[0])
            ky_vec_dtt, My = self.ky_vec_dtt(types[len(types) // 2])
            kz_vec_dtt, Mz = self.kz_vec_dtt(types[-1])

            # define the wavenumber based on the wavenumber components
            k = np.zeros((self.Nx, self.Ny, self.Nz))
            k = np.reshape(kx_vec_dtt, (-1, 1, 1)) ** 2 + k
            k = np.reshape(ky_vec_dtt, (1, -1, 1)) ** 2 + k
            k = np.reshape(kz_vec_dtt, (1, 1, -1)) ** 2 + k
            k = np.sqrt(k)

            # define product of implied period
            M = Mx * My * Mz
            return k, M