Coverage for kwave/ksource.py: 27%

140 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-10-24 11:55 -0700

1from dataclasses import dataclass 

2from warnings import warn 

3 

4import numpy as np 

5 

6from kwave.kgrid import kWaveGrid 

7from kwave.utils import num_dim2, matlab_find 

8 

9 

@dataclass
class kSource(object):
    """
    Source terms (initial pressure, pressure, velocity and stress) for a k-Wave simulation.

    The attributes below are plain class-level defaults without type annotations, so
    ``@dataclass`` generates no fields: ``kSource()`` takes no arguments and sources are
    configured by assigning attributes afterwards, then checked with :meth:`validate`.
    """

    # Valid values for ``p_mode`` / ``u_mode`` and for ``s_mode``. They are deliberately
    # un-annotated so that ``@dataclass`` does not turn them into constructor fields.
    _PU_MODES = ('additive', 'dirichlet', 'additive-no-correction')
    _S_MODES = ('additive', 'dirichlet')

    _p0 = None  # backing storage for the ``p0`` property
    p = None  #: time varying pressure at each of the source positions given by source.p_mask
    p_mask = None  #: binary matrix specifying the positions of the time varying pressure source distribution
    p_mode = None  #: optional input to control whether the input pressure is injected as a mass source or enforced as a dirichlet boundary condition; valid inputs are 'additive' (the default) or 'dirichlet'
    p_frequency_ref = None  #: Pressure reference frequency

    ux = None  #: time varying particle velocity in the x-direction at each of the source positions given by source.u_mask
    uy = None  #: time varying particle velocity in the y-direction at each of the source positions given by source.u_mask
    uz = None  #: time varying particle velocity in the z-direction at each of the source positions given by source.u_mask
    u_mask = None  #: binary matrix specifying the positions of the time varying particle velocity distribution
    u_mode = None  #: optional input to control whether the input velocity is applied as a force source or enforced as a dirichlet boundary condition; valid inputs are 'additive' (the default) or 'dirichlet'
    u_frequency_ref = None  #: Velocity reference frequency

    sxx = None  #: Stress source in x -> x direction
    syy = None  #: Stress source in y -> y direction
    szz = None  #: Stress source in z -> z direction
    sxy = None  #: Stress source in x -> y direction
    sxz = None  #: Stress source in x -> z direction
    syz = None  #: Stress source in y -> z direction
    s_mask = None  #: Stress source mask
    s_mode = None  #: Stress source mode

    def is_p0_empty(self) -> bool:
        """
        Check whether the `p0` field is unset, has no elements, or is all zeros.
        """
        # np.any handles arrays, lists and empty inputs uniformly (empty -> False)
        return self.p0 is None or not np.any(self.p0)

    @property
    def p0(self):
        """
        Initial pressure within the acoustic medium
        """
        return self._p0

    @p0.setter
    def p0(self, val):
        # An unset, empty or all-zero initial pressure is stored as None so that
        # "no initial pressure" has a single representation. The value itself is
        # stored unchanged (no copy / conversion) when it has non-zero content.
        if val is None or not np.any(val):
            self._p0 = None
        else:
            self._p0 = val

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _require(condition, message: str = '') -> None:
        """
        Raise ``AssertionError(message)`` when ``condition`` is falsy.

        Used instead of bare ``assert`` so the checks still run under ``python -O``
        while keeping the exception type these checks have always raised.
        """
        if not condition:
            raise AssertionError(message)

    @staticmethod
    def _num_time_points(signal) -> int:
        """
        Number of time points (columns) in ``signal``; 0 when it is None.

        A 1-D signal is treated as a single time series.
        """
        return 0 if signal is None else np.atleast_2d(signal).shape[1]

    @staticmethod
    def _is_binary_mask(unique_values: np.ndarray) -> bool:
        """
        Heuristic for "binary mask": at most two unique values that sum to one,
        i.e. {0, 1}, {1} or their boolean equivalents.
        """
        return unique_values.size <= 2 and unique_values.sum() == 1

    @staticmethod
    def _count_labels(unique_values: np.ndarray, mask_name: str) -> int:
        """
        Return the number of labels in a labelled mask after checking they are 1, 2, ..., N.

        Raises:
            ValueError: if the non-zero labels are not exactly the integers 1..N.
        """
        labels = unique_values[unique_values != 0]
        if not np.array_equal(labels, np.arange(1, labels.size + 1)):
            raise ValueError(f'If using a labelled source.{mask_name}, the source labels must be '
                             f'monotonically increasing and start from 1.')
        return labels.size

    def _check_frequency_ref(self, frequency_ref, name: str, kgrid, medium) -> None:
        """
        Check that ``source.<name>_frequency_ref`` is a positive scalar supported by the grid.

        The range check needs ``medium.sound_speed``. When no medium is available it is
        skipped with a warning.
        """
        self._require(np.isscalar(frequency_ref) and frequency_ref > 0,
                      f'source.{name}_frequency_ref must be a positive scalar.')
        if medium is None:
            warn(f'source.{name}_frequency_ref could not be checked against the spatial grid '
                 f'because no medium was given.')
            return
        # Highest frequency the grid supports: f = c * k / (2 * pi). The parentheses
        # matter: "/ 2 * np.pi" would multiply by pi instead of dividing by 2 * pi.
        max_frequency = kgrid.k_max_all * np.min(medium.sound_speed) / (2 * np.pi)
        self._require(frequency_ref <= max_frequency,
                      f'source.{name}_frequency_ref is higher than the maximum frequency '
                      f'supported by the spatial grid.')

    # --------------------------------------------------------------- validation

    def validate(self, kgrid: "kWaveGrid", medium=None) -> None:
        """
        Validate the object fields for correctness

        Args:
            kgrid: Instance of `~kwave.kgrid.kWaveGrid` class
            medium: Optional medium object exposing ``sound_speed``. It is used to
                range-check ``p_frequency_ref`` / ``u_frequency_ref``. When omitted,
                ``self.medium`` is used if it has been assigned. If neither is
                available, the range check is skipped with a warning.

        Returns:
            None

        Side effects:
            ``p_mode`` / ``u_mode`` are set to ``'additive-no-correction'`` when the
            corresponding reference frequency is defined.

        Raises:
            ValueError: wrong size of ``p0`` / ``p_mask`` / ``s_mask``; the number of time
                series does not match the mask; or invalid ``p_mask`` labels.
            AssertionError: a required mask is missing or empty; an invalid mode;
                ``p0`` combined with ``p``; an invalid or out-of-range reference
                frequency; or a wrongly sized ``u_mask``.
            NotImplementedError: a labelled ``u_mask`` or any stress source.
        """
        if medium is None:
            medium = getattr(self, 'medium', None)

        self._validate_p0(kgrid)
        self._validate_pressure_source(kgrid, medium)
        self._validate_velocity_source(kgrid, medium)
        self._validate_stress_source(kgrid)

    def _validate_p0(self, kgrid) -> None:
        """Check the initial pressure matches the computational grid size."""
        if self.p0 is not None and np.shape(self.p0) != kgrid.k.shape:
            # throw an error if p0 is not the correct size
            raise ValueError('source.p0 must be the same size as the computational grid.')

        # NOTE: once the elastic code is ported, source.p0 should be reformulated in terms
        # of the stress source terms using the fact that source.p = [0.5 0.5] / (2*CFL)
        # is the same as source.p0 = 1.

    def _validate_pressure_source(self, kgrid, medium) -> None:
        """Check the time varying pressure source ``p`` against ``p_mask`` and the grid."""
        if self.p is None:
            return

        # force p_mask to be given if p is given
        self._require(self.p_mask is not None, 'source.p_mask must be given if source.p is given.')

        # check mask is the correct size
        # noinspection PyTypeChecker
        if (num_dim2(self.p_mask) != kgrid.dim) or (np.shape(self.p_mask) != kgrid.k.shape):
            raise ValueError('source.p_mask must be the same size as the computational grid.')

        # check mask is not empty
        self._require(np.any(self.p_mask),
                      'source.p_mask must be a binary grid with at least one element set to 1.')

        # don't allow both source.p0 and source.p in the same simulation
        # USERS: please contact us via http://www.k-wave.org/forum if this is a problem
        self._require(self.p0 is None, "source.p0 and source.p can't be defined in the same simulation.")

        # check the source mode input is valid
        if self.p_mode is not None:
            self._require(self.p_mode in self._PU_MODES,
                          "source.p_mode must be set to ''additive'', ''additive-no-correction'', or ''dirichlet''.")

        # check if a reference frequency is defined
        if self.p_frequency_ref is not None:
            self._check_frequency_ref(self.p_frequency_ref, 'p', kgrid, medium)
            # change source mode to not include k-space correction
            self.p_mode = 'additive-no-correction'

        # rows are time series, columns are time points (1-D p is a single series)
        p = np.atleast_2d(self.p)
        if p.shape[1] > kgrid.Nt:
            warn(' WARNING: source.p has more time points than kgrid.Nt, remaining time points will not be used.')

        num_series = p.shape[0]
        p_unique = np.unique(self.p_mask)
        if self._is_binary_mask(p_unique):
            # if more than one time series is given, it must match the number of source elements
            if num_series > 1 and num_series != np.count_nonzero(self.p_mask):
                raise ValueError('The number of time series in source.p must match the number of '
                                 'source elements in source.p_mask.')
        else:
            # labelled mask: labels must be 1..N and there must be one time series per label
            num_labels = self._count_labels(p_unique, 'p_mask')
            if num_series != num_labels:
                raise ValueError('The number of time series in source.p must match the number of '
                                 'labelled source elements in source.p_mask.')

    def _validate_velocity_source(self, kgrid, medium) -> None:
        """Check the time varying velocity sources ``ux/uy/uz`` against ``u_mask`` and the grid."""
        components = {'ux': self.ux, 'uy': self.uy, 'uz': self.uz}
        if self.u_mask is None and all(signal is None for signal in components.values()):
            return

        # force u_mask to be given
        self._require(self.u_mask is not None,
                      'source.u_mask must be given if source.ux, source.uy or source.uz is given.')

        # check mask is the correct size
        self._require(num_dim2(self.u_mask) == kgrid.dim and np.shape(self.u_mask) == kgrid.k.shape,
                      'source.u_mask must be the same size as the computational grid.')

        # check mask is not empty
        self._require(np.any(self.u_mask),
                      'source.u_mask must be a binary grid with at least one element set to 1.')

        # check the source mode input is valid
        if self.u_mode is not None:
            self._require(self.u_mode in self._PU_MODES,
                          "source.u_mode must be set to ''additive'', ''additive-no-correction'', or ''dirichlet''.")

        # check if a reference frequency is defined
        if self.u_frequency_ref is not None:
            self._check_frequency_ref(self.u_frequency_ref, 'u', kgrid, medium)
            # change source mode to not include k-space correction
            self.u_mode = 'additive-no-correction'

        # only the components actually given take part in the remaining checks
        given = {name: np.atleast_2d(signal) for name, signal in components.items() if signal is not None}
        for name, signal in given.items():
            if signal.shape[1] > kgrid.Nt:
                warn(f' WARNING: source.{name} has more time points than kgrid.Nt, '
                     f'remaining time points will not be used.')

        u_unique = np.unique(self.u_mask)
        if not self._is_binary_mask(u_unique):
            raise NotImplementedError('Labelled source.u_mask is not supported yet.')

        # if more than one time series is given, every given component must provide one
        # time series per source element
        num_elements = np.count_nonzero(self.u_mask)
        num_series = [signal.shape[0] for signal in given.values()]
        if any(n > 1 for n in num_series) and any(n != num_elements for n in num_series):
            raise ValueError('The number of time series in source.ux (etc) must match the number of '
                             'source elements in source.u_mask.')

    def _validate_stress_source(self, kgrid) -> None:
        """Check the stress source inputs; stress sources themselves are not supported yet."""
        names = ('sxx', 'syy', 'szz', 'sxy', 'sxz', 'syz')
        components = {name: getattr(self, name) for name in names}
        if self.s_mask is None and all(signal is None for signal in components.values()):
            return

        # force s_mask to be given
        self._require(self.s_mask is not None,
                      'source.s_mask must be given if a stress source (source.sxx etc) is given.')

        # check mask is the correct size
        # noinspection PyTypeChecker
        if (num_dim2(self.s_mask) != kgrid.dim) or (np.shape(self.s_mask) != kgrid.k.shape):
            raise ValueError('source.s_mask must be the same size as the computational grid.')

        # check mask is not empty
        self._require(np.any(self.s_mask),
                      "source.s_mask must be a binary grid with at least one element set to 1.")

        # check the source mode input is valid
        if self.s_mode is not None:
            self._require(self.s_mode in self._S_MODES,
                          "source.s_mode must be set to ''additive'' or ''dirichlet''.")

        for name, signal in components.items():
            if signal is not None and self._num_time_points(signal) > kgrid.Nt:
                warn(f' WARNING: source.{name} has more time points than kgrid.Nt, '
                     f'remaining time points will not be used.')

        # building the source indexing variables for stress sources is not ported yet
        raise NotImplementedError('Stress sources (source.sxx etc) are not supported yet.')

    # ------------------------------------------------------------ source lengths

    @property
    def flag_ux(self):
        """
        Get the length of the sources in the x-direction. This allows the
        inputs to be defined independently and be of any length.

        Returns:
            Number of time points in ``ux`` (a 1-D ``ux`` is a single series), or 0 if unset
        """
        return self._num_time_points(self.ux)

    @property
    def flag_uy(self):
        """
        Get the length of the sources in the y-direction. This allows the
        inputs to be defined independently and be of any length.

        Returns:
            Number of time points in ``uy`` (a 1-D ``uy`` is a single series), or 0 if unset
        """
        return self._num_time_points(self.uy)

    @property
    def flag_uz(self):
        """
        Get the length of the sources in the z-direction. This allows the
        inputs to be defined independently and be of any length.

        Returns:
            Number of time points in ``uz`` (a 1-D ``uz`` is a single series), or 0 if unset
        """
        return self._num_time_points(self.uz)