Triforce: A general recipe for kickass Generative Models

For the past two years, I’ve been tinkering around with generative models in my spare time. I think I’ve landed on an approach that produces by far the most compelling results available today, and which scales like big language models. I’d like to outline the approach here.

First of all, I want to touch on something that’ll become immediately obvious: this isn’t a novel architecture or anything. In fact, it is pretty much OpenAI’s DALL E with a diffusion upsampler attached. Instead, it’s a way of thinking about how one can (1) improve upon DALL E and (2) universally model generative domains using a single set of techniques.

Three Models

This approach uses three different neural networks to produce the finished result, all trained separately from one another.

The first is a Discrete VAE. This model is responsible for translating your base medium (for example, images, music, video, etc) into a string of integers. The DVAE preserves the structure of your medium, but simplifies the contents of it into discrete bins that can be reasoned about. The DVAE can also compress the medium so it is more computationally tractable to reason about.
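To make the “string of integers” idea concrete, here is a minimal sketch of the discretization step. It assumes a nearest-neighbour codebook lookup, which is only one of several ways to discretize; the codebook, the shapes, and the function names are all made up for illustration and do not come from any particular model.

```python
import numpy as np

def quantize(latents, codebook):
    """Map each latent vector to the index of its nearest codebook entry.

    latents:  (num_positions, dim) encoder outputs
    codebook: (vocab_size, dim) embedding table
    returns:  (num_positions,) integer tokens
    """
    # Squared Euclidean distance between every latent and every code.
    dists = ((latents[:, None, :] - codebook[None, :, :]) ** 2).sum(axis=-1)
    return dists.argmin(axis=1)

def dequantize(tokens, codebook):
    """Look the tokens back up; this is what a decoder would consume."""
    return codebook[tokens]

# Toy codebook: code k is the vector [k, k, k, k].
codebook = np.arange(8, dtype=np.float64)[:, None] * np.ones((1, 4))
# Toy "encoder outputs": slightly perturbed copies of codes 3, 3, 5 and 0.
latents = codebook[[3, 3, 5, 0]] + 0.1
tokens = quantize(latents, codebook)
reconstruction = dequantize(tokens, codebook)
```

The round trip through `tokens` loses the small perturbation, which is the lossy part that the later stages have to make up for.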

The second is the causal transformer, essentially a GPT model. This model is trained in next-token-prediction where the tokens are the discrete outputs of the DVAE. These models are especially neat because you can throw anything you like into the sequence and they will learn how to reason about them. Have text and audio and want to produce images? Discretize all three and throw them into your causal transformer! It’ll learn how to convert between these mediums and predict image tokens. Want to flip the problem around and predict text from images and audio clips? Just flip this sequence around! The flexibility of this architecture is incredible.
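As a rough illustration of “throwing everything into the sequence”, the sketch below concatenates tokens from three modalities and builds next-token-prediction pairs. The vocabulary sizes and helper names are hypothetical, and giving each modality its own id range is one possible convention, not something prescribed above.

```python
# Hypothetical vocabulary sizes for each discretized modality.
TEXT_VOCAB, AUDIO_VOCAB, IMAGE_VOCAB = 256, 1024, 8192

def build_sequence(text_tokens, audio_tokens, image_tokens):
    """Concatenate modalities into one token stream with disjoint id ranges."""
    audio_offset = TEXT_VOCAB
    image_offset = TEXT_VOCAB + AUDIO_VOCAB
    return (
        list(text_tokens)
        + [t + audio_offset for t in audio_tokens]
        + [t + image_offset for t in image_tokens]
    )

def next_token_pairs(sequence):
    """Inputs and targets for next-token prediction: position i predicts token i + 1."""
    return sequence[:-1], sequence[1:]

sequence = build_sequence([5, 7], [10, 11], [0, 42])
inputs, targets = next_token_pairs(sequence)
```

Because the image tokens come last, the model learns to predict them from the text and audio that precede them. Predicting text instead is a matter of changing the concatenation order.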

The final stage is the diffusion network. To understand why this is necessary, you first have to understand that DVAEs have absolutely awful decoders. They are always lossy, and that cannot be fixed because VAEs do not scale. Anecdotally, this is almost certainly the reason that DALL E’s generates are so blurry.

Diffusion models are, bar none, the best super resolution models in existence. What is good for super resolution is also good for upsampling the output of your DVAE decoder. You simply feed the output of your DVAE as a prior to your diffusion model and train it to reproduce full resolution images. Unlike DVAEs, diffusion models respond excellently to scaling. Unlike GANs, diffusion models do not suffer from mode collapse.
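Here is a minimal sketch of how one training example for such an upsampler could be put together. It assumes a DDPM-style noising step and conditioning by channel concatenation; neither detail is specified above, and `make_training_example` and its arguments are hypothetical names.

```python
import numpy as np

def make_training_example(target, dvae_decoded, alpha_bar_t, rng):
    """Build one (input, label) pair for a conditional denoiser.

    target:       (C, H, W) full-resolution image
    dvae_decoded: (C, H, W) lossy DVAE reconstruction, already resized to H x W
    alpha_bar_t:  fraction of signal kept at the sampled timestep, in (0, 1)
    """
    noise = rng.normal(size=target.shape)
    noisy = np.sqrt(alpha_bar_t) * target + np.sqrt(1.0 - alpha_bar_t) * noise
    # The network sees the noisy target plus the DVAE output as conditioning
    # and is trained to predict the noise that was added.
    model_input = np.concatenate([noisy, dvae_decoded], axis=0)
    return model_input, noise

rng = np.random.default_rng(0)
target = rng.uniform(-1.0, 1.0, size=(3, 8, 8))
# Stand-in for a blurry DVAE decode of the same image.
dvae_decoded = target + 0.1 * rng.normal(size=target.shape)
model_input, label = make_training_example(target, dvae_decoded, alpha_bar_t=0.5, rng=rng)
```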

That is the slowest generative model in existence

… You’re right. Autoregressive transformers are slow to sample from, and so are diffusion networks. This is not a fast technique, and it’ll likely never see use on the edge. However, it is capable of producing extremely compelling generates, better than anything I have seen or heard in the literature. While there is certainly a place in the world for something that works fast, there is also a place for something that truly works well. I think we are only a few years away from ML models that produce generates that the average human would consider true art. Voice, music, paintings, etc: it’s all possible with enough data, compute and patience.

To that end, I am currently building a text-to-speech triforce model which I suspect will blow every previous TTS model out of the water. It’s going to be slow and ungodly large, but my goal is to build something that you can truly enjoy listening to. Something that you might actually use to narrate audio books or as a stand-in for voice actors, for example.

Like all large transformer models, this thing is going to be enormously data hungry, so my last few months have been spent building a massive speech dataset pulled from podcasts, audiobooks and YouTube. I hope to write about that soon.


Accelerated Differentiable Image Warping in PyTorch

Computing optical flow is an important part of video understanding. There are many ways to train a model to compute this, but one of the more compelling methods is to:

  1. Feed a model an image pair
  2. Have it predict optical flow
  3. Apply that optical flow to the original image
  4. Compute a pixel-wise loss against the second image.

In order to use this algorithm, however, you need a differentiable way to do step (3), typically called an “image warp”. TensorFlow has just such an operation in contrib, but to my knowledge PyTorch does not.

After digging around for a while today, I found what I needed in one of NVIDIA’s open source repositories:

In this repository, the author has implemented a new CUDA primitive called “resample2d”. Although there isn’t any documentation on this operation, it is exactly what is needed to compute an image warp given an optical flow vector.

Suppose you have an image and a .flo file, which you can find in several places. Here is how you would use this operation to compute the second image:

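import torch
import torchvision

# Both of these imports come from the repository mentioned above.
from utils.flow_utils import readFlow
from networks.resample2d_package.resample2d import Resample2d

# load_image is your own helper; it should return a float (1, C, H, W) tensor.
im1 = load_image('image.png').to('cuda')
# Move the flow's channel axis to the front and add a batch dimension.
flow = torch.tensor(readFlow('flow.flo')).permute(2,0,1).unsqueeze(0).contiguous().to('cuda')
resample = Resample2d()
# Warp the first image by the flow field.
synth = resample(im1, flow)
torchvision.utils.save_image(synth, "flowed_image.png")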

You’ll need to import the code from the above linked repository to run this. Note that resample2d must be performed on the GPU. It does not work on CPU and just returns all zeros.
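If you don’t have a GPU handy, a similar warp can be sketched on top of PyTorch’s built-in `torch.nn.functional.grid_sample`, which also runs on CPU. This is a stand-in under my own assumptions, not the resample2d kernel: `warp_with_flow` is a hypothetical helper, it assumes flow channel 0 is the horizontal displacement and channel 1 the vertical one (both in pixels), and its handling of image borders may differ from resample2d.

```python
import torch
import torch.nn.functional as F

def warp_with_flow(image, flow):
    """Sample `image` (N, C, H, W) at positions displaced by `flow` (N, 2, H, W).

    Output pixel (x, y) is bilinearly sampled from the image at
    (x + flow[:, 0], y + flow[:, 1]); displacements are in pixels.
    """
    _, _, h, w = image.shape
    xs = torch.arange(w, dtype=image.dtype, device=image.device).view(1, 1, w)
    ys = torch.arange(h, dtype=image.dtype, device=image.device).view(1, h, 1)
    x = xs + flow[:, 0]  # (N, H, W) absolute sample columns
    y = ys + flow[:, 1]  # (N, H, W) absolute sample rows
    # grid_sample expects coordinates normalized to [-1, 1].
    x = 2.0 * x / max(w - 1, 1) - 1.0
    y = 2.0 * y / max(h - 1, 1) - 1.0
    grid = torch.stack((x, y), dim=-1)  # (N, H, W, 2), last dim is (x, y)
    return F.grid_sample(image, grid, mode="bilinear",
                         padding_mode="border", align_corners=True)

image = torch.arange(16, dtype=torch.float32).view(1, 1, 4, 4)

# A flow of +1 pixel in x makes every output pixel read its right-hand neighbour.
shift = torch.zeros(1, 2, 4, 4)
shift[:, 0] = 1.0
target = warp_with_flow(image, shift)

# Steps 3 and 4 of the recipe: warp with a predicted flow, then take a pixel-wise loss.
flow = torch.zeros(1, 2, 4, 4, requires_grad=True)  # stand-in for a model's output
loss = (warp_with_flow(image, flow) - target).abs().mean()
loss.backward()  # gradients reach the flow, so a flow-predicting model could be trained
```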


Batch Normalization is a Hack

Batch normalization has a simple goal: stabilize the gradients of large computational graphs. In doing so, this technique has enabled the deep learning renaissance that almost every major ML breakthrough in the last 5 years has relied on.

The concept is sound: by normalizing the mean and variance of the inputs of nearly every layer in a neural network, the gradients of that network rarely explode during the backward pass. The end result is that many neural networks can be easily trained with gradient techniques that would otherwise have never converged.

So why am I calling it a hack? Let’s dig in.

It violates batch invariance

When training a neural network with SGD, we often use mini-batches. Mini-batches allow a computational graph to average the gradients of a forward pass across a greater portion of the input space. They also allow us to easily scale training to fit the computational resources available to us.

There is an important caveat to mini-batches: the choice of which inputs are included in a batch is supposed to be totally random. As a result, there is no structure to the batch dimension of a mini-batched input vector. Therefore, it is always wrong to compute across this dimension, which is precisely what batch normalization does.

To give a grounded example of why this is bad, consider some work on GANs I have been doing recently. The discriminator of this GAN is a VGG-style network that used batch normalization. In conventional GAN style, the discriminator discerns whether or not the inputs it is given are real or fake. Since fake inputs must backprop through the generator, they are always batched together. To provide symmetry, real inputs are also batched together.

While analyzing the performance of the GAN, I noticed that the discriminator was always predicting a near-uniform value across the batch dimension. E.g. if one image was predicted as fake, all of the images in the minibatch would be predicted fake and vice versa. This struck me as odd: shouldn’t the discriminator be making predictions per-image and not per-batch?

It turns out this odd behavior was caused by batch normalization. Switching all of the VGG blocks in the discriminator to group norm caused the issue to go away and improved the generator’s inception score in validation by 5%!
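The coupling is easy to reproduce outside of a GAN. The sketch below uses simplified NumPy versions of both normalizations (training-mode statistics, no learned scale or shift, random data instead of real activations) and checks whether the output for one sample changes when only its batch-mates change.

```python
import numpy as np

def batch_norm(x, eps=1e-5):
    """Normalize each channel with statistics taken over batch and spatial axes."""
    mean = x.mean(axis=(0, 2, 3), keepdims=True)
    var = x.var(axis=(0, 2, 3), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def group_norm(x, groups, eps=1e-5):
    """Normalize each sample independently over groups of channels."""
    n, c, h, w = x.shape
    g = x.reshape(n, groups, c // groups, h, w)
    mean = g.mean(axis=(2, 3, 4), keepdims=True)
    var = g.var(axis=(2, 3, 4), keepdims=True)
    return ((g - mean) / np.sqrt(var + eps)).reshape(n, c, h, w)

rng = np.random.default_rng(0)
sample = rng.normal(size=(1, 4, 8, 8))
mates_a = rng.normal(size=(3, 4, 8, 8))
mates_b = rng.normal(loc=5.0, size=(3, 4, 8, 8))  # very different batch-mates

batch_a = np.concatenate([sample, mates_a])
batch_b = np.concatenate([sample, mates_b])

# Does the output for the *same* sample depend on who else is in the batch?
bn_changed = not np.allclose(batch_norm(batch_a)[0], batch_norm(batch_b)[0])
gn_changed = not np.allclose(group_norm(batch_a, 2)[0], group_norm(batch_b, 2)[0])
print("batch norm output changed:", bn_changed)
print("group norm output changed:", gn_changed)
```

With batch norm, the first sample’s activations shift when its neighbours change; with group norm they do not, because the statistics never cross the batch axis.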

I’m not the only one to experience these issues with batch normalization in GANs. In this paper, Xiang et al. took an extensive look at batch normalization in generators and found evidence of performance loss in the end result.

I think it’d be pretty silly to assume that this “bad behavior” is constrained only to GANs. While I can’t point to any specific evidence, I wouldn’t be surprised to see research showing that batch norm hurts some part of the performance of most neural networks, and I’d bet that violating batch invariance is the root cause.

Scaling? Good luck.

Since batch normalization operates across the batch axis, it doesn’t play nicely with data parallelism. While the major ML frameworks have workarounds for this problem, these workarounds bottleneck the network with expensive all-reduce operations scattered throughout the graph.

This isn’t really an issue with simple networks because you can simply normalize across the batch dimension per-compute unit, but as your neural network begins to scale to the point that you can no longer fit ~8 samples per compute unit, the normalization really starts to break down. The result is degraded network performance or slower training.
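To see what a synchronized workaround has to communicate, here is a sketch of the idea in NumPy. It is not how any particular framework implements it: each “device” holds a shard of the batch, and the plain Python sums stand in for the cross-device all-reduce. The helper names are hypothetical.

```python
import numpy as np

def local_sums(x):
    """What each device computes on its own shard: count, sum, sum of squares."""
    return x.shape[0], x.sum(axis=0), (x ** 2).sum(axis=0)

def synced_stats(shards):
    """Combine per-device sums into the mean and variance of the whole batch.

    The additions below stand in for the all-reduce across devices.
    """
    counts, sums, sq_sums = zip(*(local_sums(s) for s in shards))
    n = sum(counts)
    mean = sum(sums) / n
    var = sum(sq_sums) / n - mean ** 2
    return mean, var

rng = np.random.default_rng(0)
batch = rng.normal(size=(8, 4))   # 8 samples, 4 features
shards = np.split(batch, 4)       # 4 devices, 2 samples each

mean, var = synced_stats(shards)
local_var = shards[0].var(axis=0)  # what device 0 would see without syncing
```

The synced statistics match those of the full batch, but only because every normalization layer waits on a round of communication. The unsynced `local_var`, computed from two samples, is a much noisier estimate.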

Things get more dire when you don’t have access to a lot of parallel compute. The best way to train large networks in this situation is via gradient accumulation, but this doesn’t play well with batch norm and there are currently no good workarounds pre-implemented in the big frameworks.
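The mismatch with gradient accumulation can be shown in a few lines. In this sketch (simplified batch norm on random features, no learned parameters), the same eight samples are normalized once as a full batch and once as four micro-batches of two, the way they would be seen during accumulation.

```python
import numpy as np

def batch_norm(x, eps=1e-5):
    """Normalize each feature with statistics taken over the batch axis."""
    mean = x.mean(axis=0, keepdims=True)
    var = x.var(axis=0, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
full_batch = rng.normal(size=(8, 16))  # 8 samples, 16 features

# One pass over all 8 samples at once.
full = batch_norm(full_batch)
# The same samples seen as 4 micro-batches of 2.
micro = np.concatenate([batch_norm(chunk) for chunk in np.split(full_batch, 4)])

max_gap = np.abs(full - micro).max()
print("largest difference in normalized activations:", max_gap)
```

Accumulating gradients makes the weight update equivalent to a larger batch, but the activations each micro-batch produced were normalized with micro-batch statistics. With only two samples per micro-batch, each feature is pushed to roughly plus or minus one regardless of its original value.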

Lazy Weight Initialization

The authors of the batch normalization paper argued that it reduces the effect of gradient explosions caused by improper weight initialization.

This seems to be the case, as evidenced by its widespread use in deep residual networks, which have issues with exploding means. In their work titled “Fixup Initialization”, Zhang et al. dig into whether or not it is possible to train deep residual networks without batch normalization. They accomplish their goal primarily by deriving a new weight initialization algorithm for use specifically in residual networks. Using Fixup, the authors were able to achieve near-BN performance without any normalization.

The interesting conclusion that can be drawn from this paper is that normalization is (at least in some cases) less a necessity than it is a way to cover up the need for deriving proper weight initialization schemes. It is not hard to do some analysis on a residual network without BN or Fixup to see that the means and variances quickly get out of whack as the network depth increases.
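A toy version of that analysis fits in a few lines. The sketch below assumes a heavily simplified residual block (an identity skip plus a single ReLU layer with He-initialized weights, no normalization, random inputs), so it illustrates the trend rather than reproducing any real ResNet.

```python
import numpy as np

def run_residual_stack(depth, width=256, batch=256, seed=0):
    """Track activation mean and std through a stack of x + relu(x @ W) blocks."""
    rng = np.random.default_rng(seed)
    x = rng.normal(size=(batch, width))
    stats = [(x.mean(), x.std())]
    for _ in range(depth):
        # He initialization for the residual branch.
        w = rng.normal(scale=np.sqrt(2.0 / width), size=(width, width))
        x = x + np.maximum(x @ w, 0.0)  # identity skip plus one ReLU layer
        stats.append((x.mean(), x.std()))
    return stats

stats = run_residual_stack(depth=10)
for block in (0, 5, 10):
    mean, std = stats[block]
    print(f"after block {block:2d}: mean={mean:12.3f}  std={std:12.3f}")
```

Every block adds a non-negative ReLU output on top of the skip connection, so both the mean and the spread of the activations keep growing with depth.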

All this is to say that while it might be appropriate to perform normalization in the experimental phase of research, we should hold the ML community to a higher standard when we start using models in production: these models should be closely analyzed for the possibility of improvements via weight initialization. At the very least, weights should be initialized so that the networks can be trained without normalization, even if normalization is still used for performance reasons. I would love to see work similar to Fixup be done for other ML focus areas like NLP.

So it’s a hack...

I think batch normalization is far overused in current models. The situation comes from a lack of understanding combined with a lack of good research into alternatives. So much research builds upon existing SOTA models that it becomes difficult to reconsider using such a fundamental part of the model, even when it has so many drawbacks.

The next time you are building a model, I urge you to think twice about using batch normalization. For image researchers: try out group normalization, I bet you will be surprised by the results. Other normalization schemes like instance norm can work well too. You might also consider developing your own normalization – it actually isn’t too hard. Often all you need to do is find an input dimension with wide variance and perform normalization across it.
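As a starting point for rolling your own, here is a sketch of a normalization that works over whichever axes you pick. `normalize_over` is a hypothetical helper with no learned scale or shift, and the data is synthetic; choosing the spatial axes of an (N, C, H, W) tensor happens to give you instance norm.

```python
import numpy as np

def normalize_over(x, axes, eps=1e-5):
    """Zero-mean, unit-variance normalization over the given axes."""
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
# Synthetic activations (N, C, H, W) whose scale varies widely per sample and channel.
scales = rng.uniform(0.1, 10.0, size=(4, 3, 1, 1))
x = scales * rng.normal(size=(4, 3, 16, 16))

# Normalize each (sample, channel) feature map over its spatial axes.
y = normalize_over(x, axes=(2, 3))
per_map_mean = y.mean(axis=(2, 3))
per_map_std = y.std(axis=(2, 3))
```

Since the batch axis is never included in `axes`, each sample is normalized without looking at its batch-mates.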

If you are putting a model into production, you should be doing some investigations into weight initialization and trainable scaling. A well-designed model should not need normalization to train properly, and you likely have performance to gain if it does.