Coverage for kwave/utils/ioutils.py: 65%
173 statements

import os
import platform
import socket
import warnings
from datetime import datetime

import cv2
import h5py
import numpy as np

from .misc import get_date_string
from .conversionutils import cast_to_type
from kwave.utils import dotdict


def get_h5_literals():
    literals = dotdict({
        # data type
        'DATA_TYPE_ATT_NAME': 'data_type',
        'MATRIX_DATA_TYPE_MATLAB': 'single',
        'MATRIX_DATA_TYPE_C': 'float',
        'INTEGER_DATA_TYPE_MATLAB': 'uint64',
        'INTEGER_DATA_TYPE_C': 'long',

        # real / complex
        'DOMAIN_TYPE_ATT_NAME': 'domain_type',
        'DOMAIN_TYPE_REAL': 'real',
        'DOMAIN_TYPE_COMPLEX': 'complex',

        # file descriptors
        'FILE_MAJOR_VER_ATT_NAME': 'major_version',
        'FILE_MINOR_VER_ATT_NAME': 'minor_version',
        'FILE_DESCR_ATT_NAME': 'file_description',
        'FILE_CREATION_DATE_ATT_NAME': 'creation_date',
        'CREATED_BY_ATT_NAME': 'created_by',

        # file type
        'FILE_TYPE_ATT_NAME': 'file_type',
        'HDF_INPUT_FILE': 'input',
        'HDF_OUTPUT_FILE': 'output',
        'HDF_CHECKPOINT_FILE': 'checkpoint',

        # file version information
        'HDF_FILE_MAJOR_VERSION': '1',
        'HDF_FILE_MINOR_VERSION': '2',

        # compression level
        'HDF_COMPRESSION_LEVEL': 0
    })
    return literals
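
# Usage sketch for get_h5_literals (illustrative only): the returned dotdict exposes the
# literals as attributes, which is how the writer functions below look them up.
#
#     h5_literals = get_h5_literals()
#     h5_literals.MATRIX_DATA_TYPE_MATLAB   # 'single'
#     h5_literals.HDF_COMPRESSION_LEVEL     # 0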


def write_matrix(filename, matrix: np.ndarray, matrix_name, compression_level=None):
    # get literals
    h5_literals = get_h5_literals()

    if compression_level is None:
        compression_level = h5_literals.HDF_COMPRESSION_LEVEL

    # dims = num_dim(matrix)
    dims = len(matrix.shape)

    if dims == 3:
        matrix = np.transpose(matrix, [2, 1, 0])  # C <=> Fortran ordering
    if dims == 2:
        matrix = np.transpose(matrix)  # C <=> Fortran ordering

    # get the size of the input matrix
    if dims == 3:
        Nx, Ny, Nz = matrix.shape
    elif dims == 2:
        Ny, Nz = matrix.shape
        Nx = 1
    else:
        Nx, Ny, Nz = 1, 1, 1

    # check the size of the matrix and set the chunk size and compression level
    if dims == 3:
        # set chunk size to Nx * Ny
        chunk_size = [Nx, Ny, 1]
    elif dims == 2:
        # set chunk size to Nx
        chunk_size = [Nx, 1, 1]
    elif dims <= 1:
        # check whether the matrix is larger than 1 MB
        one_mb = (1024 ** 2) // 8  # number of 8-byte elements per MB (kept as an int, since h5py chunk sizes must be integers)
        if matrix.size > one_mb:
            # set chunk size to 1 MB
            if Nx > Ny:
                chunk_size = [one_mb, 1, 1]
            elif Ny > Nz:
                chunk_size = [1, one_mb, 1]
            else:
                chunk_size = [1, 1, one_mb]
        else:
            # set no compression
            compression_level = 0

            # set chunk size to grid size
            if matrix.size == 1:
                chunk_size = (1, 1, 1)
            elif Nx > Ny:
                chunk_size = (Nx, 1, 1)
            elif Ny > Nz:
                chunk_size = (1, Ny, 1)
            else:
                chunk_size = (1, 1, Nz)
    else:
        # throw an error for unknown matrix size
        raise ValueError('Input matrix must have 1, 2 or 3 dimensions.')

    # check that the format of the matrix is either single precision (float in C++)
    # or uint64 (unsigned long in C++)
    if matrix.dtype == np.float32:
        # set data type flags
        data_type_matlab = h5_literals.MATRIX_DATA_TYPE_MATLAB
        data_type_c = h5_literals.MATRIX_DATA_TYPE_C
    elif matrix.dtype == np.uint64:
        # set data type flags
        data_type_matlab = h5_literals.INTEGER_DATA_TYPE_MATLAB
        data_type_c = h5_literals.INTEGER_DATA_TYPE_C
    else:
        # throw an error for unknown data type
        raise ValueError("Input matrix must be of type 'single' (float32) or 'uint64'.")

    # check whether the input matrix is real or complex; if complex, rearrange the
    # data into the C++ format
    if np.isreal(matrix).all():

        # set file tag
        domain_type = h5_literals.DOMAIN_TYPE_REAL

    elif dims == 3:

        # set file tag
        domain_type = h5_literals.DOMAIN_TYPE_COMPLEX

        # rearrange the data so the real and imaginary parts are stored in the
        # same matrix
        matrix = np.concatenate((matrix.real, matrix.imag), axis=0)
        matrix = matrix.reshape((Nx, 2, Ny, Nz))
        matrix = np.transpose(matrix, (1, 0, 2, 3))
        matrix = matrix.reshape((2 * Nx, Ny, Nz))

        # update the size of Nx
        Nx = 2 * Nx

    elif dims <= 1:

        # set file tag
        domain_type = h5_literals.DOMAIN_TYPE_COMPLEX

        # rearrange the data so the real and imaginary parts are stored in the
        # same matrix
        nelems = matrix.size
        matrix = matrix.reshape((nelems, 1))
        matrix = np.concatenate((matrix.real, matrix.imag), axis=0)
        matrix = matrix.reshape((nelems, 2, 1, 1))
        matrix = np.transpose(matrix, (1, 0, 2, 3))

        # update the matrix size (double each dimension that is larger than one)
        Nx = Nx * (2 - int(Nx == 1))
        Ny = Ny * (2 - int(Ny == 1))
        Nz = Nz * (2 - int(Nz == 1))

        # double store in the x-direction if a complex scalar
        if Nx == 1 and Ny == 1 and Nz == 1:
            Nx = 2 * Nx

        # put in the correct dimension
        matrix = matrix.reshape((Nx, Ny, Nz))

    else:
        raise NotImplementedError('Currently there is no support for saving 2D complex matrices.')

    # allocate a holder for the new matrix within the file
    opts = {
        'dtype': data_type_matlab,
        'chunks': tuple(chunk_size)
    }
    if compression_level != 0:
        # use compression
        opts['compression'] = compression_level

    # write the matrix into the file
    with h5py.File(filename, "a") as f:
        f.create_dataset(f'/{matrix_name}', [Nx, Ny, Nz], data=matrix, **opts)

        # set attributes for the matrix (used by k-Wave++)
        assign_str_attr(f[f'/{matrix_name}'].attrs, h5_literals.DOMAIN_TYPE_ATT_NAME, domain_type)
        assign_str_attr(f[f'/{matrix_name}'].attrs, h5_literals.DATA_TYPE_ATT_NAME, data_type_c)
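
# Usage sketch for write_matrix (illustrative only; the file name and array below are
# assumed, not taken from this module): input arrays must be float32 or uint64, and 2D/3D
# arrays are transposed from C to Fortran ordering before being written.
#
#     pressure = np.random.rand(64, 64, 64).astype(np.float32)
#     write_matrix('kwave_input.h5', pressure, 'p0_source_input')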


def write_attributes_typed(filename, file_description=None):
    # get literals
    h5_literals = get_h5_literals()

    # get computer info
    comp_info = dotdict({
        'date': datetime.now().strftime("%d-%b-%Y"),
        'computer_name': socket.gethostname(),
        'operating_system_type': platform.system(),
        'operating_system': platform.system() + " " + platform.release() + " " + platform.version(),
        'user_name': os.environ.get('USERNAME'),
        'matlab_version': 'N/A',
        'kwave_version': '1.3',
        'kwave_path': 'N/A',
    })

    # set the file description if not provided by the user
    if file_description is None:
        file_description = f'Input data created by {comp_info.user_name} running MATLAB ' \
                           f'{comp_info.matlab_version} on {comp_info.operating_system_type}'

    # set additional file attributes
    with h5py.File(filename, "a") as f:
        f[h5_literals.FILE_MAJOR_VER_ATT_NAME] = h5_literals.HDF_FILE_MAJOR_VERSION
        f[h5_literals.FILE_MINOR_VER_ATT_NAME] = h5_literals.HDF_FILE_MINOR_VERSION
        f[h5_literals.CREATED_BY_ATT_NAME] = 'k-Wave 1.3'
        f[h5_literals.FILE_DESCR_ATT_NAME] = file_description
        f[h5_literals.FILE_TYPE_ATT_NAME] = h5_literals.HDF_INPUT_FILE
        f[h5_literals.FILE_CREATION_DATE_ATT_NAME] = get_date_string()


def write_attributes(filename, file_description=None, legacy=False):

    if not legacy:
        write_attributes_typed(filename, file_description)
        return

    warnings.warn("Attributes will soon be written as typed datasets rather than HDF5 attributes.",
                  DeprecationWarning)

    # get literals
    h5_literals = get_h5_literals()

    # get computer info
    comp_info = dotdict({
        'date': datetime.now().strftime("%d-%b-%Y"),
        'computer_name': socket.gethostname(),
        'operating_system_type': platform.system(),
        'operating_system': platform.system() + " " + platform.release() + " " + platform.version(),
        'user_name': os.environ.get('USERNAME'),
        'matlab_version': 'N/A',
        'kwave_version': '1.3',
        'kwave_path': 'N/A',
    })

    # set the file description if not provided by the user
    if file_description is None:
        file_description = f'Input data created by {comp_info.user_name} running MATLAB ' \
                           f'{comp_info.matlab_version} on {comp_info.operating_system_type}'

    # set additional file attributes
    with h5py.File(filename, "a") as f:
        assign_str_attr(f.attrs, h5_literals.FILE_MAJOR_VER_ATT_NAME, h5_literals.HDF_FILE_MAJOR_VERSION)
        assign_str_attr(f.attrs, h5_literals.FILE_MINOR_VER_ATT_NAME, h5_literals.HDF_FILE_MINOR_VERSION)
        assign_str_attr(f.attrs, h5_literals.CREATED_BY_ATT_NAME, 'k-Wave N/A')
        assign_str_attr(f.attrs, h5_literals.FILE_DESCR_ATT_NAME, file_description)
        assign_str_attr(f.attrs, h5_literals.FILE_TYPE_ATT_NAME, h5_literals.HDF_INPUT_FILE)
        assign_str_attr(f.attrs, h5_literals.FILE_CREATION_DATE_ATT_NAME, get_date_string())
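
# Usage sketch for write_attributes (illustrative file name): the typed writer is used by
# default; passing legacy=True stores the metadata as fixed-length string HDF5 attributes
# instead of datasets.
#
#     write_attributes('kwave_input.h5', file_description='Example input file')
#     # or, for the deprecated attribute-based layout:
#     write_attributes('kwave_input.h5', legacy=True)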


def write_flags(filename):
    """
    write_flags reads the input HDF5 file and derives and writes the
    required source and medium flags based on the datasets present in the
    file. For example, if the file contains a dataset named 'BonA', the
    nonlinear_flag is written as true. Conditional flags are also
    written. The source mode flags are written when appropriate if they
    are not already present in the file. The default source mode is
    'additive'.

    List of flags that are always written:
        ux_source_flag
        uy_source_flag
        uz_source_flag
        sxx_source_flag
        sxy_source_flag
        sxz_source_flag
        syy_source_flag
        syz_source_flag
        szz_source_flag
        p_source_flag
        p0_source_flag
        transducer_source_flag
        nonuniform_grid_flag
        nonlinear_flag
        absorbing_flag
        axisymmetric_flag
        elastic_flag
        sensor_mask_type

    List of conditional flags:
        u_source_mode
        u_source_many
        p_source_mode
        p_source_many
        s_source_mode
        s_source_many

    Args:
        filename: name of the input HDF5 file to read and update

    Returns:
        None
    """
    h5_literals = get_h5_literals()

    with h5py.File(filename, 'r') as hf:
        names = hf.keys()

        v_list = [
            ('ux_source', 'u_source_many'),
            ('uy_source', 'u_source_many'),
            ('uz_source', 'u_source_many'),

            ('sxx_source', 's_source_many'),
            ('syy_source', 's_source_many'),
            ('szz_source', 's_source_many'),
            ('sxy_source', 's_source_many'),
            ('sxz_source', 's_source_many'),
            ('syz_source', 's_source_many'),

            ('p_source', 'p_source_many')
        ]
        variable_list = {}
        for prefix, many_flag_key in v_list:
            inp_name = f'{prefix}_input'
            flag_name = f'{prefix}_flag'
            if inp_name in names:
                variable_list[flag_name] = hf[inp_name].shape[1]
                variable_list[many_flag_key] = hf[inp_name].shape[0] != 1
            else:
                variable_list[flag_name] = 0

        # --------------------
        # u source
        # --------------------

        # write u_source_mode if not already in the file (1 is Additive, 0 is Dirichlet)
        if any(variable_list[flag] for flag in ['ux_source_flag', 'uy_source_flag', 'uz_source_flag']) \
                and 'u_source_mode' not in names:
            variable_list['u_source_mode'] = 1

        # --------------------
        # s source
        # --------------------

        # write s_source_mode if not already in the file (1 is Additive, 0 is Dirichlet)
        if any(variable_list[flag] for flag in ['sxx_source_flag', 'syy_source_flag', 'szz_source_flag',
                                                'sxy_source_flag', 'sxz_source_flag', 'syz_source_flag']) \
                and 's_source_mode' not in names:
            variable_list['s_source_mode'] = 1

        # --------------------
        # p source
        # --------------------

        # write p_source_mode if not already in the file (1 is Additive, 0 is Dirichlet)
        if variable_list['p_source_flag'] and 'p_source_mode' not in names:
            variable_list['p_source_mode'] = 1

        # check for p0_source_input and set p0_source_flag
        variable_list['p0_source_flag'] = 'p0_source_input' in names

        # --------------------
        # additional flags
        # --------------------

        # check for transducer_source_input and set transducer_source_flag
        variable_list['transducer_source_flag'] = 'transducer_source_input' in names

        # check for BonA and set the nonlinear flag
        variable_list['nonlinear_flag'] = 'BonA' in names

        # check for alpha_coeff and set the absorbing flag
        variable_list['absorbing_flag'] = 'alpha_coeff' in names

        # check for lambda and set the elastic flag
        variable_list['elastic_flag'] = 'lambda' in names

        # set the axisymmetric grid flag to false
        variable_list['axisymmetric_flag'] = 0

        # set the nonuniform grid flag to false
        variable_list['nonuniform_grid_flag'] = 0

        # check for sensor_mask_index and sensor_mask_corners
        if 'sensor_mask_index' in names:
            variable_list['sensor_mask_type'] = 0
        elif 'sensor_mask_corners' in names:
            variable_list['sensor_mask_type'] = 1
        else:
            raise ValueError('Either sensor_mask_index or sensor_mask_corners must be defined in the input file')

    # --------------------
    # write flags to file
    # --------------------

    # cast all the flag variables to 64-bit unsigned integers (long in C++) and write them to the file
    for key, value in variable_list.items():
        # cast matrix to a 64-bit unsigned integer
        value = np.array(value, dtype=np.uint64)
        write_matrix(filename, value, key)
        del value
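
# Usage sketch for write_flags (illustrative file name): the file is expected to already
# contain the source and sensor datasets, e.g. 'p0_source_input' and either
# 'sensor_mask_index' or 'sensor_mask_corners'; the derived flags are then appended as
# uint64 datasets via write_matrix.
#
#     write_flags('kwave_input.h5')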


def write_grid(filename, grid_size, grid_spacing, pml_size, pml_alpha, Nt, dt, c_ref):
    """
    write_grid creates and writes the wavenumber grids and PML variables
    required by the k-Wave C++ code to the HDF5 file specified by the
    user.

    List of parameters that are written:
        Nx
        Ny
        Nz
        Nt
        dt
        dx
        dy
        dz
        c_ref
        pml_x_alpha
        pml_y_alpha
        pml_z_alpha
        pml_x_size
        pml_y_size
        pml_z_size

    Returns:
        None
    """
    h5_literals = get_h5_literals()

    # =========================================================================
    # STORE FLOATS
    # =========================================================================

    variable_list = {
        'dt': dt,
        'dx': grid_spacing[0],
        'dy': grid_spacing[1],
        'dz': grid_spacing[2],
        'pml_x_alpha': pml_alpha[0],
        'pml_y_alpha': pml_alpha[1],
        'pml_z_alpha': pml_alpha[2],
        'c_ref': c_ref
    }

    # change float variables to single precision (float in C++), then add to the HDF5 file
    for key, value in variable_list.items():
        # cast matrix to single precision
        value = cast_to_type(value, h5_literals.MATRIX_DATA_TYPE_MATLAB)
        write_matrix(filename, value, key)
        del value

    # =========================================================================
    # STORE INTEGERS
    # =========================================================================

    # integer variables
    variable_list = {
        'Nx': grid_size[0],
        'Ny': grid_size[1],
        'Nz': grid_size[2],
        'Nt': Nt,
        'pml_x_size': pml_size[0],
        'pml_y_size': pml_size[1],
        'pml_z_size': pml_size[2]
    }

    # change all the index variables to 64-bit unsigned integers (long in C++), then add to the HDF5 file
    for key, value in variable_list.items():
        # cast matrix to 64-bit unsigned integer
        value = cast_to_type(value, h5_literals.INTEGER_DATA_TYPE_MATLAB)
        write_matrix(filename, value, key)
        del value
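
# Usage sketch for write_grid (illustrative values): grid_size, grid_spacing, pml_size and
# pml_alpha are 3-element sequences; each scalar is cast to the required type and written
# to the file via write_matrix.
#
#     write_grid('kwave_input.h5', grid_size=(128, 128, 128), grid_spacing=(1e-4, 1e-4, 1e-4),
#                pml_size=(10, 10, 10), pml_alpha=(2.0, 2.0, 2.0), Nt=500, dt=2e-8, c_ref=1500.0)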


def assign_str_attr(attrs, attr_name, attr_val):
    """
    Assigns an HDF5 attribute with the value stored as a fixed-length string.

    Args:
        attrs: h5py attribute manager (e.g. file.attrs or dataset.attrs)
        attr_name: name of the attribute
        attr_val: string value of the attribute

    Returns:
        None
    """
    attrs.create(attr_name, attr_val, None, dtype=f'<S{len(attr_val)}')
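
# Usage sketch for assign_str_attr (illustrative names): the value is stored as a
# fixed-length byte string, which is the attribute format used by k-Wave++.
#
#     with h5py.File('kwave_input.h5', 'a') as f:
#         assign_str_attr(f.attrs, 'created_by', 'k-Wave 1.3')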


def loadImage(path, is_gray):
    if is_gray:
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    else:
        img = cv2.imread(path, cv2.IMREAD_COLOR)
        raise NotImplementedError('Loading colour images is not yet supported.')
        # im = squeeze(double(im(:, :, 1)) + double(im(:, :, 2)) + double(im(:, :, 3)));
    img = img.astype(float)

    # invert and scale the pixel values to the range 0 -> 1
    img = img.max() - img
    img = img * (1 / img.max())
    return img
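
# Usage sketch for loadImage (illustrative path): only grayscale loading is currently
# implemented; the returned image is inverted and scaled to the range 0 -> 1.
#
#     img = loadImage('source_mask.png', is_gray=True)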