Source code for thelper.nn.sr.srcnn
import thelper.nn
[docs]class SRCNN(thelper.nn.Module):
"""Implements the SRCNN architecture.
See Dong et al., "Image Super-Resolution Using Deep Convolutional Networks" (2014) for more
information (https://arxiv.org/abs/1501.00092).
"""
[docs] def __init__(self, task, num_channels=1, base_filter=64, groups=1):
# note: must always forward args to base class to keep backup
super(SRCNN, self).__init__(task, num_channels=num_channels, base_filter=base_filter, groups=groups)
self.conv1 = thelper.nn.common.ConvBlock(num_channels, base_filter * groups, kernel_size=9,
stride=1, padding=0, activation="relu", norm=None, groups=groups)
self.conv2 = thelper.nn.common.ConvBlock(base_filter * groups, base_filter // 2 * groups, kernel_size=5,
stride=1, padding=0, activation="relu", norm=None, groups=groups)
self.conv3 = thelper.nn.common.ConvBlock((base_filter // 2) * groups, num_channels, kernel_size=5,
stride=1, padding=0, activation=None, norm=None, groups=groups)
self.set_task(task)
[docs] def forward(self, x):
x0 = x.view(x.shape[0] * x.shape[1], 1, x.shape[2], x.shape[3])
x0 = self.conv1(x0)
x0 = self.conv2(x0)
x0 = self.conv3(x0)
x0 = x0.view(x.shape[0], x.shape[1], x0.shape[2], x0.shape[3])
return x0
[docs] def weight_init(self):
for m in self.modules():
thelper.nn.common.weights_init_xavier(m)
[docs] def set_task(self, task):
if not isinstance(task, thelper.tasks.Regression):
raise AssertionError("SRCNN architecture only available for super res regression tasks")
self.task = task