Coverage for kwave/kWaveSimulation_helper/save_to_disk_func.py: 12%
204 statements
« prev ^ index » next coverage.py v6.5.0, created at 2022-10-24 11:55 -0700
1import os
3from scipy.io import savemat
5from kwave import kWaveMedium, kWaveGrid, SimulationOptions
6from kwave.utils import scale_time, TicToc, num_dim2, write_matrix, write_attributes
7from kwave.utils import dotdict
8import numpy as np
def save_to_disk_func(
        kgrid: kWaveGrid, medium: kWaveMedium, source,
        opt: SimulationOptions, values: dotdict, flags: dotdict):
    """
    Collect every simulation input and write it to the file named by ``opt.save_to_disk``.

    Integer-typed and float-typed inputs are gathered into two separate dotdicts by the
    ``grab_*`` helpers, z-related float entries are dropped for 2D grids, and the result
    is written by ``save_file`` (format chosen from the file extension).

    Args:
        kgrid: computational grid.
        medium: medium properties.
        source: source object; its fields are read by the source helpers.
        opt: simulation options (PML sizes/alphas, output path, HDF5 compression level).
        values: precomputed values (time step, sound speed, source/sensor indices, ...).
        flags: switches describing the simulation setup (source types, elastic, ...).
    """
    # report how long precomputation took, then restart the timer for the save step
    print(' precomputation completed in ', scale_time(TicToc.toc()))
    TicToc.tic()
    print(' saving input files to disk...')

    # NOTE: the check that the sensor mask is either binary or given as cuboid corners
    # ("Optional input 'SaveToDisk' only supported for sensor masks defined as a binary
    # matrix or the opposing corners of a rectangle (2D) or cuboid (3D).") is disabled
    # temporarily (modified by Farid).

    # --- variable list --------------------------------------------------------
    ints, floats = dotdict(), dotdict()

    grab_integer_variables(ints, kgrid, flags, medium)
    grab_pml_size(ints, opt)
    grab_float_variables(floats, kgrid, opt, values, flags.elastic_code, flags.axisymmetric)

    if kgrid.dim == 2:
        # 2D simulations are stored with a singleton z-dimension and no z-PML
        ints.Nz = 1
        ints.pml_z_size = 0

    grab_medium_props(ints, floats, medium, flags.elastic_code)
    grab_source_props(
        ints, floats, source,
        values.u_source_pos_index, values.s_source_pos_index, values.p_source_pos_index,
        values.transducer_input_signal, values.delay_mask,
    )
    grab_sensor_props(ints, kgrid.dim, values.sensor_mask_index, values.record.cuboid_corners_list)
    grab_nonuniform_grid_props(floats, kgrid, flags.nonuniform_grid)

    # --- datacast and saving --------------------------------------------------
    remove_z_dimension(floats, kgrid.dim)
    save_file(opt.save_to_disk, ints, floats, opt.hdf_compression_level)

    print(' completed in ', scale_time(TicToc.toc()))
def grab_integer_variables(integer_variables, kgrid, flags, medium):
    """
    Add the integer variables used within the time loop (all codes) to ``integer_variables``.

    Grid sizes come from ``kgrid``, source/simulation flags from ``flags`` and the
    nonlinearity flag from ``medium.is_nonlinear()``. ``absorbing_flag`` is stored as a
    ``None`` placeholder; ``grab_medium_props`` assigns its actual value.
    """
    entries = {size_name: getattr(kgrid, size_name) for size_name in ('Nx', 'Ny', 'Nz', 'Nt')}
    entries['p_source_flag'] = flags.source_p
    entries['p0_source_flag'] = flags.source_p0

    # velocity and stress source flags: '<component>_source_flag' <- flags.source_<component>
    for component in ('ux', 'uy', 'uz', 'sxx', 'syy', 'szz', 'sxy', 'sxz', 'syz'):
        entries[f'{component}_source_flag'] = getattr(flags, f'source_{component}')

    entries['transducer_source_flag'] = flags.transducer_source
    entries['nonuniform_grid_flag'] = flags.nonuniform_grid
    entries['nonlinear_flag'] = medium.is_nonlinear()
    entries['absorbing_flag'] = None  # placeholder, overwritten by grab_medium_props
    entries['elastic_flag'] = flags.elastic_code
    entries['axisymmetric_flag'] = flags.axisymmetric
    # pseudonym for the sensor flags -- 0: binary mask indices, 1: cuboid corners
    entries['sensor_mask_type'] = flags.cuboid_corners

    integer_variables.update(entries)
def grab_pml_size(integer_variables, opt):
    """
    Store the PML thickness for each axis in ``integer_variables``.

    These values are not used within the time loop; they are written directly to the
    output file as ``pml_x_size``, ``pml_y_size`` and ``pml_z_size``.
    """
    for axis in ('x', 'y', 'z'):
        key = f'pml_{axis}_size'
        integer_variables[key] = getattr(opt, key)
def grab_float_variables(float_variables: dotdict, kgrid, opt, values, is_elastic_code, is_axisymmetric):
    """
    Add the single-precision (float) variables to ``float_variables``.

    Grid spacings and PML absorption coefficients are always stored; they are written
    straight to the output file. The remaining time-loop variables depend on the code
    path: elastic, axisymmetric, or the default code.
    """
    spacing_and_pml = {
        'dx': kgrid.dx,
        'dy': kgrid.dy,
        'dz': kgrid.dz,
        'pml_x_alpha': opt.pml_x_alpha,
        'pml_y_alpha': opt.pml_y_alpha,
        'pml_z_alpha': opt.pml_z_alpha,
    }
    float_variables.update(spacing_and_pml)

    if is_elastic_code:  # pragma: no cover
        grab_elastic_code_variables(float_variables, kgrid, values)
        return
    if is_axisymmetric:
        grab_axisymmetric_variables(float_variables, values)
        return

    # default code path: single precision variables used within the time loop
    for name in ('dt', 'c0', 'c_ref', 'rho0', 'rho0_sgx', 'rho0_sgy', 'rho0_sgz'):
        float_variables[name] = getattr(values, name)
def grab_elastic_code_variables(float_variables, kgrid, values):  # pragma: no cover
    """
    Add the float variables for the elastic code to ``float_variables``.

    Most entries are ``None`` placeholders. Only the reduced derivative-shift operators
    (``ddx_k_shift_*_r``) and the reduced negative half-step shift operators
    (``*_shift_neg_r``) are populated.
    NOTE(review): this path is excluded from coverage and looks incomplete; confirm
    before relying on it.
    """
    # single precision variables used within the time loop (not populated yet)
    for name in ('dt', 'c_ref', 'lambda', 'mu',
                 'rho0_sgx', 'rho0_sgy', 'rho0_sgz',
                 'mu_sgxy', 'mu_sgxz', 'mu_sgyz'):
        float_variables[name] = None

    def negative_half_shift(wavenumbers, spacing):
        # exp(-i k d / 2), reordered with ifftshift
        return np.fft.ifftshift(np.exp(-1j * wavenumbers * spacing / 2))

    # shift variables used for calculating u_non_staggered and I outputs
    x_shift = negative_half_shift(kgrid.k_vec.x, kgrid.dx)
    y_shift = negative_half_shift(kgrid.k_vec.y, kgrid.dy).T
    z_shift = np.transpose(negative_half_shift(kgrid.k_vec.z, kgrid.dz), (1, 2, 0))

    # reduced sizes for use with a real-to-complex FFT
    nz = 1 if kgrid.dim == 2 else kgrid.Nz
    nx_r = kgrid.Nx // 2 + 1
    ny_r = kgrid.Ny // 2 + 1
    nz_r = nz // 2 + 1

    shift_pos = values.ddx_k_shift_pos
    shift_neg = values.ddx_k_shift_neg

    float_variables['ddx_k_shift_pos_r'] = shift_pos[:nx_r]
    float_variables['ddy_k_shift_pos'] = None
    float_variables['ddz_k_shift_pos'] = None

    float_variables['ddx_k_shift_neg_r'] = shift_neg[:nx_r]
    float_variables['ddy_k_shift_neg'] = None
    float_variables['ddz_k_shift_neg'] = None

    # NOTE(review): `[:n]` slices the FIRST axis. After `.T` / the (1, 2, 0) transpose the
    # first axis of y_shift / z_shift is a singleton when k_vec.y is (Ny, 1) and k_vec.z is
    # (Nz, 1, 1), so no reduction happens there. Confirm the intended reduction axis.
    float_variables['x_shift_neg_r'] = x_shift[:nx_r]
    float_variables['y_shift_neg_r'] = y_shift[:ny_r]
    float_variables['z_shift_neg_r'] = z_shift[:nz_r]

    # PML / multi-axial PML placeholders (not populated yet)
    for name in ('pml_x', 'pml_y', 'pml_z',
                 'pml_x_sgx', 'pml_y_sgy', 'pml_z_sgz',
                 'mpml_x_sgx', 'mpml_y_sgy', 'mpml_z_sgz',
                 'mpml_x', 'mpml_y', 'mpml_z'):
        float_variables[name] = None
def grab_axisymmetric_variables(float_variables, values):
    """
    Copy the time-loop float variables needed by the axisymmetric code from ``values``.

    The keys are ``dt``, ``c0``, ``c_ref``, ``rho0``, ``rho0_sgx`` and ``rho0_sgy``;
    no z-staggered density is stored for this code path.
    """
    time_loop_names = ('dt', 'c0', 'c_ref', 'rho0', 'rho0_sgx', 'rho0_sgy')
    for name in time_loop_names:
        float_variables[name] = getattr(values, name)
def grab_medium_props(integer_variables, float_variables, medium, is_elastic_code):
    """
    Add medium-dependent variables (nonlinearity and absorption).

    - ``BonA`` is stored when ``medium.is_nonlinear()`` is true.
    - ``integer_variables.absorbing_flag`` is 0 for a non-absorbing medium, otherwise
      2 when ``medium.stokes`` is set and 1 when it is not.
    - For an absorbing medium, the elastic code gets ``None`` placeholders (chi, eta, ...)
      and the default code gets ``alpha_coeff`` / ``alpha_power``.
    """
    # variables used in nonlinear simulations
    if medium.is_nonlinear():
        float_variables['BonA'] = medium.BonA

    # variables used in absorbing simulations
    if not medium.absorbing:
        integer_variables.absorbing_flag = 0
        return

    integer_variables.absorbing_flag = 2 if medium.stokes else 1

    if is_elastic_code:  # pragma: no cover
        for name in ('chi', 'eta', 'eta_sgxy', 'eta_sgxz', 'eta_sgyz'):
            float_variables[name] = None
    else:
        float_variables['alpha_coeff'] = medium.alpha_coeff
        float_variables['alpha_power'] = medium.alpha_power
def grab_source_props(integer_variables, float_variables, source,
                      u_source_pos_index, s_source_pos_index, p_source_pos_index,
                      transducer_input_signal, delay_mask):
    """
    Add all source variables (modes, indices and time-varying inputs).

    Source modes and indices are only defined when the matching source flag is set.
    The mode describes whether the source is added or replaced; the indices describe
    which grid points act as the source. ``u_source_index`` is reused for any of the
    velocity sources and for the transducer source.
    """
    grab_velocity_source_props(integer_variables, source, u_source_pos_index=u_source_pos_index)
    grab_stress_source_props(integer_variables, source, s_source_pos_index=s_source_pos_index)
    grab_pressure_source_props(
        integer_variables, source,
        p_source_pos_index=p_source_pos_index,
        u_source_pos_index=u_source_pos_index,
    )
    grab_time_varying_source_props(
        integer_variables, float_variables, source,
        transducer_input_signal=transducer_input_signal,
        delay_mask=delay_mask,
    )
def grab_velocity_source_props(integer_variables, source, u_source_pos_index):
    """
    Add velocity-source mode, "many" flag and indices when any velocity source is active.

    ``u_source_mode`` encodes ``source.u_mode``: 'dirichlet' -> 0,
    'additive-no-correction' -> 1, 'additive' -> 2 (any other mode raises ``KeyError``).
    ``u_source_many`` is ``num_dim2(...) > 1`` for the first active component in the
    order ux, uy, uz.
    """
    components = ('ux', 'uy', 'uz')
    first_active = next(
        (c for c in components if integer_variables.get(f'{c}_source_flag')), None)
    if first_active is None:
        return

    mode_codes = {'dirichlet': 0, 'additive-no-correction': 1, 'additive': 2}
    integer_variables['u_source_mode'] = mode_codes[source.u_mode]
    integer_variables['u_source_many'] = num_dim2(getattr(source, first_active)) > 1
    integer_variables.u_source_index = u_source_pos_index
def grab_stress_source_props(integer_variables, source, s_source_pos_index):
    """
    Add stress-source mode, "many" flag and indices when any stress source is active.

    ``s_source_mode`` is ``True`` unless ``source.s_mode == 'dirichlet'``.
    ``s_source_many`` is ``num_dim2(...) > 1`` for the first active component in the
    order sxx, syy, szz, sxy, sxz, syz.
    """
    components = ('sxx', 'syy', 'szz', 'sxy', 'sxz', 'syz')
    first_active = next(
        (c for c in components if getattr(integer_variables, f'{c}_source_flag')), None)
    if first_active is None:
        return

    integer_variables.s_source_mode = source.s_mode != 'dirichlet'
    integer_variables.s_source_many = num_dim2(getattr(source, first_active)) > 1
    integer_variables.s_source_index = s_source_pos_index
def grab_pressure_source_props(integer_variables, source, p_source_pos_index, u_source_pos_index):
    """
    Add pressure-source mode, "many" flag and indices, plus the transducer source index.

    ``p_source_mode`` encodes ``source.p_mode``: 'dirichlet' -> 0,
    'additive-no-correction' -> 1, 'additive' -> 2. A transducer source stores its
    indices under ``u_source_index`` (shared with the velocity sources).
    """
    if integer_variables.p_source_flag:
        pressure_mode_codes = {
            'dirichlet': 0,
            'additive-no-correction': 1,
            'additive': 2,
        }
        integer_variables.p_source_mode = pressure_mode_codes[source.p_mode]
        integer_variables.p_source_many = num_dim2(source.p) > 1
        integer_variables.p_source_index = p_source_pos_index

    if integer_variables.transducer_source_flag:
        integer_variables.u_source_index = u_source_pos_index
def grab_time_varying_source_props(integer_variables, float_variables, source, transducer_input_signal, delay_mask):
    """
    Store the actual source values for every active source flag.

    For each component ``c`` among ux, uy, uz, sxx, syy, szz, sxy, sxz, syz and p, when
    ``<c>_source_flag`` is set, ``source.<c>`` is stored as ``<c>_source_input``. These
    are indexed as (position_index, time_index). A transducer source stores its input
    signal and delay mask. The initial pressure ``p0`` is defined everywhere (no indices).
    """
    for component in ('ux', 'uy', 'uz', 'sxx', 'syy', 'szz', 'sxy', 'sxz', 'syz', 'p'):
        if getattr(integer_variables, f'{component}_source_flag'):
            setattr(float_variables, f'{component}_source_input', getattr(source, component))

    if integer_variables.transducer_source_flag:
        float_variables.transducer_source_input = transducer_input_signal
        integer_variables.delay_mask = delay_mask

    # initial pressure source: only defined if the p0 source flag is set
    if integer_variables.p0_source_flag:
        float_variables.p0_source_input = source.p0
def grab_sensor_props(integer_variables, kgrid_dim, sensor_mask_index, cuboid_corners_list):
    """
    Add the sensor-mask description to ``integer_variables``.

    ``integer_variables.sensor_mask_type`` selects the representation:
      * 0 -> the mask is a list of grid indices, stored as ``sensor_mask_index``;
      * 1 -> the mask is a list of cuboid corners, stored as ``sensor_mask_corners``.
        For 2D grids the 4-row corner list is spread into a 6-row array (rows 0, 1, 3, 4)
        with rows 2 and 5 set to 1 (presumably the z-coordinates of the two corners).

    Raises:
        NotImplementedError: for any other ``sensor_mask_type``.
    """
    mask_type = integer_variables.sensor_mask_type

    if mask_type == 0:
        integer_variables.sensor_mask_index = sensor_mask_index
    elif mask_type == 1:
        if kgrid_dim == 2:
            corners = np.ones((6, cuboid_corners_list.shape[1]))
            corners[[0, 1, 3, 4], :] = cuboid_corners_list[[0, 1, 2, 3], :]
        else:
            corners = cuboid_corners_list
        integer_variables.sensor_mask_corners = corners
    else:
        raise NotImplementedError('unknown option for sensor_mask_type')
def grab_nonuniform_grid_props(float_variables, kgrid, is_nonuniform_grid):
    """
    Add the nonuniform-grid variables to ``float_variables`` (only if the grid is nonuniform).

    For each axis, the regular (``kgrid.dudn``) and staggered (``kgrid.dudn_sg``) values are
    stored under ``d?ud?n`` / ``d?ud?n_sg?``. A size-1 (scalar) value is expanded to an
    array of ones oriented along that axis: (Nx, 1), (1, Ny) or (1, 1, Nz). These are applied
    using the bsxfun formulation.

    Each staggered value is now size-checked against itself. Previously it was checked
    against the already-expanded non-staggered array, so a scalar staggered value was
    never expanded.
    """
    if not is_nonuniform_grid:
        return

    def _ones_if_scalar(value, shape):
        # a single value carries no per-point information: store explicit ones instead
        return np.ones(shape) if np.array(value).size == 1 else value

    x_shape = (kgrid.Nx, 1)
    y_shape = (1, kgrid.Ny)
    z_shape = (1, 1, kgrid.Nz)

    float_variables['dxudxn'] = _ones_if_scalar(kgrid.dudn.x, x_shape)
    float_variables['dyudyn'] = _ones_if_scalar(kgrid.dudn.y, y_shape)
    float_variables['dzudzn'] = _ones_if_scalar(kgrid.dudn.z, z_shape)

    float_variables['dxudxn_sgx'] = _ones_if_scalar(kgrid.dudn_sg.x, x_shape)
    float_variables['dyudyn_sgy'] = _ones_if_scalar(kgrid.dudn_sg.y, y_shape)
    float_variables['dzudzn_sgz'] = _ones_if_scalar(kgrid.dudn_sg.z, z_shape)
def remove_z_dimension(float_variables, kgrid_dim):
    """
    For 2D grids, delete every float variable whose name contains 'z' (in place).

    Grids of any other dimensionality are left untouched.
    """
    if kgrid_dim != 2:
        return
    z_names = [name for name in float_variables if 'z' in name]
    for name in z_names:
        del float_variables[name]
def enforce_filename_standards(filepath):
    """
    Return ``(filepath, extension)``, appending ``.h5`` when the path has no extension.

    The extension is taken from ``os.path.splitext``; an existing extension is returned
    unchanged, even if it is not one the caller supports.
    """
    extension = os.path.splitext(filepath)[1]
    if extension:
        return filepath, extension
    # use HDF5 as the default format
    return filepath + '.h5', '.h5'
def save_file(filepath, integer_variables, float_variables, hdf_compression_level):
    """
    Save the collected variables to ``filepath`` in the format given by its extension.

    A path without an extension gets ``.h5`` appended (see ``enforce_filename_standards``).
    ``.h5`` files are written with ``save_h5_file`` and ``.mat`` files with ``save_mat_file``.

    Raises:
        NotImplementedError: if the extension is neither ``.h5`` nor ``.mat``.
    """
    filepath, filename_ext = enforce_filename_standards(filepath)

    if filename_ext == '.h5':
        save_h5_file(filepath, integer_variables, float_variables, hdf_compression_level)
    elif filename_ext == '.mat':
        save_mat_file(filepath, integer_variables, float_variables)
    else:
        # quote explicitly: MATLAB-style '' escaping is plain literal concatenation in
        # Python and would silently drop the quotes
        raise NotImplementedError(
            f"unknown file extension '{filename_ext}' for 'SaveToDisk' filename")
def save_h5_file(filepath, integer_variables, float_variables, hdf_compression_level):
    """
    Write all variables to an HDF5 file at ``filepath``, replacing any existing file.

    Float variables are cast to ``np.float32`` (float in C++) and written first, then the
    index/integer variables are cast to ``np.uint64`` (long in C++). Each is written with
    ``write_matrix`` using ``hdf_compression_level``. The file attributes are then added
    with ``write_attributes(..., legacy=True)``.
    """
    # the hdf5 library gives an error if the file already exists, so delete it first
    if os.path.exists(filepath):
        os.remove(filepath)

    for group, target_dtype in ((float_variables, np.float32), (integer_variables, np.uint64)):
        for name, data in group.items():
            write_matrix(filepath, np.array(data, dtype=target_dtype), name, hdf_compression_level)

    # TODO: update to currently breaking code after references are updated
    write_attributes(filepath, legacy=True)
def save_mat_file(filepath, integer_variables, float_variables):
    """
    Write all variables to a MATLAB ``.mat`` file at ``filepath`` via ``scipy.io.savemat``.

    Float variables are cast to ``np.float32`` (float in C++) and integer variables to
    ``np.uint64``, the same dtypes used for the HDF5 output. Cast copies are written; the
    caller's mappings are not modified.

    Raises:
        TypeError: if a name appears in both ``float_variables`` and ``integer_variables``
            (raised by the ``dict(**a, **b)`` merge).
    """
    # ----------------
    # SAVE .MAT FILE
    # ----------------
    as_single = {key: np.array(value, dtype=np.float32) for key, value in float_variables.items()}
    as_uint64 = {key: np.array(value, dtype=np.uint64) for key, value in integer_variables.items()}

    # dict(**a, **b) rejects duplicate names instead of silently overwriting one of them
    savemat(filepath, dict(**as_single, **as_uint64))