Coverage for kwave/reconstruction/beamform.py: 0%

57 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-10-24 11:52 -0700

1import numpy as np 

2from scipy.signal import hilbert 

3from scipy.interpolate import interp1d 

4from uff import UFF, ChannelData 

5from matplotlib import pyplot as plt 

6import kwave.reconstruction.tools as tools 

7from uff.position import Position 

8from kwave.reconstruction.shifted_transform import ShiftedTransform 

9 

10 

def beamform(
    channel_data: ChannelData,
    *,
    f_number: float = 1.2,
    num_px_z: int = 256,
    imaging_depth: float = 40e-3,
    apodization_window: str = "none",
) -> None:
    """Delay-and-sum beamform UFF channel data and display a log-compressed B-mode image.

    One image line is formed per entry of ``channel_data.sequence``. Each line sits
    laterally at the x position of that event's transmit-wave origin. It is sampled at
    ``num_px_z`` depths evenly spaced from 0 to ``imaging_depth``.

    Args:
        channel_data: UFF channel data. ``channel_data.data`` is indexed as
            ``[:, event, channel, sample]``; the original author documented an example
            shape of ``(1, 96, 32, 1585)``. The object is NOT modified.
        f_number: Receive f-number. The aperture passed to ``tools.apodize`` grows
            with depth as ``z / f_number``.
        num_px_z: Number of depth samples per image line.
        imaging_depth: Maximum imaging depth, in metres (the plot converts to mm with
            a factor of 1e3).
        apodization_window: Window name forwarded to ``tools.apodize``
            (e.g. ``"none"`` or ``"boxcar"``).

    Returns:
        None. The image is shown with ``matplotlib.pyplot.show()``.

    Raises:
        ValueError: If ``channel_data.sequence`` is empty.
    """
    num_events = len(channel_data.sequence)
    if num_events == 0:
        raise ValueError("channel_data.sequence is empty; nothing to beamform")

    number_samples = np.size(channel_data.data, axis=-1)
    sound_speed = channel_data.sound_speed

    # Depth samples shared by every image line.
    z = np.linspace(0, imaging_depth, num_px_z)

    beamformed_data = np.zeros((num_px_z, num_events), dtype=complex)

    # Analytic signal along the sample axis (its magnitude is the envelope). It is kept
    # in a local rather than written back to ``channel_data.data``. scipy's ``hilbert``
    # rejects complex input, so mutating the caller's object made a second call fail.
    analytic_data = hilbert(channel_data.data, axis=-1)

    # Lateral position (x) of each image line, filled per event.
    wave_origin_x = np.empty(num_events)

    # TODO: sequence entries reference events by key (``.event``); this loop instead
    #       pairs the e_id-th sequence entry with the e_id-th unique event.
    for e_id, _ in enumerate(channel_data.sequence):
        event = channel_data.unique_events[e_id]
        receive_setup = event.receive_setup
        # One transmit wave per transmit event is assumed, hence index 0.
        transmit_wave = event.transmit_setup.transmit_waves[0]

        time_vector = tools.make_time_vector(num_samples=number_samples,
                                             sampling_freq=receive_setup.sampling_frequency,
                                             time_offset=receive_setup.time_offset)

        # TODO: wave indices are 1-based here; make them 0-based.
        wave_origin_x[e_id] = channel_data.unique_waves[transmit_wave.wave - 1].origin.position.x

        # Pixels of this image line as (x, y, z) rows, with x fixed at the wave origin.
        pixel_positions = np.column_stack([np.full(num_px_z, wave_origin_x[e_id]),
                                           np.zeros(num_px_z),
                                           z])
        expanding_aperture = z / f_number

        # Transmit path length: signed distance from the wave origin (negative for
        # pixels shallower than the origin) plus a constant offset.
        # NOTE(review): the 1.2 factor on t0_point[0] is an unexplained constant (it is
        # not tied to f_number). A commented-out alternative in the original used the
        # distance from origin to t0_point. Confirm the intended offset.
        origin = tools.get_origin_array(channel_data, transmit_wave)
        t0_point = tools.get_t0(transmit_wave)
        transmit_distance = (np.sign(z - origin[2]) * np.linalg.norm(pixel_positions - origin, axis=1)
                             + np.abs(1.2 * t0_point[0]))

        probe = channel_data.probes[receive_setup.probe - 1]
        # TODO: element positions are stored as transforms rather than positions.
        probe_transform = ShiftedTransform.deserialize(probe.transform.serialize())
        t0 = transmit_wave.time_offset  # invariant across elements of this event

        # NOTE(review): receive channels are mapped with the *transmit* setup's
        # channel_mapping. Confirm this should not be receive_setup.channel_mapping.
        # Entries look like [<element_number>, ...] with 1-based element numbers.
        channel_mapping = event.transmit_setup.channel_mapping
        for chan_id, mapping_entry in enumerate(channel_mapping):
            # chan_id is the position in channel_mapping, i.e. the channel axis of the
            # data. The previous formula ``element - 1 - channel_mapping[0][0]`` gave -1
            # for the first element, silently wrapping to the last channel.
            element_number = mapping_entry[0]

            element_position = Position.deserialize(
                probe.element[element_number - 1].transform.translation.serialize())
            element_location = Position.deserialize(probe_transform(element_position).serialize())

            lateral_distance = np.abs(pixel_positions[:, 0] - element_location[0])
            receive_apodization = tools.apodize(lateral_distance, expanding_aperture, apodization_window)

            receive_distance = np.linalg.norm(pixel_positions - np.array(element_location), axis=1)

            # Round-trip delay for every pixel on the line.
            delay = (transmit_distance + receive_distance) / sound_speed + t0

            # Sample the analytic signal at the delays; samples outside the recorded
            # time window contribute 0.
            signal = np.squeeze(analytic_data[:, e_id, chan_id, :])
            interp = interp1d(x=time_vector, y=signal, kind='cubic', bounds_error=False, fill_value=0)
            beamformed_data[:, e_id] += np.squeeze(receive_apodization * interp(delay).T)

    # Envelope, peak normalisation and log compression. The zero-peak guard avoids NaNs
    # for an all-zero image; the 1e-12 floor keeps log10 finite for zero pixels.
    envelope = np.abs(beamformed_data)
    peak = envelope.max()
    normalised = envelope / peak if peak > 0 else envelope
    compressed_beamformed_data = 20 * np.log10(normalised + 1e-12)

    x_mm = 1e3 * wave_origin_x
    z_mm = 1e3 * z
    plt.figure()
    plt.imshow(compressed_beamformed_data, vmin=-60, vmax=0, cmap='Greys_r',
               extent=[x_mm.min(), x_mm.max(), z_mm.max(), z_mm.min()])
    plt.xlabel('x[mm]', fontsize=12)
    plt.ylabel('z[mm]', fontsize=12)
    plt.title(channel_data.description)
    plt.colorbar()
    plt.show()