"""Coordinate-convolution (CoordConv) layers from the ``thelper.nn.coordconv`` module."""

import torch
import torch.nn


class AddCoords(torch.nn.Module):
    """Appends normalized coordinate channels (and optionally a radius channel) to a 4D tensor.

    For an input of shape ``(batch_size, channels, height, width)``, the following channels are
    concatenated after the existing ones:

    - ``xx``: the column (width) index, linearly mapped to ``[-1, 1]``;
    - ``yy``: the row (height) index, linearly mapped to ``[-1, 1]``;
    - (optional) the radius ``sqrt((xx - 0.5)**2 + (yy - 0.5)**2)``.

    For a 3x4 input, the two coordinate channels look like this before normalization::

        xx:            yy:
        [0,1,2,3]      [0,0,0,0]
        [0,1,2,3]      [1,1,1,1]
        [0,1,2,3]      [2,2,2,2]

    A dimension of size 1 maps to the single value ``-1``.

    Adapted from mkocabas' implementation.

    Args:
        radius_channel: if true, a third (radius) channel is appended after ``xx`` and ``yy``.
    """

    def __init__(self, radius_channel=False):
        super().__init__()
        self.radius_channel = radius_channel

    @staticmethod
    def _coord_range(size, dtype, device):
        """Returns ``size`` evenly spaced values spanning ``[-1, 1]`` (``[-1]`` when ``size == 1``)."""
        steps = torch.arange(size, dtype=dtype, device=device)
        # max(..., 1) avoids a 0/0 division (which would produce NaNs) for singleton dimensions
        return steps / max(size - 1, 1) * 2 - 1

    def forward(self, in_tensor):
        """Concatenates the coordinate channels to ``in_tensor`` along dimension 1.

        Args:
            in_tensor: 4D tensor of shape ``(batch_size, channels, height, width)``; height and
                width do not need to be equal.

        Returns:
            A tensor of shape ``(batch_size, channels + 2, height, width)`` (``channels + 3`` when
            ``radius_channel`` is set) on the same device as ``in_tensor``.

        Raises:
            ValueError: if ``in_tensor`` is not 4-dimensional.
        """
        if in_tensor.dim() != 4:
            raise ValueError(
                f"AddCoords expects a 4D (batch, channels, height, width) tensor, "
                f"got shape {tuple(in_tensor.shape)}")
        batch_size, _, height, width = in_tensor.shape
        # build the maps directly where they are needed (no per-call host->device copy); keep the
        # input's precision for floating tensors, and fall back to float32 for integer inputs
        dtype = in_tensor.dtype if in_tensor.is_floating_point() else torch.float32
        device = in_tensor.device
        # broadcast 1D ranges into full (height, width) maps so non-square inputs are supported
        xx_channel = self._coord_range(width, dtype, device).view(1, 1, 1, width).expand(1, 1, height, width)
        yy_channel = self._coord_range(height, dtype, device).view(1, 1, height, 1).expand(1, 1, height, width)
        coord_channels = [xx_channel, yy_channel]
        if self.radius_channel:
            # NOTE: coordinates live in [-1, 1], so (0.5, 0.5) is not the geometric center; the
            # offset is kept unchanged so that models trained with this layer behave identically
            coord_channels.append(torch.sqrt(torch.pow(xx_channel - 0.5, 2) + torch.pow(yy_channel - 0.5, 2)))
        coords = torch.cat(coord_channels, dim=1).expand(batch_size, -1, -1, -1)
        return torch.cat([in_tensor, coords], dim=1)
class CoordConv2d(torch.nn.Module):
    """2D convolution applied after appending coordinate channels to its input (CoordConv).

    The input first goes through :class:`AddCoords`, then through a regular
    :class:`torch.nn.Conv2d`. The conv's input channel count is widened automatically to
    account for the appended channels: two for ``xx``/``yy``, three with the radius channel.

    Args:
        in_channels: number of channels of the tensor given to ``forward``, before augmentation.
        *args: forwarded to :class:`torch.nn.Conv2d` after the (widened) input channel count.
        radius_channel: whether to also append the radius channel.
        **kwargs: forwarded to :class:`torch.nn.Conv2d`.
    """

    def __init__(self, in_channels, *args, radius_channel=False, **kwargs):
        super().__init__()
        self.addcoord = AddCoords(radius_channel=radius_channel)
        # xx + yy coordinate maps, plus one more channel when the radius is requested
        n_coord_channels = 2 + (1 if radius_channel else 0)
        self.conv = torch.nn.Conv2d(in_channels + n_coord_channels, *args, **kwargs)

    def forward(self, in_tensor):
        """Augments ``in_tensor`` with coordinate channels and convolves the result."""
        return self.conv(self.addcoord(in_tensor))
class CoordConvTranspose(torch.nn.Module):
    """Transposed 2D convolution applied after appending coordinate channels (e.g. for segmentation).

    The input first goes through :class:`AddCoords`, then through a
    :class:`torch.nn.ConvTranspose2d` built from ``*args``/``**kwargs`` exactly as given.

    NOTE: unlike ``CoordConv2d``, the transposed conv's ``in_channels`` is NOT widened here.
    The value passed in ``*args``/``**kwargs`` must already count the appended channels:
    two extra, or three when ``radius_channel`` is true.

    Args:
        radius_channel: whether :class:`AddCoords` also appends the radius channel.
        *args: forwarded unchanged to :class:`torch.nn.ConvTranspose2d`.
        **kwargs: forwarded unchanged to :class:`torch.nn.ConvTranspose2d`.
    """

    def __init__(self, radius_channel, *args, **kwargs):
        super().__init__()
        self.addcoord = AddCoords(radius_channel=radius_channel)
        self.convT = torch.nn.ConvTranspose2d(*args, **kwargs)

    def forward(self, in_tensor):
        """Augments ``in_tensor`` with coordinate channels and applies the transposed convolution."""
        augmented = self.addcoord(in_tensor)
        return self.convT(augmented)