Source code for thelper.nn.common

import torch


class DenseBlock(torch.nn.Module):
    """Fully-connected layer followed by an optional normalization and an optional activation.

    Args:
        input_size: number of input features of the ``torch.nn.Linear`` layer.
        output_size: number of output features (also the width of the normalization layer).
        bias: whether the linear layer has a bias term.
        activation: ``'relu'``, ``'prelu'``, ``'lrelu'``, ``'tanh'``, ``'sigmoid'`` or ``None``.
        norm: ``'batch'``, ``'instance'`` or ``None``.

    Raises:
        ValueError: if ``norm`` or ``activation`` is not one of the values listed above.
    """

    def __init__(self, input_size, output_size, bias=True, activation='relu', norm='batch'):
        super(DenseBlock, self).__init__()
        self.fc = torch.nn.Linear(input_size, output_size, bias=bias)
        self.norm = norm
        if self.norm == 'batch':
            self.bn = torch.nn.BatchNorm1d(output_size)
        elif self.norm == 'instance':
            # NOTE(review): InstanceNorm1d normalizes over a trailing length dimension;
            # confirm which input rank callers use with this mode.
            self.bn = torch.nn.InstanceNorm1d(output_size)
        elif self.norm is None:
            self.bn = None
        else:
            # unknown values used to be accepted here and only failed later inside forward()
            raise ValueError('Bad normalization selection')
        self.activation = activation
        if self.activation == 'relu':
            self.act = torch.nn.ReLU(True)
        elif self.activation == 'prelu':
            self.act = torch.nn.PReLU()
        elif self.activation == 'lrelu':
            self.act = torch.nn.LeakyReLU(0.2, True)
        elif self.activation == 'tanh':
            self.act = torch.nn.Tanh()
        elif self.activation == 'sigmoid':
            self.act = torch.nn.Sigmoid()
        elif self.activation is None:
            self.act = None
        else:
            raise ValueError('Bad activation selection')

    def forward(self, x):
        """Apply linear -> (normalization) -> (activation) to ``x``."""
        out = self.fc(x)
        if self.bn is not None:
            out = self.bn(out)
        if self.act is not None:
            out = self.act(out)
        return out
class ConvBlock(torch.nn.Module):
    """2D convolution followed by an optional normalization and an optional activation.

    By default ``forward`` is bound to ``forward_bn_act`` (conv -> norm -> act); the alternative
    ordering ``forward_act_bn`` (conv -> act -> norm) is also provided.

    Args:
        input_size: number of input channels.
        output_size: number of output channels.
        kernel_size, stride, padding, bias, groups: forwarded to ``torch.nn.Conv2d``.
        activation: ``'relu'``, ``'prelu'``, ``'lrelu'``, ``'tanh'``, ``'sigmoid'`` or ``None``.
        norm: ``'batch'``, ``'instance'`` or ``None``.
        prelu_params: for ``activation='prelu'``, ``1`` uses a single shared slope; any other
            value requests one slope per output channel.

    Raises:
        ValueError: if ``norm`` or ``activation`` is not one of the values listed above.
    """

    def __init__(self, input_size, output_size, kernel_size=4, stride=2, padding=1, bias=True,
                 activation='relu', norm='batch', groups=1, prelu_params=1):
        super(ConvBlock, self).__init__()
        self.conv = torch.nn.Conv2d(input_size, output_size, kernel_size, stride, padding,
                                    bias=bias, groups=groups)
        self.norm = norm
        if self.norm == 'batch':
            self.bn = torch.nn.BatchNorm2d(output_size)
        elif self.norm == 'instance':
            self.bn = torch.nn.InstanceNorm2d(output_size)
        elif self.norm is None:
            self.bn = None
        else:
            # ValueError subclasses Exception, so existing `except Exception` callers still work
            raise ValueError('Bad normalization selection')
        self.activation = activation
        if self.activation == 'relu':
            self.act = torch.nn.ReLU(True)
        elif self.activation == 'prelu':
            if prelu_params != 1:
                # the activation sees the conv *output*, which has output_size channels;
                # sizing it with input_size broke whenever input_size != output_size
                prelu_params = output_size
            self.act = torch.nn.PReLU(num_parameters=prelu_params)
        elif self.activation == 'lrelu':
            self.act = torch.nn.LeakyReLU(0.2, True)
        elif self.activation == 'tanh':
            self.act = torch.nn.Tanh()
        elif self.activation == 'sigmoid':
            self.act = torch.nn.Sigmoid()
        elif self.activation is None:
            self.act = None
        else:
            raise ValueError('Bad activation selection')
        # instance-level binding selects the default layer ordering
        self.forward = self.forward_bn_act

    def forward_bn_act(self, x):
        """Apply conv -> (normalization) -> (activation)."""
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.act is not None:
            x = self.act(x)
        return x

    def forward_act_bn(self, x):
        """Apply conv -> (activation) -> (normalization)."""
        x = self.conv(x)
        if self.act is not None:
            x = self.act(x)
        if self.bn is not None:
            x = self.bn(x)
        return x
class DeconvBlock(torch.nn.Module):
    """Transposed 2D convolution followed by an optional normalization and activation.

    With the default ``kernel_size=4, stride=2, padding=1`` the output spatial size is twice
    the input spatial size.

    Args:
        input_size: number of input channels.
        output_size: number of output channels.
        kernel_size, stride, padding, bias: forwarded to ``torch.nn.ConvTranspose2d``.
        activation: ``'relu'``, ``'prelu'``, ``'lrelu'``, ``'tanh'``, ``'sigmoid'`` or ``None``.
        norm: ``'batch'``, ``'instance'`` or ``None``.

    Raises:
        ValueError: if ``norm`` or ``activation`` is not one of the values listed above.
    """

    def __init__(self, input_size, output_size, kernel_size=4, stride=2, padding=1, bias=True,
                 activation='relu', norm='batch'):
        super(DeconvBlock, self).__init__()
        self.deconv = torch.nn.ConvTranspose2d(input_size, output_size, kernel_size, stride,
                                               padding, bias=bias)
        self.norm = norm
        if self.norm == 'batch':
            self.bn = torch.nn.BatchNorm2d(output_size)
        elif self.norm == 'instance':
            self.bn = torch.nn.InstanceNorm2d(output_size)
        elif self.norm is None:
            self.bn = None
        else:
            # unknown values used to be accepted here and only failed later inside forward()
            raise ValueError('Bad normalization selection')
        self.activation = activation
        if self.activation == 'relu':
            self.act = torch.nn.ReLU(True)
        elif self.activation == 'prelu':
            self.act = torch.nn.PReLU()
        elif self.activation == 'lrelu':
            self.act = torch.nn.LeakyReLU(0.2, True)
        elif self.activation == 'tanh':
            self.act = torch.nn.Tanh()
        elif self.activation == 'sigmoid':
            self.act = torch.nn.Sigmoid()
        elif self.activation is None:
            self.act = None
        else:
            raise ValueError('Bad activation selection')

    def forward(self, x):
        """Apply transposed conv -> (normalization) -> (activation) to ``x``."""
        out = self.deconv(x)
        if self.bn is not None:
            out = self.bn(out)
        if self.act is not None:
            out = self.act(out)
        return out
class ResNetBlock(torch.nn.Module):
    """Residual block: conv -> (norm) -> (act) -> conv -> (norm), plus the identity input.

    The final ``torch.add`` with the input requires the conv outputs to keep the input shape
    (e.g. ``stride=1`` with ``padding=(kernel_size - 1) // 2``). No activation follows the add.

    NOTE(review): the same ``self.bn`` module normalizes both conv outputs, so its affine
    parameters and running statistics are shared by the two convolutions. This is kept as-is
    to preserve existing ``state_dict`` keys; confirm whether it is intended.

    Args:
        num_filter: channel count of input, intermediate and output feature maps.
        kernel_size, stride, padding, bias: forwarded to both ``torch.nn.Conv2d`` layers.
        activation: ``'relu'``, ``'prelu'``, ``'lrelu'``, ``'tanh'``, ``'sigmoid'`` or ``None``.
        norm: ``'batch'``, ``'instance'`` or ``None``.

    Raises:
        ValueError: if ``norm`` or ``activation`` is not one of the values listed above.
    """

    def __init__(self, num_filter, kernel_size=3, stride=1, padding=1, bias=True,
                 activation='relu', norm='batch'):
        super(ResNetBlock, self).__init__()
        self.conv1 = torch.nn.Conv2d(num_filter, num_filter, kernel_size, stride, padding, bias=bias)
        self.conv2 = torch.nn.Conv2d(num_filter, num_filter, kernel_size, stride, padding, bias=bias)
        self.norm = norm
        if self.norm == 'batch':
            self.bn = torch.nn.BatchNorm2d(num_filter)
        elif self.norm == 'instance':
            self.bn = torch.nn.InstanceNorm2d(num_filter)
        elif self.norm is None:
            self.bn = None
        else:
            # unknown values used to be accepted here and only failed later inside forward()
            raise ValueError('Bad normalization selection')
        self.activation = activation
        if self.activation == 'relu':
            self.act = torch.nn.ReLU(True)
        elif self.activation == 'prelu':
            self.act = torch.nn.PReLU()
        elif self.activation == 'lrelu':
            self.act = torch.nn.LeakyReLU(0.2, True)
        elif self.activation == 'tanh':
            self.act = torch.nn.Tanh()
        elif self.activation == 'sigmoid':
            self.act = torch.nn.Sigmoid()
        elif self.activation is None:
            self.act = None
        else:
            raise ValueError('Bad activation selection')

    def forward(self, x):
        """Return ``x`` plus the two-convolution branch applied to ``x``."""
        residual = x
        out = self.conv1(x)
        if self.bn is not None:
            out = self.bn(out)
        if self.act is not None:
            out = self.act(out)
        out = self.conv2(out)
        if self.bn is not None:
            out = self.bn(out)
        return torch.add(out, residual)
class PSBlock(torch.nn.Module):
    """Sub-pixel convolution: conv to ``output_size * scale_factor**2`` channels, then pixel shuffle.

    Normalization (if any) and activation (if any) are applied after the pixel shuffle, on
    ``output_size`` channels.

    Args:
        input_size: number of input channels.
        output_size: number of channels after the pixel shuffle.
        scale_factor: spatial upscaling factor passed to ``torch.nn.PixelShuffle``.
        kernel_size, stride, padding, bias: forwarded to ``torch.nn.Conv2d``.
        activation: ``'relu'``, ``'prelu'``, ``'lrelu'``, ``'tanh'``, ``'sigmoid'`` or ``None``.
        norm: ``'batch'``, ``'instance'`` or ``None``.

    Raises:
        ValueError: if ``norm`` or ``activation`` is not one of the values listed above.
    """

    def __init__(self, input_size, output_size, scale_factor, kernel_size=3, stride=1, padding=1,
                 bias=True, activation='relu', norm='batch'):
        super(PSBlock, self).__init__()
        self.conv = torch.nn.Conv2d(input_size, output_size * scale_factor ** 2, kernel_size,
                                    stride, padding, bias=bias)
        self.ps = torch.nn.PixelShuffle(scale_factor)
        self.norm = norm
        if self.norm == 'batch':
            self.bn = torch.nn.BatchNorm2d(output_size)
        elif self.norm == 'instance':
            self.bn = torch.nn.InstanceNorm2d(output_size)
        elif self.norm is None:
            self.bn = None
        else:
            # unknown values used to be accepted here and only failed later inside forward()
            raise ValueError('Bad normalization selection')
        self.activation = activation
        if self.activation == 'relu':
            self.act = torch.nn.ReLU(True)
        elif self.activation == 'prelu':
            self.act = torch.nn.PReLU()
        elif self.activation == 'lrelu':
            self.act = torch.nn.LeakyReLU(0.2, True)
        elif self.activation == 'tanh':
            self.act = torch.nn.Tanh()
        elif self.activation == 'sigmoid':
            self.act = torch.nn.Sigmoid()
        elif self.activation is None:
            self.act = None
        else:
            raise ValueError('Bad activation selection')

    def forward(self, x):
        """Apply conv -> pixel shuffle -> (normalization) -> (activation) to ``x``."""
        out = self.ps(self.conv(x))
        if self.bn is not None:
            out = self.bn(out)
        if self.act is not None:
            out = self.act(out)
        return out
class Upsample2xBlock(torch.nn.Module):
    """Upsample feature maps by a factor of 2 using one of three strategies.

    Args:
        input_size: number of input channels.
        output_size: number of output channels.
        bias: forwarded to the underlying convolution.
        upsample: ``'deconv'`` (transposed convolution, ``DeconvBlock``), ``'ps'`` (sub-pixel
            convolution, ``PSBlock``) or ``'rnc'`` (nearest-neighbour resize followed by a 3x3
            ``ConvBlock``).
        activation, norm: forwarded to the selected sub-block.

    Raises:
        ValueError: if ``upsample`` is not one of the values listed above.
    """

    def __init__(self, input_size, output_size, bias=True, upsample='deconv', activation='relu',
                 norm='batch'):
        super(Upsample2xBlock, self).__init__()
        scale_factor = 2
        if upsample == 'deconv':
            # 1. Deconvolution (transposed convolution); kernel 4 / stride 2 / padding 1 doubles H and W
            self.upsample = DeconvBlock(input_size, output_size, kernel_size=4, stride=2, padding=1,
                                        bias=bias, activation=activation, norm=norm)
        elif upsample == 'ps':
            # 2. Sub-pixel convolution (pixel shuffler)
            self.upsample = PSBlock(input_size, output_size, scale_factor=scale_factor, bias=bias,
                                    activation=activation, norm=norm)
        elif upsample == 'rnc':
            # 3. Resize and convolution
            self.upsample = torch.nn.Sequential(
                torch.nn.Upsample(scale_factor=scale_factor, mode='nearest'),
                ConvBlock(input_size, output_size, kernel_size=3, stride=1, padding=1, bias=bias,
                          activation=activation, norm=norm),
            )
        else:
            # previously self.upsample was left unset and the failure only surfaced in forward()
            raise ValueError('Bad upsample selection')

    def forward(self, x):
        """Upsample ``x`` with the strategy selected at construction."""
        return self.upsample(x)
def weights_init_kaiming(m):
    """Initialize a module in place; meant for use with ``torch.nn.Module.apply``.

    Modules whose class name contains ``Linear``, ``Conv2d`` or ``ConvTranspose2d`` get
    Kaiming-normal weights and zero bias. Modules whose class name contains ``Norm`` get weights
    drawn from N(1.0, 0.02) and zero bias; normalization layers created without affine
    parameters (``weight``/``bias`` set to ``None``) are left untouched.
    """
    classname = m.__class__.__name__
    if any(key in classname for key in ('Linear', 'Conv2d', 'ConvTranspose2d')):
        # in-place `_` variant; the un-suffixed `kaiming_normal` is deprecated
        torch.nn.init.kaiming_normal_(m.weight)
        if m.bias is not None:
            m.bias.data.zero_()
    elif 'Norm' in classname:
        # e.g. InstanceNorm2d defaults to no affine params, so weight/bias may be None
        weight = getattr(m, 'weight', None)
        if weight is not None:
            weight.data.normal_(1.0, 0.02)
        bias = getattr(m, 'bias', None)
        if bias is not None:
            bias.data.zero_()
def weights_init_xavier(m):
    """Initialize a module in place; meant for use with ``torch.nn.Module.apply``.

    Modules whose class name contains ``Linear``, ``Conv2d`` or ``ConvTranspose2d`` get
    Xavier-uniform weights and zero bias. Modules whose class name contains ``Norm`` get weights
    drawn from N(1.0, 0.02) and zero bias; normalization layers created without affine
    parameters (``weight``/``bias`` set to ``None``) are left untouched.
    """
    classname = m.__class__.__name__
    if any(key in classname for key in ('Linear', 'Conv2d', 'ConvTranspose2d')):
        # in-place `_` variant; the un-suffixed `xavier_uniform` is deprecated
        torch.nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            m.bias.data.zero_()
    elif 'Norm' in classname:
        # e.g. InstanceNorm2d defaults to no affine params, so weight/bias may be None
        weight = getattr(m, 'weight', None)
        if weight is not None:
            weight.data.normal_(1.0, 0.02)
        bias = getattr(m, 'bias', None)
        if bias is not None:
            bias.data.zero_()
def shave(imgs, border_size=0):
    """Crop ``border_size`` pixels from every side of the last two (spatial) dimensions.

    Args:
        imgs: a 4D batch indexed as ``imgs[n, c, h, w]``, or a 3D image indexed as
            ``imgs[c, h, w]``.
        border_size: number of pixels removed from each border; ``0`` keeps the full extent.

    Returns:
        For 4D input, a newly allocated ``torch.FloatTensor`` holding the cropped batch;
        otherwise the cropped slice of ``imgs``.
    """
    size = list(imgs.shape)
    if len(size) == 4:
        # explicit end indices: a `-border_size` end is `-0` == 0 for border_size=0, which
        # selected nothing and broke the default call
        height_end = size[2] - border_size
        width_end = size[3] - border_size
        shave_imgs = torch.FloatTensor(size[0], size[1],
                                       height_end - border_size, width_end - border_size)
        shave_imgs[:] = imgs[:, :, border_size:height_end, border_size:width_end]
        return shave_imgs
    height_end = size[1] - border_size
    width_end = size[2] - border_size
    return imgs[:, border_size:height_end, border_size:width_end]