---
title: Examples01
keywords: fastai
sidebar: home_sidebar
summary: "Example using fa_convnav to select modules for investigation with pytorch hooks."
---
{% raw %}

Import the fastai deep learning library, including the pretrained vision models.

from fastai2.basics import *
from fastai2.callback.all import *
from fastai2.vision.all import *
import torch
import PIL
from PIL import Image
import cv2

Import the fa_convnav.navigator module.

from fa_convnav.navigator import *

Create a fastai DataBlock and dataloaders using the Oxford-IIIT Pet dataset (downloaded with fastai's untar_data), applying some simple image transforms in the process.

pets = DataBlock(blocks=(ImageBlock, CategoryBlock), 
                 get_items=get_image_files, 
                 splitter=RandomSplitter(),
                 get_y=RegexLabeller(pat = r'/([^/]+)_\d+.jpg$'),
                 item_tfms=Resize(460),
                 batch_tfms=[*aug_transforms(size=224, max_rotate=30, min_scale=0.75), Normalize.from_stats(*imagenet_stats)])

dls = pets.dataloaders(untar_data(URLs.PETS)/"images",  bs=128)
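
Optionally, we can check that the dataloaders and transforms behave as expected using fastai's show_batch (the max_n and figsize values below are arbitrary choices).

# display a grid of transformed training images with their labels
dls.show_batch(max_n=9, figsize=(8, 8))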

Specify the pretrained model architecture we want to use (the ImageNet weights are downloaded when the Learner is created below).

model = resnet18

Create a fastai Learner object from the dataloader, the chosen model, and an optimiser.

learn = cnn_learner(
    dls, 
    model, 
    opt_func=partial(Adam, lr=slice(3e-3), wd=0.01, eps=1e-8), 
    metrics=error_rate, 
    config=cnn_config(ps=0.33)).to_fp16()

Create a ConvNav instance.

cn = ConvNav(learn, learn.summary())

Check the CNDF dataframe.

cn.view(top=True)
Resnet: Resnet18
Input shape: [128 x 3 x 224 x 224] (bs, ch, h, w)
Output features: [128 x 37] (bs, classes)
Currently frozen to parameter group 3 out of 3

Module_name Model Division Container_child Container_block Layer_description Torch_class Output_dimensions Parameters Trainable Currently
Index
0 Sequential torch.nn.modules.container.Sequential
1 0 Sequential torch.nn.modules.container.Sequential
2 0.0 Conv2d Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) torch.nn.modules.conv.Conv2d [128 x 64 x 112 x 112] 9,408 False Frozen
3 0.1 BatchNorm2d BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) torch.nn.modules.batchnorm.BatchNorm2d [128 x 64 x 112 x 112] 128 True
4 0.2 ReLU ReLU(inplace=True) torch.nn.modules.activation.ReLU [128 x 64 x 112 x 112] 0 False
5 0.3 MaxPool2d MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False) torch.nn.modules.pooling.MaxPool2d [128 x 64 x 56 x 56] 0 False
6 0.4 Sequential torch.nn.modules.container.Sequential
7 0.4.0 BasicBlock torchvision.models.resnet.BasicBlock
8 0.4.0.conv1 Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) torch.nn.modules.conv.Conv2d [128 x 64 x 56 x 56] 36,864 False Frozen
9 0.4.0.bn1 BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) torch.nn.modules.batchnorm.BatchNorm2d [128 x 64 x 56 x 56] 128 True
...69 more layers

Select a spread of equally spaced blocks from the model and store their module objects in spread.

spread = cn.spread('block', 4)
resnet18
Spread of block where n = 4

Module_name Model Division Container_child Container_block Num_layers Torch_class Output_dimensions Parameters Trainable Currently
Index
7 0.4.0 0 4 BasicBlock 5 torchvision.models.resnet.BasicBlock [128 x 64 x 56 x 56] Frozen
20 0.5.0 0 5 BasicBlock 8 torchvision.models.resnet.BasicBlock [128 x 128 x 28 x 28] Frozen
36 0.6.0 0 6 BasicBlock 8 torchvision.models.resnet.BasicBlock [128 x 256 x 14 x 14] Frozen
61 0.7.1 0 7 BasicBlock 5 torchvision.models.resnet.BasicBlock [128 x 512 x 7 x 7] Frozen

View the modules.

for b in spread:
  print(f'\n{b}')
BasicBlock(
  (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

BasicBlock(
  (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (downsample): Sequential(
    (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)

BasicBlock(
  (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (downsample): Sequential(
    (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)

BasicBlock(
  (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

The following code registers forward and backward hooks on target_module and then passes a single image, input_img, through the model. The forward hook stores the input and output activations of target_module, and the backward hook stores its gradients during the backward pass. These are then converted into image representations of the input activations, gradients, feature map and gradient class activation map (gradcam) associated with target_module.

If you are not familiar with PyTorch hooks, don't worry: they are simply a way to 'hook' into a model at specific points (here, the target_module) and capture activations and gradients while the model runs. Gradcams let us see which areas of an image contribute most towards the final predictions.
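
As a minimal sketch of the idea (separate from the full class below, and using the first convolutional layer of the model body as an arbitrary target), a forward hook can be registered, used and removed in a few lines:

# minimal sketch: capture the output of one module during a single forward pass
acts = {}
def capture(module, inp, out): acts['out'] = out.detach()

m = learn.model.eval().cuda()                        # model in eval mode on the GPU
handle = m[0][0].register_forward_hook(capture)      # hook the body's first conv layer (arbitrary choice)
with torch.no_grad():
    _ = m(torch.randn(1, 3, 224, 224).cuda())        # dummy forward pass
handle.remove()                                      # always remove hooks when finished
print(acts['out'].shape)                             # expect something like [1, 64, 112, 112]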

The code is adapted and abridged from these sources, which are also a good place to start reading about hooks and gradcams.

class examine_modules():
  "Gets activation stats and gradients for a module"
  def __init__(self, model, target_module, input_img):
    self.model = model
    self.gradients = dict()
    self.activations_out = dict()
    self.activations_in = dict()
    self.input_img_resized = input_img.resize((224, 224), Image.BILINEAR)

    def forward_hook(module, inp, output):
        self.activations_in['value'] = inp[0]
        self.activations_out['value'] = output
        return None
        
    def backward_hook(module, grad_input, grad_output):
        self.gradients['value'] = grad_output[0]
        return None

    self.handle_fwd = target_module.register_forward_hook(forward_hook)
    self.handle_bck = target_module.register_backward_hook(backward_hook)

    def normalize(tensor):
      mean, std = imagenet_stats
      mean = torch.FloatTensor(mean).view(1, 3, 1, 1).expand_as(tensor).to(tensor.device)
      std = torch.FloatTensor(std).view(1, 3, 1, 1).expand_as(tensor).to(tensor.device)
      return tensor.sub(mean).div(std)

    self.torch_img = torch.from_numpy(np.asarray(self.input_img_resized)).permute(2, 0, 1).unsqueeze(0).float().div(255).cuda()
    self.normed_torch_image = normalize(self.torch_img)

  def model_pass(self):
    "make a pass through the model with input image (shape: 1, 3, H, W)"

    preds = self.model(self.torch_img)                  # forward pass gives predictions (preds)
    loss = preds[:, preds.max(1)[-1]].squeeze()         # class prediction (highest pred) becomes the loss 

    self.model.zero_grad()
    loss.backward(retain_graph=False)                   # backward pass with highest prediction as loss

    self.acts_in = self.activations_in['value']         # input activations 
    self.acts_out = self.activations_out['value']       # output feature maps, shape (1, 512, 7, 7) for example
    self.grads = self.gradients['value']                # gradients w.r.t. the module output
    return None

  def gradcam(self):
    "Make gradcam"
    b, c, h, w = self.grads.size()
    alpha = self.grads.view(b, c, -1).mean(2)                  # channel-wise mean of the gradients (shape 1, c)
    weights = alpha.view(b, c, 1, 1)                      # create weight matrix shape (1,512,1,1) containing gradient means

    saliency_map = (weights*self.acts_out).sum(1, keepdim=True)    # sum across second dimension (dim 1) keeping dimensions as per inputs
    saliency_map = F.relu(saliency_map)
    saliency_map = F.interpolate(saliency_map, size=(224, 224), mode='bilinear', align_corners=False)
    saliency_map_min, saliency_map_max = saliency_map.min(), saliency_map.max()
    saliency_map = (saliency_map - saliency_map_min).div(saliency_map_max - saliency_map_min).data

    mask = saliency_map

    heatmap = cv2.applyColorMap(np.uint8(255 * mask.squeeze().cpu()), cv2.COLORMAP_JET)
    heatmap = torch.from_numpy(heatmap).permute(2, 0, 1).float().div(255)
    b, g, r = heatmap.split(1)
    heatmap = torch.cat([r, g, b])
    
    cam_rgb = heatmap + self.torch_img.cpu()
    cam_rgb = cam_rgb.div(cam_rgb.max()).squeeze()
    self.gradcam = cam_rgb.permute(1, 2, 0).numpy()     # channels-last (H, W, 3) image for display
    return None

  def display(self, n):
    "Display input image, input activations, output activations, gradient and gradcam images in a row"

    def display_image(x):
      "Denormalises, processea and displays a torch.image , `x`"
      x -= x.mean()
      x /= (x.std() + 1e-5)
      x *= 0.1
      x += 0.5
      x = np.clip(x, 0, 1)
      x *= 255
      x = np.clip(x, 0, 255).astype('uint8')
      plt.imshow(x)
      return None

    fig, ax = plt.subplots(1,5, figsize=(20, 4))

    plt.subplot(1, 5, 1)
    img = self.input_img_resized.resize(self.acts_in[0][1].shape, Image.BILINEAR)
    plt.imshow(img)
    if not n: plt.title("\nInput image \n(from dataloader)\n", fontsize=18)
    plt.ylabel(f'Block {n}', fontsize=20)

    plt.subplot(1, 5, 2)
    f_in = self.acts_in[0][1].cpu().detach().numpy()
    if not n: plt.title("\nInput activations\n", fontsize=20)
    display_image(f_in)
    
    plt.subplot(1, 5, 3)
    g = self.grads[0][1].cpu().detach().numpy()
    if not n: plt.title("\nGradients\n", fontsize=18)
    display_image(g)
    
    plt.subplot(1, 5, 4)
    f_out = self.acts_out[0][1].cpu().detach().numpy()
    if not n: plt.title("\nOutput activations \n(feature map)\n", fontsize=18)
    display_image(f_out)
    
    plt.subplot(1, 5, 5)
    if not n: plt.title("\nGradcam\n", fontsize=18)
    plt.imshow(self.gradcam)

    plt.show()
    return None
  
  def remove_hooks(self):
    self.handle_fwd.remove()
    self.handle_bck.remove()

By plotting the input activations, gradients, feature maps and gradcam successively for each of the four modules we selected (stored in spread), we can easily follow an image as it is processed by the model. By repeating the hooks and plots at different stages of model training we can see whether our model is training appropriately or not.

# put the model in eval mode and move it to the GPU
model = learn.model.eval().cuda()

# select an image from the validation set
fname = dls.valid_ds.items[0]
pil_img = PIL.Image.open(fname)

# loop through our selected modules in spread, hooking and displaying each in turn
def display_img_grid(model, modules, input_img):
  for n, m in enumerate(modules):
    x = examine_modules(model, m, input_img)
    x.model_pass()
    x.gradcam()
    x.display(n)
    x.remove_hooks()

display_img_grid(model, spread, pil_img)

Resnet18 has been pretrained on the ImageNet dataset, and the gradcam shows that it already does a pretty good job of identifying the dog or cat (ImageNet contains many dog and cat images) even without additional training. Now train the model on our dog and cat breeds dataset and see what it can do.

learn.fit_one_cycle(5, 1e-3)
display_img_grid(model, spread, pil_img)
epoch train_loss valid_loss error_rate time
0 2.758332 0.546939 0.161028 00:58
1 1.190638 0.355577 0.106225 00:58
2 0.689494 0.318071 0.091340 00:58
3 0.456033 0.291165 0.079161 00:58
4 0.364217 0.294581 0.083221 01:00

After 5 epochs the model localises those features of the dog or cat which distinguish it from other breeds (not just from other animals). This corresponds to a low error rate in the breed classification task and shows that the model is training correctly. Next we can unfreeze the body of the model and train for a further 5 epochs.

learn.unfreeze()
learn.fit_one_cycle(5, slice(1e-5, 1e-4, 1e-4))
display_img_grid(model, spread, pil_img)
epoch train_loss valid_loss error_rate time
0 0.303732 0.284637 0.079161 01:01
1 0.284223 0.285909 0.078484 01:00
2 0.251341 0.268647 0.081867 01:00
3 0.218321 0.265787 0.077808 01:00
4 0.208654 0.265001 0.077131 00:59

After a further 5 epochs of training the unfrozen model, close examination of the gradcam image above shows an even greater localisation of the distinguishing features of the dog or cat.

Now let's try something different and train the model again, but this time at a much higher learning rate.

learn.fit_one_cycle(1, 1e-2)
display_img_grid(model, spread, pil_img)
epoch train_loss valid_loss error_rate time
0 4.272862 4.995016 0.947903 00:58

As expected, the high learning rate hinders rather than helps gradient descent and degrades model performance, likely wiping out much of the benefit of the pretrained weights. The rise in error rate corresponds to gradcam images that are noisier and less focused on the areas of interest. This model is no longer training effectively.
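
A quick, optional check before retraining is fastai's learning rate finder, which plots loss against learning rate and suggests a sensible maximum value to pass to fit_one_cycle.

# plot loss against learning rate to choose a sensible value for fit_one_cycle
learn.lr_find()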

{% endraw %}