by Subhaditya Mukherjee
Here we will see what happens when we “don’t” take BatchNorm for granted.
Batch Norm is one of the most widely used layers in neural networks. Ever since it came out, it has become possible to train neural networks that are faster, more accurate and more robust to hyperparameter choices such as the learning rate and weight initialization. Sounds almost like magic, doesn’t it? You would think that, for something so magical, the implementation must be crazy hard. (You would be wrong.)
What happened is that, due to the black-box nature of neural networks, we started taking this magic for granted. There were, of course, many hypotheses about how and why it has the effect it does, but I recently found a paper that made a serious attempt to understand it.
So what happens when we just train the batchnorm layers and freeze everything else? What happens when we use different types of networks? Let us dig in…
Before we get to anything, a quick primer on batchnorm.
Why do we need it? It standardizes the inputs to each layer over a mini-batch. This allows the network to “focus” and learn what is more important, numerically speaking.
Now, what does it look like? (The snippet below is in Julia. Note that these are the components and not the entire implementation.)
Here we first generate a random array. The inputs γ and β are learnable parameters, which we will look into later. ϵ is a very small number that prevents a division by zero when the variance is 0. We first take the mean, then the variance, and then we standardize the data using them. At the end we scale the result by γ and shift it by β.
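Written out as equations (following the usual batch norm formulation, with m the number of values being averaged), these components are:

$$\mu = \frac{1}{m}\sum_{i=1}^{m} x_i, \qquad \sigma^2 = \frac{1}{m}\sum_{i=1}^{m} (x_i - \mu)^2,$$

$$\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}, \qquad y_i = \gamma\,\hat{x}_i + \beta.$$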
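Now, consider the random array to be a batch of data. In a real BatchNorm layer the mean and variance are computed per feature over the batch, and the layer also keeps a running mean and variance (values that are continuously updated from the stream of batches), which are used instead of the batch statistics at inference time. That's about it.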
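```julia
ran = rand(10, 20)  # a random 10×20 array standing in for a batch

function bnorm_x(x, γ, β, ϵ = 1e-5)
    mean_x = sum(x) / length(x)                        # mean of x
    variance_x = sum((x .- mean_x) .^ 2) / length(x)   # variance of x
    x̂ = (x .- mean_x) ./ sqrt(variance_x + ϵ)          # standardize
    @info size(x̂), size(β), size(γ)
    return γ .* x̂ .+ β                                 # scale by γ, shift by β
end

bnorm_x(ran, 1.0, 0.0)
```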
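For comparison, here is a minimal NumPy sketch of the fuller picture: per-feature statistics over the batch dimension, plus running estimates that are used at inference time. The function names are hypothetical, and the momentum of 0.1, eps of 1e-5 and the unbiased running variance are my assumptions modelled on PyTorch's documented defaults; this is an illustration, not PyTorch's actual implementation.

```python
import numpy as np

def batchnorm_train(x, gamma, beta, running_mean, running_var,
                    momentum=0.1, eps=1e-5):
    """Training-mode batch norm for x of shape (batch, features).

    Normalizes each feature with the statistics of the current batch and
    updates the running estimates in place (exponential moving average).
    """
    batch_mean = x.mean(axis=0)          # per-feature mean over the batch
    batch_var = x.var(axis=0)            # biased variance, used to normalize
    x_hat = (x - batch_mean) / np.sqrt(batch_var + eps)
    n = x.shape[0]
    # Moving averages of the statistics, kept for inference time.
    running_mean *= 1 - momentum
    running_mean += momentum * batch_mean
    running_var *= 1 - momentum
    running_var += momentum * batch_var * n / (n - 1)  # unbiased estimate
    return gamma * x_hat + beta          # learnable scale and shift

def batchnorm_eval(x, gamma, beta, running_mean, running_var, eps=1e-5):
    """Inference-mode batch norm: uses the running statistics, not the batch."""
    x_hat = (x - running_mean) / np.sqrt(running_var + eps)
    return gamma * x_hat + beta

rng = np.random.default_rng(0)
x = rng.random((10, 20))                 # a "batch" of 10 samples, 20 features
gamma, beta = np.ones(20), np.zeros(20)
running_mean, running_var = np.zeros(20), np.ones(20)
y = batchnorm_train(x, gamma, beta, running_mean, running_var)
print(y.mean(axis=0)[:3], y.std(axis=0)[:3])
```

With γ = 1 and β = 0, every column of `y` has (almost exactly) zero mean and unit standard deviation, while `running_mean` and `running_var` have moved 10% of the way towards the batch statistics.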
Okay, so this bit will be in PyTorch. I will try to explain everything I do, but I cannot paste the full code here, as the post would become too long. So I will just show what is different from a standard training setup.
The entire code (with comments) can be found in my repository.
So our main workflow remains the same as in any other deep learning project.
We focus on steps 3 and 4 here.
Let us first import the libraries we need: PyTorch and tqdm (a tiny little progress bar helper which I absolutely adore).
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
```
Before going ahead, let us first understand what we want. We need to train the network as usual, but make sure that only the BatchNorm layers get trained, just to see how far we can stretch it.
To do that, let us create a function which goes through our entire model and, whenever it finds a layer which is NOT BatchNorm, tells PyTorch to stop computing gradients for that layer's weights and biases (by setting `requires_grad` to `False`). In the process, we make sure that only the BatchNorm layers get trained. That is why we freeze the weights and biases.
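```python
def freezeOthers(m):
    # Called on every submodule via model.apply(freezeOthers).
    # BatchNorm1d/2d/3d all subclass _BatchNorm: leave those trainable.
    if isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
        return
    # Stop computing gradients for the weights and biases of every other layer.
    if getattr(m, 'weight', None) is not None:
        m.weight.requires_grad_(False)
    if getattr(m, 'bias', None) is not None:
        m.bias.requires_grad_(False)
```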
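An equivalent way to express the same idea, which I find easier to check, is to walk `model.modules()` and set `requires_grad` on each module's own parameters, then list which parameters are still trainable. The helper name `freeze_all_but_batchnorm` and the tiny model are hypothetical stand-ins, not the ResNet used in the experiment.

```python
import torch
import torch.nn as nn

def freeze_all_but_batchnorm(model: nn.Module) -> None:
    """Hypothetical helper: only parameters owned by BatchNorm layers keep
    requires_grad=True (BatchNorm1d/2d/3d all subclass _BatchNorm)."""
    for module in model.modules():
        is_bn = isinstance(module, nn.modules.batchnorm._BatchNorm)
        # recurse=False: each parameter is visited once, by the module owning it
        for param in module.parameters(recurse=False):
            param.requires_grad_(is_bn)

# A tiny stand-in model (an assumption for illustration only).
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(16),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(16, 10),
)
freeze_all_but_batchnorm(model)

trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)  # ['1.weight', '1.bias']: BatchNorm's γ and β only
```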
Now for the main training loop. We first put the model in training mode (so that BatchNorm uses batch statistics and updates its running averages) and move it to the GPU. Then we apply our previous function, once, before any update: if we only froze the other layers after `optimizer.step()`, the very first step would already have changed every layer. After this, PyTorch computes gradients only for the BatchNorm layers. Then we iterate over the data. Following the standard PyTorch loop, we reset the gradients, pass the batch through our model, perform backpropagation and step through our optimizer. We also add a tiny little option to print out the current loss. (This helps when you are looking at it train.)

```python
def train(args, model, device, train_loader, optimizer, epoch):
    model.train()                  # Training mode: BatchNorm uses batch statistics
    model.to(device)               # Move the model to the given device (e.g. the GPU)
    model.apply(freezeOthers)      # Freeze non-BatchNorm layers before any update
    for batch_idx, (data, target) in enumerate(tqdm(train_loader)):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()      # Reset grads
        output = model(data)       # Pass the batch through the model
        loss = F.cross_entropy(output, target)  # Calculate loss
        loss.backward()            # Backprop: only BatchNorm params get gradients
        optimizer.step()           # Update the trainable (BatchNorm) parameters
        if batch_idx % args.log_interval == 0:
            print(loss.item())
            if args.dry_run:
                break
```
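To check that only the BatchNorm parameters actually move during training, here is a minimal, self-contained sketch on random CPU data. The tiny model, the random "images" and the SGD settings are assumptions for illustration. Freezing happens before the optimizer is created, and the optimizer only receives the trainable parameters, so frozen tensors are never handed to it at all (with momentum or weight decay, a parameter that still holds a zeroed gradient can otherwise keep drifting).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Tiny stand-in model and random data (assumptions for illustration only).
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
)

# Freeze everything except BatchNorm BEFORE the first optimizer step.
for module in model.modules():
    is_bn = isinstance(module, nn.modules.batchnorm._BatchNorm)
    for param in module.parameters(recurse=False):
        param.requires_grad_(is_bn)

# Hand the optimizer only the parameters that are still trainable.
optimizer = torch.optim.SGD(
    [p for p in model.parameters() if p.requires_grad], lr=0.1, momentum=0.9
)

before = {name: p.detach().clone() for name, p in model.named_parameters()}

model.train()
data = torch.randn(16, 3, 32, 32)
target = torch.randint(0, 10, (16,))
for _ in range(3):
    optimizer.zero_grad()
    loss = F.cross_entropy(model(data), target)
    loss.backward()
    optimizer.step()

# Names of the parameters whose values changed during the three steps.
changed = {name for name, p in model.named_parameters()
           if not torch.equal(p, before[name])}
print(sorted(changed))
```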
Well, I tested ResNet-110 on the CIFAR-10 dataset. For the normally trained network, I got a test accuracy of around 92%. While only training BatchNorm, I got to around 68% test accuracy. Now you might think that is very far off. Yes, it is... but notice that we are training only around 0.48% of the parameters (Cite 1). Wow…
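To get a feel for where a figure like that comes from, here is a sketch that counts the BatchNorm parameters of a CIFAR-style ResNet-110. The architecture details (three stages of 16/32/64 channels, 18 basic blocks per stage, parameter-free zero-padding shortcuts) are my assumption of the common CIFAR variant, not necessarily the exact network from the experiment above; `BasicBlock`, `cifar_resnet` and `batchnorm_fraction` are hypothetical names.

```python
import torch.nn as nn
import torch.nn.functional as F

class BasicBlock(nn.Module):
    """CIFAR-style residual block with a parameter-free shortcut
    (subsample + zero-pad channels); this variant is an assumption."""
    def __init__(self, in_ch, out_ch, stride):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_ch)
        self.stride, self.pad = stride, out_ch - in_ch

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        shortcut = x[:, :, ::self.stride, ::self.stride]
        if self.pad:  # zero-pad the channel dimension to match out_ch
            shortcut = F.pad(shortcut, (0, 0, 0, 0, 0, self.pad))
        return F.relu(out + shortcut)

def cifar_resnet(depth=110, num_classes=10):
    n = (depth - 2) // 6  # basic blocks per stage: 18 for depth 110
    layers = [nn.Conv2d(3, 16, 3, 1, 1, bias=False), nn.BatchNorm2d(16), nn.ReLU()]
    in_ch = 16
    for out_ch, stride in [(16, 1), (32, 2), (64, 2)]:
        for i in range(n):
            layers.append(BasicBlock(in_ch, out_ch, stride if i == 0 else 1))
            in_ch = out_ch
    layers += [nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(64, num_classes)]
    return nn.Sequential(*layers)

def batchnorm_fraction(model):
    """Return (#BatchNorm params, #all params, fraction that is BatchNorm)."""
    bn = sum(p.numel() for m in model.modules()
             if isinstance(m, nn.modules.batchnorm._BatchNorm)
             for p in m.parameters(recurse=False))
    total = sum(p.numel() for p in model.parameters())
    return bn, total, bn / total

bn, total, frac = batchnorm_fraction(cifar_resnet(110))
print(f"{bn} BatchNorm parameters out of {total} ({100 * frac:.2f}%)")
```

For this particular variant, the γ and β parameters come out at a little under half a percent of all parameters; other shortcut choices shift the number slightly.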
But is this conclusive? Well, I did a few more experiments, but instead of giving you my graphs, let me point you towards an awesome blog post I found. It is by someone I admire, and you can go and play around with the network and see the results for yourself.
Now that we understand a bit more about how expressive these layers seem to be, let me share some points I found extremely interesting from the paper (Cite 1).
By now, I think we have come to realize that maybe we should not take what we think we know for granted. Sometimes it takes a difficult-to-digest paper to make one understand that. BatchNorm does play an important part in the network, and this shows why we need to be able to dig into the structure and the black box of neural network architectures.
I would love to discuss this further, and if you have any questions, do feel free to reach out in the comments, or connect with me on GitHub.