Batch normalization has a simple goal: stabilize the gradients of large computational graphs. In doing so, this technique has enabled the deep learning renaissance that almost every major ML breakthrough in the last 5 years has relied on.
The concept is sound: by normalizing the mean and variance of the inputs to nearly every layer in a neural network, the gradients of that network rarely explode during the backward pass. The end result is that many neural networks can be easily trained with gradient techniques that would otherwise never have converged.
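To make the mechanics concrete, here is a minimal sketch of the training-mode batch norm computation for a fully connected layer (running averages, the convolutional case, and other details are omitted):

```python
import torch

def batch_norm_forward(x, gamma, beta, eps=1e-5):
    """Training-mode batch norm for a (batch, features) tensor.

    The statistics are computed over dim 0 -- the batch dimension -- which is
    exactly the property the rest of this post takes issue with.
    """
    mean = x.mean(dim=0, keepdim=True)                # per-feature mean over the batch
    var = x.var(dim=0, unbiased=False, keepdim=True)  # per-feature variance over the batch
    x_hat = (x - mean) / torch.sqrt(var + eps)        # zero-mean, unit-variance features
    return gamma * x_hat + beta                       # learnable scale and shift

x = torch.randn(32, 64)  # 32 samples, 64 features
y = batch_norm_forward(x, torch.ones(64), torch.zeros(64))
```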
So why am I calling it a hack? Let’s dig in.
It violates batch invariance
When training a neural network with SGD, we often use mini-batches. Mini-batches let us average gradients across a larger portion of the input space in a single update. They also allow us to easily scale training to fit the computational resources available to us.
There is an important caveat to mini-batches: the choice of which inputs are included in a batch is supposed to be totally random. As a result, there is no structure to the batch dimension of a mini-batched input tensor. Therefore, it is always wrong to compute across that dimension, which is precisely what batch normalization does.
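A quick way to see the problem: in training mode, a sample's batch-norm output depends on whatever else happens to be in its batch, while a per-sample scheme like group norm is unaffected. A minimal sketch in PyTorch (the shapes here are arbitrary):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(8).train()  # training mode: normalizes with batch statistics
gn = nn.GroupNorm(4, 8)         # group norm: statistics are computed per sample

x = torch.randn(4, 8, 16, 16)              # a batch of 4 images
x_swapped = x.clone()
x_swapped[1:] = torch.randn(3, 8, 16, 16)  # replace everything except sample 0

# Batch norm: sample 0's output changes when its batch-mates change.
print(torch.allclose(bn(x)[0], bn(x_swapped)[0]))  # False
# Group norm: sample 0's output is independent of the rest of the batch.
print(torch.allclose(gn(x)[0], gn(x_swapped)[0]))  # True
```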
To give a grounded example of why this is bad, consider some work on GANs I have been doing recently. The discriminator of this GAN is a VGG-style network that uses batch normalization. In conventional GAN style, the discriminator discerns whether the inputs it is given are real or fake. Since fake inputs must backprop through the generator, they are always batched together; to provide symmetry, real inputs are also batched together.
While analyzing the performance of the GAN, I noticed that the discriminator was always predicting a near-uniform value across the batch dimension: if one image was predicted fake, all of the images in the mini-batch would be predicted fake, and vice versa. This struck me as odd: shouldn’t the discriminator be making predictions per-image, not per-batch?
It turns out this odd behavior was caused by batch normalization. Switching all of the VGG blocks in the discriminator to group norm caused the issue to go away and improved the generator’s inception score in validation by 5%!
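The swap itself is a one-line change per block. Here is a sketch of what a VGG-style block looks like with GroupNorm in place of BatchNorm2d (the channel counts and group size are illustrative, not the exact ones from my discriminator):

```python
import torch.nn as nn

def vgg_block(in_ch, out_ch, num_groups=32):
    """A VGG-style conv block using GroupNorm instead of BatchNorm2d."""
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
        nn.GroupNorm(num_groups, out_ch),   # was: nn.BatchNorm2d(out_ch)
        nn.ReLU(inplace=True),
        nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
        nn.GroupNorm(num_groups, out_ch),   # was: nn.BatchNorm2d(out_ch)
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2),
    )

block = vgg_block(64, 128)
```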
I’m not the only one to experience these issues with batch normalization in GANs. In this paper, Xiang et al. took an extensive look at batch normalization in generators and found evidence of performance loss in the resulting models.
I think it’d be pretty silly to assume that this “bad behavior” is constrained only to GANs. While I can’t point to any specific evidence, I wouldn’t be surprised to see research showing that batch norm hurts some aspect of the performance of most neural networks, and I’d bet that violating batch invariance is the root cause.
Scaling? Good luck.
Since batch normalization operates across the batch axis, it doesn’t play nicely with data parallelism. While the major ML frameworks have workarounds for this problem, these workarounds bottleneck the network with expensive all-reduce operations scattered throughout the graph.
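In PyTorch, for example, the workaround is SyncBatchNorm, which all-reduces the batch statistics across data-parallel workers on every forward pass. A rough sketch of how it is typically wired up (the `build_model()` constructor and the distributed setup are assumed to exist):

```python
import torch.nn as nn

# Assumes torch.distributed has already been initialized (e.g. via torchrun)
# and that build_model() is your own constructor returning a BN-based network.
model = build_model()

# Replace every BatchNorm layer with a synchronized version whose statistics
# are all-reduced across all data-parallel workers -- correct, but it adds a
# communication round at every normalization layer.
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = nn.parallel.DistributedDataParallel(model.cuda())
```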
None of this is really an issue for simple networks, because you can just normalize across the batch dimension per compute unit and skip the synchronization. But as your neural network scales to the point that you can no longer fit ~8 samples per compute unit, the per-device normalization really starts to break down. The result is degraded network performance or slower training.
Things get more dire when you don’t have access to a lot of parallel compute. The best way to train large networks in this situation is via gradient accumulation, but this doesn’t play well with batch norm and there are currently no good workarounds pre-implemented in the big frameworks.
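For context, gradient accumulation looks roughly like the sketch below (`model`, `optimizer`, `loss_fn`, and `loader` are placeholders). The important detail is that any batch-norm layer still computes its statistics from the small micro-batch, so accumulating gradients does not recover large-batch normalization:

```python
accum_steps = 8                      # effective batch = 8 micro-batches
optimizer.zero_grad()
for step, (x, y) in enumerate(loader):
    loss = loss_fn(model(x), y) / accum_steps  # scale so gradients average out
    loss.backward()                            # gradients accumulate in .grad
    if (step + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
    # Any BatchNorm layer inside `model` normalized this forward pass using
    # only the micro-batch `x`, not the effective batch -- that is the problem.
```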
Lazy Weight Initialization
The authors of the batch normalization paper argued that it reduces the effect of gradient explosions caused by improper weight initialization.
This seems to be the case, as evidenced by its widespread use in deep residual networks, which have issues with exploding means. In their work titled “Fixup Initialization”, Zhang et al. dig into whether it is possible to train deep residual networks without batch normalization. They accomplish their goal primarily by deriving a new weight initialization scheme specifically for residual networks. Using Fixup, the authors were able to achieve near-BN performance without any normalization.
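For the curious, here is a rough, simplified sketch of what the Fixup recipe looks like for a basic two-convolution residual branch, based on my reading of the paper (the full recipe also covers the classification layer and a few other details I am glossing over):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FixupBasicBlock(nn.Module):
    """A simplified residual block initialized roughly along Fixup's rules:
    He-init the first conv and shrink it by L^(-1/(2m-2)) (m = 2 layers here),
    zero-init the last conv, and add scalar biases plus a trainable scale."""

    def __init__(self, channels, num_residual_blocks):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bias1 = nn.Parameter(torch.zeros(1))
        self.bias2 = nn.Parameter(torch.zeros(1))
        self.scale = nn.Parameter(torch.ones(1))

        nn.init.kaiming_normal_(self.conv1.weight)
        with torch.no_grad():
            self.conv1.weight.mul_(num_residual_blocks ** -0.5)  # L^(-1/2) for m = 2
        nn.init.zeros_(self.conv2.weight)  # the residual branch starts as the identity

    def forward(self, x):
        out = F.relu(self.conv1(x + self.bias1))
        out = self.conv2(out + self.bias2) * self.scale
        return F.relu(x + out)
```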
The interesting conclusion to draw from this paper is that normalization is (at least in some cases) less a necessity than a way to cover up the lack of a proper weight initialization scheme. It is not hard to do some analysis on a residual network without BN or Fixup and see that the means and variances quickly get out of whack as the network depth increases.
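To see what I mean, here is a quick sanity check: stack residual blocks with standard He initialization and no normalization, and watch the activation variance climb with depth (exact numbers will vary with the seed, but the trend is the point):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def plain_block(width=512):
    """A residual branch with He init and no normalization of any kind."""
    blk = nn.Sequential(nn.Linear(width, width), nn.ReLU(), nn.Linear(width, width))
    for m in blk:
        if isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight)
            nn.init.zeros_(m.bias)
    return blk

with torch.no_grad():
    h = torch.randn(256, 512)
    for depth in range(1, 51):
        h = h + plain_block()(h)          # residual connection, no norm anywhere
        if depth % 10 == 0:
            print(depth, h.var().item())  # variance grows rapidly with depth
```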
All this is to say that while it might be appropriate to perform normalization in the experimental phase of research, we should hold the ML community to a higher standard when we start using models in production: these models should be closely analyzed for the possibility of improvements via weight initialization. At the very least, weights should be initialized so that the networks can be trained without normalization, even if normalization is still used for performance reasons. I would love to see work similar to Fixup be done for other ML focus areas like NLP.
So it’s a hack…
I think batch normalization is far overused in current models. The situation stems from a lack of understanding combined with a lack of good research into alternatives. So much research builds upon existing SOTA models that it becomes difficult to reconsider such a fundamental part of the model, even when it has so many drawbacks.
The next time you are building a model, I urge you to think twice about using batch normalization. For image researchers: try out group normalization; I bet you will be surprised by the results. Other normalization schemes like instance norm can work well too. You might also consider developing your own normalization – it actually isn’t too hard. Often all you need to do is find an input dimension with wide variance and perform normalization across it.
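As a toy illustration of what rolling your own can look like, here is a hedged sketch of a layer that normalizes across the channel dimension of an image tensor (essentially a simplified cousin of layer/instance norm; the names are mine):

```python
import torch
import torch.nn as nn

class DimNorm(nn.Module):
    """Toy custom normalization: pick a dimension with wide variance
    (here, the channel dimension of an NCHW tensor) and normalize across it."""

    def __init__(self, num_channels, dim=1, eps=1e-5):
        super().__init__()
        self.dim, self.eps = dim, eps
        self.gamma = nn.Parameter(torch.ones(num_channels))
        self.beta = nn.Parameter(torch.zeros(num_channels))

    def forward(self, x):
        mean = x.mean(dim=self.dim, keepdim=True)
        var = x.var(dim=self.dim, unbiased=False, keepdim=True)
        x_hat = (x - mean) / torch.sqrt(var + self.eps)
        # Broadcast the per-channel affine parameters over N, H, W.
        return x_hat * self.gamma.view(1, -1, 1, 1) + self.beta.view(1, -1, 1, 1)

norm = DimNorm(64)
y = norm(torch.randn(8, 64, 32, 32))
```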
If you are putting a model into production, you should be doing some investigations into weight initialization and trainable scaling. A well-designed model should not need normalization to train properly, and you likely have performance to gain if it does.
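As a concrete starting point for the trainable-scaling half of that advice, here is a minimal, hedged sketch: batch norm’s learnable affine parameters without any batch statistics (the class name is mine):

```python
import torch
import torch.nn as nn

class TrainableScale(nn.Module):
    """Per-channel learnable scale and shift for NCHW tensors -- the affine
    half of batch norm with the batch statistics removed entirely."""

    def __init__(self, num_channels):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(num_channels, 1, 1))
        self.beta = nn.Parameter(torch.zeros(num_channels, 1, 1))

    def forward(self, x):  # x: (N, C, H, W)
        return x * self.gamma + self.beta
```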