---
title: Examples01
keywords: fastai
sidebar: home_sidebar
summary: "Example using fa_convnav to select modules for investigation with pytorch hooks."
---
Import the fastai deep learning library, including the pretrained vision models.
from fastai2.basics import *
from fastai2.callback.all import *
from fastai2.vision.all import *
import torch
import PIL
from PIL import Image
import cv2
Import the `fa_convnav.navigator` module.
from fa_convnav.navigator import *
Create a fastai DataBlock and DataLoaders using the Oxford-IIIT Pet dataset (downloaded with fastai's `untar_data`), applying some simple image transforms in the process.
pets = DataBlock(blocks=(ImageBlock, CategoryBlock),
                 get_items=get_image_files,
                 splitter=RandomSplitter(),
                 get_y=RegexLabeller(pat=r'/([^/]+)_\d+.jpg$'),
                 item_tfms=Resize(460),
                 batch_tfms=[*aug_transforms(size=224, max_rotate=30, min_scale=0.75),
                             Normalize.from_stats(*imagenet_stats)])
dls = pets.dataloaders(untar_data(URLs.PETS)/"images", bs=128)
Specify the pretrained model we want to use (the ImageNet weights are downloaded when the Learner is created).
model = resnet18
Create a fastai Learner object from the dataloader, the chosen model, and an optimiser.
learn = cnn_learner(
    dls,
    model,
    opt_func=partial(Adam, lr=slice(3e-3), wd=0.01, eps=1e-8),
    metrics=error_rate,
    config=cnn_config(ps=0.33)).to_fp16()
Create a ConvNav instance.
cn = ConvNav(learn, learn.summary())
Check the CNDF dataframe.
cn.view(top=True)
Select a spread of equally spaced blocks from the model and store their module objects in `spread`.
spread = cn.spread('block', 4)
View the modules.
for b in spread:
    print(f'\n{b}')
The following code registers forward and backward hooks on `target_module`, then passes a single image, `input_img`, through the model. During the forward pass the input and output activations of the `target_module` are stored, and during the backward pass the gradients. These are then converted into image representations of the input activations, gradients, feature map and gradient class activation map (gradcam) associated with the `target_module`.
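For reference, the gradcam computed below follows the standard Grad-CAM formulation: the gradients flowing back into the target module are averaged over the spatial dimensions to give one weight per feature map, and the ReLU of the weighted sum of feature maps gives the class activation map,

$$\alpha_k = \frac{1}{HW}\sum_{i,j}\frac{\partial y^{c}}{\partial A^{k}_{ij}}, \qquad L_{\text{Grad-CAM}} = \mathrm{ReLU}\Big(\sum_k \alpha_k A^{k}\Big)$$

where $A^{k}$ is the $k$-th feature map produced by the target module and $y^{c}$ is the score of the predicted class.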
If you are not familiar with PyTorch hooks, don't worry: they are just a way to 'hook' into a model at specific points (i.e. the `target_module`) and draw out activations and gradients while the model runs. Gradcams allow us to see which areas of an image contribute most towards the final predictions.
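As a minimal sketch of the idea (separate from the class below, and using a hypothetical stand-in convolution rather than a module selected with fa_convnav), registering, triggering and removing a forward hook in plain PyTorch looks like this:

```python
import torch
import torch.nn as nn

stored = {}

def forward_hook(module, inp, output):
    # called automatically after the module's forward pass;
    # stash the output activations somewhere we can read them later
    stored['acts'] = output.detach()

conv = nn.Conv2d(3, 16, kernel_size=3)           # hypothetical stand-in for a target_module
handle = conv.register_forward_hook(forward_hook)

_ = conv(torch.randn(1, 3, 224, 224))            # running the module triggers the hook
print(stored['acts'].shape)                      # torch.Size([1, 16, 222, 222])

handle.remove()                                  # always remove hooks when finished
```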
The code is adapted and abridged from these sources, which are also good places to start reading about hooks and gradcams.
class examine_modules():
    "Gets activation stats and gradients for a module"

    def __init__(self, model, target_module, input_img):
        self.model = model
        self.gradients = dict()
        self.activations_out = dict()
        self.activations_in = dict()
        self.input_img_resized = input_img.resize((224, 224), Image.BILINEAR)

        def forward_hook(module, inp, output):
            self.activations_in['value'] = inp[0]
            self.activations_out['value'] = output
            return None

        def backward_hook(module, grad_input, grad_output):
            self.gradients['value'] = grad_output[0]
            return None

        self.handle_fwd = target_module.register_forward_hook(forward_hook)
        self.handle_bck = target_module.register_backward_hook(backward_hook)

        def normalize(tensor):
            mean, std = imagenet_stats
            mean = torch.FloatTensor(mean).view(1, 3, 1, 1).expand_as(tensor).to(tensor.device)
            std = torch.FloatTensor(std).view(1, 3, 1, 1).expand_as(tensor).to(tensor.device)
            return tensor.sub(mean).div(std)

        self.torch_img = torch.from_numpy(np.asarray(self.input_img_resized)).permute(2, 0, 1).unsqueeze(0).float().div(255).cuda()
        self.normed_torch_image = normalize(self.torch_img)
    def model_pass(self):
        "Make a pass through the model with the input image (shape: 1, 3, H, W)"
        preds = self.model(self.torch_img)            # forward pass gives predictions (preds)
        loss = preds[:, preds.max(1)[-1]].squeeze()   # class prediction (highest pred) becomes the loss
        self.model.zero_grad()
        loss.backward(retain_graph=False)             # backward pass with highest prediction as loss
        self.acts_in = self.activations_in['value']   # input activations
        self.acts_out = self.activations_out['value'] # output feature maps, e.g. shape (1, 512, 7, 7)
        self.grads = self.gradients['value']          # gradients w.r.t. the module output
        return None
    def gradcam(self):
        "Make gradcam"
        b, c, h, w = self.grads.size()
        alpha = self.grads.view(b, c, -1).mean(2)     # mean of each feature map's gradients, shape (1, 512)
        weights = alpha.view(b, c, 1, 1)              # weight matrix of shape (1, 512, 1, 1) containing the gradient means
        saliency_map = (weights * self.acts_out).sum(1, keepdim=True)  # weighted sum over the channel dimension (dim 1), keeping dimensions
        saliency_map = F.relu(saliency_map)
        saliency_map = F.interpolate(saliency_map, size=(224, 224), mode='bilinear', align_corners=False)
        saliency_map_min, saliency_map_max = saliency_map.min(), saliency_map.max()
        saliency_map = (saliency_map - saliency_map_min).div(saliency_map_max - saliency_map_min).data
        mask = saliency_map
        heatmap = cv2.applyColorMap(np.uint8(255 * mask.squeeze().cpu()), cv2.COLORMAP_JET)
        heatmap = torch.from_numpy(heatmap).permute(2, 0, 1).float().div(255)
        b, g, r = heatmap.split(1)
        heatmap = torch.cat([r, g, b])                # convert BGR (OpenCV) to RGB
        cam_rgb = heatmap + self.torch_img.cpu()      # overlay heatmap on the input image
        cam_rgb = cam_rgb.div(cam_rgb.max()).squeeze()
        self.gradcam = cv2.merge((cam_rgb.numpy()))
        return None
    def display(self, n):
        "Display input image, input activations, output activations, gradient and gradcam images in a row"

        def display_image(x):
            "Denormalise, process and display an image array `x`"
            x -= x.mean()
            x /= (x.std() + 1e-5)
            x *= 0.1
            x += 0.5
            x = np.clip(x, 0, 1)
            x *= 255
            x = np.clip(x, 0, 255).astype('uint8')
            plt.imshow(x)
            return None

        fig, ax = plt.subplots(1, 5, figsize=(20, 4))

        plt.subplot(1, 5, 1)
        img = self.input_img_resized.resize(self.acts_in[0][1].shape, Image.BILINEAR)
        plt.imshow(img)
        if not n: plt.title("\nInput image \n(from dataloader)\n", fontsize=18)
        plt.ylabel(f'Block {n}', fontsize=20)

        plt.subplot(1, 5, 2)
        f_in = self.acts_in[0][1].cpu().detach().numpy()
        if not n: plt.title("\nInput activations\n", fontsize=20)
        display_image(f_in)

        plt.subplot(1, 5, 3)
        g = self.grads[0][1].cpu().detach().numpy()
        if not n: plt.title("\nGradients\n", fontsize=18)
        display_image(g)

        plt.subplot(1, 5, 4)
        f_out = self.acts_out[0][1].cpu().detach().numpy()
        if not n: plt.title("\nOutput activations \n(feature map)\n", fontsize=18)
        display_image(f_out)

        plt.subplot(1, 5, 5)
        if not n: plt.title("\nGradcam\n", fontsize=18)
        plt.imshow(self.gradcam)

        plt.show()
        return None

    def remove_hooks(self):
        "Remove the forward and backward hooks from the target module"
        self.handle_fwd.remove()
        self.handle_bck.remove()
By plotting the input activations, gradients, feature maps and gradcam successively for each of the four modules we selected (and stored in `spread`), we can easily follow an image as it is processed by the model. By repeating the hooks and plots at different stages of model training we can see whether or not our model is training appropriately.
# place the model onto the GPU
model = learn.model.eval().cuda()

# select an image from the dataloader
fname = dls.valid_ds.items[0]
pil_img = PIL.Image.open(fname)

# loop through our selected modules in spread, hooking each one in turn
def display_img_grid(model, modules, input_img):
    for n, m in enumerate(modules):
        x = examine_modules(model, m, input_img)
        x.model_pass()
        x.gradcam()
        x.display(n)
        x.remove_hooks()

display_img_grid(model, spread, pil_img)
Resnet18 has been pre-trained on the ImageNet dataset and the gradcam shows that it already does a pretty good job of identifying the dog or cat (ImageNet contains plenty of dog and cat images) even without additional training. Now train the model on our dog and cat breeds dataset and see what it can do.
learn.fit_one_cycle(5, 1e-3)
display_img_grid(model, spread, pil_img)
After 5 epochs the model localises those features of the dog or cat which distinguish it from other breeds (not just from other animals). This corresponds to a low error rate in the breed classification task and shows that the model is training correctly. Next we can unfreeze the body of the model and train for a further 5 epochs.
learn.unfreeze()
learn.fit_one_cycle(5, slice(1e-5, 1e-4, 1e-4))
display_img_grid(model, spread, pil_img)
After a further 5 epochs of training the unfrozen model, close examination of the gradcam image above shows an even greater localization of the distinguishing features of the dog or cat.
Now let's try something different and train the model again, but this time at a much higher learning rate.
learn.fit_one_cycle(1, 1e-2)
display_img_grid(model, spread, pil_img)
As expected, the high learning rate hinders rather than helps gradient descent and reduces model performance, likely undoing much of what the pretrained weights had learned. The rise in error rate corresponds to gradcam images that are noisier and less focused on the areas of interest. This model is no longer training effectively.