Implementing a fastai Learner from Scratch
Building a BasicLearner class which trains a neural net to classify handwritten digits.
Background
In this notebook, I’ll work through the first “Further Research” exercise at the end of Chapter 4 of the Practical Deep Learning for Coders textbook:
Create your own implementation of Learner from scratch, based on the training loop shown in this chapter.
I want to emphasize that this Learner implementation is basic, based only on what we’ve learned in Chapter 4. I’ll call my implementation BasicLearner, as it corresponds to the BasicOptim optimizer created in the chapter. I’ll use my BasicLearner implementation to train a simple neural net on the MNIST_SAMPLE dataset.
Training Loop
I’ll start by recreating the training loop from Chapter 4 to train a simple neural net to classify the handwritten digits 3 and 7.
Load and Prepare the Data
The MNIST_SAMPLE dataset is available through fastai’s URLs, and I download it using untar_data:

from fastai.vision.all import *

path = untar_data(URLs.MNIST_SAMPLE)
path.ls()
(#3) [Path('/root/.fastai/data/mnist_sample/train'),Path('/root/.fastai/data/mnist_sample/valid'),Path('/root/.fastai/data/mnist_sample/labels.csv')]
Then I stack the training and validation images into 3-dimensional tensors, starting with the training set:
stacked_threes = torch.stack([tensor(Image.open(o)) for o in (path/'train'/'3').ls().sorted()]).float()/255
stacked_sevens = torch.stack([tensor(Image.open(o)) for o in (path/'train'/'7').ls().sorted()]).float()/255
stacked_threes.shape, stacked_sevens.shape
(torch.Size([6131, 28, 28]), torch.Size([6265, 28, 28]))
show_image(stacked_threes[0]);
show_image(stacked_sevens[0]);
valid_3_tens = torch.stack([tensor(Image.open(o)) for o in (path/'valid'/'3').ls()]).float()/255
valid_7_tens = torch.stack([tensor(Image.open(o)) for o in (path/'valid'/'7').ls()]).float()/255
valid_3_tens.shape, valid_7_tens.shape
(torch.Size([1010, 28, 28]), torch.Size([1028, 28, 28]))
show_image(valid_3_tens[0]);
show_image(valid_7_tens[0]);
We then combine the training sets for 3s and 7s and flatten the tensors so that each image’s 28×28 pixels form a one-dimensional row of 784 values.
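As a quick aside, here’s a toy illustration of what I mean by flattening with view (my own example, not from the text):

t = torch.arange(8).view(2, 2, 2)   # a stack of two tiny 2x2 "images"
t.view(-1, 2*2)                     # shape [2, 4]: each image becomes one row of 4 pixels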
train_x = torch.cat([stacked_threes, stacked_sevens]).view(-1, 28*28)
train_y = tensor([1]*stacked_threes.shape[0] + [0]*stacked_sevens.shape[0]).unsqueeze(1)
train_x.shape, train_y.shape
(torch.Size([12396, 784]), torch.Size([12396, 1]))
Then do the same for the validation sets:
valid_x = torch.cat([valid_3_tens, valid_7_tens]).view(-1, 28*28)
valid_y = tensor([1]*len(valid_3_tens) + [0]*len(valid_7_tens)).unsqueeze(1)
valid_x.shape, valid_y.shape
(torch.Size([2038, 784]), torch.Size([2038, 1]))
We create training and validation datasets with the same structure PyTorch’s Dataset expects: when indexed, they return a tuple of independent and dependent variables. A list of tuples fits that structure:
dset = list(zip(train_x, train_y))
x,y = dset[0]
x.shape, y
(torch.Size([784]), tensor([1]))
valid_dset = list(zip(valid_x, valid_y))
x,y = valid_dset[0]
x.shape, y
(torch.Size([784]), tensor([1]))
Then we feed those datasets into fastai DataLoaders:
dl = DataLoader(dset, batch_size=256)
xb,yb = first(dl)
xb.shape, yb.shape
(torch.Size([256, 784]), torch.Size([256, 1]))
valid_dl = DataLoader(valid_dset, batch_size=256)
valid_xb, valid_yb = first(valid_dl)
valid_xb.shape, valid_yb.shape
(torch.Size([256, 784]), torch.Size([256, 1]))
Create Our Model
For this exercise, the text has us create a simple neural net with a ReLU sandwiched between two linear layers. I’ve kept the number of intermediate activations (30) the same as in the text:
simple_net = nn.Sequential(
    nn.Linear(28*28, 30),
    nn.ReLU(),
    nn.Linear(30, 1)
)
Create a Loss Function
The loss function we will use does the following:
- Pass the model’s activations through a sigmoid function so that they are between 0 and 1.
- When the target is 1 (the digit 3), take the difference between 1 and the activation. When the target is 0 (the digit 7), take the difference between 0 and the activation.
- Take the mean of these distances between activations and targets.
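To make the loss concrete, here’s a small hand-worked example with made-up activations (my own illustration, not from the text):

acts = tensor([2.0, -1.0, 0.5])          # hypothetical raw model activations
targs = tensor([1, 0, 1])                # a 3, a 7, and a 3
probs = acts.sigmoid()                   # tensor([0.8808, 0.2689, 0.6225])
torch.where(targs==1, 1-probs, probs)    # tensor([0.1192, 0.2689, 0.3775]) -> mean is about 0.2552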
def mnist_loss(predictions, targets):
    predictions = predictions.sigmoid()
    return torch.where(targets==1, 1-predictions, predictions).mean()
Create a Function to Calculate Predictions, Loss and Gradients
The calc_grad function takes the independent and dependent data batches as inputs, passes them through the model to get the activations (predictions), calculates the batch’s loss, and calls backward on the loss to compute the weights’ gradients:
def calc_grad(xb, yb, model):
    preds = model(xb)
    loss = mnist_loss(preds, yb)
    loss.backward()
Create an Optimizer
The optimizer handles stepping the weights and resetting the gradients. When stepping the weights, we use the parameters’ .data attribute, since PyTorch doesn’t track gradients on operations applied to it. The zero_grad method resets the gradients (by setting them to None) so that they don’t accumulate when the next batch’s gradients are calculated:
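To see why the reset matters, here’s a quick demonstration (my own sketch) that PyTorch adds new gradients onto existing ones:

w = tensor([1.0]).requires_grad_()
(w * 2).sum().backward()
print(w.grad)    # tensor([2.])
(w * 2).sum().backward()
print(w.grad)    # tensor([4.]) -- the second gradient was added to the first
w.grad = None    # what zero_grad does: clear the gradient entirely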
class BasicOptim:
    def __init__(self,params,lr): self.params,self.lr = list(params),lr

    def step(self, *args, **kwargs):
        for p in self.params: p.data -= p.grad.data * self.lr

    def zero_grad(self, *args, **kwargs):
        for p in self.params: p.grad = None
lr = 0.1
opt = BasicOptim(simple_net.parameters(), lr)
Create a Function to Train One Epoch
For each training epoch:

- Get a batch from the training DataLoader.
- Calculate the activations, loss, and gradients.
- Step the weights in the direction opposite of the gradients.
- Reset the gradients to zero.
def train_epoch(model):
    for xb,yb in dl:
        calc_grad(xb, yb, model)
        opt.step()
        opt.zero_grad()
Create a Function to Calculate a Metric for One Batch
The metric of choice in the chapter is accuracy, the fraction of digits predicted correctly across the batch:
def batch_accuracy(xb, yb):
    # xb here is the model's activations, not the raw input batch
    preds = xb.sigmoid()
    correct = (preds>0.5) == yb
    return correct.float().mean()
Create a Function to Calculate the Metric for One Epoch
For each batch in the validation DataLoader, calculate the accuracy. Then take the mean of all batch accuracies as the accuracy for the epoch:
def validate_epoch(model):
    accs = [batch_accuracy(model(xb), yb) for xb,yb in valid_dl]
    return round(torch.stack(accs).mean().item(), 4)
Create a Function for the Training Loop
train_model takes as inputs a model and the number of epochs to train it for. For each epoch, it trains the model on the training batches and prints the epoch’s metric on the validation batches:
def train_model(model, epochs):
    for i in range(epochs):
        train_epoch(model)
        print(validate_epoch(model), end=' ')
Train the Model
As is done in the text, I’ll train the model for 40 epochs.
train_model(simple_net, 40)
0.5127 0.8013 0.9175 0.9419 0.957 0.9653 0.9672 0.9677 0.9687 0.9702 0.9726 0.9736 0.9745 0.9755 0.9755 0.9765 0.977 0.9785 0.9785 0.9785 0.9795 0.9799 0.9804 0.9809 0.9809 0.9814 0.9819 0.9819 0.9819 0.9824 0.9829 0.9829 0.9829 0.9829 0.9829 0.9829 0.9829 0.9829 0.9829 0.9829
I get similar starting and final accuracies to the example from the text.
BasicLearner Class
My BasicLearner implementation should recreate the training process performed in the sections above. I’ll start by defining the inputs and outputs for an instance of this class:
Inputs and Outputs
The fastai Learner requires the following inputs:

- DataLoaders with training and validation sets.
- The model we want to train.
- An optimizer function.
- A loss function.
- Any metrics we want calculated.
The Learner outputs a table with the following information when its fit(epochs, lr) method is called. I’ve bolded the items that I’m going to show in the first iteration of my Learner:

- **Epoch #**.
- Training Loss.
- Validation Loss.
- **Metric**.
- Time.
With these inputs and outputs in mind, I’ll write the BasicLearner class:
class BasicLearner:
    def __init__(self, dls, model, opt_func, loss_func, metric):
        self.dls = dls
        self.model = model
        self.opt_func = opt_func
        self.loss_func = loss_func
        self.metric = metric

    def calc_grad(self, xb, yb, model):
        preds = self.model(xb)
        loss = self.loss_func(preds, yb)
        loss.backward()

    def train_epoch(self):
        for xb,yb in self.dls.train:
            self.calc_grad(xb, yb, self.model)
            self.opt.step()
            self.opt.zero_grad()

    def validate_epoch(self):
        accs = [self.metric(self.model(xb), yb) for xb,yb in self.dls.valid]
        return round(torch.stack(accs).mean().item(), 4)

    def train_model(self, model, epochs):
        print("Epoch", self.metric.__name__, sep="\t")
        for i in range(self.epochs):
            self.train_epoch()
            print(i, self.validate_epoch(), sep="\t")

    def fit(self, epochs, lr):
        self.lr = lr
        self.epochs = epochs
        self.opt = self.opt_func(self.model.parameters(), self.lr)
        self.train_model(self.model, self.epochs)
I’ll combine my training and validation DataLoaders and confirm that they contain the correct number of tuples in their datasets:
dls = DataLoaders(dl, valid_dl)
len(dls.train.dataset)
12396
len(dls.valid.dataset)
2038
I’ll create a fresh neural net to train from scratch:
simple_net = nn.Sequential(
    nn.Linear(28*28, 30),
    nn.ReLU(),
    nn.Linear(30, 1)
)
I’ll instantiate my BasicLearner class:
learn = BasicLearner(dls=dls,
                     model=simple_net,
                     opt_func=BasicOptim,
                     loss_func=mnist_loss,
                     metric=batch_accuracy)
And train the model:
learn.fit(40, 0.1)
Epoch batch_accuracy
0 0.5068
1 0.814
2 0.9184
3 0.9419
4 0.9575
5 0.9648
6 0.9663
7 0.9677
8 0.9692
9 0.9707
10 0.9736
11 0.9736
12 0.9741
13 0.9755
14 0.9765
15 0.9775
16 0.978
17 0.9785
18 0.979
19 0.979
20 0.979
21 0.9795
22 0.9795
23 0.9804
24 0.9804
25 0.9809
26 0.9814
27 0.9819
28 0.9819
29 0.9819
30 0.9814
31 0.9819
32 0.9819
33 0.9824
34 0.9824
35 0.9829
36 0.9829
37 0.9829
38 0.9829
39 0.9829
Looks good! I’m getting similar starting and ending accuracy values as before.
Improving the BasicLearner Class
Now that I’ve confirmed that my BasicLearner is able to train a neural net to 98% accuracy classifying 3s and 7s, I’d like to add a bit more functionality to the class.
First, I’d like to add a predict method to the learner, which will take a tensor image as input and output the prediction, so that I can test whether my model has truly learned how to classify 3s and 7s.
class BasicLearner:
    def __init__(self, dls, model, opt_func, loss_func, metric):
        self.dls = dls
        self.model = model
        self.opt_func = opt_func
        self.loss_func = loss_func
        self.metric = metric

    def calc_grad(self, xb, yb, model):
        preds = self.model(xb)
        loss = self.loss_func(preds, yb)
        loss.backward()

    def train_epoch(self):
        for xb,yb in self.dls.train:
            self.calc_grad(xb, yb, self.model)
            self.opt.step()
            self.opt.zero_grad()

    def validate_epoch(self):
        accs = [self.metric(self.model(xb), yb) for xb,yb in self.dls.valid]
        return round(torch.stack(accs).mean().item(), 4)

    def train_model(self, model, epochs):
        print("Epoch", self.metric.__name__, sep="\t")
        for i in range(self.epochs):
            self.train_epoch()
            print(i, self.validate_epoch(), sep="\t")

    def fit(self, epochs, lr):
        self.lr = lr
        self.epochs = epochs
        self.opt = self.opt_func(self.model.parameters(), self.lr)
        self.train_model(self.model, self.epochs)

    def predict(self, x):
        prediction = self.model(x)
        prediction = prediction.sigmoid()
        label = "3" if prediction > 0.5 else "7"
        return prediction, label
I’ll instantiate a new model and BasicLearner and train it again:
simple_net = nn.Sequential(
    nn.Linear(28*28, 30),
    nn.ReLU(),
    nn.Linear(30, 1)
)
learn = BasicLearner(dls=dls,
                     model=simple_net,
                     opt_func=BasicOptim,
                     loss_func=mnist_loss,
                     metric=batch_accuracy)
learn.fit(40, 0.1)
Epoch batch_accuracy
0 0.5073
1 0.8184
2 0.9194
3 0.9419
4 0.957
5 0.9638
6 0.9658
7 0.9672
8 0.9697
9 0.9706
10 0.9726
11 0.9741
12 0.9741
13 0.9755
14 0.976
15 0.9765
16 0.9765
17 0.978
18 0.978
19 0.978
20 0.9795
21 0.9795
22 0.9799
23 0.9809
24 0.9809
25 0.9814
26 0.9814
27 0.9814
28 0.9819
29 0.9814
30 0.9814
31 0.9824
32 0.9829
33 0.9829
34 0.9829
35 0.9829
36 0.9824
37 0.9824
38 0.9824
39 0.9824
With the model trained, I can see if it predicts an image of a 3 correctly:
show_image(dls.valid.dataset[1][0].view(-1,28,28));
learn.predict(dls.valid.dataset[1][0])
(tensor([1.0000], grad_fn=<SigmoidBackward0>), '3')
The final piece that I’ll add is a “Train Loss” column in the fit method’s output during training. The loss of each training batch will be stored in a tensor; at the end of each epoch, I’ll calculate and print the mean loss, then reset the tensor to empty before the next epoch.
class BasicLearner:
    def __init__(self, dls, model, opt_func, loss_func, metric):
        self.dls = dls
        self.model = model
        self.opt_func = opt_func
        self.loss_func = loss_func
        self.metric = metric

    def calc_grad(self, xb, yb, model):
        preds = self.model(xb)
        loss = self.loss_func(preds, yb)
        # store the loss of each batch,
        # to be averaged across the epoch later
        self.loss = torch.cat((self.loss, tensor([loss])))
        loss.backward()

    def train_epoch(self):
        for xb,yb in self.dls.train:
            self.calc_grad(xb, yb, self.model)
            self.opt.step()
            self.opt.zero_grad()

    def validate_epoch(self):
        accs = [self.metric(self.model(xb), yb) for xb,yb in self.dls.valid]
        return round(torch.stack(accs).mean().item(), 4)

    def train_model(self, model, epochs):
        print("Epoch", "Train Loss", self.metric.__name__, sep="\t")
        for i in range(self.epochs):
            self.loss = tensor([])
            self.train_epoch()
            print(i, round(self.loss.mean().item(), 4), self.validate_epoch(), sep="\t\t")

    def fit(self, epochs, lr):
        self.lr = lr
        self.epochs = epochs
        self.opt = self.opt_func(self.model.parameters(), self.lr)
        self.train_model(self.model, self.epochs)

    def predict(self, x):
        prediction = self.model(x)
        prediction = prediction.sigmoid()
        label = "3" if prediction > 0.5 else "7"
        return prediction, label
simple_net = nn.Sequential(
    nn.Linear(28*28, 30),
    nn.ReLU(),
    nn.Linear(30, 1)
)
learn = BasicLearner(dls=dls,
                     model=simple_net,
                     opt_func=BasicOptim,
                     loss_func=mnist_loss,
                     metric=batch_accuracy)
learn.fit(40, 0.1)
Epoch Train Loss batch_accuracy
0 0.3627 0.5229
1 0.1088 0.7715
2 0.0593 0.9111
3 0.0439 0.9389
4 0.0375 0.9516
5 0.0337 0.9629
6 0.0311 0.9653
7 0.0291 0.9667
8 0.0275 0.9672
9 0.0261 0.9687
10 0.025 0.9721
11 0.0241 0.9736
12 0.0233 0.9746
13 0.0225 0.9755
14 0.0219 0.9755
15 0.0213 0.976
16 0.0208 0.9765
17 0.0204 0.978
18 0.02 0.9785
19 0.0196 0.9785
20 0.0193 0.979
21 0.0189 0.979
22 0.0186 0.979
23 0.0184 0.9799
24 0.0181 0.9804
25 0.0178 0.9804
26 0.0176 0.9804
27 0.0174 0.9804
28 0.0172 0.9804
29 0.017 0.9814
30 0.0168 0.9824
31 0.0166 0.9824
32 0.0164 0.9829
33 0.0163 0.9829
34 0.0161 0.9829
35 0.016 0.9824
36 0.0158 0.9829
37 0.0157 0.9829
38 0.0155 0.9829
39 0.0154 0.9829
# check the prediction again
learn.predict(dls.valid.dataset[1][0])
(tensor([1.0000], grad_fn=<SigmoidBackward0>), '3')
Further Improvements
My BasicLearner is able to train a neural net to classify two digits with 98% accuracy. During training, it prints out the epoch number, training loss, and metric, and it has a predict method to test its classification on new tensor images. While I’m happy with the result of this exercise, there are certainly numerous improvements and additions that could be made to bring this learner closer to the functionality of the fastai Learner class.
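For example, one Learner output I listed earlier but didn’t implement is the time taken per epoch. Here’s a minimal, untested sketch of how train_model could print a time column using Python’s time module (my own addition, not from the text):

import time

def train_model(self, model, epochs):
    # same as before, but with a hypothetical per-epoch time column
    print("Epoch", "Train Loss", self.metric.__name__, "Time", sep="\t")
    for i in range(self.epochs):
        start = time.time()
        self.loss = tensor([])
        self.train_epoch()
        elapsed = round(time.time() - start, 2)   # seconds per epoch
        print(i, round(self.loss.mean().item(), 4), self.validate_epoch(), elapsed, sep="\t\t")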
I hope you enjoyed reading this blog post!