# Source code for pytomography.transforms.SPECT.psf
from __future__ import annotations
import torch
import torch.nn as nn
import numpy as np
import pytomography
from pytomography.utils import get_distance, compute_pad_size, pad_object, pad_object_z, unpad_object_z, rotate_detector_z, unpad_object
from pytomography.transforms import Transform
from pytomography.metadata import ObjectMeta, ImageMeta, PSFMeta
from torchvision.transforms import InterpolationMode
[docs]def get_PSF_transform(
sigma: np.array,
kernel_size: int,
kernel_dimensions: str ='2D',
delta: float = 1e-12,
device=pytomography.device
) -> torch.nn.Conv2d:
"""Creates a 2D convolutional layer that is used for PSF modeling.
Args:
sigma (array): Array of length Lx corresponding to blurring (sigma of Gaussian) as a function of distance from scanner
kernel_size (int): Size of the kernel used in each layer. Needs to be large enough to cover most of Gaussian
delta (float, optional): Used to prevent division by 0 when sigma=0. Defaults to 1e-9.
device (str, optional): Pytorch device used for computation. Defaults to 'cpu'.
kernel_dimensions (str, optional): Whether or not blurring is done independently in each transaxial slice ('1D') or
if blurring is done between transaxial slices ('2D'). Defaults to '2D'.
Returns:
torch.nn.Conv2d: Convolutional neural network layer used to apply blurring to objects of shape [batch_size, Lx, Ly, Lz]
"""
N = len(sigma)
layer = torch.nn.Conv2d(N, N, kernel_size, groups=N, padding='same',
padding_mode='zeros', bias=0, device=device)
x_grid, y_grid = torch.meshgrid(2*[torch.arange(-int(kernel_size//2), int(kernel_size//2)+1)], indexing='ij')
x_grid = x_grid.unsqueeze(dim=0).repeat((N,1,1))
y_grid = y_grid.unsqueeze(dim=0).repeat((N,1,1))
sigma = torch.tensor(sigma, dtype=torch.float32).reshape((N,1,1))
kernel = torch.exp(-(x_grid**2 + y_grid**2) / (2*sigma**2 + delta))
if kernel_dimensions=='1D':
kernel[y_grid!=0] = 0
kernel = kernel / kernel.sum(axis=(1,2)).reshape(N,1,1)
layer.weight.data = kernel.unsqueeze(dim=1).to(device)
return layer
class SPECTPSFTransform(Transform):
    r"""obj2obj transform used to model the effects of PSF blurring in SPECT. The smoothing kernel used to apply PSF modeling uses a Gaussian kernel with width :math:`\sigma` dependent on the distance of the point to the detector; that information is specified in the ``PSFMeta`` parameter.

    Args:
        psf_meta (PSFMeta): Metadata corresponding to the parameters of PSF blurring
    """
    def __init__(
        self,
        psf_meta: PSFMeta,
    ) -> None:
        """Initializer that sets corresponding psf parameters"""
        super().__init__()
        self.psf_meta = psf_meta

    def configure(
        self,
        object_meta: ObjectMeta,
        image_meta: ImageMeta
    ) -> None:
        """Function used to initialize the transform using corresponding object and image metadata

        Args:
            object_meta (ObjectMeta): Object metadata.
            image_meta (ImageMeta): Image metadata.
        """
        super().configure(object_meta, image_meta)
        self.kernel_size = self.compute_kernel_size()
        # One convolution layer per unique detector radius; layers are shared
        # between all projection angles that have the same radius.
        self.layers = {}
        for radius in np.unique(image_meta.radii):
            sigma = self.get_sigma(radius, object_meta.dx, object_meta.shape, self.psf_meta.collimator_slope, self.psf_meta.collimator_intercept)
            # Divide by dx to convert sigma from physical units to voxel units.
            self.layers[radius] = get_PSF_transform(sigma/object_meta.dx, self.kernel_size, kernel_dimensions=self.psf_meta.kernel_dimensions, device=self.device)

    def compute_kernel_size(self) -> int:
        """Function used to compute the kernel size used for PSF blurring. In particular, uses the ``max_sigmas`` attribute of ``PSFMeta`` to determine what the kernel size should be such that the kernel encompasses at least ``max_sigmas`` at all points in the object.

        Returns:
            int: The corresponding kernel size used for PSF blurring.
        """
        s = self.object_meta.padded_shape[0]
        dx = self.object_meta.dr[0]
        # The largest blurring occurs at the greatest distance from the detector,
        # which is half the (padded) transaxial extent of the object.
        largest_sigma = self.psf_meta.collimator_slope*(s/2 * dx) + self.psf_meta.collimator_intercept
        # Kernel covers max_sigmas on each side of the center; *2 + 1 keeps it odd.
        return int(np.round(largest_sigma/dx * self.psf_meta.max_sigmas)*2 + 1)

    def get_sigma(
        self,
        radius: float,
        dx: float,
        shape: tuple,
        collimator_slope: float,
        collimator_intercept: float
    ) -> np.ndarray:
        r"""Uses PSF Meta data information to get blurring :math:`\sigma` as a function of distance from detector. It is assumed that ``sigma = collimator_slope*d + collimator_intercept`` where :math:`d` is the distance from the detector.

        Args:
            radius (float): The distance from the detector
            dx (float): Transaxial plane pixel spacing
            shape (tuple): Tuple containing (Lx, Ly, Lz): dimensions of object space
            collimator_slope (float): See collimator intercept
            collimator_intercept (float): Collimator slope and collimator intercept are defined such that ``sigma(d) = collimator_slope*d + collimator_intercept``
                where sigma corresponds to sigma of a Gaussian function that characterizes blurring as a function of distance from the detector.

        Returns:
            np.ndarray: An array of length Lx corresponding to blurring at each point along the 1st axis in object space
        """
        # Use the padded transaxial dimension so sigma is defined for every
        # plane of the padded object that gets blurred.
        dim = shape[0] + 2*compute_pad_size(shape[0])
        distances = get_distance(dim, radius, dx)
        sigma = collimator_slope * distances + collimator_intercept
        return sigma

    @torch.no_grad()
    def __call__(
        self,
        object_i: torch.Tensor,
        i: int,
        norm_constant: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Applies PSF modeling for the situation where an object is being detected by a detector at the :math:`+x` axis.

        Args:
            object_i (torch.Tensor): Tensor of size [batch_size, Lx, Ly, Lz] being projected along its first axis
            i (int): The projection index: used to find the corresponding angle in image space corresponding to ``object_i``. In particular, the x axis (tensor ``axis=1``) of the object is aligned with the detector at angle i.
            norm_constant (torch.Tensor, optional): A tensor used to normalize the output during back projection. Defaults to None.

        Returns:
            torch.Tensor | tuple[torch.Tensor, torch.Tensor]: Tensor of size [batch_size, Lx, Ly, Lz] such that projection of this tensor along the first axis corresponds to a PSF corrected projection; when ``norm_constant`` is given, a (object, norm_constant) tuple with both blurred identically.
        """
        z_pad_size = int((self.kernel_size-1)/2)
        # Pad along z so that 'same'-padded 2D convolution does not clip the
        # Gaussian tails at the first and last transaxial slices.
        object_i = pad_object_z(object_i, z_pad_size)
        object_i = self.layers[self.image_meta.radii[i]](object_i)
        object_i = unpad_object_z(object_i, pad_size=z_pad_size)
        # Adjust normalization constant (back projection): apply the exact same
        # blurring so the normalization matches the forward model.
        if norm_constant is not None:
            norm_constant = pad_object_z(norm_constant, z_pad_size)
            norm_constant = self.layers[self.image_meta.radii[i]](norm_constant)
            norm_constant = unpad_object_z(norm_constant, pad_size=z_pad_size)
            return object_i, norm_constant
        else:
            return object_i