Implementing a fastai Learner from Scratch

A BasicLearner class that trains a neural net to classify handwritten digits.

from fastai.vision.all import *
Background
In this notebook, I’ll work through the first “Further Research” exercise at the end of Chapter 4 of the Practical Deep Learning for Coders textbook:
Create your own implementation of Learner from scratch, based on the training loop shown in this chapter.
This Learner implementation is intentionally basic, reflecting only what we’ve learned in Chapter 4. I’ll call my implementation BasicLearner, as it corresponds to the BasicOptim optimizer created in the chapter. I’ll use BasicLearner to train a simple neural net on the MNIST_SAMPLE dataset.
MNIST_SAMPLE Training Loop
I’ll start by recreating the training loop from Chapter 4, which trains a simple neural net to classify handwritten 3s and 7s.
Load and Prepare the Data
The MNIST_SAMPLE dataset is available through fastai’s URLs constant; I download it using untar_data:
path = untar_data(URLs.MNIST_SAMPLE)
path.ls()

(#3) [Path('/root/.fastai/data/mnist_sample/train'),Path('/root/.fastai/data/mnist_sample/valid'),Path('/root/.fastai/data/mnist_sample/labels.csv')]
Then I stack the lists of training set and validation set images into 3-dimensional tensors, scaling the pixel values to the range 0 to 1. First, the training set:
stacked_threes = torch.stack([tensor(Image.open(o)) for o in (path/'train'/'3').ls().sorted()]).float()/255
stacked_sevens = torch.stack([tensor(Image.open(o)) for o in (path/'train'/'7').ls().sorted()]).float()/255
stacked_threes.shape, stacked_sevens.shape

(torch.Size([6131, 28, 28]), torch.Size([6265, 28, 28]))
show_image(stacked_threes[0]);
show_image(stacked_sevens[0]);
valid_3_tens = torch.stack([tensor(Image.open(o)) for o in (path/'valid'/'3').ls()]).float()/255
valid_7_tens = torch.stack([tensor(Image.open(o)) for o in (path/'valid'/'7').ls()]).float()/255
valid_3_tens.shape, valid_7_tens.shape

(torch.Size([1010, 28, 28]), torch.Size([1028, 28, 28]))
show_image(valid_3_tens[0]);
show_image(valid_7_tens[0]);
We then concatenate the training sets for 3s and 7s and flatten the tensors so that each image’s 28×28 grid becomes a one-dimensional row of 784 pixels, with a label of 1 for each 3 and 0 for each 7:
train_x = torch.cat([stacked_threes, stacked_sevens]).view(-1, 28*28)
train_y = tensor([1]*stacked_threes.shape[0] + [0]*stacked_sevens.shape[0]).unsqueeze(1)
train_x.shape, train_y.shape

(torch.Size([12396, 784]), torch.Size([12396, 1]))
Then do the same for the validation sets:
valid_x = torch.cat([valid_3_tens, valid_7_tens]).view(-1, 28*28)
valid_y = tensor([1]*len(valid_3_tens) + [0]*len(valid_7_tens)).unsqueeze(1)
valid_x.shape, valid_y.shape

(torch.Size([2038, 784]), torch.Size([2038, 1]))
We create training and validation datasets with the same structure as PyTorch’s Dataset, where indexing returns an (x, y) tuple:
dset = list(zip(train_x, train_y))
x,y = dset[0]
x.shape, y

(torch.Size([784]), tensor([1]))
valid_dset = list(zip(valid_x, valid_y))
x,y = valid_dset[0]
x.shape, y

(torch.Size([784]), tensor([1]))
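As an aside, PyTorch’s Dataset protocol only requires __getitem__ (returning an (x, y) tuple) and __len__, which is why a plain list of tuples works. A minimal sketch of the equivalent explicit class (my own illustration, not from the text):

class PairDataset:
    # behaves like list(zip(train_x, train_y))
    def __init__(self, x, y): self.x, self.y = x, y
    def __getitem__(self, i): return self.x[i], self.y[i]
    def __len__(self): return len(self.x)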
Then feed those datasets into fastai’s DataLoader:
dl = DataLoader(dset, batch_size=256)
xb,yb = first(dl)
xb.shape, yb.shape

(torch.Size([256, 784]), torch.Size([256, 1]))
valid_dl = DataLoader(valid_dset, batch_size=256)
valid_xb, valid_yb = first(valid_dl)
valid_xb.shape, valid_yb.shape

(torch.Size([256, 784]), torch.Size([256, 1]))
Create Our Model
For this exercise, the text has us create a simple neural net with a ReLU sandwiched between two linear layers. I’ve kept the number of intermediate activations (30) the same as in the text:
simple_net = nn.Sequential(
    nn.Linear(28*28, 30),
    nn.ReLU(),
    nn.Linear(30, 1)
)

Create a Loss Function
The loss function we will use does the following:
- Pass the model’s activations through a sigmoid function so that they are between 0 and 1.
- When the target is 1 (the digit 3), take the distance between 1 and the activation. When the target is 0 (the digit 7), take the distance between the activation and 0, which is just the activation itself.
- Take the mean of these distances across the batch.
def mnist_loss(predictions, targets):
    predictions = predictions.sigmoid()
    return torch.where(targets==1, 1-predictions, predictions).mean()
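To make the torch.where step concrete, here is a tiny worked example with made-up activations (my own illustration, not from the notebook):

preds = tensor([[2.0], [-1.0]])   # raw activations: the model leans 3, then 7
targs = tensor([[1], [0]])        # true labels: a 3 and a 7
mnist_loss(preds, targs)
# sigmoid([2.0, -1.0]) ≈ [0.881, 0.269]
# distances: 1 - 0.881 = 0.119 (target 1) and 0.269 (target 0)
# mean ≈ tensor(0.194)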
Create a Function to Calculate Predictions, Loss and Gradients

The calc_grad function takes the independent and dependent data batches as inputs, passes the independent batch through the model to get the activations (predictions), calculates the batch’s loss, and calls backward on the loss to compute the weights’ gradients:
def calc_grad(xb, yb, model):
    preds = model(xb)
    loss = mnist_loss(preds, yb)
    loss.backward()

Create an Optimizer
The optimizer handles stepping the weights and resetting the gradients. When stepping the weights, we use the parameters’ .data attribute, since PyTorch doesn’t track gradients on it. The zero_grad method resets the gradients (by setting them to None) so that they don’t accumulate additively when the next batch’s gradients are calculated:
class BasicOptim:
    def __init__(self,params,lr): self.params,self.lr = list(params),lr

    def step(self, *args, **kwargs):
        for p in self.params: p.data -= p.grad.data * self.lr

    def zero_grad(self, *args, **kwargs):
        for p in self.params: p.grad = None

lr = 0.1
opt = BasicOptim(simple_net.parameters(), lr)
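As a quick aside (my own toy example, not from the text), this is why zero_grad matters: PyTorch adds each new backward pass’s gradients onto whatever is already stored.

w = torch.tensor([1.0], requires_grad=True)
(w * 2).sum().backward()
w.grad            # tensor([2.])
(w * 2).sum().backward()
w.grad            # tensor([4.]) -- the second pass accumulated onto the first
w.grad = None     # what zero_grad does, so the next batch starts fresh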
Create a Function to Train One Epoch

For each training epoch:
- Get a batch from the training DataLoader.
- Calculate the activations, loss, and gradients.
- Step the weights in the direction opposite of the gradients.
- Reset the gradients to zero.
def train_epoch(model):
    for xb,yb in dl:
        calc_grad(xb, yb, model)
        opt.step()
        opt.zero_grad()

Create a Function to Calculate a Metric for One Batch
The metric of choice in the chapter is accuracy: the proportion of correctly predicted digits across the batch:
def batch_accuracy(xb, yb):
    preds = xb.sigmoid()
    correct = (preds>0.5) == yb
    return correct.float().mean()
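As a quick sanity check with made-up activations (my own example), two of the three predictions below match their targets:

batch_accuracy(tensor([[2.0], [-1.0], [0.5]]), tensor([[1], [0], [0]]))
# sigmoid -> [0.88, 0.27, 0.62]; thresholded at 0.5 -> predicted 3, 7, 3
# targets -> 3, 7, 7; two of three correct -> tensor(0.6667)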
Create a Function to Calculate the Metric for One Epoch

For each batch in the validation DataLoader, calculate the accuracy. Then, take the mean of all batch accuracy values as the accuracy for the epoch:
def validate_epoch(model):
    accs = [batch_accuracy(model(xb), yb) for xb,yb in valid_dl]
    return round(torch.stack(accs).mean().item(), 4)

Create a Function for the Training Loop
train_model takes as inputs a model and the number of epochs to train it for. For each epoch, it trains the model on the training set batches and prints the epoch’s metric calculated on the validation set batches:
def train_model(model, epochs):
    for i in range(epochs):
        train_epoch(model)
        print(validate_epoch(model), end=' ')

Train the Model
As is done in the text, I’ll train the model for 40 epochs.
train_model(simple_net, 40)

0.5127 0.8013 0.9175 0.9419 0.957 0.9653 0.9672 0.9677 0.9687 0.9702 0.9726 0.9736 0.9745 0.9755 0.9755 0.9765 0.977 0.9785 0.9785 0.9785 0.9795 0.9799 0.9804 0.9809 0.9809 0.9814 0.9819 0.9819 0.9819 0.9824 0.9829 0.9829 0.9829 0.9829 0.9829 0.9829 0.9829 0.9829 0.9829 0.9829
I get starting and final accuracies similar to the example in the text.
BasicLearner Class
My BasicLearner should recreate the training process performed in the above sections. I’ll start by defining the inputs and outputs for an instance of this class:
Inputs and Outputs
The fastai Learner requires the following inputs; for comparison, a sketch of the equivalent fastai call follows the list:
- DataLoaders with training and validation sets.
- The model we want to train.
- An optimizer function.
- A loss function.
- Any metrics we want calculated.
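For reference, here is roughly how those same pieces plug into fastai’s own Learner, as shown in Chapter 4 (using the DataLoaders I build below; fastai’s SGD plays the role of my BasicOptim, and the output is omitted):

learn = Learner(dls, simple_net, opt_func=SGD,
                loss_func=mnist_loss, metrics=batch_accuracy)
learn.fit(10, lr=0.1)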
The Learner prints a table with the following information when its fit(epochs, lr) method is called. In the first iteration of my BasicLearner, I’ll show only the epoch number and the metric:
- Epoch #.
- Training Loss.
- Validation Loss.
- Metric.
- Time.
With these inputs and outputs in mind, I’ll write the BasicLearner class:
class BasicLearner:
    def __init__(self, dls, model, opt_func, loss_func, metric):
        self.dls = dls
        self.model = model
        self.opt_func = opt_func
        self.loss_func = loss_func
        self.metric = metric

    def calc_grad(self, xb, yb, model):
        preds = self.model(xb)
        loss = self.loss_func(preds, yb)
        loss.backward()

    def train_epoch(self):
        for xb,yb in self.dls.train:
            self.calc_grad(xb, yb, self.model)
            self.opt.step()
            self.opt.zero_grad()

    def validate_epoch(self):
        accs = [self.metric(self.model(xb), yb) for xb,yb in self.dls.valid]
        return round(torch.stack(accs).mean().item(), 4)

    def train_model(self, model, epochs):
        print("Epoch", self.metric.__name__, sep="\t")
        for i in range(self.epochs):
            self.train_epoch()
            print(i, self.validate_epoch(), sep="\t")

    def fit(self, epochs, lr):
        self.lr = lr
        self.epochs = epochs
        self.opt = self.opt_func(self.model.parameters(), self.lr)
        self.train_model(self.model, self.epochs)

I’ll combine my training and validation DataLoaders and confirm that they contain the correct number of tuples in their datasets:
dls = DataLoaders(dl, valid_dl)

len(dls.train.dataset)

12396

len(dls.valid.dataset)

2038
I’ll create a fresh neural net so that training starts from scratch:
simple_net = nn.Sequential(
    nn.Linear(28*28, 30),
    nn.ReLU(),
    nn.Linear(30, 1)
)

I’ll instantiate my BasicLearner class:
learn = BasicLearner(dls=dls,
                     model=simple_net,
                     opt_func=BasicOptim,
                     loss_func=mnist_loss,
                     metric=batch_accuracy)

And train the model:

learn.fit(40, 0.1)

Epoch batch_accuracy
0 0.5068
1 0.814
2 0.9184
3 0.9419
4 0.9575
5 0.9648
6 0.9663
7 0.9677
8 0.9692
9 0.9707
10 0.9736
11 0.9736
12 0.9741
13 0.9755
14 0.9765
15 0.9775
16 0.978
17 0.9785
18 0.979
19 0.979
20 0.979
21 0.9795
22 0.9795
23 0.9804
24 0.9804
25 0.9809
26 0.9814
27 0.9819
28 0.9819
29 0.9819
30 0.9814
31 0.9819
32 0.9819
33 0.9824
34 0.9824
35 0.9829
36 0.9829
37 0.9829
38 0.9829
39 0.9829
Looks good! I’m getting starting and ending accuracy values similar to before.
Improving the BasicLearner Class
Now that I’ve confirmed that my BasicLearner is able to train a neural net to get 98% accuracy classifying 3s and 7s, I would like to add a bit more functionality to the class.
First, I’d like to add a predict method to the learner, which will take a tensor image as input and output the prediction (a probability and a label), so that I can test whether my model has truly learned how to classify 3s and 7s.
class BasicLearner:
    def __init__(self, dls, model, opt_func, loss_func, metric):
        self.dls = dls
        self.model = model
        self.opt_func = opt_func
        self.loss_func = loss_func
        self.metric = metric

    def calc_grad(self, xb, yb, model):
        preds = self.model(xb)
        loss = self.loss_func(preds, yb)
        loss.backward()

    def train_epoch(self):
        for xb,yb in self.dls.train:
            self.calc_grad(xb, yb, self.model)
            self.opt.step()
            self.opt.zero_grad()

    def validate_epoch(self):
        accs = [self.metric(self.model(xb), yb) for xb,yb in self.dls.valid]
        return round(torch.stack(accs).mean().item(), 4)

    def train_model(self, model, epochs):
        print("Epoch", self.metric.__name__, sep="\t")
        for i in range(self.epochs):
            self.train_epoch()
            print(i, self.validate_epoch(), sep="\t")

    def fit(self, epochs, lr):
        self.lr = lr
        self.epochs = epochs
        self.opt = self.opt_func(self.model.parameters(), self.lr)
        self.train_model(self.model, self.epochs)

    def predict(self, x):
        prediction = self.model(x)
        prediction = prediction.sigmoid()
        label = "3" if prediction > 0.5 else "7"
        return prediction, label

I’ll instantiate a new model and BasicLearner and train it again:
simple_net = nn.Sequential(
    nn.Linear(28*28, 30),
    nn.ReLU(),
    nn.Linear(30, 1)
)

learn = BasicLearner(dls=dls,
                     model=simple_net,
                     opt_func=BasicOptim,
                     loss_func=mnist_loss,
                     metric=batch_accuracy)

learn.fit(40, 0.1)

Epoch batch_accuracy
0 0.5073
1 0.8184
2 0.9194
3 0.9419
4 0.957
5 0.9638
6 0.9658
7 0.9672
8 0.9697
9 0.9706
10 0.9726
11 0.9741
12 0.9741
13 0.9755
14 0.976
15 0.9765
16 0.9765
17 0.978
18 0.978
19 0.978
20 0.9795
21 0.9795
22 0.9799
23 0.9809
24 0.9809
25 0.9814
26 0.9814
27 0.9814
28 0.9819
29 0.9814
30 0.9814
31 0.9824
32 0.9829
33 0.9829
34 0.9829
35 0.9829
36 0.9824
37 0.9824
38 0.9824
39 0.9824
With the model trained, I can see if it predicts an image of a 3 correctly:
show_image(dls.valid.dataset[1][0].view(-1,28,28));
learn.predict(dls.valid.dataset[1][0])

(tensor([1.0000], grad_fn=<SigmoidBackward0>), '3')
The final piece I’ll add is a “training loss” column in the fit method’s output. The loss of each training batch will be stored in a tensor; at the end of each epoch, I’ll calculate the mean loss, print it out, and reset the loss tensor to empty for the next epoch:
class BasicLearner:
    def __init__(self, dls, model, opt_func, loss_func, metric):
        self.dls = dls
        self.model = model
        self.opt_func = opt_func
        self.loss_func = loss_func
        self.metric = metric

    def calc_grad(self, xb, yb, model):
        preds = self.model(xb)
        loss = self.loss_func(preds, yb)
        # store the loss of each batch,
        # later to be averaged across the epoch
        self.loss = torch.cat((self.loss, tensor([loss])))
        loss.backward()

    def train_epoch(self):
        for xb,yb in self.dls.train:
            self.calc_grad(xb, yb, self.model)
            self.opt.step()
            self.opt.zero_grad()

    def validate_epoch(self):
        accs = [self.metric(self.model(xb), yb) for xb,yb in self.dls.valid]
        return round(torch.stack(accs).mean().item(), 4)

    def train_model(self, model, epochs):
        print("Epoch", "Train Loss", self.metric.__name__, sep="\t")
        for i in range(self.epochs):
            # reset the stored losses at the start of each epoch
            self.loss = tensor([])
            self.train_epoch()
            print(i, round(self.loss.mean().item(), 4), self.validate_epoch(), sep="\t\t")

    def fit(self, epochs, lr):
        self.lr = lr
        self.epochs = epochs
        self.opt = self.opt_func(self.model.parameters(), self.lr)
        self.train_model(self.model, self.epochs)

    def predict(self, x):
        prediction = self.model(x)
        prediction = prediction.sigmoid()
        label = "3" if prediction > 0.5 else "7"
        return prediction, label

simple_net = nn.Sequential(
    nn.Linear(28*28, 30),
    nn.ReLU(),
    nn.Linear(30, 1)
)

learn = BasicLearner(dls=dls,
                     model=simple_net,
                     opt_func=BasicOptim,
                     loss_func=mnist_loss,
                     metric=batch_accuracy)

learn.fit(40, 0.1)

Epoch Train Loss batch_accuracy
0 0.3627 0.5229
1 0.1088 0.7715
2 0.0593 0.9111
3 0.0439 0.9389
4 0.0375 0.9516
5 0.0337 0.9629
6 0.0311 0.9653
7 0.0291 0.9667
8 0.0275 0.9672
9 0.0261 0.9687
10 0.025 0.9721
11 0.0241 0.9736
12 0.0233 0.9746
13 0.0225 0.9755
14 0.0219 0.9755
15 0.0213 0.976
16 0.0208 0.9765
17 0.0204 0.978
18 0.02 0.9785
19 0.0196 0.9785
20 0.0193 0.979
21 0.0189 0.979
22 0.0186 0.979
23 0.0184 0.9799
24 0.0181 0.9804
25 0.0178 0.9804
26 0.0176 0.9804
27 0.0174 0.9804
28 0.0172 0.9804
29 0.017 0.9814
30 0.0168 0.9824
31 0.0166 0.9824
32 0.0164 0.9829
33 0.0163 0.9829
34 0.0161 0.9829
35 0.016 0.9824
36 0.0158 0.9829
37 0.0157 0.9829
38 0.0155 0.9829
39 0.0154 0.9829
# check prediction again
learn.predict(dls.valid.dataset[1][0])

(tensor([1.0000], grad_fn=<SigmoidBackward0>), '3')
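A small design note on the loss tracking (my own aside, not from the text): concatenating tensor([loss]) works, but appending plain Python floats is a simpler alternative. A sketch of how calc_grad could store them instead:

def calc_grad(self, xb, yb, model):
    preds = self.model(xb)
    loss = self.loss_func(preds, yb)
    # store a plain float; train_model would reset self.losses = [] each epoch
    # and print round(sum(self.losses) / len(self.losses), 4)
    self.losses.append(loss.item())
    loss.backward()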
Further Improvements
My BasicLearner can train a neural net to classify two digits with 98% accuracy. During training, it prints out the epoch number, training loss, and metric. It also has a predict method for testing its classification on new tensor images. While I’m happy with the result of this exercise, there are certainly numerous improvements and additions that could bring this learner closer to the functionality of the fastai Learner class; a rough sketch of one follows.
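For example, the fastai Learner also reports validation loss and the time taken per epoch. Here is a sketch of how those columns might be added to my class; the validate_loss method and the timing logic are my own assumptions, not from the text:

import time

def validate_loss(self):
    # hypothetical addition: mean loss across the validation batches
    losses = [self.loss_func(self.model(xb), yb) for xb,yb in self.dls.valid]
    return round(torch.stack(losses).mean().item(), 4)

def train_model(self, model, epochs):
    print("Epoch", "Train Loss", "Valid Loss", self.metric.__name__, "Time", sep="\t")
    for i in range(self.epochs):
        start = time.time()
        self.loss = tensor([])
        self.train_epoch()
        print(i, round(self.loss.mean().item(), 4), self.validate_loss(),
              self.validate_epoch(), f"{time.time() - start:.1f}s", sep="\t")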
I hope you enjoyed reading this blog post!