Coverage for kwave/ksource.py: 27%
140 statements
« prev ^ index » next coverage.py v6.5.0, created at 2022-10-24 11:55 -0700
« prev ^ index » next coverage.py v6.5.0, created at 2022-10-24 11:55 -0700
1from dataclasses import dataclass
2from warnings import warn
4import numpy as np
6from kwave.kgrid import kWaveGrid
7from kwave.utils import num_dim2, matlab_find
@dataclass
class kSource(object):
    """
    Container for the sources of a k-Wave simulation.

    Fields left as ``None`` are treated as "not defined". Call :meth:`validate`
    with the simulation grid to check that the defined fields are consistent.
    """

    _p0 = None
    p = None  #: time varying pressure at each of the source positions given by source.p_mask
    p_mask = None  #: binary matrix specifying the positions of the time varying pressure source distribution
    p_mode = None  #: optional input to control whether the input pressure is injected as a mass source or enforced as a dirichlet boundary condition; valid inputs are 'additive' (the default) or 'dirichlet'
    p_frequency_ref = None  #: Pressure reference frequency

    ux = None  #: time varying particle velocity in the x-direction at each of the source positions given by source.u_mask
    uy = None  #: time varying particle velocity in the y-direction at each of the source positions given by source.u_mask
    uz = None  #: time varying particle velocity in the z-direction at each of the source positions given by source.u_mask
    u_mask = None  #: binary matrix specifying the positions of the time varying particle velocity distribution
    u_mode = None  #: optional input to control whether the input velocity is applied as a force source or enforced as a dirichlet boundary condition; valid inputs are 'additive' (the default) or 'dirichlet'
    u_frequency_ref = None  #: Velocity reference frequency

    sxx = None  #: Stress source in x -> x direction
    syy = None  #: Stress source in y -> y direction
    szz = None  #: Stress source in z -> z direction
    sxy = None  #: Stress source in x -> y direction
    sxz = None  #: Stress source in x -> z direction
    syz = None  #: Stress source in y -> z direction
    s_mask = None  #: Stress source mask
    s_mode = None  #: Stress source mode

    # Accepted values of p_mode / u_mode and of s_mode (not annotated, so not dataclass fields).
    _PU_MODES = ('additive', 'dirichlet', 'additive-no-correction')
    _S_MODES = ('additive', 'dirichlet')
    # Names of the stress source components, in the order they are checked.
    _STRESS_FIELDS = ('sxx', 'syy', 'szz', 'sxy', 'sxz', 'syz')

    def is_p0_empty(self) -> bool:
        """
        Check whether the initial pressure `p0` is unset or contains no non-zero values.

        Returns:
            True if `p0` is None, has no elements, or is identically zero; False otherwise.
        """
        return self.p0 is None or np.size(self.p0) == 0 or np.count_nonzero(self.p0) == 0

    @property
    def p0(self):
        """
        Initial pressure within the acoustic medium
        """
        return self._p0

    @p0.setter
    def p0(self, val):
        # An unset, empty or all-zero initial pressure is stored as None so that
        # "no p0" has a single representation.
        if val is None or np.size(val) == 0 or np.count_nonzero(val) == 0:
            self._p0 = None
        else:
            self._p0 = val

    def validate(self, kgrid: 'kWaveGrid', medium=None) -> None:
        """
        Validate the object fields for correctness

        Checks mask sizes, non-empty masks, source modes, reference frequencies and
        the number of time series given for binary or labelled masks. When a reference
        frequency is defined, `p_mode` / `u_mode` is set to ``'additive-no-correction'``.

        Args:
            kgrid: Instance of `~kwave.kgrid.kWaveGrid` class
            medium: Optional medium exposing `sound_speed`, used only to range-check
                `p_frequency_ref` / `u_frequency_ref`. If omitted, a `medium` attribute
                set on this source is used when present; otherwise the range check is
                skipped with a warning.

        Returns:
            None

        Raises:
            ValueError: if a mask or `p0` has the wrong size, or the number of time
                series does not match the mask.
            AssertionError: for missing masks, empty masks, invalid modes, and invalid
                or out-of-range reference frequencies (raised explicitly, so the checks
                also run under ``python -O``).
            NotImplementedError: for labelled `u_mask` inputs and for stress sources,
                whose handling is not implemented yet.
        """
        if medium is None:
            medium = getattr(self, 'medium', None)

        if self.p0 is not None and np.shape(self.p0) != kgrid.k.shape:
            # throw an error if p0 is not the correct size
            raise ValueError('source.p0 must be the same size as the computational grid.')

        # check for a time varying pressure source input
        if self.p is not None:
            self._validate_p(kgrid, medium)

        # check for time varying velocity source input
        if any(getattr(self, k) is not None for k in ('ux', 'uy', 'uz', 'u_mask')):
            self._validate_u(kgrid, medium)

        # check for time varying stress source input
        if any(getattr(self, k) is not None for k in self._STRESS_FIELDS + ('s_mask',)):
            self._validate_s(kgrid)

    def _validate_p(self, kgrid, medium) -> None:
        """Check the time varying pressure source `p` against `p_mask` and the grid."""
        # force p_mask to be given if p is given
        self._require(self.p_mask is not None)

        # noinspection PyTypeChecker
        if (num_dim2(self.p_mask) != kgrid.dim) or (np.shape(self.p_mask) != kgrid.k.shape):
            raise ValueError('source.p_mask must be the same size as the computational grid.')

        self._require(np.sum(self.p_mask) != 0,
                      'source.p_mask must be a binary grid with at least one element set to 1.')

        # don't allow both source.p0 and source.p in the same simulation
        # USERS: please contact us via http://www.k-wave.org/forum if this
        # is a problem
        self._require(self.p0 is None, "source.p0 and source.p can't be defined in the same simulation.")

        if self.p_mode is not None:
            self._require(self.p_mode in self._PU_MODES,
                          "source.p_mode must be set to ''additive'', ''additive-no-correction'', or ''dirichlet''.")

        if self.p_frequency_ref is not None:
            self._check_frequency_ref(self.p_frequency_ref, 'p_frequency_ref', kgrid, medium)
            # a reference frequency disables the k-space correction
            self.p_mode = 'additive-no-correction'

        p_series = self._as_series(self.p)
        self._warn_if_too_long(p_series, 'p', kgrid.Nt)
        self._check_series_counts([p_series.shape[0]], self.p_mask, 'p', 'p_mask')

    def _validate_u(self, kgrid, medium) -> None:
        """Check the velocity sources `ux`/`uy`/`uz` against `u_mask` and the grid."""
        # force u_mask to be given
        self._require(self.u_mask is not None)

        self._require(num_dim2(self.u_mask) == kgrid.dim and np.shape(self.u_mask) == kgrid.k.shape,
                      'source.u_mask must be the same size as the computational grid.')

        self._require(np.array(self.u_mask).sum() != 0,
                      'source.u_mask must be a binary grid with at least one element set to 1.')

        if self.u_mode is not None:
            self._require(self.u_mode in self._PU_MODES,
                          "source.u_mode must be set to ''additive'', ''additive-no-correction'', or ''dirichlet''.")

        if self.u_frequency_ref is not None:
            self._check_frequency_ref(self.u_frequency_ref, 'u_frequency_ref', kgrid, medium)
            # a reference frequency disables the k-space correction
            self.u_mode = 'additive-no-correction'

        # only the velocity components that were actually given take part in the checks
        velocities = {
            name: self._as_series(signal)
            for name, signal in (('ux', self.ux), ('uy', self.uy), ('uz', self.uz))
            if signal is not None
        }
        for name, series in velocities.items():
            self._warn_if_too_long(series, name, kgrid.Nt)

        self._check_series_counts([series.shape[0] for series in velocities.values()],
                                  self.u_mask, 'ux (etc)', 'u_mask', allow_labelled=False)

    def _validate_s(self, kgrid) -> None:
        """Check the stress sources against `s_mask`, then report that they are unsupported."""
        # force s_mask to be given
        self._require(self.s_mask is not None)

        if (num_dim2(self.s_mask) != kgrid.dim) or (np.shape(self.s_mask) != kgrid.k.shape):
            raise ValueError('source.s_mask must be the same size as the computational grid.')

        self._require(np.sum(self.s_mask) != 0,
                      "source.s_mask must be a binary grid with at least one element set to 1.")

        if self.s_mode is not None:
            self._require(self.s_mode in self._S_MODES,
                          "source.s_mode must be set to ''additive'' or ''dirichlet''.")

        stresses = {
            name: self._as_series(getattr(self, name))
            for name in self._STRESS_FIELDS
            if getattr(self, name) is not None
        }
        for name, series in stresses.items():
            self._warn_if_too_long(series, name, kgrid.Nt)

        self._check_series_counts([series.shape[0] for series in stresses.values()],
                                  self.s_mask, 'sxx (etc)', 's_mask')

        # creating the indexing variable for the stress source elements is not implemented yet
        raise NotImplementedError('Indexing of stress sources (source.s_mask) is not yet implemented.')

    @staticmethod
    def _require(condition, message=None) -> None:
        """Raise AssertionError (optionally with `message`) when `condition` is falsy."""
        # explicit raise instead of `assert` so the check survives `python -O`
        if not condition:
            if message is None:
                raise AssertionError()
            raise AssertionError(message)

    @staticmethod
    def _as_series(signal) -> np.ndarray:
        """View a source signal as a 2D array of shape (number of time series, number of time points)."""
        # a 1D signal is treated as a single time series (a MATLAB row vector)
        return np.atleast_2d(np.asarray(signal))

    @staticmethod
    def _warn_if_too_long(series: np.ndarray, name: str, nt) -> None:
        """Warn if the 2D `series` of source `name` has more time points than `nt`."""
        if series.shape[1] > nt:
            warn(f' WARNING: source.{name} has more time points than kgrid.Nt, remaining time points will not be used.')

    @staticmethod
    def _is_binary_mask(unique_values: np.ndarray) -> bool:
        """Return True if the sorted unique mask values describe a binary (0/1) mask."""
        return unique_values.size <= 2 and unique_values.sum() == 1

    @staticmethod
    def _labels_are_valid(unique_values: np.ndarray) -> bool:
        """Return True if the sorted unique labels increase in steps of one and include 1."""
        # sum of consecutive differences == count - 1 <=> no gaps between labels
        steps = np.sum(np.diff(unique_values))
        return bool(steps == unique_values.size - 1 and np.any(unique_values == 1))

    @classmethod
    def _check_frequency_ref(cls, frequency_ref, field: str, kgrid, medium) -> None:
        """
        Check a reference frequency is a positive scalar supported by the grid.

        Args:
            frequency_ref: The reference frequency to check.
            field: Source field name used in the error message (e.g. 'p_frequency_ref').
            kgrid: Grid exposing `k_max_all`.
            medium: Medium exposing `sound_speed`, or None to skip the range check.
        """
        # check frequency is a scalar, positive number
        cls._require(np.isscalar(frequency_ref) and frequency_ref > 0)

        if medium is None:
            warn(f'source.{field} was not range-checked because no medium was supplied to validate().')
            return

        # the highest frequency the grid supports is f = k_max * c_min / (2 * pi)
        max_frequency = kgrid.k_max_all * np.min(np.asarray(medium.sound_speed)) / (2 * np.pi)
        cls._require(frequency_ref <= max_frequency,
                     f'source.{field} is higher than the maximum frequency supported by the spatial grid.')

    @classmethod
    def _check_series_counts(cls, counts, mask, signal_label: str, mask_name: str,
                             allow_labelled: bool = True) -> None:
        """
        Check the number of given time series against a binary or labelled source mask.

        Args:
            counts: Number of time series (rows) of each given source signal.
            mask: The source mask (binary 0/1, or labelled 0..N).
            signal_label: Signal name used in error messages (e.g. 'p', 'ux (etc)').
            mask_name: Mask field name used in error messages (e.g. 'p_mask').
            allow_labelled: If False, a labelled mask raises NotImplementedError.

        Raises:
            ValueError: if the counts do not match the mask, or the labels are invalid.
            NotImplementedError: for a labelled mask when `allow_labelled` is False.
        """
        unique_values = np.unique(mask)

        if cls._is_binary_mask(unique_values):
            # a single time series is applied to every element; otherwise one per element
            n_elements = np.sum(mask)
            if any(c > 1 for c in counts) and any(c != n_elements for c in counts):
                raise ValueError(f'The number of time series in source.{signal_label} must match '
                                 f'the number of source elements in source.{mask_name}.')
            return

        if not allow_labelled:
            raise NotImplementedError(f'Labelled source.{mask_name} inputs are not yet supported.')

        # check the source labels are monotonic, and start from 1
        if not cls._labels_are_valid(unique_values):
            raise ValueError(f'If using a labelled source.{mask_name}, the source labels must be '
                             f'monotonically increasing and start from 1.')

        # one time series per label (the 0 background value is not a label)
        if any(c != unique_values.size - 1 for c in counts):
            raise ValueError(f'The number of time series in source.{signal_label} must match '
                             f'the number of labelled source elements in source.{mask_name}.')

    @property
    def flag_ux(self):
        """
        Get the length of the sources in X-direction, this allows the
        inputs to be defined independently and be of any length

        Returns:
            Number of time points in `ux`, or 0 if `ux` is not defined
        """
        return self._as_series(self.ux).shape[1] if self.ux is not None else 0

    @property
    def flag_uy(self):
        """
        Get the length of the sources in Y-direction, this allows the
        inputs to be defined independently and be of any length

        Returns:
            Number of time points in `uy`, or 0 if `uy` is not defined
        """
        return self._as_series(self.uy).shape[1] if self.uy is not None else 0

    @property
    def flag_uz(self):
        """
        Get the length of the sources in Z-direction, this allows the
        inputs to be defined independently and be of any length

        Returns:
            Number of time points in `uz`, or 0 if `uz` is not defined
        """
        return self._as_series(self.uz).shape[1] if self.uz is not None else 0