# Source code for thelper.nn.sr.vdsr

import torch

import thelper.nn


class VDSR(thelper.nn.Module):
    """Implements the VDSR architecture.

    See Kim et al., "Accurate Image Super-Resolution Using Very Deep Convolutional Networks" (2015) for
    more information (https://arxiv.org/abs/1511.04587).

    The network is a stack of convolution blocks whose output is added back to its own input (global
    residual learning), so the tensor returned by ``forward`` has the same shape as the input tensor.
    No upsampling layer is defined here; presumably the input is expected to already be interpolated to
    the target resolution, as in the paper -- TODO confirm against the data pipeline.
    """

    def __init__(self, task, num_channels=1, base_filter=64, kernel_size0=3, num_residuals=18,
                 groups=1, activation='relu', norm='batch'):
        """Builds the network layers and binds the model to a task.

        Args:
            task: task object; must be a ``thelper.tasks.Regression`` instance (see ``set_task``).
            num_channels: number of channels consumed by the first convolution and produced by the last
                one. When set to 1, every channel of a multi-channel input is processed independently.
            base_filter: number of feature maps used by all intermediate convolutions.
            kernel_size0: kernel size of the first convolution. It is padded with ``kernel_size0 // 2``,
                so it should be odd for the final residual addition to see matching spatial sizes.
            num_residuals: number of intermediate 3x3 convolution blocks (``None``/0 means none).
            groups: number of convolution groups forwarded to every ``ConvBlock``.
            activation: activation identifier forwarded to every ``ConvBlock`` except the output one.
            norm: normalization identifier forwarded to every ``ConvBlock`` except the output one.

        Raises:
            AssertionError: if ``task`` is not a ``thelper.tasks.Regression`` instance.
        """
        # note: must always forward args to base class to keep backup; this must stay the first
        # statement so that ``vars()`` only captures the constructor arguments
        super(VDSR, self).__init__(task, **{k: v for k, v in vars().items()
                                            if k not in ["self", "task", "__class__"]})
        self.kernel_size0 = kernel_size0
        self.num_channels = num_channels
        self.input_conv = thelper.nn.common.ConvBlock(input_size=num_channels, output_size=base_filter,
                                                      kernel_size=self.kernel_size0, stride=1,
                                                      padding=self.kernel_size0 // 2, norm=norm,
                                                      bias=False, groups=groups, activation=activation)
        self.num_residuals = num_residuals
        # 'or 0' keeps a None/0 value from building any intermediate block (range() of a negative
        # value is empty as well)
        conv_blocks = [
            thelper.nn.common.ConvBlock(input_size=base_filter, output_size=base_filter, kernel_size=3,
                                        stride=1, padding=1, norm=norm, bias=False, groups=groups,
                                        activation=activation)
            for _ in range(self.num_residuals or 0)
        ]
        self.residual_layers = torch.nn.Sequential(*conv_blocks)
        # the last block has no activation/normalization: it produces the residual added to the input
        self.output_conv = thelper.nn.common.ConvBlock(input_size=base_filter, output_size=num_channels,
                                                       kernel_size=3, stride=1, padding=1, activation=None,
                                                       norm=None, bias=False, groups=groups)
        self.weight_init()
        self.set_task(task)

    def forward(self, x):
        """Runs the network on a 4D input tensor and returns a tensor of the same shape.

        When ``num_channels == 1``, the channels of ``x`` are folded into the batch dimension so that
        each channel is processed independently by the single-channel network. Otherwise ``x`` is fed
        as-is and must have ``num_channels`` channels.

        Args:
            x: input tensor, indexed as (batch, channels, height, width).

        Returns:
            The input plus the predicted residual, reshaped back to ``x.shape``.
        """
        batch_size, channels, height, width = x.shape
        if self.num_channels == 1:
            # reshape (unlike view) also accepts non-contiguous inputs such as channel slices
            x0 = x.reshape(batch_size * channels, 1, height, width)
        else:
            # the input convolution already consumes all channels at once; folding them into the
            # batch would feed it a single channel and fail
            x0 = x
        residual = x0
        x0 = self.input_conv(x0)
        if len(self.residual_layers) > 0:
            x0 = self.residual_layers(x0)
        x0 = self.output_conv(x0)
        x0 = torch.add(x0, residual)
        return x0.reshape(batch_size, channels, height, width)

    def weight_init(self):
        """Applies ``thelper.nn.common.weights_init_kaiming`` to this module and all its submodules."""
        # iteration order (self first, then children) determines which random draws each layer gets
        for m in self.modules():
            thelper.nn.common.weights_init_kaiming(m)

    def set_task(self, task):
        """Binds the model to ``task`` after checking that it is a regression task.

        Raises:
            AssertionError: if ``task`` is not a ``thelper.tasks.Regression`` instance.
        """
        if not isinstance(task, thelper.tasks.Regression):
            raise AssertionError("VDSR architecture only available for super res regression tasks")
        self.task = task