'<iframe width="560" height="315" src="https://www.youtube.com/embed/3g5YLK1nbu8" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe>') HTML(
fast.ai Chapter 7:Test Time Augmentation
Here’s a video walkthrough of this notebook:
Introduction
In this notebook, I work through the first of four “Further Research” problems assigned at the end of Chapter 7 in the textbook “Deep Learning for Coders with fastai and PyTorch”.
The prompt for this exercise is:
Use the fastai documentation to build a function that crops an image to a square in each of the four corners; then implement a TTA method that averages the predictions on a center crop and those four crops. Did it help? Is it better than the TTA method of fastai?
What is Test Time Augmentation?
I’ll quote directly from the text:
During inference or validation, creating multiple versions of each image using data augmentation, and then taking the average or maximum of the predictions for each augmented version of the image.
TTA is data augmentation applied at validation time, in the hope that objects located outside the center of the image (the default fastai validation crop) can still be recognized by the model, increasing its accuracy.
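To make that concrete, here is a minimal sketch of the idea (my illustration, not fastai code; model and augment are hypothetical stand-ins for a trained model and an augmentation function):

import torch

def tta_predict(model, x, augment, n=4, use_max=False):
    "Average (or take the max of) predictions over `n` augmented copies of batch `x`."
    preds = torch.stack([model(augment(x)) for _ in range(n)])  # (n, batch, classes)
    return preds.max(0)[0] if use_max else preds.mean(0)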
The default Learner.tta method averages the predictions on the center crop and four randomly generated crops. The method I’ll create will average the predictions on the center crop and four corner crops.
User-defined Test Time Augmentation
Read and understand the Learner.tta and RandomCrop source code
def tta(self:Learner, ds_idx=1, dl=None, n=4, item_tfms=None, batch_tfms=None, beta=0.25, use_max=False):
    "Return predictions on the `ds_idx` dataset or `dl` using Test Time Augmentation"
    if dl is None: dl = self.dls[ds_idx].new(shuffled=False, drop_last=False)
    if item_tfms is not None or batch_tfms is not None: dl = dl.new(after_item=item_tfms, after_batch=batch_tfms)
    try:
        self(_before_epoch)
        with dl.dataset.set_split_idx(0), self.no_mbar():
            if hasattr(self,'progress'): self.progress.mbar = master_bar(list(range(n)))
            aug_preds = []
            for i in self.progress.mbar if hasattr(self,'progress') else range(n):
                self.epoch = i #To keep track of progress on mbar since the progress callback will use self.epoch
                aug_preds.append(self.get_preds(dl=dl, inner=True)[0][None])
        aug_preds = torch.cat(aug_preds)
        aug_preds = aug_preds.max(0)[0] if use_max else aug_preds.mean(0)
        self.epoch = n
        with dl.dataset.set_split_idx(1): preds,targs = self.get_preds(dl=dl, inner=True)
    finally: self(event.after_fit)
    if use_max: return torch.stack([preds, aug_preds], 0).max(0)[0],targs
    preds = (aug_preds,preds) if beta is None else torch.lerp(aug_preds, preds, beta)
    return preds,targs
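One detail worth flagging: the final blend uses torch.lerp(aug_preds, preds, beta), which computes aug_preds + beta*(preds - aug_preds). With the default beta=0.25, the returned predictions are 75% augmented and 25% plain center crop. A quick check of the semantics:

import torch

a, b = torch.tensor([0., 1.]), torch.tensor([1., 0.])
torch.lerp(a, b, 0.25)   # tensor([0.2500, 0.7500]), i.e. 0.75*a + 0.25*b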
class RandomCrop(RandTransform):
    "Randomly crop an image to `size`"
    split_idx,order = None,1
    def __init__(self, size, **kwargs):
        size = _process_sz(size)
        store_attr()
        super().__init__(**kwargs)

    def before_call(self, b, split_idx):
        self.orig_sz = _get_sz(b)
        if split_idx: self.tl = (self.orig_sz-self.size)//2
        else:
            wd = self.orig_sz[0] - self.size[0]
            hd = self.orig_sz[1] - self.size[1]
            w_rand = (wd, -1) if wd < 0 else (0, wd)
            h_rand = (hd, -1) if hd < 0 else (0, hd)
            self.tl = fastuple(random.randint(*w_rand), random.randint(*h_rand))

    def encodes(self, x:(Image.Image,TensorBBox,TensorPoint)):
        return x.crop_pad(self.size, self.tl, orig_sz=self.orig_sz)
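One thing worth noting (my reading of the source, so treat it as an assumption): before_call branches on split_idx, taking a deterministic center crop on the validation set (split_idx=1) and a random crop otherwise. That is why tta wraps the augmented passes in set_split_idx(0), turning the randomness on, and the final pass in set_split_idx(1), turning it off. A tiny check, using the img loaded in the next section:

rc = RandomCrop(224)
rc.before_call(img, split_idx=1); rc.tl   # deterministic: the centered top-left corner
rc.before_call(img, split_idx=0); rc.tl   # random: changes on every call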
Practice cropping images using the .crop method on a PILImage object

A PIL Image has a method called crop which takes a crop rectangle tuple, (left, upper, right, lower), and crops the image within those pixel bounds.
Here’s an image with a grizzly bear at the top and a black bear on the bottom. There are four coordinates of interest: left, upper, right, and bottom. The leftmost points on the image are at pixel 0, and the rightmost points are at the image-width pixel value. Likewise, the uppermost points are at pixel 0, and the bottommost points are at the image-height pixel value.
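As a standalone sanity check of the box semantics (plain PIL, independent of the notebook’s image):

from PIL import Image

im = Image.new('RGB', (320, 240))   # width=320, height=240
sq = im.crop((0, 0, 224, 224))      # box is (left, upper, right, lower), origin at top left
sq.size                             # (224, 224)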
= "/content/gdrive/MyDrive/fastai-course-v4/images/test/grizzly_black.png"
f = PILImage.create(f)
img 320) img.to_thumb(
Top-Left Corner Crop
A top-left corner crop corresponds to a left pixel of 0, an upper pixel of 0, a right pixel of 224, and a bottom pixel of 224. The order in the tuple is left, upper, right, bottom, so (0, 0, 224, 224). You can see that this crop is taken from the top left corner of the original image.
img.crop((0,0,224,224))
Top Right Corner Crop
For the top right corner, I need the image width, since the left edge of the crop will be 224 pixels from the right edge of the image; that translates to w-224. The upper pixel is 0, the rightmost pixel is at w, and the bottom pixel is 224. You can see that this crop is at the top right corner of the original.
w = img.width
h = img.height
img.crop((w-224, 0, w, 224))
Bottom Right Corner Crop
For the bottom right corner, the left pixel is 224 from the right edge, w-224, the upper pixel is 224 from the bottom, h-224, the right pixel is at w, and the bottom is at h.
img.crop((w-224, h-224, w, h))
Bottom Left Corner Crop
The bottom left corner’s leftmost pixel is 0, its uppermost pixel is 224 pixels from the bottom of the whole image, h-224, its rightmost pixel is 224, and its bottommost pixel is the bottom of the whole image, at h.
img.crop((0, h-224, 224, h))
Center Crop
Finally, for the center crop, the leftmost pixel is 112 left of the image center, w/2 - 112, the upper pixel is 112 above the image center, h/2 - 112, the rightmost pixel is 112 right of the center, w/2 + 112, and the bottom pixel is 112 below the center, h/2 + 112.
img.crop((w/2-112, h/2-112, w/2+112, h/2+112))
Summary
To better visualize this, here are a couple of images which show the left, upper, right and bottom coordinates for the corner and center crops.

Summary of corner crop arguments (left, upper, right, bottom)

Summary of center crop arguments (left, upper, right, bottom)
Define a function which takes an image and returns a stacked Tensor with four corner crops and a center crop

I wrap those five lines of code into a function called corner_crop, which takes a PILImage img and a square side length size (defaulting to 224) as its arguments. It first grabs the width and height of the image, then saves the crops of the four corners and the center as TensorImages, returning them all in a single stacked Tensor.
def corner_crop(img, size=224):
    """Returns a Tensor with 5 cropped square images
    img: PILImage
    size: int
    """
    w,h = img.width, img.height
    top_left = TensorImage(img.crop((0, 0, size, size)))
    top_right = TensorImage(img.crop((w-size, 0, w, size)))
    bottom_right = TensorImage(img.crop((w-size, h-size, w, h)))
    bottom_left = TensorImage(img.crop((0, h-size, size, h)))
    center = TensorImage(img.crop((w/2-size/2, h/2-size/2, w/2+size/2, h/2+size/2)))
    return torch.stack([top_left, top_right, bottom_right, bottom_left, center])
I’ll test the corner_crop function and make sure that the five images are cropped correctly.
Here’s the top left corner.
imgs = corner_crop(img)

# Top Left Corner Crop
imgs[0].show()
Top right corner:
# Top Right Corner Crop
imgs[1].show()
Bottom right:
# Bottom Right Corner Crop
imgs[2].show()
Bottom left:
# Bottom Left Corner Crop
imgs[3].show()
And center:
# Center Crop
imgs[4].show()
Define a new CornerCrop transform by extending the Transform class definition
The main purpose of all that was to wrap my head around how the crop behaves so that I can wrap it into a transform.
Transforms are any function that you want to apply to your data. I’ll extend the base Transform class and add the functionality I need for these crops. When an object of the CornerCrop class is constructed, the constructor takes size and corner_type arguments. Since I’ll use this within a for-loop, the corner_type argument is an integer from 0 to 3, corresponding to the loop counter. The transform is applied to the data in the encodes method: I grab the original image width and height, create a list of cropped images using the left, upper, right, bottom coordinates we saw above, and, based on corner_type, return the corresponding crop.
class CornerCrop(Transform):
    "Crop a square of `size` from one of the 4 corners"
    def __init__(self, size, corner_type=0, **kwargs):
        self.size = size
        self.corner_type = corner_type
        super().__init__(**kwargs)  # (added) forward kwargs to Transform

    def encodes(self, x:(Image.Image,TensorBBox,TensorPoint)):
        self.w, self.h = x.size
        self.crops = [
            x.crop((0, 0, self.size, self.size)),                          # top left
            x.crop((self.w-self.size, 0, self.w, self.size)),              # top right
            x.crop((self.w-self.size, self.h-self.size, self.w, self.h)),  # bottom right
            x.crop((0, self.h-self.size, self.size, self.h))               # bottom left
        ]
        return self.crops[self.corner_type]
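Before wiring it into a DataBlock, I can sanity-check the transform by calling it directly on the bear image from earlier; a fastai Transform dispatches its call to encodes. This check is my addition, not from the original exercise:

crop_tfm = CornerCrop(224, corner_type=1)
cropped = crop_tfm(img)   # Transform.__call__ dispatches to encodes
cropped.size              # (224, 224): the top right corner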
To test this transform, I created an image with the top left, top right, bottom right, and bottom left corners labeled, and made multiple copies so that I can create batches.
# test image for CornerCrop
path = Path('/content/gdrive/MyDrive/fastai-course-v4/images/test/corner_crop_images')
Image.open((path/'01.jpg'))
I create a DataBlock and pass my CornerCrop to the item_tfms parameter. I’ll cycle through the different corner types: 0 corresponds to top left, 1 is top right, 2 is bottom right, and 3 is bottom left. All images in my batch should be cropped to the same corner.
I set corner_type to 0, build the DataBlock and DataLoaders, and the batch shows the top left.
# get the data
# path = untar_data(URLs.IMAGENETTE)
path = Path('/content/gdrive/MyDrive/fastai-course-v4/images/test/corner_crop_images')

# build the DataBlock and DataLoaders using CornerCrop
dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                   get_items=get_image_files,
                   get_y=parent_label,
                   item_tfms=CornerCrop(224,0))

dls = dblock.dataloaders(path, bs=4)

# view a batch
dls.show_batch()
I set corner_type to 1, build the DataBlock and DataLoaders, and the batch shows the top right.
# build the DataBlock and DataLoaders using CornerCrop
dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                   get_items=get_image_files,
                   get_y=parent_label,
                   item_tfms=CornerCrop(224,1))

dls = dblock.dataloaders(path, bs=4)

# view a batch
dls.show_batch()
I set corner_type to 2, build the DataBlock and DataLoaders, and the batch shows the bottom right.
# build the DataBlock and DataLoaders using CornerCrop
dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                   get_items=get_image_files,
                   get_y=parent_label,
                   item_tfms=CornerCrop(224,2))

dls = dblock.dataloaders(path, bs=4)

# view a batch
dls.show_batch()
I set corner_type to 3, build the DataBlock and DataLoaders, and the batch shows the bottom left.
# build the DataBlock and DataLoaders using CornerCrop
dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                   get_items=get_image_files,
                   get_y=parent_label,
                   item_tfms=CornerCrop(224,3))

dls = dblock.dataloaders(path, bs=4)

# view a batch
dls.show_batch()
Now, I can implement this transform into a new TTA method.
Define a new Learner.corner_crop_tta method by repurposing the existing Learner.tta definition
I’ll largely rely on the definition of tta in the built-in Learner class. In this method, predictions are calculated on four sets of augmented images and then averaged along with predictions on a center-crop dataset.
In the existing for-loop, four sets of predictions on randomly generated crops are appended into a list.
for i in self.progress.mbar if hasattr(self,'progress') else range(n):
    self.epoch = i #To keep track of progress on mbar since the progress callback will use self.epoch
    aug_preds.append(self.get_preds(dl=dl, inner=True)[0][None])
In my loop, I create a new DataLoader each time, passing a different corner_type argument to the CornerCrop transform. I also have to pass the ToTensor transform so that the PIL image is converted to a Tensor. In the first iteration, it will append predictions on the top left corner crops. In the next one, it will append predictions on the top right, then the bottom right, and finally, on the fourth loop, the bottom left.
aug_preds = []
for i in range(4):
    dl = dls[1].new(after_item=[CornerCrop(224,i), ToTensor])
    #self.epoch = i #To keep track of progress on mbar since the progress callback will use self.epoch
    aug_preds.append(learn.get_preds(dl=dl, inner=True)[0][None])
Since I am to average these with the center-crop image predictions, I’ll create a new DataLoader without the CornerCrop transform and calculate the predictions on those images:
dl = dls[1].new(shuffled=False, drop_last=False)
with dl.dataset.set_split_idx(1): preds,targs = learn.get_preds(dl=dl, inner=True)
Finally, I’ll append the center-crop preds to the aug_preds list, concatenate them into a single tensor, and take the mean of the predictions:
aug_preds.append(preds[None])
preds = torch.cat(aug_preds).mean(0)
I decided to create a new Learner2 class which extends the built-in Learner, and added a corner_crop_tta method by copying over the tta method, commenting out the lines I won’t need, and adding the lines and changes written above.
class Learner2(Learner):
    def corner_crop_tta(self:Learner, ds_idx=1, dl=None, n=4, beta=0.25, use_max=False):
        "Return predictions on the `ds_idx` dataset or `dl` using Corner Crop Test Time Augmentation"
        if dl is None: dl = self.dls[ds_idx].new(shuffled=False, drop_last=False)
        # if item_tfms is not None or batch_tfms is not None: dl = dl.new(after_item=item_tfms, after_batch=batch_tfms)
        try:
            #self(_before_epoch)
            with dl.dataset.set_split_idx(0), self.no_mbar():
                if hasattr(self,'progress'): self.progress.mbar = master_bar(list(range(n)))
                aug_preds = []
                # Crop image from four corners
                for i in self.progress.mbar if hasattr(self,'progress') else range(n):
                    dl = dl.new(after_item=[CornerCrop(224,i), ToTensor])
                    self.epoch = i #To keep track of progress on mbar since the progress callback will use self.epoch
                    aug_preds.append(self.get_preds(dl=dl, inner=True)[0][None])
            # aug_preds = torch.cat(aug_preds)
            # aug_preds = aug_preds.max(0)[0] if use_max else aug_preds.mean(0)
            self.epoch = n
            dl = self.dls[ds_idx].new(shuffled=False, drop_last=False)
            # Crop image from center
            with dl.dataset.set_split_idx(1): preds,targs = self.get_preds(dl=dl, inner=True)
            aug_preds.append(preds[None])
        finally: self(event.after_fit)
        # if use_max: return torch.stack([preds, aug_preds], 0).max(0)[0],targs
        # preds = (aug_preds,preds) if beta is None else torch.lerp(aug_preds, preds, beta)
        # preds = torch.cat([aug_preds, preds]).mean(0)
        preds = torch.cat(aug_preds).mean(0)
        return preds,targs
Implement this new TTA method on the Imagenette classification model
In the last section of this notebook, I train a model on the Imagenette dataset, which is a subset of the larger ImageNet dataset. Imagenette has 10 distinct classes.
# get the data
path = untar_data(URLs.IMAGENETTE)

# build the DataBlock and DataLoaders
# for a single-label classification
dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                   get_items=get_image_files,
                   get_y=parent_label, # image folder names are the class names
                   item_tfms=Resize(460),
                   batch_tfms=aug_transforms(size=224, min_scale=0.75))

dls = dblock.dataloaders(path, bs=64)

# view a batch
dls.show_batch()
# Try `CornerCrop` on a new DataLoader
# add `ToTensor` transform to convert PILImage to TensorImage
new_dl = dls[1].new(after_item=[CornerCrop(224,3), ToTensor])
new_dl.show_batch()
# baseline training
model = xresnet50()
learn = Learner2(dls, model, loss_func=CrossEntropyLossFlat(), metrics=accuracy)
learn.fit_one_cycle(5, 3e-3)
epoch | train_loss | valid_loss | accuracy | time
---|---|---|---|---
0 | 1.628959 | 2.382344 | 0.450336 | 02:39
1 | 1.258259 | 3.365233 | 0.386482 | 02:45
2 | 0.992097 | 1.129573 | 0.653473 | 02:49
3 | 0.709120 | 0.643617 | 0.802091 | 02:47
4 | 0.571318 | 0.571139 | 0.824122 | 02:45
I run the default tta method, pass the predictions and targets to the accuracy function, and calculate an accuracy of about 83.5%, which is higher than the default center-crop validation accuracy.
# built-in TTA method
preds_tta, targs_tta = learn.tta()
accuracy(preds_tta, targs_tta).item()
0.8345780372619629
Finally, I run my new corner_crop_tta method, pass the predictions and targets to the accuracy function, and calculate an accuracy of about 70.9%, which is lower than the default center-crop validation accuracy.
# user-defined TTA method
preds, targs = learn.corner_crop_tta()
accuracy(preds, targs).item()
0.7098581194877625
I’ll walk through the corner_crop_tta code to verify the accuracy calculated above.
I first create an empty list for my augmented image predictions.
Then I loop through a range of 4, each time creating a new DataLoader which applies the CornerCrop transform for the corresponding corner type, and append the predictions onto the list.
# get predictions on corner cropped validation images
aug_preds = []
for i in range(4):
    dl = dls[1].new(after_item=[CornerCrop(224,i), ToTensor])
    #self.epoch = i #To keep track of progress on mbar since the progress callback will use self.epoch
    aug_preds.append(learn.get_preds(dl=dl, inner=True)[0][None])

len(aug_preds), aug_preds[0].shape
(4, torch.Size([1, 2678, 1000]))
I then create a new DataLoader without my transform, and get those predictions.
# get predictions on center crop validation images
dl = dls[1].new(shuffled=False, drop_last=False)
with dl.dataset.set_split_idx(1): preds,targs = learn.get_preds(dl=dl, inner=True)

preds.shape
torch.Size([2678, 1000])
The shape of these predictions is missing an axis, so I index with None, which adds a new axis.
# add an axis to match augmented prediction tensor shape
preds = preds[None]
preds.shape
torch.Size([1, 2678, 1000])
I append the center crop predictions onto the augmented predictions, concatenate all five sets of predictions into a single Tensor, and calculate the mean.
# average all 5 sets of predictions
aug_preds.append(preds)
preds = torch.cat(aug_preds).mean(0)
I then pass those averaged predictions and the targets to the accuracy function and calculate an accuracy which is slightly higher than above. I ran these five cells multiple times and got the same accuracy value each time, but when I ran the corner_crop_tta method multiple times, I got a different accuracy value each run, so something in the corner_crop_tta definition is incorrect. I’ll go with this value since it was consistent.
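My best guess at the cause (an assumption, untested): the corner passes in corner_crop_tta run under set_split_idx(0), which leaves the random batch_tfms from aug_transforms active, and the loop rebinds dl onto its own output, stacking transforms across iterations; the standalone cells above do neither. A possible replacement for the loop, sketched under those assumptions:

# Hypothetical fix for the loop inside corner_crop_tta (my sketch, untested):
# build each corner DataLoader fresh from the validation set and evaluate it
# under split_idx 1 so random train-time transforms stay off.
aug_preds = []
for i in range(4):
    corner_dl = learn.dls[1].new(shuffled=False, drop_last=False,
                                 after_item=[CornerCrop(224, i), ToTensor])
    with corner_dl.dataset.set_split_idx(1):
        aug_preds.append(learn.get_preds(dl=corner_dl, inner=True)[0][None])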
# calculate validation set accuracy
accuracy(preds, targs).item()
0.7311426401138306
The following table summarizes the results from this training:
Validation | Accuracy
---|---
Center Crop | 82.4%
Center Crop + 4 Random Crops: Linearly Interpolated | 83.5%
Center Crop + 4 Corner Crops: Averaged | 73.1%
There are a few further research items I should pursue in the future:
- Fix the corner_crop_tta method so that it returns the same accuracy each time it’s run on the same trained model
- Try corner_crop_tta on a multi-label classification dataset such as PASCAL
- Try linear interpolation (between the center crop and the corner crop maximum) instead of the mean; see the sketch after this list