from fastai.vision.all import *
from fastai.callback.hook import *
matplotlib.rc('image', cmap='Greys')
Comparing CNN Performance by Varying Batch Normalization Placement
Background
In this notebook I’ll work through the following prompt given in the “Further Research” section of Chapter 13 (Convolutional Neural Networks):
Try moving the activation function after the batch normalization layer in conv. Does it make a difference? See what you can find out about what order is recommended and why.
path = untar_data(URLs.MNIST)
path.ls()
(#2) [Path('/root/.fastai/data/mnist_png/testing'),Path('/root/.fastai/data/mnist_png/training')]
Activation Function After Batch Normalization
I’ll first train a model with the same architecture as the textbook: the activation function (in this case nn.ReLU) is placed after nn.BatchNorm2d in nn.Sequential.
def conv(ni, nf, ks=3, act=True):
    layers = [nn.Conv2d(ni, nf, stride=2, kernel_size=ks, padding=ks//2)]
    layers.append(nn.BatchNorm2d(nf))
    if act: layers.append(nn.ReLU())
    return nn.Sequential(*layers)

conv(1,8) # activation function is after batch normalization
Sequential(
(0): Conv2d(1, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
I’ll re-use the simple_cnn architecture as well as the get_dls and fit helper functions defined in the textbook:
def simple_cnn():
    return sequential(
        conv(1, 8, ks=5),        # 14x14
        conv(8, 16),             # 7x7
        conv(16, 32),            # 4x4
        conv(32, 64),            # 2x2
        conv(64, 10, act=False), # 1x1
        Flatten(),
    )
# create a function to change dls params
def get_dls(bs=64):
    return DataBlock(
        blocks=(ImageBlock(cls=PILImageBW), CategoryBlock),
        get_items=get_image_files,
        splitter=GrandparentSplitter('training', 'testing'),
        get_y=parent_label,
        batch_tfms=Normalize()
    ).dataloaders(path, bs=bs)

dls = get_dls(512)
dls.show_batch(max_n=9, figsize=(4,4))
def fit(epochs=1, lr=0.06):
    learn = Learner(dls, simple_cnn(), loss_func=F.cross_entropy, metrics=accuracy,
                    cbs=ActivationStats(with_hist=True))
    learn.fit_one_cycle(epochs, lr)
    return learn

learn = fit()
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 0.135710 | 0.056765 | 0.986600 | 01:10 |
I’ll take a look at the activation distribution across batches in a couple of ways:

- color_dim, which displays an image where each column of pixels is a histogram of one batch of activations; lighter colors are close to zero and darker colors are non-zero.
- plot_layer_stats, which displays the mean, standard deviation, and %-near-zero activations across batches during training.
The output of color_dim looks as expected: a smooth transition from zero to non-zero activations.

learn.activation_stats.color_dim(-4)
The %-near-zero chart looks better than the counterexamples in the text, but the values are still pretty high (~65% of the activations are near zero after batch #25).

learn.activation_stats.plot_layer_stats(-2)
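For a bit more intuition about what “% near zero” means here, below is a rough sketch of that kind of statistic (my own helper for illustration; fastai’s ActivationStats computes its own version internally, and the exact threshold may differ):

# Rough sketch (my own helper, not fastai's exact definition) of a
# "% near zero" statistic: the fraction of a layer's activations whose
# magnitude falls below a small threshold.
import torch

def near_zero_frac(acts, thresh=0.05):
    "Fraction of activations with |value| below `thresh`."
    return (acts.abs() < thresh).float().mean().item()

# example on a made-up batch of activations
acts = torch.randn(512, 8, 14, 14)
near_zero_frac(acts)  # ~0.04 for standard-normal values with thresh=0.05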
Activation Function Before Batch Normalization
Next, I’ll reverse the order of nn.ReLU and nn.BatchNorm2d, placing the activation function before the batch normalization:
def conv(ni, nf, ks=3, act=True):
    layers = [nn.Conv2d(ni, nf, stride=2, kernel_size=ks, padding=ks//2)]
    if act: layers.append(nn.ReLU()) # activation function before batch norm
    layers.append(nn.BatchNorm2d(nf))
    return nn.Sequential(*layers)

conv(1,8)
Sequential(
(0): Conv2d(1, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(1): ReLU()
(2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
simple_cnn()
Sequential(
(0): Sequential(
(0): Conv2d(1, 8, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
(1): ReLU()
(2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(1): Sequential(
(0): Conv2d(8, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(1): ReLU()
(2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(2): Sequential(
(0): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(1): ReLU()
(2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(3): Sequential(
(0): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(1): ReLU()
(2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(4): Sequential(
(0): Conv2d(64, 10, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(5): fastai.layers.Flatten(full=False)
)
The accuracy is slightly higher than when the activation function was placed after the batch norm (0.986600):
learn = fit()
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 0.130553 | 0.055105 | 0.987000 | 01:09 |
The activations look pretty good, and this model seems to have more non-zero activations at the start of training, which is an improvement.
learn.activation_stats.color_dim(-4)
The standard deviation of the activations is smoother than before. However, there are more near-zero activations (and the % near zero increases during training), which is a downside of this architecture.
learn.activation_stats.plot_layer_stats(-2)
Running More Experiments
It’s tough to arrive at a conclusion with only one model per architecture, trained for a single epoch. To get a better sense of how performance is affected by the position of the activation function, I’ll train 10 models for 5 epochs for each architecture, storing the activation stats along the way.
To store the activation stats, I’ll grab the layer_stats for the penultimate layer (idx = -2) for each training run.
I’ll store the layer_stats objects in an L object so I can use the itemgot method to get the mean, standard deviation, and %-near-zero values across trainings.
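As a quick illustration, itemgot(i) pulls the i-th element out of every item in an L; here’s a toy example with made-up numbers standing in for per-run stats:

# toy example (made-up numbers): each item is one run's (means, stds, %-near-zero)
stats = L(
    ([0.1, 0.2], [1.0, 1.1], [0.60, 0.65]),  # run 1
    ([0.0, 0.3], [0.9, 1.2], [0.55, 0.70]),  # run 2
)
stats.itemgot(2)  # -> (#2) [[0.6, 0.65], [0.55, 0.7]] -- the %-near-zero lists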
I’ll also store each training run’s histograms so I can view the output of color_dim manually (using idx = -4 as done in the textbook):
h = learn.activation_stats.hist(-4)
ax = subplots(figsize=(10,5))[1][0]
ax.imshow(h, origin='lower');
ax.axis('off');
Activation Function After Batch Normalization
I’ll update fit so that it doesn’t display the progress bar and logging during training (so that I don’t get a screen full of logs):
def fit(epochs=5, lr=0.06):
    learn = Learner(dls, simple_cnn(), loss_func=F.cross_entropy, metrics=accuracy,
                    cbs=ActivationStats(with_hist=True))
    with learn.no_logging(), learn.no_bar(): learn.fit_one_cycle(epochs, lr)
    return learn
I’ll train 10 models, each for 5 epochs, and save the layer_stats and hist for each model. I’ll also make sure to define conv such that the activation function is after the batch normalization.
These training runs should take about 50 minutes in total.
layer_stats_after = L()
hist_after = L()

def conv(ni, nf, ks=3, act=True):
    layers = [nn.Conv2d(ni, nf, stride=2, kernel_size=ks, padding=ks//2)]
    layers.append(nn.BatchNorm2d(nf))
    if act: layers.append(nn.ReLU()) # activation function after batch norm
    return nn.Sequential(*layers)

for _ in range(10):
    learn = fit()
    layer_stats_after.append(learn.activation_stats.layer_stats(-2))
    hist_after.append(learn.activation_stats.hist(-4))
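Since these runs take a while, one optional safeguard (my own addition, not something the original experiments did, and the filename is arbitrary) is to persist the collected stats to disk so a notebook disconnect doesn’t throw away the training time; the same idea applies to the “before” runs below:

# optional: save the collected stats to disk (hypothetical filename)
torch.save({'layer_stats': list(layer_stats_after), 'hist': list(hist_after)},
           'stats_activation_after_bn.pt')
# reload later with: saved = torch.load('stats_activation_after_bn.pt')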
Activation Function Before Batch Normalization
I’ll run a similar set of trainings for the architecture where the activation function is placed before batch normalization.
layer_stats_before = L()
hist_before = L()

def conv(ni, nf, ks=3, act=True):
    layers = [nn.Conv2d(ni, nf, stride=2, kernel_size=ks, padding=ks//2)]
    if act: layers.append(nn.ReLU()) # activation function before batch norm
    layers.append(nn.BatchNorm2d(nf))
    return nn.Sequential(*layers)

for _ in range(10):
    learn = fit()
    layer_stats_before.append(learn.activation_stats.layer_stats(-2))
    hist_before.append(learn.activation_stats.hist(-4))
Analysis of Training Results
I’ll start by plotting the mean, standard deviation and %-near-zero activations for both architectures (activation function before/after batch norm).
I’ll create this helper function to plot the spread of mean, standard deviation and %-near-zero activations:
def plot_stats_spread(layer_stats, titles, super_title):
    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    fig.suptitle(super_title, fontsize=16)

    for i, title in enumerate(titles):
        y = layer_stats.itemgot(i)
        x = range(len(y[0]))

        y_min = np.minimum.reduce(y)
        y_max = np.maximum.reduce(y)

        axs[i].fill_between(x, y_min, y_max, alpha=0.5)
        axs[i].set_title(f"Spread of {title} Activations Across 10 Trainings")
        axs[i].set_xlabel("Batch")
        axs[i].set_ylabel(f"{title}")

    plt.tight_layout()
    plt.show()

titles = ['Mean', 'Std', '%-near-zero']
In the following six charts, I plot the spread or range of activation stats over 5 epochs across the 10 trainings. I chose to plot the spread (instead of 10 line plots) so that you can more easily see the range of values for each statistic.
When the activation function is placed after the batch norm, the standard deviations and the percentage of activations near zero are lower than when the activation is placed before the batch norm. This shows the benefit of placing the activation after the batch norm.
"Activation Function After Batch Norm") plot_stats_spread(layer_stats_after, titles,
"Activation Function Before Batch Norm") plot_stats_spread(layer_stats_before, titles,
I’ll create a helper function to plot the mean and median values for each statistic:
def plot_stats_avg(layer_stats, titles, super_title):
    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    fig.suptitle(super_title, fontsize=16)

    for i, title in enumerate(titles):
        y = layer_stats.itemgot(i)
        x = range(len(y[0]))

        y_mean = np.mean(np.stack(y), axis=0)
        y_median = np.median(np.stack(y), axis=0)

        axs[i].plot(x, y_mean, color='red', alpha=0.5, linewidth=1)
        axs[i].plot(x, y_median, color='blue', alpha=0.5, linewidth=1)
        axs[i].set_title(f"Average {title} Activations Across 10 Trainings")
        axs[i].set_xlabel("Batch")
        axs[i].set_ylabel(f"{title}")

    plt.tight_layout()
    plt.show()
The mean (red) and median (blue) values across the 10 training runs for each architecture show the same trends: lower standard deviation and lower %-near-zero activations when the activation function is placed after batch norm. Additionally, the average mean activations for the first 100 or so batches are less than 0 when the activation function is placed after the batch norm.
Overall, these plots reiterate the benefit of placing the activation function after the batch norm.
"Activation Function After Batch Norm") plot_stats_avg(layer_stats_after, titles,
"Activation Function Before Batch Norm") plot_stats_avg(layer_stats_before, titles,
Next, I’ll look at histograms of mean and median activation values for both architectures:
def plot_hist_avg(hist, super_title):
    fig, axs = plt.subplots(2, 1, figsize=(20, 5))
    fig.suptitle(super_title, fontsize=16)

    h_mean = torch.stack(list(hist)).mean(0)
    h_median = torch.stack(list(hist)).median(0)[0]

    axs[0].imshow(h_mean, origin='lower');
    axs[0].set_title(f"Mean Activations Across 10 Trainings")
    axs[0].axis('off');

    axs[1].imshow(h_median, origin='lower');
    axs[1].set_title(f"Median Activations Across 10 Trainings")
    axs[1].axis('off');
It’s tough to visually distinguish between the different shades of gray, but the clear difference is that the non-zero activations increase much more slowly when the activation function is placed after the batch norm than when it’s placed before.
"Activation Function After Batch Norm") plot_hist_avg(hist_after,
"Activation Function Before Batch Norm") plot_hist_avg(hist_before,
References to Batch Norm Placement
My quick experiments suggest that placing the activation function after the batch normalization is preferable because it decreases the percentage of activations that are near zero and decreases the standard deviation of the activations, both of which are signs of “stable” training.
However, these are just quick, rough experiments. What do the experts say? I googled around and found the following:
This fastai forums thread, which led me to more resources.
In the original Batch Normalization paper the authors state (emphasis mine):
We add the BN transform immediately before the nonlinearity, by normalizing \(x = Wu + b\).
A quick note on terminology: “BN before the nonlinearity” is the same ordering as what I’ve been calling “activation function after BN” (and “BN after the nonlinearity” corresponds to “activation function before BN”).
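To make that mapping concrete, here are the two orderings written out as plain nn.Sequential blocks (an illustration only, using the same layer types as the conv definitions above):

from torch import nn

# "BN before the nonlinearity" == activation function AFTER batch norm
bn_before_relu = nn.Sequential(
    nn.Conv2d(1, 8, 3, stride=2, padding=1), nn.BatchNorm2d(8), nn.ReLU())

# "BN after the nonlinearity" == activation function BEFORE batch norm
bn_after_relu = nn.Sequential(
    nn.Conv2d(1, 8, 3, stride=2, padding=1), nn.ReLU(), nn.BatchNorm2d(8))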
They compare placing BN before and after the nonlinearity; my experiments match their explanation that the activations are sparser (i.e., there are more near-zero activations) when BN is placed after the nonlinearity (emphasis mine):
The goal of Batch Normalization is to achieve a stable distribution of activation values throughout training, and in our experiments we apply it before the nonlinearity since that is where matching the first and second moments is more likely to result in a stable distribution. On the contrary, (Gulcehre & Bengio, 2013) apply the standardization layer to the output of the nonlinearity, which results in sparser activations. In our large-scale image classification experiments, we have not observed the nonlinearity inputs to be sparse, neither with nor without Batch Normalization.
I didn’t explore it here, but they also consider applying BN to the layer inputs \(u\), before the weights are applied (which can be seen as another way of placing it after the nonlinearity; emphasis mine):
We could have also normalized the layer inputs \(u\), but since \(u\) is likely the output of another nonlinearity, the shape of its distribution is likely to change during training, and constraining its first and second moments would not eliminate the covariate shift. In contrast, \(Wu + b\) is more likely to have a symmetric, non-sparse distribution, that is “more Gaussian” (Hyvarinen & Oja, 2000); normalizing it is likely to produce activations with a stable distribution.
In a 2016 Keras GitHub comment Francois Chollet said (emphasis mine):
I haven’t gone back to check what they are suggesting in their original paper, but I can guarantee that recent code written by Christian applies relu before BN. It is still occasionally a topic of debate, though.
Final Thoughts
This portion of the chapter got me excited because I felt I was finally peeking under the hood to see the dynamics of a neural net training run. The ActivationStats callback is a simple but powerful tool for understanding how different architectural changes affect training dynamics. I look forward to digging deeper under the hood in part 2 of the course!
I hope you enjoyed this blog post. Follow me on Twitter @vishal_learner.