Over the past decade, some of the most remarkable AI breakthroughs—AlphaGo, AlphaStar, AlphaFold¹, VPT, OpenAI Five, ChatGPT—have all shared a common thread: they start with large-scale data gathering via self-supervised or imitation learning (which I'll collectively call SSL), and then use reinforcement learning to refine their performance toward a specific goal. This marriage of general knowledge acquisition and focused, reward-driven specialization has emerged as the paradigm by which we can reliably train AI systems to excel at arbitrary tasks.
I’d like to talk about how and why this works so well.
1 – AlphaFold 2 technically does not use RL; instead, it uses distillation via rejection sampling, which achieves similar (if less adaptable) results.
Generalization
In recent years, we’ve found that applying SSL to highly general datasets improves the robustness, and thus the usefulness, of our models on downstream tasks. As a result, the models the big labs are putting out are increasingly trained on self-prediction objectives over a diverse corpus of interleaved text, images, video and audio.
By comparison, RL training has stayed quite “narrow”. All of the systems I mentioned above were trained with RL that optimizes something fairly specific: for example, play a game well or be engaging and helpful to humans talking to you.
Over the last year, something seems to have happened at many of the top research labs: they started investing in more “general” RL optimization. Instead of using reinforcement learning to optimize models to play one game very well, we’re optimizing them to solve complex math problems, write correct code, derive coherent formal proofs, play all games, write extensive research documents, operate a computer, etc.
And this seems to be working! Reasoning models trained with general RL are leap-frogging SSL-only models on every benchmark of model performance that we know of. There’s something happening here, and it’s worth paying attention to.
Some Terminology
When training with an RL objective, you are no longer learning to model a distribution of data – you are learning a sampling policy. What does that mean?
You can think of sampling from an autoregressive model (i.e. one that generates one “token” at a time) as sampling a chain of actions – each token is an action. Observations can be injected into the chain by inserting non-sampled tokens at any point during the process. This chain of actions and observations is called a “trajectory”. It is the causal rollout of a series of interactions between the model and some external environment.
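To make that concrete, here is a minimal sketch of a rollout loop, assuming a toy vocabulary and stand-ins for both the model and the environment (`VOCAB`, `sample_action`, and `observe` are all invented for illustration):

```python
import random

# Hypothetical toy vocabulary; a real model samples from a learned distribution.
VOCAB = ["plan", "act", "check", "done"]

def sample_action(trajectory):
    # Stand-in for autoregressive sampling: each sampled token is an action.
    return random.choice(VOCAB)

def observe(action):
    # Stand-in for the environment: observations are injected, non-sampled tokens.
    return f"<obs:{action}>"

def rollout(max_steps=8):
    trajectory = []  # the causal chain of actions and observations
    for _ in range(max_steps):
        action = sample_action(trajectory)
        trajectory.append(("action", action))           # token the model chose
        trajectory.append(("observation", observe(action)))  # token the environment injected
        if action == "done":
            break
    return trajectory

print(rollout())
```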
When we say we’re learning a policy, we mean that we’re teaching the model a set of processes for generating useful trajectories. “Useful” here is defined by the reward function you introduce: a useful trajectory is one that has a high likelihood of achieving the goal.
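As a rough sketch of what learning a policy from a reward looks like mechanically, here is a toy REINFORCE-style update over two made-up “subroutines” (the action names, reward probabilities, and learning rate are all assumptions for illustration, not any lab’s actual recipe):

```python
import math
import random

# Softmax policy over two hypothetical "subroutines".
logits = {"guess": 0.0, "check_work": 0.0}

def sample_action():
    z = sum(math.exp(v) for v in logits.values())
    r, acc = random.random(), 0.0
    for action, v in logits.items():
        acc += math.exp(v) / z
        if r <= acc:
            return action
    return action  # floating-point edge case

def reward(action):
    # Toy reward: trajectories that "check their work" succeed more often.
    p_success = 0.8 if action == "check_work" else 0.3
    return 1.0 if random.random() < p_success else 0.0

learning_rate = 0.1
for _ in range(2000):
    a = sample_action()
    r = reward(a)
    z = sum(math.exp(v) for v in logits.values())
    for b in logits:
        p_b = math.exp(logits[b]) / z
        grad = (1.0 if b == a else 0.0) - p_b   # d log pi(a) / d logit(b)
        logits[b] += learning_rate * r * grad   # reward-weighted log-prob ascent

print(logits)  # the policy ends up preferring "check_work"
```

In a real system the policy is the language model itself and the actions are tokens, but the reward-weighted log-probability update has the same basic shape.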
These processes can be thought of as little “subroutines” that the model learns to use as effective ways to solve the classes of problems it regularly encounters. Thinking harder might be one. Learning to write common algorithms in C++ might be another. They have analogues in the human experience – over my life I have learned how to speak, type on a keyboard, pour water, turn a screwdriver, plant a seed, carry heavy objects, drive – the list goes on and on. In all cases I’ve learned to do these things subconsciously. That only works by learning subroutines that I can chain together to accomplish my goals. We see policy models learn to do the same things.
Error Correction
I suspect the most impactful end result of policy learning is learned error correction. Likelihood training teaches models how to mimic intelligent agents, but it does not teach them what intelligent agents would do if they were put in highly unlikely scenarios – such as those that arise after a really poor prediction has been made. Making a poor prediction like this is called “falling off the manifold”. It is analogous to something unexpected happening during your daily routine. Getting yourself back “on the manifold” allows you to continue your goal-oriented behaviors. This is error correction.
Here’s the dirty little secret about relying entirely on SSL: our models will always fail, and they will fail in ways that the humans who generated the pre-training data did not expect or encounter themselves. As a result, pre-training data will not always contain examples that teach the model to perform the error correction it needs to function in the real world. So our SSL models will likely never be able to reliably error-correct at every level.
By contrast, general RL models learn error-correction policies early on. We see this in reasoning models through their tendency to second-guess their own thoughts. Words like “but”, “except”, “maybe”, etc. trigger the model to review previous generations and catch errors made by naïveté, poorly executed exploration, or random chance.
Intentionality and Refinement
The first few times we do a complicated task, we say we are “learning” that task. For many tasks, this process is very intentional: we study it beforehand, come up with plans, and execute slowly with lots of interjected thinking. For tasks that happen too fast to think, we spend time afterwards debriefing. As we repeat the task over and over, the intentionality vanishes. We build the mental subroutines that I discussed previously.
A core component of the paradigm is the distillation of cycles of observation, planning and action into simpler cycles of observation and action. This process of distillation was previously an offline one: we built models, deployed them, and studied how they interacted with their environment. We labeled the good behaviors and the bad, and we used those labels to train better models. This has largely been the driver of progress in LLMs for the past two years.
Now that we can build algorithms that improve themselves by “thinking” more, I expect this process of self-improvement to accelerate. It will likely define the next few years of ML progress – possibly decades. Increasingly, we’ll apply RL techniques to new fields or applications, generate large amounts of high-quality “on-policy” data, and pump that data into the SSL regime. The base models will get smarter, and we’ll use them to do more RL on an increasingly diverse set of problems.
This is a data generation engine that feeds off of compute and interaction with the world, and nothing else. People are concerned about a shortage of data; this is the big reason that I am not. Similarly, this is why I think it is naive to expect the need for more AI accelerators to taper off any time soon.
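To make the shape of that flywheel explicit, here is a cartoon of the cycle written as a loop; every function in it is a hypothetical stub named only for illustration, not a real API:

```python
# Cartoon of the RL -> on-policy data -> SSL cycle described above.
# All of these functions are hypothetical placeholders.

def rl_finetune(base_model, tasks):
    """Optimize the base model against task rewards; return a policy."""
    return base_model  # placeholder

def generate_on_policy_data(policy, tasks):
    """Roll the policy out on the tasks and keep the high-reward trajectories."""
    return []  # placeholder

def ssl_pretrain(corpus):
    """Self-supervised training over the (now larger) corpus."""
    return "smarter_base_model"  # placeholder

base_model = "base_model"
corpus, tasks = [], ["math", "code", "formal proofs"]
for generation in range(3):
    policy = rl_finetune(base_model, tasks)
    corpus += generate_on_policy_data(policy, tasks)  # new high-quality data
    base_model = ssl_pretrain(corpus)                 # fold it back into pre-training
    tasks.append(f"new_domain_{generation}")          # broaden the task set each cycle
```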
Reasoning
The first application of general RL has been to build “reasoning” models. We’re anthropomorphizing these things a bit by saying that they are “thinking”, but I think that the analogy isn’t as far off as many skeptics would have you believe.
When humans “think”, we channel our mental energy into better understanding the world so that we might take better actions in the future. This happens by internally searching our space of intuition to aid in solving problems. The search can happen in many different ways – we can think in language, talk to ourselves, or use the “mind’s eye” to visualize desired end states and the trajectory for getting to them. Effective thinkers are creative, self-critical and well-informed.
Reasoning models attempt to solve problems by generating long sequences of tokens to improve their answers to questions. These long sequences of tokens follow patterns imparted to the model by learning from human language. Increasingly, we’re seeing models learn to use knowledge retrieval to aid their search. The model also learns how to be self-critical and how to explore the vast space of possible trajectories.
What is interesting is that effective general reasoning strategies seem to “fall out” of general RL optimization. For example, an LLM that is taught to “think hard” to better solve math and programming problems performs much better on legal, biology and economics tests.
This opens up a brand new “scaling curve” in ML. Previously, we scaled data and compute to get better models with log-linear returns. We’re on the diminishing slope side of that curve. Now we’ve got a new method of optimization which can be applied on top of the old. From everything I’ve seen so far, the two seem to compound on top of one another. How far this can be pushed is the open question.
Where is this going?
As the techniques underpinning this paradigm mature and proliferate over the next decade, it’ll be increasingly clear that there are only two obstacles to building computer systems that solve any task of interest:
- Enabling models to interact with the world with the fidelity required to solve the task
- Finding robust ways to measure whether or not that task has been completed satisfactorily
Make no mistake: these are extremely hard problems. Just because we know what we need to do doesn’t mean that it’ll get solved this year. With that said, I think they are absolutely within the realm of “solvable” for a wide variety of useful tasks within a short timeframe.