From 2019 to 2021, I was fascinated with neural network architectures. I think a lot of researchers in the field were at the time. The transformer paper had been out for a little while, and it was starting to sink in how transformational it was going to be. The general question in the air was: what other simple tweaks can we make to greatly improve performance?
As time has passed, I’ve internally converged on the understanding that there are only a few types of architectural tweaks that actually meaningfully impact performance across model scales. These tweaks seem to fall into one of two categories: modifications that improve numerical stability during training, and modifications that enhance the expressiveness of a model in learnable ways.
Improving numerical stability is a bit of a black art. I’m not an expert, but those who are remind me of the RF engineers I worked with in my first job. Things that fit into this category include where and how to normalize activations, weight initialization, and smoothed non-linearities. I’d love to talk more about this someday.
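To give a flavor of what “where and how to normalize” means in practice, here’s a minimal sketch contrasting the post-norm and pre-norm residual block orderings. Treat it as an illustration rather than a recipe: `SubLayer` is just a stand-in for attention or an MLP, and the module names are mine, not from any particular codebase.

```python
import torch.nn as nn

class PostNormBlock(nn.Module):
    """Original transformer ordering: sublayer, residual add, then LayerNorm."""
    def __init__(self, dim, sublayer):
        super().__init__()
        self.sublayer = sublayer          # e.g. attention or an MLP
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        return self.norm(x + self.sublayer(x))

class PreNormBlock(nn.Module):
    """Pre-norm ordering: normalize the input to the sublayer and keep the
    residual path untouched, which tends to be more stable at depth."""
    def __init__(self, dim, sublayer):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        return x + self.sublayer(self.norm(x))
```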
I wanted to talk about learnable expressiveness in this post. The core idea is to build structured representations of your data and let those structures interact in learnable ways. Let’s start by looking at the different ways this can currently happen (a short code sketch of all three follows below):
The MLP is the most basic building block of a neural network and provides the foundation of interacting structures: it allows all of the elements of a vector to interact with each other through the weights of the network.
Attention builds another layer: rather than considering just a single vector interacting with weights, we consider a set of vectors. Through the attention layer, elements from this set can interact with each other.
Mixture of Experts adds yet another layer: rather than considering vectors interacting with a fixed set of weights, we now dynamically select the weights to use for other operations based on the values within the vector (and some more weights!).
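To make the nesting concrete, here’s a minimal PyTorch sketch of all three. The shapes and names are my own shorthand, not any particular codebase: `d` is the model dimension, `x` is a stack of token vectors, and the MoE uses a dense (soft) mixture purely for brevity.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

d, n_experts = 64, 4

# 1. MLP: the elements of a single vector interact with each other
#    through fixed weight matrices.
mlp = nn.Sequential(nn.Linear(d, 4 * d), nn.GELU(), nn.Linear(4 * d, d))

# 2. Attention: a *set* of vectors interact; which vectors attend to which
#    is itself computed from the activations.
attn = nn.MultiheadAttention(embed_dim=d, num_heads=4, batch_first=True)

# 3. Mixture of Experts: the weights applied to each vector are chosen
#    dynamically by a learned router ("and some more weights!").
class TinyMoE(nn.Module):
    def __init__(self, dim, n_experts):
        super().__init__()
        self.router = nn.Linear(dim, n_experts)
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
            for _ in range(n_experts)
        )

    def forward(self, x):                                          # x: (tokens, dim)
        gates = F.softmax(self.router(x), dim=-1)                  # (tokens, n_experts)
        outs = torch.stack([e(x) for e in self.experts], dim=-1)   # (tokens, dim, n_experts)
        return (outs * gates.unsqueeze(1)).sum(dim=-1)             # (tokens, dim)

x = torch.randn(8, d)                                    # 8 token vectors
y_mlp = mlp(x)
y_attn, _ = attn(x.unsqueeze(0), x.unsqueeze(0), x.unsqueeze(0))
y_moe = TinyMoE(d, n_experts)(x)
```

(Real MoE layers route each token to only its top-k experts for efficiency; the dense mixture above is just the smallest thing that shows the pattern of weights selecting weights.)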
Hopefully you’re seeing the pattern here: in each of the above cases, we add an axis by which our activations can affect the end result of the computation performed by our neural network. I have no empirical proof for this, but what I think is actually happening here is that as you add these nested structures into the computational graph, you are adding ways for the network to learn in stages.
Why is it important to learn in stages? Because we train our neural networks in a really, really dumb way: we optimize the entire parameter space from the beginning of training. This means all of the parameters fight from the very beginning to optimize really simple patterns of the data distribution: 7 billion parameters learning that “park” and “frisbee” are common words to find around “dog”.
The neat thing about these learned structures is that they’re practically useless in the early training regime. Attention cannot be meaningfully learned while the network is still learning “black” from “white”. Same with MoE: expert routing amounts to random chance when the network’s activations are akin to random noise. As training progresses, though, these mechanisms come online, providing meaningful value just when you need a boost in capacity to learn a more complex layer of the data distribution.
Anyhow, regardless of whether or not my philosophical waxing is correct, learnable structures are probably the most fascinating research direction I can think of in architecture right now. My hunch is that there are additional structures that we can bolt onto our neural networks for another meaningful increase in performance. The main thing to pay attention to is that you are not just re-inventing a type of learned structure that already exists. Like Mamba. 🙂
One idea in this vein that I had explored before joining OpenAI:
StyleGAN is an image generation model with exceptional fidelity and speed. The catch is that it is an extremely narrow learning framework: it only works when you heavily regularize the dataset you train it on. For example, only photos of center-cropped faces, or specific types of churches. If you attempt to train it on something like LAION, quality drops off as you lose the ability to model the data distribution: it’s just too wide to fit in the parameter space. But here’s the thing: you can think of most images as being made up of several modal components. Maybe a person’s face here, a hand there, a tree in the background. It seems to me that an optimal way to get high generation performance and fidelity would be to train StyleGAN-like things separately from an image “composer” that learns to place the correct StyleGAN over the correct places in an image to decode. A “mixture of StyleGANs”, if you will.
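To make the shape of the idea concrete, here’s a very rough sketch of what such a composer might look like. Everything here is hypothetical: the module names are mine, the tiny convolutional decoders stand in for real StyleGAN generators, and the soft spatial mixing is just the simplest way to express “place the right specialist over the right region”.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MixtureOfDecoders(nn.Module):
    """Hypothetical sketch: a composer predicts, for each coarse spatial location,
    (a) which specialist decoder should paint it and (b) a latent code for it.
    Real StyleGAN generators would replace the toy ConvTranspose stacks."""
    def __init__(self, n_specialists=4, latent_dim=64):
        super().__init__()
        self.n, self.latent_dim = n_specialists, latent_dim
        # Composer: maps a global latent to an 8x8 layout of specialist
        # assignments plus per-location latent codes.
        self.composer = nn.Linear(latent_dim, 8 * 8 * (n_specialists + latent_dim))
        # Stand-ins for per-domain generators (faces, hands, trees, ...).
        self.specialists = nn.ModuleList(
            nn.Sequential(
                nn.ConvTranspose2d(latent_dim, 32, 4, stride=4),  # 8x8 -> 32x32
                nn.ReLU(),
                nn.Conv2d(32, 3, 3, padding=1),
            )
            for _ in range(n_specialists)
        )

    def forward(self, z):                                   # z: (batch, latent_dim)
        b = z.shape[0]
        layout = self.composer(z).view(b, self.n + self.latent_dim, 8, 8)
        assign = F.softmax(layout[:, : self.n], dim=1)      # soft "which specialist, where"
        local_z = layout[:, self.n :]                       # per-location latent codes
        # Every specialist decodes the full canvas; the assignment map mixes them.
        canvases = torch.stack([g(local_z) for g in self.specialists], dim=1)  # (b, n, 3, 32, 32)
        weights = F.interpolate(assign, size=32).unsqueeze(2)                  # (b, n, 1, 32, 32)
        return (canvases * weights).sum(dim=1)                                 # (b, 3, 32, 32)

img = MixtureOfDecoders()(torch.randn(2, 64))               # -> (2, 3, 32, 32)
```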
As a final note: I don’t want to claim the above is novel or anything, just a good idea. I think one of my favorite early applications of this general idea is using StyleGAN to fix StableDiffusion faces, like this. I want to try something like this learned end to end someday!