In my last post, I briefly discussed the infuriating fact that a neural network, even when deeply flawed, will often “work” in the sense that it’ll perform above random at classification, or a generative network might produce outputs that sometimes look plausibly drawn from the dataset.
Given an idea you’re testing out that is performing poorly – how, then, do you tell the difference between a botched implementation and an idea that just isn’t good? I think this is one of the toughest questions I have to deal with on a daily basis as an ML engineer. It’s the difference between funneling an immense amount of work into an idea that doesn’t pan out (which happens often!) and calling it early to look at something else.
I definitely don’t have all the answers, but I have gathered a few tricks over the last couple of years that I wanted to share:
Know how to interpret your loss curves. Different classes of NNs will have different loss curve shapes, but tweaks to a NN rarely change that shape – generally performance increases or decreases in a stepwise fashion. New notches in the curves or progressively diverging performance are worth digging into further. In particular – are you sure that you didn’t add more compute to your NN?
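For what it’s worth, comparing shapes doesn’t need anything fancy. Here’s a minimal sketch (the run names, smoothing constant, and raw per-step loss lists are all placeholders) for overlaying a baseline and a candidate run:

```python
# Overlay two smoothed loss curves so shape changes (new notches, divergence)
# are easy to spot. Assumes losses were logged once per step into plain lists.
import matplotlib.pyplot as plt


def smooth(values, beta=0.98):
    """Exponential moving average -- the usual way to de-noise a loss curve."""
    out, avg = [], values[0]
    for v in values:
        avg = beta * avg + (1 - beta) * v
        out.append(avg)
    return out


def compare_runs(baseline_losses, candidate_losses):
    plt.plot(smooth(baseline_losses), label="baseline")
    plt.plot(smooth(candidate_losses), label="candidate")
    plt.xscale("log")  # log-x makes early- vs. late-training shape differences visible
    plt.yscale("log")
    plt.xlabel("step")
    plt.ylabel("loss")
    plt.legend()
    plt.show()
```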
For generative models, build a good eval and plot it regularly. This eval should actually sample an image/audio clip/etc from your generator, and it should measure some aspect of your modality that you are not actively optimizing. My favorite approach here is to take a pre-trained classifier for the modality and use it to compute loss values on real and generated samples for a given label. Evals that judge sampling results will often surface performance differences that are not easily visible in the training loss.
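As a rough sketch of the kind of eval I mean (PyTorch assumed; `generator`, `classifier`, and the conditioning interface are placeholders – the only point is to score samples with a model you are not training against):

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def classifier_eval(generator, classifier, labels, device="cuda"):
    """Score generated samples with a frozen, pre-trained classifier."""
    generator.eval()
    classifier.eval()
    labels = labels.to(device)
    samples = generator(labels)   # placeholder: sample conditioned on the labels
    logits = classifier(samples)  # the classifier is frozen and never trained here
    # Cross-entropy against the conditioning labels: this number often starts
    # drifting before anything shows up in the training loss.
    return F.cross_entropy(logits, labels).item()
```

Computing the same number on a batch of real samples gives you a floor to compare against.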
Plot grad and param norms. In runs that are likely to diverge before long, you’ll notice spikes starting to appear in the grad norms. When training with FP16, these will occur at regular intervals as the loss scale overflows, but they will begin to occur more often as a NN becomes less stable. If you’re seeing an abnormal or increasing number of grad spikes in your logs, it’s worth considering stopping the run and re-calibrating some of your hyperparameters (increased weight decay can reduce param norms and, correspondingly, grad norms; a lower learning rate can also help reduce both).
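A minimal sketch of what I mean (PyTorch assumed; the `logger` object, its wandb-style `log` call, and the spike threshold are placeholders), run after `loss.backward()` and before `optimizer.step()`:

```python
import torch


def log_norms(model, step, logger, spike_threshold=10.0, last_grad_norm=None):
    """Log global grad and param L2 norms, and flag sudden grad spikes."""
    grads = [p.grad for p in model.parameters() if p.grad is not None]
    grad_norm = torch.norm(torch.stack([g.norm(2) for g in grads]), 2).item()
    param_norm = torch.norm(
        torch.stack([p.detach().norm(2) for p in model.parameters()]), 2
    ).item()
    logger.log({"grad_norm": grad_norm, "param_norm": param_norm}, step=step)
    # Flag suspiciously large jumps relative to the previous step.
    if last_grad_norm is not None and grad_norm > spike_threshold * last_grad_norm:
        print(f"step {step}: grad norm spiked to {grad_norm:.2f}")
    return grad_norm
```

If you’re using AMP, unscale the grads first so the logged numbers aren’t dominated by the loss scale.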
Param norms are a good early warning sign for parameters that are overfitting some aspect of the dataset. They seem to ultimately be the cause of grad spikes and, eventually, training divergence. The thing to look for is param norms that appear to be growing without an upper bound. If you plot the norm for each parameter individually, you can sometimes catch this earlier, since this curve can take a long time to “warm up” for some parameters. The solution to exploding param norms is the same – increased weight decay or a lower learning rate (or rethinking some aspect of your architecture).
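Per-tensor tracking is cheap to add. A sketch (PyTorch assumed) that returns one number per named parameter, which you can log every N steps and plot as separate curves:

```python
import torch


@torch.no_grad()
def per_param_norms(model):
    """L2 norm of each parameter tensor, keyed by its name."""
    return {name: p.norm(2).item() for name, p in model.named_parameters()}
```

The curve with no plateau – the one that just keeps climbing – is the one to watch.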
Activation norms are a good thing to watch to judge the long-term stability of your network. They should generally correlate pretty well with parameter norms, but they also incorporate variance that can be attributed to the dataset. Very high activation norms are a good indication of a network that is becoming brittle and will soon fail. This isn’t universally true, but in my experience it is a great way to compare the stability of two networks: if you have a baseline that works and a new network with higher activation norms, the latter is more likely to diverge over the long term.
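One way to collect these is with forward hooks. A sketch (PyTorch assumed; hooking every module is the simplest starting point, though you may only care about a few suspects):

```python
import torch


def attach_activation_hooks(model, store):
    """Record the L2 norm of each module's output into `store`, keyed by module name."""
    handles = []
    for name, module in model.named_modules():
        def hook(mod, inputs, output, name=name):
            if torch.is_tensor(output):
                store[name] = output.detach().norm(2).item()
        handles.append(module.register_forward_hook(hook))
    return handles  # call handle.remove() on each when you're done profiling
```

Dumping `store` to your logger every few hundred steps is enough to compare a baseline against a candidate architecture.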
Activation norms can also be used to isolate the specific cause of a failure. Generally only a small set of the norms will be divergent (for example, several of the projection weights in your MLP layers). Knowing where you are having numeric problems helps when trying to fix them – by adding normalization layers or L2 weight regularization, for example.
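As a hypothetical example of acting on that information: once the per-layer norms point at a specific set of weights (the `"mlp.proj"` name filter below is purely illustrative), you can hit just those parameters with heavier weight decay via optimizer param groups instead of regularizing the whole network:

```python
import torch


def build_optimizer(model, lr=1e-4, base_wd=0.01, heavy_wd=0.1):
    """AdamW with extra weight decay on the parameters flagged by the activation norms."""
    heavy, rest = [], []
    for name, p in model.named_parameters():
        (heavy if "mlp.proj" in name else rest).append(p)
    return torch.optim.AdamW(
        [
            {"params": rest, "weight_decay": base_wd},
            {"params": heavy, "weight_decay": heavy_wd},
        ],
        lr=lr,
    )
```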