Implementing Filtered Back Projection

We’ll use the classes of PyTomography to implement filtered back projection in SPECT.

[1]:
import sys
sys.path.append('/home/gpuvmadm/PyTomography/src')
import os
from pytomography.projections import ForwardProjectionNet, BackProjectionNet
from pytomography.metadata import ObjectMeta, ImageMeta
from pytomography.mappings import MapNet
import numpy as np
import matplotlib.pyplot as plt
import torch

The two foundational tools of image reconstruction are

  1. Forward projection \(\sum_{i} c_{ij} a_i\)

  2. Back projection \(\sum_{j} c_{ij} b_j\)

Let’s discuss what these operators actually mean. First, let’s define our quantities. \(c_{ij}\) is known as the system matrix, and may include information involving attenuation and PSF correction. \(a_i\) is an arbitrary object and \(b_j\) is an arbitrary image.

It’s worth now discussing what the indices \(i\) and \(j\) actually mean. You might think: objects are three dimensional, shouldn’t there be at least 3 indices when we’re doing linear operations? Consider the following: because we are in a discrete space, any 3 dimensional object can be flattened into a single (albeit very long) one dimensional vector: a 128x128x128 3D matrix becomes a 1D vector of length 2097152. That’s how many voxels there are in object space: you can think of index \(i\) as indexing a single voxel.

The same can be said for an image. If we have 64 projections of matrix size 128x128, then that can be thought of as a single vector of length 1048576. That’s also how many individual detector elements there are.

So in forward projection \(\sum_{i} c_{ij} a_i\), the system matrix \(c_{ij}\) maps the contribution from voxel \(i\) to detector element \(j\). In back projection \(\sum_{j} c_{ij} b_j\), the system matrix \(c_{ij}\) maps the intensity in detector element \(j\) back to every possible voxel \(i\) that could have contributed to it. In reality, however, not every voxel that could have contributed to detector element \(j\) does so with equal intensity; it is for this reason that forward projection followed by back projection does not return the original object.
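
To make the matrix-vector picture concrete, here is a minimal sketch using plain PyTorch (not the PyTomography classes used below) with a small, hypothetical system matrix: if we store \(c_{ij}\) in a matrix C such that C[j, i] corresponds to \(c_{ij}\), then forward projection is a matrix-vector product with the flattened object, and back projection is a product with the transpose.

import torch
# Toy illustration only: a hypothetical 10 x 64 system matrix acting on a flattened 4x4x4 object
a = torch.rand(4, 4, 4).flatten()   # object vector a_i (64 voxels)
C = torch.rand(10, 64)              # hypothetical system matrix, C[j, i] = c_ij (10 detector elements)
b = C @ a                           # forward projection: b_j = sum_i c_ij a_i
a_back = C.T @ b                    # back projection: sum_j c_ij b_j, one value per voxel i
print(b.shape, a_back.shape)        # torch.Size([10]) torch.Size([64])
# a_back has the same shape as the object, but it is not equal to the original a

The ForwardProjectionNet and BackProjectionNet used below implement these same two operations for a SPECT system matrix.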

Let’s experiment with these operators. First we’ll make a rectangular box in object space:

[2]:
x = torch.linspace(-1,1,128)
y = torch.linspace(-1,1,128)
z = torch.linspace(-1,1,132)
xv, yv, zv = torch.meshgrid([x,y,z], indexing='ij')
object_truth = (xv>-0.2)*(xv<0.2)*(yv>-0.15)*(yv<0.15)*(zv>-0.1)*(zv<0.1)
object_truth = object_truth.to(torch.float).unsqueeze(dim=0) # add batch dimension
object_truth.shape
[2]:
torch.Size([1, 128, 128, 132])
[3]:
plt.figure(figsize=(5,4))
plt.pcolormesh(object_truth[0].sum(axis=2).T, cmap='Greys_r')
plt.axis('off')
plt.colorbar()
[3]:
<matplotlib.colorbar.Colorbar at 0x7f9692bdac40>
../_images/notebooks_t_fbp_6_1.png

Before we do any projections, we need to get corresponding metadata for our object. In this case, we’ll assume the voxels are 1 cm \(\times\) 1 cm \(\times\) 1 cm. For our image space, we’ll assume 60 projections are taken at an angular spacing of 6 degrees.

[4]:
angles = np.arange(0,360.,6.)
object_meta = ObjectMeta(dr=(1,1,1), shape=object_truth[0].shape)
image_meta = ImageMeta(object_meta, angles=angles)

With this metadata, we can create our forward and back projection networks. For now, we won’t use any image correction techniques.

[6]:
fp_net = ForwardProjectionNet(obj2obj_nets=[],
                              im2im_nets=[],
                              object_meta=object_meta,
                              image_meta=image_meta)
bp_net = BackProjectionNet(obj2obj_nets=[],
                           im2im_nets=[],
                           object_meta=object_meta,
                           image_meta=image_meta)

We can now use fp_net to convert the object into an image (this is a Mickey-Mouse version of a detector simulation):

[7]:
image = fp_net(object_truth)
image.shape
[7]:
torch.Size([1, 60, 128, 132])

We can look at the projections at a few different angles, for example:

[8]:
fig, axes = plt.subplots(1,5,figsize=(15,4))
for i, proj in enumerate([0,5,10,15,20]):
    axes[i].pcolormesh(image[0][proj].T, cmap='Greys_r')
    axes[i].set_title(f'Angle={image_meta.angles[proj]}')
    axes[i].axis('off')
../_images/notebooks_t_fbp_14_0.png

At angles like 60 degrees, the intensity is not uniform across the face of the box: the projection is most intense near the center, where the line of sight passes through the most material, and falls off toward the edges. This is like looking through a semi-transparent box in real life; viewed from an off angle, it appears darkest near the center.

We can also back project the image to turn it back into an object:

[9]:
object_new = bp_net(image)
object_new.shape
[9]:
torch.Size([1, 128, 128, 132])

But if we look at the new object:

[10]:
plt.figure(figsize=(5,4))
plt.pcolormesh(object_new[0].sum(axis=2).T, cmap='Greys_r')
plt.axis('off')
plt.colorbar()
[10]:
<matplotlib.colorbar.Colorbar at 0x7f965df532b0>
../_images/notebooks_t_fbp_19_1.png

We can see that it has been blurred (note: there is no blurring in the \(z\) direction, because a voxel at height \(z_0\) only contributes to detector elements at height \(z_0\) due to collimation).
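
As a quick, optional check, we can compare a 1D profile through the middle of the original object with the same profile through the back projection. This is just a sketch; the slice indices below are the approximate centre of the 128x128x132 volume used here.

# Sketch: transaxial profile (along x) through roughly the centre of the volume
profile_truth = object_truth[0][:, 64, 66]
profile_bp = object_new[0][:, 64, 66]
plt.plot(profile_truth, label='original object')
plt.plot(profile_bp / profile_bp.max(), label='back projection (scaled to its maximum)')
plt.xlabel('voxel index along x')
plt.legend()
plt.show()

The back projected profile should show smooth, spread-out edges where the original is a sharp step.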

Example: Filtered Back Projection. In this case the object estimate is given by

\[\hat{f}_i = \pi \sum_j c_{ij} \left( \mathcal{F}^{-1}(|\omega|\mathcal{F}(g)) \right)_j\]

where \(g\) is the acquired image and the term in brackets involves applying a 1D convolution (in this case, multiplication by the ramp filter \(|\omega|\) in Fourier space) to the image along the \(r\) axis.

[11]:
# frequencies along the r axis of the image (dimension -2)
freq_fft = torch.fft.fftfreq(image.shape[-2])
# ramp filter |omega|, reshaped to broadcast over the z axis
filter = torch.abs(freq_fft).reshape((-1,1))
# apply the filter to each projection in Fourier space along the r axis
image_fft = torch.fft.fft(image, axis=-2)
image_fft = image_fft * filter
image_filtered = torch.fft.ifft(image_fft, axis=-2).real
[22]:
object_fbp = bp_net(image_filtered, normalize=True) * np.pi
[23]:
plt.figure(figsize=(5,4))
plt.pcolormesh(object_fbp[0][:,:,64].T, cmap='Greys_r')
plt.axis('off')
plt.colorbar()
[23]:
<matplotlib.colorbar.Colorbar at 0x7f965da15f10>
../_images/notebooks_t_fbp_24_1.png

The box is no longer blurred, but reconstruction artifacts are present. Such artifacts are not present when algorithms like OSEM are used for reconstruction.
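
As a final, optional comparison, one can overlay a central profile of the FBP reconstruction on the ground truth (again just a sketch; the slice indices are the approximate centre of the volume). The edges should be recovered far more sharply than with plain back projection, and any over- and undershoot near them reflects the artifacts mentioned above.

# Sketch: central transaxial profile of the FBP reconstruction versus the ground truth
plt.plot(object_truth[0][:, 64, 66], label='ground truth')
plt.plot(object_fbp[0][:, 64, 66], label='FBP reconstruction')
plt.xlabel('voxel index along x')
plt.legend()
plt.show()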