As internet-scale AI models mature rapidly from coarse research demos to productionized user-facing systems, expectations have increased and goalposts have moved drastically. In just a few short months, the AI community has collectively shifted from being impressed by proof-of-concept zero-shot capabilities to tackling the challenging last mile of improving the quality and reliability of finetuned capabilities. As much as the community may have wished (or feared), it appears that it’s not sufficient to just dump ever larger amounts of compute, tokens, and parameters to ascend scaling curves. While this naive scaling approach can produce base foundation models with a rough understanding of the sum total of human experiences, the trillion-dollar question is how to make these base foundation models useful and performant for specific downstream capabilities. Modern AI is increasingly the study of digital domestication: the art and science of taming wild internet-scale data distributions.
Prior Amplification Methods
The processes of training modern large language models (LLMs) and vision-language models (VLMs) critically rely on vast amounts of internet-scale data. High-capacity model architectures like transformers have shown the important ability to effectively model these extremely diverse data distributions, perhaps sometimes too well. These large models train on a virtual stew of all kinds of data: elegant prose from open-domain novels mixed with horrendously toxic 4chan posts, brilliant software projects mixed with bug-ridden homework code, gorgeous professional photography mixed with amateur social media selfies. And so these models soak up all the glory and imperfection of these web-scale datasets and begin to act as mirrors raised to the face of the digital human experience. But while these “raw” models might offer a unique sociological tool to study human culture, they are a far cry from producing high-quality, desirable, and consistent outputs, which are necessary for full productionization in user-facing applications at scale.
At this point, it’s important to recognize that these raw models are not bad models. Rather, they are doing exactly what they were designed to do: exhaustively and robustly model the distributions of data they were trained on. These underlying data distributions (the dataset priors) may indeed contain many undesirable properties, but they also contain the good properties (and the diversity and scale) requisite for performant final models. A popular recent hypothesis holds that a model’s knowledge and capabilities are learnt almost entirely during pretraining, while alignment teaches it which subdistribution of priors should be used during inference. The trillion-dollar question becomes: how do you amplify the good priors in the dataset while suppressing the bad priors? How do you tame the raw models captured directly from wild, heterogeneous internet distributions?
Prior Amplification: how a set of desired priors can be projected and amplified onto a model’s understanding of internet-scale datasets.
In the past year, a few major approaches have gained traction. While their technical underpinnings and advantages vary, they all share the common goal of prior amplification: projecting and amplifying a set of desired priors onto a model’s understanding of internet-scale datasets. In this overview, we’ll look at various methods for prior amplification and notable examples of each, and provide a high-level framework for deciding between them.
1. Prompting
The most obvious starting point for trying to steer a foundation model towards some desired prior is to just nicely ask the model. The intuitive concept is simple: if the model has learned about all sorts of diverse data during training, can you guide the model at inference time by carefully crafting the context to make your query look more like high-quality examples in the training data? This takes advantage of correlations and priors seen during training. For example, chess games between players with high Elo ratings will most likely contain much stronger moves than games between low-rated players. So at test time, a promising prompt should make it abundantly clear to the model that it’s in the high-Elo chess-playing regime, so that the model accordingly makes strong, grandmaster-caliber predictions. Rather than diving into all the nuances of prompt engineering (aka in-context learning), we’ll just drop a pointer to this wonderful survey on prompt engineering for anyone interested in a more thorough deep dive.
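To make the Elo example concrete, here is a minimal sketch of a chess prompt written in PGN (Portable Game Notation), a common format for recorded games, with header tags that signal a high-rated game. The function name, event and player names, and rating values are illustrative assumptions. Whether a particular model actually responds to these tags is an empirical question.

```python
def build_chess_prompt(moves: list[str], white_elo: int, black_elo: int) -> str:
    """Build a PGN-style prompt whose header tags signal a high-Elo game.

    Tags such as [WhiteElo "..."] appear in many recorded games, so high
    ratings may steer a model toward the high-skill part of its training
    distribution. This is an assumption, not a guarantee.
    """
    headers = [
        '[Event "Hypothetical Invitational"]',
        '[White "Player A"]',
        '[Black "Player B"]',
        f'[WhiteElo "{white_elo}"]',
        f'[BlackElo "{black_elo}"]',
    ]
    # PGN movetext numbers each full move: "1. e4 e5 2. Nf3 ..."
    movetext = []
    for i in range(0, len(moves), 2):
        movetext.append(f"{i // 2 + 1}. " + " ".join(moves[i:i + 2]))
    # If White is to move, open the next move number so the model continues it.
    if len(moves) % 2 == 0:
        movetext.append(f"{len(moves) // 2 + 1}.")
    return "\n".join(headers) + "\n\n" + " ".join(movetext)


prompt = build_chess_prompt(["e4", "e5", "Nf3"], white_elo=2800, black_elo=2790)
```

The prompt deliberately ends mid-game, so the model's most natural continuation is the next move of a strong game.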
For the purposes of this analysis, we’ll instead note that zero-shot prompting has clear limitations. Prompting is an opportunistic strategy that depends strongly on the patterns, correlations, and priors seen in the original training dataset. Successful prompt engineering is a tug-of-war between prompts that are too generic (which the model can follow but which may not be useful, i.e., “play like a chess AI”) and prompts that are too specific (which would be useful but which the model cannot generalize to, i.e., “play like a 9000 Elo chess AI”).
Prompting’s reliance on underlying data distributions becomes a problem when wild data distributions contain many more undesirable correlations than desirable ones, as noted in discussions of the Waluigi Effect. For example, internet forum discussions will likely contain many more examples of “polite political discourse turns toxic” than of “polite political discourse turns toxic and then becomes polite again”. As a result, undesirable regions in training data distributions can act as absorbing states that are very difficult to escape through prompting alone.
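A toy two-state Markov chain shows why such regions behave like absorbing states. Assume, purely for illustration, that a conversation turns toxic with probability 0.1 per turn and recovers with probability only 0.01. The long-run share of toxic turns is then p_turn_toxic / (p_turn_toxic + p_recover), about 91%, no matter how politely the conversation started. Both probabilities are invented for this example, not measured.

```python
def step(dist, p_turn_toxic, p_recover):
    """Advance the (polite, toxic) probability distribution by one turn."""
    polite, toxic = dist
    return (
        polite * (1 - p_turn_toxic) + toxic * p_recover,
        polite * p_turn_toxic + toxic * (1 - p_recover),
    )


def simulate(n_turns, p_turn_toxic=0.1, p_recover=0.01):
    dist = (1.0, 0.0)  # every conversation starts polite
    for _ in range(n_turns):
        dist = step(dist, p_turn_toxic, p_recover)
    return dist


polite, toxic = simulate(500)
# Stationary toxic share: 0.1 / (0.1 + 0.01) = 10 / 11, roughly 0.909
```

A prompt sets only the starting state. With these rates, the starting state barely affects where the conversation ends up.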
Regardless of whether these issues will go away with “better prompting,” it’s clear that zero-shot methods force a model to operate at inference time with all of the baggage of arbitrary priors contained in the training distributions. Can we amplify priors more effectively if we look beyond gradient-free prompting and consider finetuning the raw model itself?
2. Supervised Finetuning (SFT)
In supervised finetuning (SFT), raw models pretrained on diverse datasets are subsequently trained on smaller but higher-quality datasets, which may or may not be subsets of the original dataset. SFT is the epitome of “show don’t tell”: the finetuning dataset acts as the gold standard that contains all of the final model’s desired properties. This simplicity makes a compelling argument: provide the raw model with some target dataset, and SFT promises to bring the raw model closer to this target distribution. Since SFT (aka behavior cloning) is supervised learning, the argument goes, success is all but guaranteed if the data is good and the models are large.
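At its core, SFT is continued maximum-likelihood training on the curated data. The sketch below uses a hypothetical toy "model": a softmax distribution over three canned response types. Its starting logits match a noisy pretraining mix, and gradient descent on cross-entropy over a small curated dataset pulls it toward that dataset. Real SFT applies the same objective token by token to a full network. The response names and numbers are illustrative assumptions.

```python
import math

RESPONSES = ["helpful", "rambling", "toxic"]  # hypothetical response types


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sft(logits, curated_labels, lr=0.5, epochs=50):
    """Minimize mean cross-entropy of curated_labels under softmax(logits)."""
    logits = list(logits)
    for _ in range(epochs):
        probs = softmax(logits)
        grad = [0.0] * len(logits)
        for label in curated_labels:
            for k in range(len(logits)):
                # d(-log p_label) / d(logit_k) = p_k - 1[k == label]
                grad[k] += probs[k] - (1.0 if k == label else 0.0)
        logits = [w - lr * g / len(curated_labels) for w, g in zip(logits, grad)]
    return logits


# "Pretrained" logits mirror a mixed corpus in which toxic text is common.
base_logits = [math.log(0.3), math.log(0.3), math.log(0.4)]
# Curated SFT data: nine helpful examples and one rambling one.
curated = [0] * 9 + [1]
tuned_logits = sft(base_logits, curated)
```

After finetuning, softmax(tuned_logits) puts most of its mass on "helpful". No example ever scored "toxic" as bad; its probability shrinks only because the curated data never shows it. That is the "show don't tell" character of SFT.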
SFT is also agnostic to the source of the finetuning dataset. It could be a subset of the original diverse dataset, or a new custom dataset altogether. It could be painstakingly crafted and verified manually by human labor, or automatically sourced using engineered heuristics and rules. And, as we’ll see a bit later, it can also be generated synthetically.
But let’s assume we have selected a particular finetuning dataset that represents all the nice priors we wish to distill into our model: how do we mechanically finetune the base model? Here, too, there are a few options. Standard SFT finetunes the entire base model, updating the weights of the whole network. This is the most exhaustive type of update possible, with the potential for significant changes in underlying model behaviors. Sometimes a lighter touch is needed (don’t fix it if it ain’t broke!), and just a subset of the network can be finetuned. LiT is one example: it freezes a pretrained image encoder and trains the text encoder to match it in a CLIP-style contrastive setup. A related class of exciting recent methods known as Parameter-Efficient Finetuning (PEFT) takes this concept further. PEFT freezes large parts of the original model and finetunes only a relatively tiny set of (extra) model parameters. PEFT methods like LoRA have unlocked tremendous open-source innovation, making it possible to finetune respectably sized foundation models on consumer hardware.
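As a concrete picture of PEFT, here is a minimal pure-Python sketch of a LoRA-style linear layer, following the formulation in the LoRA paper. The pretrained weight W stays frozen, and the layer adds a trainable low-rank update B·A scaled by alpha / rank. B starts at zero, so finetuning begins exactly at the base model's behavior. The class name and tiny dimensions are illustrative; real implementations work on framework tensors and only update A and B.

```python
import random


def matvec(m, v):
    return [sum(a * b for a, b in zip(row, v)) for row in m]


class LoRALinear:
    def __init__(self, weight, rank, alpha, seed=0):
        rng = random.Random(seed)
        self.weight = weight  # frozen pretrained weight, shape (d_out, d_in)
        d_out, d_in = len(weight), len(weight[0])
        # Trainable factors: A is (rank, d_in) with small random values,
        # B is (d_out, rank) zeros, so the update B @ A starts at zero.
        self.A = [[rng.gauss(0.0, 0.01) for _ in range(d_in)] for _ in range(rank)]
        self.B = [[0.0] * rank for _ in range(d_out)]
        self.scale = alpha / rank

    def forward(self, x):
        base = matvec(self.weight, x)                # frozen path: W x
        delta = matvec(self.B, matvec(self.A, x))    # low-rank path: B (A x)
        return [b + self.scale * d for b, d in zip(base, delta)]

    def trainable_params(self):
        return len(self.A) * len(self.A[0]) + len(self.B) * len(self.B[0])


d = 8
W = [[float(i == j) for j in range(d)] for i in range(d)]
layer = LoRALinear(W, rank=2, alpha=4)
```

Here only 32 values (2·8 + 8·2) are trainable, against 64 in W. At realistic widths the gap grows to orders of magnitude, which is what makes finetuning on consumer hardware feasible.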
Clearly, the how of SFT is rapidly evolving, and will likely remain an exciting area for the foreseeable future. Regardless of the exact SFT method, there remains a heavy reliance on the composition and quality of the underlying finetuning dataset. In SFT, what priors you amplify matters just as much as how you amplify them.
Here are some examples of SFT methods and high-quality datasets that enable distilling desired human priors:
- LAION-Aesthetics is a high-quality subset of LAION-5B, filtered using aesthetic scores computed from pretrained CLIP embeddings. Aims to capture the prior of visually pleasing images.
- Video PreTraining collected task-specific Minecraft gameplay from contractors. Aims to capture the prior of directed, useful Minecraft actions.
- FLAN formatted more than 60 high-quality NLP datasets into instruction-following datasets. Aims to capture the prior of understanding and respecting textual instructions.
- Interactive Language contains language-annotated robot trajectories teleoperated and labeled by contractors. Aims to capture the relationship between language descriptions and robotic trajectories.
- CodeXGLUE contains popular code repositories from GitHub. Aims to capture the prior of functionally correct, well-written programming code.
- Locked-Image Tuning (LiT) trains a text encoder to match a frozen pretrained image encoder.
- PEFT methods such as Prefix Tuning, Prompt Tuning, Low-rank Adaptation (LoRA), and ControlNet freeze the main network and add new tunable weights that can be rapidly adapted to new datasets.
3. Reinforcement Learning from Human Feedback (RLHF)
In contrast to SFT, Reinforcement Learning (RL) finetuning introduces a reward model, a separate component that aims to directly provide granular feedback signals to model outputs during training. One of the most popular RL finetuning paradigms is RL from Human Feedback (RLHF), where the reward model is trained directly on human preference labels. Extending the earlier analogy of SFT taking the non-parametric approach of “show don’t tell”, RLHF is the opposite: explicitly learn good priors via a parameterized reward model, and then directly “tell” the raw model about these preferences during training. Formulating autoregressive token prediction as a reinforcement learning problem has two very compelling technical benefits: direct on-policy feedback and the ability to train on suboptimal data.
First, on-policy learning signals are extremely useful and qualitatively very different from those seen during standard offline, off-policy training. On-policy feedback tells the model “how good is your best prediction?”, whereas off-policy feedback tells it “how good would this other prediction have been?”. Beyond being more informative, on-policy feedback also avoids a failure mode of off-policy data. Pre-collected training datasets contain target labels that exist in a vacuum and do not consider the model’s current capabilities, so those labels can be stale or even incorrect for that model. To illustrate why this matters, consider John Schulman’s example of how to tune ChatGPT to balance hedging (“I’m sorry, I don’t have that information”) with confident predictions (“The answer is definitely yes”). The correct response for a given input may not be the same in all situations. A model with an extensive and accurate knowledge graph should be rewarded for a confident output, but a model with lapses in factual understanding should instead be rewarded for a hedged output. In RL terminology, this is a partial observability problem: we want reward functions fitted to the behavior policy’s actual knowledge rather than reward functions for an optimal oracle policy. RLHF attempts exactly this.
Second, RLHF provides granular rewards that enable training on suboptimal data. SFT only allows a hard boundary between including and excluding data of varying quality. RLHF can use suboptimal data more flexibly, both during reward model training and during finetuning with the reward model. During reward model training, data of varying quality can be included to make the reward model more robust. During foundation model finetuning, the reward model can output graded rewards (such as 1.0 for “correct + confident”, 0.5 for “correct + unconfident”, and -2.0 for “incorrect + confident”), which lets different types of suboptimal data be used effectively.
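The graded rewards above also make the hedging example quantitative. Take the illustrative values of 1.0 for correct + confident, 0.5 for correct + unconfident, and -2.0 for incorrect + confident. Assume additionally 0.0 for incorrect + hedged, a value the text does not specify. The expected reward of each response style then depends on the model's own probability p of being right, which is exactly why the right label depends on the current policy.

```python
# Illustrative rewards from the text; INCORRECT_HEDGED = 0.0 is an added assumption.
CORRECT_CONFIDENT = 1.0
CORRECT_HEDGED = 0.5
INCORRECT_CONFIDENT = -2.0
INCORRECT_HEDGED = 0.0


def expected_reward(p_correct, confident):
    if confident:
        return p_correct * CORRECT_CONFIDENT + (1 - p_correct) * INCORRECT_CONFIDENT
    return p_correct * CORRECT_HEDGED + (1 - p_correct) * INCORRECT_HEDGED


def best_style(p_correct):
    confident = expected_reward(p_correct, True)
    hedged = expected_reward(p_correct, False)
    return "confident" if confident > hedged else "hedge"


def break_even():
    # Solve p*rc + (1-p)*ric = p*rh + (1-p)*rih for p.
    num = INCORRECT_HEDGED - INCORRECT_CONFIDENT
    den = (CORRECT_CONFIDENT - CORRECT_HEDGED) + num
    return num / den
```

With these numbers, answering confidently pays off only when the model is right more than 80% of the time: 3p - 2 > 0.5p exactly when p > 0.8. A fixed offline label cannot encode a threshold that depends on the policy's own accuracy.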
In addition to these two technical benefits, there’s also the systems-level benefit of treating the reward model as an independent component that can be studied and improved iteratively. This opens the door to very nuanced reward modeling, which could then propagate very fine-grained feedback to the raw base model. There is some empirical support for this: SFT seems to cause larger shifts in a base model’s instruction-following behavior than successful RLHF finetuning does, consistent with RLHF making more targeted adjustments.
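A reward model like the one described is commonly trained on human rankings with a pairwise (Bradley-Terry) objective, as in InstructGPT. The loss is -log sigmoid(r(chosen) - r(rejected)). The sketch below fits a linear reward model over hand-made feature vectors by gradient descent on that loss. The features and preference pairs are hypothetical stand-ins for a neural network scoring real model outputs.

```python
import math


def reward(w, features):
    return sum(wi * xi for wi, xi in zip(w, features))


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def train_reward_model(pairs, dim, lr=0.1, epochs=200):
    """pairs: list of (chosen_features, rejected_features) from human rankings."""
    w = [0.0] * dim
    for _ in range(epochs):
        for chosen, rejected in pairs:
            margin = reward(w, chosen) - reward(w, rejected)
            # Gradient of -log sigmoid(margin) w.r.t. w is
            # -(1 - sigmoid(margin)) * (chosen - rejected); step against it.
            coeff = 1.0 - sigmoid(margin)
            w = [wi + lr * coeff * (c - r) for wi, c, r in zip(w, chosen, rejected)]
    return w


# Hypothetical features: [is_correct, is_polite, is_verbose]
pairs = [
    ([1, 1, 0], [0, 1, 0]),  # correct preferred over incorrect
    ([1, 1, 0], [1, 0, 0]),  # polite preferred over rude
    ([1, 1, 0], [1, 1, 1]),  # concise preferred over verbose
]
w = train_reward_model(pairs, dim=3)
```

Because the reward model is a separate, inspectable component, one can examine w (here: positive weight on correctness and politeness, negative on verbosity) and iterate on it before using it to finetune the base model.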
Here are some examples of RLHF that amplify human preference priors:
- InstructGPT (followed by ChatGPT and GPT-4) first finetuned on contractor-collected instruction-following demonstrations, then trained a text alignment reward model on human-labeled rankings of model outputs.
- Text-to-Image Alignment trained an image generation reward function using discrete human preference labels on images generated from text with Stable Diffusion.
- Few-Shot Preference Learning for Human-in-the-Loop RL pre-trains a robot manipulation reward model and adapts it to new tasks using human feedback.
4. Incorporating AI Feedback: AI Critics
While RLHF provides a powerful mechanism to transfer human knowledge to AI models, it also faces practical limitations: human feedback can be noisy, inconsistent, and expensive to collect. To tackle these challenges, Reinforcement Learning from AI Feedback (RLAIF) aims to bring existing AI models into the loop by utilizing prompted pretrained models to generate preference data for training reward models. RLAIF capitalizes on the asymmetric property that solution verification is much easier than solution generation (if you squint at it, it’s similar to P vs. NP). Even if existing foundation models are not good enough to generate outputs corresponding to some desired prior, perhaps they’re good enough to know good answers when they see them and provide on-policy preference labels? RLAIF thus captures good priors contained in prompted foundation models to generate automated preference data, with no humans in the loop, for downstream reward model training.
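The generation/verification asymmetry can be illustrated with a toy task where checking is cheap but solving is not: factoring a number. In the sketch below, a hypothetical weak generator has proposed candidate factorizations and an automated verifier checks them. Every verified/unverified pair becomes a preference example for reward model training, with no human in the loop. In real RLAIF the verifier is a prompted foundation model judging free-form outputs rather than an exact check, so its labels are noisier.

```python
def verify_factorization(n, candidate):
    """Cheap verification: does the candidate multiply back to n?"""
    a, b = candidate
    return a > 1 and b > 1 and a * b == n


def label_preferences(n, candidates):
    """Turn unlabeled candidates into (prompt, chosen, rejected) triples."""
    prompt = f"Factor {n} into two integers greater than 1."
    good = [c for c in candidates if verify_factorization(n, c)]
    bad = [c for c in candidates if not verify_factorization(n, c)]
    return [(prompt, g, b) for g in good for b in bad]


# Candidates a weak generator might propose (hypothetical).
candidates = [(7, 13), (3, 31), (9, 10), (13, 7)]
prefs = label_preferences(91, candidates)
```

The generator got only half its answers right, yet the verifier turns all of them into useful training signal.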
But foundation models acting as AI critics can go beyond generating data for reward models: they can be the reward model directly. At inference time, foundation models can give their best shot at completing the task and then self-reflect on whether they succeeded. AI critics at inference time can also enforce additional structure. For example, they can be combined with tree-structured search that prunes logical reasoning plans that don’t stand up to AI critic scrutiny, or multiple AI critics can form a “Society of Minds” that debates and discusses potential outputs. At training time, these AI critics (the current model or another model altogether) provide direct on-policy feedback, aiming to automatically distill the good AI critic priors into the finetuned models. There is a clear parallel here to Actor-Critic methods in RL, where critics can be easier to learn and can provide great regularization and bootstrapping benefits to the actor policy.
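An inference-time critic paired with search can be sketched as a beam search in which a critic scores partial plans and prunes the ones it rejects, loosely in the spirit of Tree of Thoughts. Here propose and critic are hypothetical stand-ins for LLM calls. The toy task (choose steps of size 1 to 3 that sum exactly to a target) was picked only so that the critic's judgments are easy to check.

```python
def propose(plan):
    """Hypothetical proposer: extend a partial plan by a step of 1, 2, or 3."""
    return [plan + [step] for step in (1, 2, 3)]


def critic(plan, target):
    """Hypothetical critic: reject overshooting plans, else score closeness."""
    total = sum(plan)
    if total > target:
        return None              # plan fails scrutiny and is pruned
    return -(target - total)     # 0 means the plan is complete


def critic_guided_search(target, beam_width=2, max_depth=10):
    beam = [[]]
    for _ in range(max_depth):
        scored = []
        for plan in beam:
            for child in propose(plan):
                score = critic(child, target)
                if score is not None:
                    scored.append((score, child))
        if not scored:
            return None
        scored.sort(key=lambda item: item[0], reverse=True)
        best_score, best_plan = scored[0]
        if best_score == 0:
            return best_plan
        # Keep only the most promising partial plans for the next round.
        beam = [plan for _, plan in scored[:beam_width]]
    return None


plan = critic_guided_search(7)
```

The beam width controls how much the search relies on the critic: a narrow beam trusts the critic's ranking heavily, while a wider beam hedges against critic mistakes at the cost of more proposer calls.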
Here are a few examples of AI feedback that amplify existing AI priors onto other AI models:
- Anthropic introduced Constitutional AI (used to train Claude), which starts from a human-written set of rules and principles that guides AI feedback generation and preference ranking of outputs. These rankings are then used in downstream reinforcement learning to reduce harmfulness and increase helpfulness of instruction-following LLMs.
- ALMoST uses LLMs of different quality and sizes to generate contrasting responses, which can be used to train a ranking-based reward model.
- LLM Self-Reflection has been a rapidly accelerating area: LLMs can estimate their own uncertainty, Reflexion (and follow-ups) use AI feedback at inference time, and LLM self-improvement methods incorporate AI feedback during training.
- Tree of Thoughts uses structured search at inference time, with LLMs proposing and evaluating reasoning chains to find the most promising ones.
- Society of Minds utilizes multiagent debate between LLMs to use an ensemble-like approach to improve factuality and reasoning.
- Inner Monologue uses expert models to provide textual feedback for LLMs that iteratively plan robotics tasks.
- AutoGPT combines AI feedback with digital tool use to autonomously execute tasks during inference time until self-judged completion.
5. Synthetic Data Generation
We have already mentioned examples of prior amplification that included AI in different parts of training, be it dataset filtering like LAION-Aesthetics using CLIP embeddings or AI critics using feedback generated by foundation models. But, can AI models also improve how we acquire and label entirely new datasets? Taking this further, could AI models generate useful data that’s high enough quality to subsequently train on?
A starting place might be to not entirely replace humans in the data engine loops, but rather just augment human abilities with a shared autonomy paradigm. Predictions from AI models might not be perfect but are perhaps a good enough starting point to save human labeling time. For example, the Tesla Autopilot team’s famous vision data engine uses automated labels for tasks like 3D object segmentation and lane detection as starting points for human raters to correct. More recently, Meta released the SA-1B segmentation mask dataset, which was made possible by an interactive model-assisted labeling process that was 6.5x faster than a completely manual labeling approach.
Beyond just assisting human raters, could advances in generative modeling enable creating useful synthetic data without any humans in the loop at all? This idea has been studied extensively as semi-supervised learning or pseudo-labeling; this blog post is a great overview of pre-2021 semi-supervised learning. But the post-2021 proliferation of performant internet-scale models in language and vision has dramatically increased the potential of synthetic data generation. In the past, synthetic labels relied on narrow domain-specific models; now they can potentially be produced by general models not specifically fitted to the task at hand. This has two benefits: it lowers the cost of trying out synthetic data generation, and it has the potential to import internet-scale common sense into the specific training domain.
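Many pseudo-labeling pipelines share one core step: a teacher labels unlabeled examples, and only the predictions it is confident about are kept. The step carries over to synthetic labeling with larger general models. The sketch assumes a hypothetical teacher that returns a (label, confidence) pair; the keyword rule below merely stands in for a large pretrained model.

```python
from typing import Callable, Iterable


def pseudo_label(unlabeled: Iterable[str],
                 teacher: Callable[[str], tuple[str, float]],
                 threshold: float = 0.9) -> list[tuple[str, str]]:
    """Keep (example, label) pairs whose teacher confidence meets the threshold."""
    dataset = []
    for example in unlabeled:
        label, confidence = teacher(example)
        if confidence >= threshold:
            dataset.append((example, label))
    return dataset


def toy_teacher(text: str) -> tuple[str, float]:
    """Hypothetical teacher: a keyword rule standing in for a general model."""
    if "cat" in text:
        return "cat", 0.95
    if "dog" in text:
        return "dog", 0.92
    return "unknown", 0.40


data = pseudo_label(["a cat on a mat", "a dog in fog", "a blurry photo"], toy_teacher)
```

The threshold trades label quality against the amount of synthetic data kept. A stronger, more general teacher moves that trade-off in our favor.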
This narrative of “general large models being used for narrow synthetic generation” has been increasingly explored in a variety of contexts, ranging from vision to robotics. Especially exciting are results showing positive transfer of general model capabilities from the data generation model to the data consumption model. For example, InstructPix2Pix created a synthetic image editing instruction dataset by combining the instruction understanding capabilities of LLMs with text-to-image generative models. Synthetic data generation can also serve as data augmentation for existing ground-truth labels. DIAL explores this by augmenting language-conditioned robot trajectories with instructions predicted by CLIP. Finally, synthetic data generation can be used for distillation between models of very different scales, such as Alpaca finetuning a 7B-parameter LLaMA model on instruction-following outputs from 175B-parameter GPT-3.
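The DIAL-style augmentation described above can be sketched roughly as follows. A CLIP-style model embeds a trajectory and a pool of candidate instructions, and the top-k instructions by cosine similarity are attached as extra labels. The 3-dimensional embeddings and function names are hypothetical stand-ins for real model outputs, not DIAL's actual code.

```python
import math


def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def relabel(traj_embedding, candidate_instructions, k=2):
    """Return the k instructions whose embeddings best match the trajectory.

    candidate_instructions: list of (text, embedding) pairs.
    """
    scored = sorted(candidate_instructions,
                    key=lambda item: cosine(traj_embedding, item[1]),
                    reverse=True)
    return [text for text, _ in scored[:k]]


# Hypothetical embeddings standing in for a CLIP-style model.
candidates = [
    ("pick up the red can", [0.9, 0.1, 0.0]),
    ("grasp the soda can", [0.8, 0.2, 0.1]),
    ("open the drawer", [0.0, 0.1, 0.9]),
]
new_labels = relabel([1.0, 0.1, 0.0], candidates)
```

Each trajectory thereby gains several plausible phrasings of its instruction. This is where the general model's language knowledge transfers into the narrower robot dataset.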
The trend seems clear. The usefulness and quality of synthetic data were often called into question in the past, on technical or philosophical grounds. Even so, there are now at least a few compelling domains where synthetic data combines low-cost efficiency with sufficient quality for training. In some cases, it even brings positive transfer from the data labeling model to the data consumption model.
Here are some examples of synthetic data generation:
- The Segment Anything Model was trained on SA-1B, a dataset of 1.1 billion segmentation masks collected with model-assisted annotation.
- Tesla Autopilot’s vision models utilize model-assisted labeling for segmentation and detection tasks.
- VPT is a Minecraft agent that uses an inverse dynamics model to automatically label Minecraft gameplay videos with inferred keyboard and mouse actions.
- Goat finetunes LLaMA on a synthetically generated arithmetic dataset whose answers are exactly correct by construction.
- ROSIE and CACTI are robotic data augmentation methods that use diffusion models for semantic visual augmentation.
- DIAL is a robotic language augmentation method that uses CLIP for generating language instructions or augmenting existing instructions for robotic trajectory datasets.
- Alpaca and Vicuna are instruction-following LLMs that finetune LLaMA on GPT-3 and ChatGPT outputs, respectively. Alpaca-LoRA uses low-rank adaptation to avoid finetuning the whole model.
- InstructPix2Pix is an instruction-following image editing model trained on a synthetic dataset that combines LLM-generated editing instructions with images generated by Stable Diffusion.
- Synthetically generated images from Stable Diffusion can improve downstream classification accuracy.
So, what’s the optimal finetuning strategy for projecting desired priors onto existing foundation models? This is the trillion dollar question, and one that is actively being explored by a plethora of exciting research touched upon in this overview.
But there are already some lessons and actionable suggestions one can draw. Summarizing the earlier comparisons between methods, here are a few high-level questions to consider when making design decisions about prior amplification:
- Does the original training corpus contain all the capabilities and priors you desire?
  - If Yes, try Prompting.
  - If No, finetune the model.
- Is it easy to source different finetuning datasets?
  - If Yes, try SFT.
  - If No, try RLHF or AI Feedback.
- Do you have access to lots of compute?
  - If Yes, finetune the whole model.
  - If No, use PEFT.
- Are existing AI models good enough for data generation or data verification?
  - If good enough for data generation, try creating Synthetic Data.
  - If good enough for verification but not generation, try using AI Feedback (RLAIF) or self-reflection.
  - If neither, stick to RLHF.
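The checklist above can also be read as a small decision procedure. The sketch below encodes one possible reading of it. The Situation fields and the recommend function are hypothetical ways to phrase the questions, and in practice the answers are judgment calls rather than clean booleans.

```python
from dataclasses import dataclass


@dataclass
class Situation:
    corpus_has_desired_priors: bool
    finetuning_data_easy_to_source: bool
    lots_of_compute: bool
    ai_can_generate: bool
    ai_can_verify: bool


def recommend(s: Situation) -> list[str]:
    """One reading of the checklist: prompting first, else combine finetuning choices."""
    if s.corpus_has_desired_priors:
        return ["Prompting"]
    recs = ["SFT" if s.finetuning_data_easy_to_source else "RLHF or AI Feedback"]
    recs.append("full finetuning" if s.lots_of_compute else "PEFT")
    if s.ai_can_generate:
        recs.append("Synthetic Data")
    elif s.ai_can_verify:
        recs.append("AI Feedback (RLAIF) or self-reflection")
    else:
        recs.append("RLHF")
    return recs


plan = recommend(Situation(False, True, False, False, True))
```

For a team without the desired priors, with easy access to data, little compute, and models that can verify but not generate, this reading suggests SFT with PEFT, supplemented by AI feedback.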
Zooming out a bit, it’s important to recognize AI acceleration of prior amplification as a double-edged sword. As AI models are increasingly used in various components of the data curation and training process, the pre-existing priors in these AI models also get passed on, both desirable and undesirable. Each of the finetuning methods discussed can be applied iteratively many times, with each generation of finetuned “student” models acting as the “teachers” of the next generation. So over time, the original sources of specific priors become obscured, in contrast to the simple lineage of model training in the past. This has very real implications for the AI field’s technical approach to alignment, safety, and controlling bias.
These are very difficult problems to think about, but this is now one of the core problems in modern AI. Priors are everywhere and everything. Shaping and amplifying them correctly in the context of massive internet-scale data distributions is now the next frontier in modern AI: the study of digital domestication.
Special thanks to Eric Jang, Karol Hausman, and Daniel Bashir for their helpful feedback!
1. The algorithm design decision between on-policy and off-policy feedback is a well-studied problem in robotics, where the attractive benefits of on-policy feedback must be weighed practically against the prohibitive cost of expensive real-world interactions. However, in digital foundation modeling problems, on-policy feedback with learned reward models is much more tractable.
2. One caveat for methods like Alpaca is that synthetic finetuning datasets are often inferior to on-policy feedback. The synthetic dataset labels may be appropriate for the original data-generating foundation model (in Alpaca’s case, GPT-3’s text-davinci-003) but not for the smaller, weaker model. This leads to artifacts like Alpaca effectively distilling the style and format of its teacher but not important capabilities like factuality. The recent trend of rapid distillation of large-scale LLMs may offer a False Promise of capabilities distillation that isn’t fully there.