---
title: Title
keywords: fastai
sidebar: home_sidebar
---
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import torch.nn as nn
from incendio.core import BaseModel
class Model(BaseModel):
    """Small two-layer conv net that produces one sigmoid score per image.

    Each conv layer is preceded by 2-pixel reflection padding. With
    kernel_size=5 this keeps the spatial size unchanged. The 16 feature maps
    are then average-pooled to 1x1 and mapped to a single logit by a linear
    layer.

    conv1 takes 3 input channels, so inputs are presumably (batch, 3, H, W).
    TODO: confirm against callers.
    """

    def __init__(self):
        super().__init__()
        # ReflectionPad2d has no parameters, so one instance is reused
        # before both convs.
        self.pad = nn.ReflectionPad2d(2)
        self.conv1 = nn.Conv2d(3, 8, kernel_size=5)
        self.conv2 = nn.Conv2d(8, 16, kernel_size=5)
        self.adapt = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(16, 1)

    def forward(self, x):
        """Return sigmoid scores of shape (batch, 1) for a 4-D input batch."""
        x = F.leaky_relu(self.conv1(self.pad(x)))
        x = F.leaky_relu(self.conv2(self.pad(x)))
        x = self.adapt(x)
        # (..., 16, 1, 1) -> (..., 16). squeeze() would also drop the batch
        # dim when batch == 1, giving shape (16,) and a (1,) output.
        # Flattening only the last three dims keeps the batch dim, and it
        # still handles unbatched (16, 1, 1) input.
        x = x.flatten(start_dim=-3)
        return torch.sigmoid(self.fc(x))
class Model2(BaseModel):
    """Configurable stack of reflection-padded conv stages with a sigmoid head.

    Args:
        c_in: number of input channels.
        c_outs: output channel count for each conv stage, in order. Each entry
            builds one stage: ReflectionPad2d(2) -> Conv2d(kernel_size=5) ->
            LeakyReLU, which keeps the spatial size. It must be non-empty
            because the head reads c_outs[-1]. Any sequence (list or tuple)
            is accepted.
    """

    def __init__(self, c_in, c_outs):
        super().__init__()
        # Unpacking instead of list concatenation also accepts tuples.
        dims = [c_in, *c_outs]
        # The nested Sequential layout is unchanged, so state_dict keys
        # (enc.<i>.1.weight, ...) stay compatible with earlier checkpoints.
        self.enc = nn.Sequential(*[
            nn.Sequential(
                nn.ReflectionPad2d(2),
                nn.Conv2d(ch_in, ch_out, kernel_size=5),
                nn.LeakyReLU(),
            )
            for ch_in, ch_out in zip(dims, dims[1:])
        ])
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(c_outs[-1], 1)

    def forward(self, x):
        """Return sigmoid scores of shape (batch, 1) for a 4-D input batch."""
        x = self.pool(self.enc(x))
        # (..., C, 1, 1) -> (..., C). Unlike squeeze(), this keeps the batch
        # dim when batch == 1. The debug shape prints were removed.
        x = x.flatten(start_dim=-3)
        return torch.sigmoid(self.fc(x))
# Build both models and inspect their layer structure. When run as notebook
# cells, the bare names display each module's repr.
m2 = Model2(3, [8, 16, 32])
m2
m1 = Model()
m1
# Fake batch: 2 images, 3 channels, 4x4 pixels, with integer values in
# [0, 255) cast to float. It is drawn after model construction, and parameter
# initialisation uses the same global torch RNG, so keep this order if the
# exact values matter.
x = torch.randint(255, (2, 3, 4, 4)).float()
x
x[0].shape
def show_img(img):
    """Display a channels-first image tensor with matplotlib.

    Args:
        img: 3-D tensor laid out as (C, H, W); permute(1, 2, 0) requires
            exactly 3 dims. Values are divided by 255 for display, so they
            are presumably in [0, 255].
    """
    # matplotlib converts its input through numpy. That conversion fails for
    # tensors that require grad (e.g. conv outputs) or live on a GPU, so
    # detach and move to CPU first.
    img = img.detach().cpu()
    plt.imshow(img.permute(1, 2, 0) / 255)
    plt.show()
# Show the first raw image, then the same image after 2-pixel reflection
# padding. Each spatial side grows by 2 * 2, so 4x4 becomes 8x8, and the new
# border mirrors the interior pixels.
show_img(x[0])
x[0].shape
x.shape
pad = nn.ReflectionPad2d(2)
x_pad = pad(x)
x_pad.shape
show_img(x_pad[0])
class ReflectionPaddedConv2d(nn.Module):
    """Conv2d preceded by reflection padding instead of Conv2d's zero padding.

    Args:
        in_channels, out_channels, kernel_size, stride, bias: passed to
            nn.Conv2d.
        padding: amount given to nn.ReflectionPad2d. The conv itself uses
            padding=0. With the defaults (kernel_size=3, padding=1,
            stride=1) the spatial size is preserved.
        **kwargs: extra nn.Conv2d arguments (e.g. dilation, groups). These
            are forwarded instead of being silently ignored.
    """

    def __init__(self, in_channels, out_channels, padding=1,
                 kernel_size=3, stride=1, bias=True, **kwargs):
        super().__init__()
        self.reflect = nn.ReflectionPad2d(padding)
        # padding=0: all padding is already done by self.reflect.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                              stride=stride, padding=0, bias=bias, **kwargs)

    def forward(self, x):
        """Reflection-pad x, then convolve it."""
        return self.conv(self.reflect(x))
# Compare reflection padding (rc) with Conv2d's built-in zero padding (r).
# All weights are 1 and biases 0, so each output pixel is the sum over a
# 3x3 window across all 3 input channels. The two results agree on interior
# pixels and can differ only where the window overlaps the padded border.
rc = ReflectionPaddedConv2d(in_channels=3, out_channels=3)
nn.init.constant_(rc.conv.weight, 1)
nn.init.constant_(rc.conv.bias, 0)
r = nn.Conv2d(3, 3, kernel_size=3, padding=1, padding_mode='zeros', bias=True)
nn.init.constant_(r.weight, 1)
nn.init.constant_(r.bias, 0)
x_p = r(x)
x_p.shape
# rc was previously built but never applied, so no comparison happened.
x_rc = rc(x)
x_rc.shape
# Nonzero entries show where reflection and zero padding disagree.
(x_rc - x_p).abs()
print(nn.Conv2d.__doc__)