Coverage for kwave/utils/ioutils.py: 65%

173 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-10-24 11:55 -0700

1import os 

2import platform 

3import socket 

4import warnings 

5from datetime import datetime 

6 

7import cv2 

8import h5py 

9import numpy as np 

10 

11from .misc import get_date_string 

12from .conversionutils import cast_to_type 

13from kwave.utils import dotdict 

14 

15 

def get_h5_literals():
    """
    Return the string/number constants used when writing k-Wave HDF5 files.

    Returns:
        A ``dotdict`` mapping literal names (attribute names, data-type tags,
        file-type tags, file version and compression level) to their values.
    """
    # data type
    data_type_tags = {
        'DATA_TYPE_ATT_NAME': 'data_type',
        'MATRIX_DATA_TYPE_MATLAB': 'single',
        'MATRIX_DATA_TYPE_C': 'float',
        'INTEGER_DATA_TYPE_MATLAB': 'uint64',
        'INTEGER_DATA_TYPE_C': 'long',
    }

    # real / complex
    domain_tags = {
        'DOMAIN_TYPE_ATT_NAME': 'domain_type',
        'DOMAIN_TYPE_REAL': 'real',
        'DOMAIN_TYPE_COMPLEX': 'complex',
    }

    # file descriptors
    descriptor_names = {
        'FILE_MAJOR_VER_ATT_NAME': 'major_version',
        'FILE_MINOR_VER_ATT_NAME': 'minor_version',
        'FILE_DESCR_ATT_NAME': 'file_description',
        'FILE_CREATION_DATE_ATT_NAME': 'creation_date',
        'CREATED_BY_ATT_NAME': 'created_by',
    }

    # file type
    file_type_tags = {
        'FILE_TYPE_ATT_NAME': 'file_type',
        'HDF_INPUT_FILE': 'input',
        'HDF_OUTPUT_FILE': 'output',
        'HDF_CHECKPOINT_FILE': 'checkpoint',
    }

    # file version information
    version_info = {
        'HDF_FILE_MAJOR_VERSION': '1',
        'HDF_FILE_MINOR_VERSION': '2',
    }

    # compression level
    compression = {
        'HDF_COMPRESSION_LEVEL': 0,
    }

    # merge the groups in declaration order so key order is unchanged
    combined = {}
    for group in (data_type_tags, domain_tags, descriptor_names,
                  file_type_tags, version_info, compression):
        combined.update(group)
    return dotdict(combined)

51 

52 

def write_matrix(filename, matrix: np.ndarray, matrix_name, compression_level=None):
    """
    Append a matrix to an HDF5 file as the dataset ``/<matrix_name>``.

    The axes of 2-D and 3-D inputs are reversed before writing
    (C <=> Fortran ordering), the dataset is always created with a
    three-element shape ``[Nx, Ny, Nz]``, and the ``domain_type`` and
    ``data_type`` string attributes are attached to it.

    NOTE(review): any input with fewer than two dimensions is given the
    dataset shape (1, 1, 1), so 1-D arrays with more than one element do not
    look supported here - confirm that callers only pass scalars in that case.

    Args:
        filename: path of the HDF5 file; it is opened in append mode.
        matrix: array to store; its dtype must be ``np.float32`` or ``np.uint64``.
        matrix_name: name of the dataset created at the root of the file.
        compression_level: forwarded as the ``compression`` option of
            ``create_dataset`` when non-zero. Defaults to
            ``HDF_COMPRESSION_LEVEL`` from ``get_h5_literals()`` and is forced
            to 0 for small inputs with fewer than two dimensions.

    Raises:
        ValueError: if the matrix has more than 3 dimensions or an
            unsupported dtype.
        NotImplementedError: if a 2-D matrix is complex.
    """
    # get literals
    h5_literals = get_h5_literals()

    if compression_level is None:
        compression_level = h5_literals.HDF_COMPRESSION_LEVEL

    dims = len(matrix.shape)
    if dims > 3:
        # throw error for unknown matrix size
        raise ValueError('Input matrix must have 1, 2 or 3 dimensions.')

    # reverse the axes (C <=> Fortran ordering) and get the size of the matrix
    if dims == 3:
        matrix = np.transpose(matrix, [2, 1, 0])
        Nx, Ny, Nz = matrix.shape
    elif dims == 2:
        matrix = np.transpose(matrix)
        Ny, Nz = matrix.shape
        Nx = 1
    else:
        Nx, Ny, Nz = 1, 1, 1

    # check size of matrix and set chunk size and compression level
    if dims == 3:
        # set chunk size to Nx * Ny
        chunk_size = [Nx, Ny, 1]
    elif dims == 2:
        # set chunk size to Nx
        chunk_size = [Nx, 1, 1]
    else:
        # element count the original code treats as 1 MB; integer division
        # keeps it usable as a chunk dimension (a float is not a valid size)
        one_mb = (1024 ** 2) // 8
        if matrix.size > one_mb:
            # set chunk size to 1 MB
            if Nx > Ny:
                chunk_size = [one_mb, 1, 1]
            elif Ny > Nz:
                chunk_size = [1, one_mb, 1]
            else:
                chunk_size = [1, 1, one_mb]
        else:
            # set no compression
            compression_level = 0

            # set chunk size to grid size
            if matrix.size == 1:
                chunk_size = (1, 1, 1)
            elif Nx > Ny:
                chunk_size = (Nx, 1, 1)
            elif Ny > Nz:
                chunk_size = (1, Ny, 1)
            else:
                chunk_size = (1, 1, Nz)

    # check the format of the matrix is either single precision (float in C++)
    # or uint64 (unsigned long in C++)
    if matrix.dtype == np.float32:
        data_type_matlab = h5_literals.MATRIX_DATA_TYPE_MATLAB
        data_type_c = h5_literals.MATRIX_DATA_TYPE_C
    elif matrix.dtype == np.uint64:
        data_type_matlab = h5_literals.INTEGER_DATA_TYPE_MATLAB
        data_type_c = h5_literals.INTEGER_DATA_TYPE_C
    else:
        # throw error for unknown data type
        raise ValueError("Input matrix must be of type 'single' or 'uint64'.")

    # check if the input matrix is real or complex, if complex, rearrange the
    # data in the C++ format
    # NOTE(review): the dtype check above only admits float32/uint64, so the
    # complex branches below appear unreachable - confirm before relying on them.
    if np.isreal(matrix).all():
        # set file tag
        domain_type = h5_literals.DOMAIN_TYPE_REAL

    elif dims == 3:
        # set file tag
        domain_type = h5_literals.DOMAIN_TYPE_COMPLEX

        # rearrange the data so the real and imaginary parts are stored in the
        # same matrix (np.concatenate takes the arrays as one sequence)
        matrix = np.concatenate((matrix.real, matrix.imag), axis=0)
        matrix = matrix.reshape((Nx, 2, Ny, Nz))
        matrix = np.transpose(matrix, (1, 0, 2, 3))
        matrix = matrix.reshape((2 * Nx, Ny, Nz))

        # update the size of Nx
        Nx = 2 * Nx

    elif dims <= 1:
        # set file tag
        domain_type = h5_literals.DOMAIN_TYPE_COMPLEX

        # rearrange the data so the real and imaginary parts are stored in the
        # same matrix
        nelems = matrix.size
        matrix = matrix.reshape((nelems, 1))
        matrix = np.concatenate((matrix.real, matrix.imag), axis=0)
        matrix = matrix.reshape((nelems, 2, 1, 1))
        matrix = np.transpose(matrix, (1, 0, 2, 3))

        # update the matrix size: double every dimension that is not 1
        # (kept as integers so the values stay valid for reshape)
        Nx = Nx * (2 - int(Nx == 1))
        Ny = Ny * (2 - int(Ny == 1))
        Nz = Nz * (2 - int(Nz == 1))

        # double store in x-direction if a complex scalar
        if Nx == 1 and Ny == 1 and Nz == 1:
            Nx = 2 * Nx

        # put in correct dimension
        matrix = matrix.reshape((Nx, Ny, Nz))

    else:
        raise NotImplementedError('Currently there is no support for saving 2D complex matrices.')

    # options for the new dataset within the file
    opts = {
        'dtype': data_type_matlab,
        'chunks': tuple(chunk_size),
    }
    if compression_level != 0:
        # use compression
        opts['compression'] = compression_level

    # write the matrix into the file
    with h5py.File(filename, "a") as f:
        dataset = f.create_dataset(f'/{matrix_name}', [Nx, Ny, Nz], data=matrix, **opts)

        # set attributes for the matrix (used by k-Wave++)
        assign_str_attr(dataset.attrs, h5_literals.DOMAIN_TYPE_ATT_NAME, domain_type)
        assign_str_attr(dataset.attrs, h5_literals.DATA_TYPE_ATT_NAME, data_type_c)

195 

196 

def _default_file_description():
    """Build the description stored when the caller does not provide one."""
    # USERNAME is normally only set on Windows; fall back to the POSIX
    # variables so the description does not read "created by None"
    user_name = (os.environ.get('USERNAME')
                 or os.environ.get('USER')
                 or os.environ.get('LOGNAME'))
    matlab_version = 'N/A'
    operating_system_type = platform.system()
    return f'Input data created by {user_name} running MATLAB ' \
           f'{matlab_version} on {operating_system_type}'


def write_attributes_typed(filename, file_description=None):
    """
    Store the file-level metadata as root-level entries of the HDF5 file.

    Each value is assigned with ``f[name] = value``, i.e. as an item of the
    file rather than through ``f.attrs``.

    Args:
        filename: path of the HDF5 file; it is opened in append mode.
        file_description: text stored under ``file_description``; a default
            naming the current user and operating system is used when ``None``.

    Returns:
        None
    """
    # get literals
    h5_literals = get_h5_literals()
    kwave_version = '1.3'

    # set file description if not provided by user
    if file_description is None:
        file_description = _default_file_description()

    # set additional file attributes
    with h5py.File(filename, "a") as f:
        f[h5_literals.FILE_MAJOR_VER_ATT_NAME] = h5_literals.HDF_FILE_MAJOR_VERSION
        f[h5_literals.FILE_MINOR_VER_ATT_NAME] = h5_literals.HDF_FILE_MINOR_VERSION
        f[h5_literals.CREATED_BY_ATT_NAME] = f'k-Wave {kwave_version}'
        f[h5_literals.FILE_DESCR_ATT_NAME] = file_description
        f[h5_literals.FILE_TYPE_ATT_NAME] = h5_literals.HDF_INPUT_FILE
        f[h5_literals.FILE_CREATION_DATE_ATT_NAME] = get_date_string()


def write_attributes(filename, file_description=None, legacy=False):
    """
    Write the file-level metadata of a k-Wave input file.

    Args:
        filename: path of the HDF5 file; it is opened in append mode.
        file_description: text stored under ``file_description``; a default
            naming the current user and operating system is used when ``None``.
        legacy: when False (default) the work is delegated to
            ``write_attributes_typed``. When True a ``DeprecationWarning`` is
            issued and the values are written as fixed-length string
            attributes on the file root via ``assign_str_attr``.

    Returns:
        None
    """
    if not legacy:
        write_attributes_typed(filename, file_description)
        return

    # stacklevel=2 attributes the warning to the caller of write_attributes
    warnings.warn("Attributes will soon be typed when saved and not saved ", DeprecationWarning,
                  stacklevel=2)

    # get literals
    h5_literals = get_h5_literals()

    # set file description if not provided by user
    if file_description is None:
        file_description = _default_file_description()

    # set additional file attributes
    with h5py.File(filename, "a") as f:
        assign_str_attr(f.attrs, h5_literals.FILE_MAJOR_VER_ATT_NAME, h5_literals.HDF_FILE_MAJOR_VERSION)
        assign_str_attr(f.attrs, h5_literals.FILE_MINOR_VER_ATT_NAME, h5_literals.HDF_FILE_MINOR_VERSION)
        assign_str_attr(f.attrs, h5_literals.CREATED_BY_ATT_NAME, 'k-Wave N/A')
        assign_str_attr(f.attrs, h5_literals.FILE_DESCR_ATT_NAME, file_description)
        assign_str_attr(f.attrs, h5_literals.FILE_TYPE_ATT_NAME, h5_literals.HDF_INPUT_FILE)
        assign_str_attr(f.attrs, h5_literals.FILE_CREATION_DATE_ATT_NAME, get_date_string())

263 

264 

def write_flags(filename):
    """
    Derive the source and medium flags from the datasets present in an HDF5
    file and write them back to the same file.

    For example, if the file contains a dataset named 'BonA', nonlinear_flag
    is written as true. Source mode flags are written only when a matching
    source is present and the mode is not already in the file; the value
    written is 1 (additive).

    Flags that are always written:
        ux_source_flag, uy_source_flag, uz_source_flag,
        sxx_source_flag, sxy_source_flag, sxz_source_flag,
        syy_source_flag, syz_source_flag, szz_source_flag,
        p_source_flag, p0_source_flag, transducer_source_flag,
        nonuniform_grid_flag, nonlinear_flag, absorbing_flag,
        axisymmetric_flag, elastic_flag, sensor_mask_type

    Conditional flags:
        u_source_mode, u_source_many, p_source_mode, p_source_many,
        s_source_mode, s_source_many

    Args:
        filename: path of the HDF5 file; it is read first and then appended
            to through ``write_matrix``.

    Returns:
        None

    Raises:
        ValueError: if the file contains neither 'sensor_mask_index' nor
            'sensor_mask_corners'.
    """
    # (source prefix, name of the matching "many" flag)
    v_list = [
        ('ux_source', 'u_source_many'),
        ('uy_source', 'u_source_many'),
        ('uz_source', 'u_source_many'),

        ('sxx_source', 's_source_many'),
        ('syy_source', 's_source_many'),
        ('szz_source', 's_source_many'),
        ('sxy_source', 's_source_many'),
        ('sxz_source', 's_source_many'),
        ('syz_source', 's_source_many'),

        ('p_source', 'p_source_many'),
    ]
    variable_list = {}

    with h5py.File(filename, 'r') as hf:
        # snapshot the dataset names so later membership tests do not depend
        # on the file handle still being open
        names = set(hf.keys())

        for prefix, many_flag_key in v_list:
            inp_name = f'{prefix}_input'
            flag_name = f'{prefix}_flag'
            if inp_name in names:
                input_shape = hf[inp_name].shape
                variable_list[flag_name] = input_shape[1]
                variable_list[many_flag_key] = input_shape[0] != 1
            else:
                variable_list[flag_name] = 0

    # --------------------
    # u source
    # --------------------

    # write u_source mode if not already in file (1 is Additive, 0 is Dirichlet)
    if any(variable_list[flag] for flag in ['ux_source_flag', 'uy_source_flag', 'uz_source_flag']) \
            and 'u_source_mode' not in names:
        variable_list['u_source_mode'] = 1

    # --------------------
    # s source
    # --------------------

    # write s_source mode if not already in file (1 is Additive, 0 is Dirichlet)
    if any(variable_list[flag] for flag in ['sxx_source_flag', 'syy_source_flag', 'szz_source_flag',
                                            'sxy_source_flag', 'sxz_source_flag', 'syz_source_flag']) \
            and 's_source_mode' not in names:
        variable_list['s_source_mode'] = 1

    # --------------------
    # p source
    # --------------------

    # write p_source mode if not already in file (1 is Additive, 0 is Dirichlet)
    if variable_list['p_source_flag'] and 'p_source_mode' not in names:
        variable_list['p_source_mode'] = 1

    # check for p0_source_input and set p0_source_flag
    variable_list['p0_source_flag'] = 'p0_source_input' in names

    # --------------------
    # additional flags
    # --------------------

    # check for transducer_source_input and set transducer_source_flag
    variable_list['transducer_source_flag'] = 'transducer_source_input' in names

    # check for BonA and set nonlinear flag
    variable_list['nonlinear_flag'] = 'BonA' in names

    # check for alpha_coeff and set absorbing flag
    variable_list['absorbing_flag'] = 'alpha_coeff' in names

    # check for lambda and set elastic flag
    variable_list['elastic_flag'] = 'lambda' in names

    # set axisymmetric grid flag to false
    variable_list['axisymmetric_flag'] = 0

    # set nonuniform grid flag to false
    variable_list['nonuniform_grid_flag'] = 0

    # check for sensor_mask_index and sensor_mask_corners
    if 'sensor_mask_index' in names:
        variable_list['sensor_mask_type'] = 0
    elif 'sensor_mask_corners' in names:
        variable_list['sensor_mask_type'] = 1
    else:
        raise ValueError('Either sensor_mask_index or sensor_mask_corners must be defined in the input file')

    # --------------------
    # write flags to file
    # --------------------

    # cast every flag to a 64-bit unsigned integer (long in C++) and write it
    for key, value in variable_list.items():
        write_matrix(filename, np.array(value, dtype=np.uint64), key)

408 

409 

def write_grid(filename, grid_size, grid_spacing, pml_size, pml_alpha, Nt, dt, c_ref):
    """
    Write the grid and PML parameters to the HDF5 file given by ``filename``.

    Parameters written, in this order:
        as single precision: dt, dx, dy, dz,
            pml_x_alpha, pml_y_alpha, pml_z_alpha, c_ref
        as 64-bit unsigned integers: Nx, Ny, Nz, Nt,
            pml_x_size, pml_y_size, pml_z_size

    Args:
        filename: path of the HDF5 file passed on to ``write_matrix``.
        grid_size: indexable; elements 0, 1, 2 are written as Nx, Ny, Nz.
        grid_spacing: indexable; elements 0, 1, 2 are written as dx, dy, dz.
        pml_size: indexable; elements 0, 1, 2 are written as pml_{x,y,z}_size.
        pml_alpha: indexable; elements 0, 1, 2 are written as pml_{x,y,z}_alpha.
        Nt: written as 'Nt'.
        dt: written as 'dt'.
        c_ref: written as 'c_ref'.

    Returns:
        None
    """
    literals = get_h5_literals()

    def _store(entries, type_name):
        # cast each value to the requested type, then add it to the file
        for dataset_name, raw_value in entries:
            write_matrix(filename, cast_to_type(raw_value, type_name), dataset_name)

    # =========================================================================
    # STORE FLOATS (single precision, float in C++)
    # =========================================================================
    _store(
        [
            ('dt', dt),
            ('dx', grid_spacing[0]),
            ('dy', grid_spacing[1]),
            ('dz', grid_spacing[2]),
            ('pml_x_alpha', pml_alpha[0]),
            ('pml_y_alpha', pml_alpha[1]),
            ('pml_z_alpha', pml_alpha[2]),
            ('c_ref', c_ref),
        ],
        literals.MATRIX_DATA_TYPE_MATLAB,
    )

    # =========================================================================
    # STORE INTEGERS (64-bit unsigned, long in C++)
    # =========================================================================
    _store(
        [
            ('Nx', grid_size[0]),
            ('Ny', grid_size[1]),
            ('Nz', grid_size[2]),
            ('Nt', Nt),
            ('pml_x_size', pml_size[0]),
            ('pml_y_size', pml_size[1]),
            ('pml_z_size', pml_size[2]),
        ],
        literals.INTEGER_DATA_TYPE_MATLAB,
    )

484 

485 

def assign_str_attr(attrs, attr_name, attr_val):
    """
    Create an HDF5 attribute whose value is stored as a fixed-length string.

    Args:
        attrs: object exposing ``create(name, data, shape, dtype=...)``; the
            callers in this module pass the ``.attrs`` of an h5py file or dataset.
        attr_name: name of the attribute to create.
        attr_val: value to store; its ``len()`` sets the string width.

    Returns:
        None
    """
    # '<S<n>' with n = len(attr_val): a byte string exactly as wide as the value
    fixed_width_dtype = '<S{}'.format(len(attr_val))
    attrs.create(attr_name, attr_val, None, dtype=fixed_width_dtype)

498 

499 

def loadImage(path, is_gray):
    """
    Load a grayscale image and return it inverted and scaled to at most 1.

    The pixel values are converted to float, subtracted from the image
    maximum (so the brightest pixel becomes 0) and divided by the new
    maximum (so the darkest pixel becomes 1).

    Args:
        path: path of the image file, passed to ``cv2.imread``.
        is_gray: must be truthy; colour images are not supported yet.

    Returns:
        The inverted, scaled image as a float array. An image whose pixels
        all share one value yields all zeros.

    Raises:
        NotImplementedError: if ``is_gray`` is falsy.
        OSError: if ``cv2.imread`` could not read the file.
    """
    if not is_gray:
        # the colour path always ended in this error, so fail before doing any I/O
        # im = squeeze(double(im(:, :, 1)) + double(im(:, :, 2)) + double(im(:, :, 3)));
        raise NotImplementedError

    img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread reports a missing or undecodable file by returning None
        raise OSError(f'Could not read image file: {path}')
    img = img.astype(float)

    # invert the image, then scale pixel values from 0 -> 1
    img = img.max() - img
    peak = img.max()
    if peak == 0:
        # constant image: everything is already zero, avoid dividing by zero
        return img
    return img * (1 / peak)