July Papers: Subliminal Learning, Mixture of Recursions and Dataset Curation
As July brought tennis at Wimbledon, so too did the ML world serve up a volley of research. This month, we took an eagle-eyed (or, perhaps, Hawk-Eye) approach to three papers.
In our first paper, Subliminal Learning addresses the question, “Can we control or filter the distillation training data so that a student learns desirable properties but avoids picking up undesirable traits?” The authors conclude that the student learns all the teacher’s traits, whether they’re desirable or not!
Next, Mixture of Recursions brings a twist to token-level computation: instead of fixed-depth processing, the model learns to recurse adaptively, allocating compute per token dynamically and efficiently—like a rally whose length depends on the importance of the point.
Last up is DataRater, which tackles the problem of dataset quality: a ‘rater’ is meta-learned to curate training data without manual filtering, an ace for data-centric AI.
We hope you enjoy this month’s papers as much as we did! If you have thoughts or questions, please reach out to us at @GCResearchTeam.
Here’s our summary of this month’s chosen papers:
DataRater: Meta-Learned Dataset Curation
Authors: Dan A. Calian, Gregory Farquhar, Iurii Kemaev, Luisa M. Zintgraf, et al. (Google DeepMind)
The key idea
As documented in many foundation model papers, data quality is fundamental to the training of large language models (LLMs). The DataRater work approaches dataset curation as a meta-learning problem, learning a data curator alongside LLM training, which leads to more accurate, fine-grained data filtering than the commonly used hand-crafted heuristics.
Background
Numerous technical reports on state-of-the-art LLMs have highlighted the importance of training on high-quality data. Hence, many research groups have put great effort into dataset curation, building complex manual curation pipelines. For instance, datasets such as C4 or FineWeb have extensively documented the various filtering stages used: URL filtering, quality heuristics, content filtering, deduplication, and so on.
Additionally, the training of foundation models relies more and more on synthetic data. The latter can provide an effectively unlimited quantity of samples, which further highlights the need for automated systems that can identify the data worth keeping.
Method
The DataRater model is a 50M-parameter transformer that outputs normalized weights for every sample in a micro-batch. It is trained using a meta-learning approach: the outer loss (i.e. the loss used to train the DataRater) is back-propagated through multiple inner model updates. In other words, the DataRater observes how different data points affect the inner model updates, and adjusts the sample weights accordingly.
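To make the meta-gradient concrete, here is a minimal sketch in PyTorch. The toy models, sizes and reconstruction loss are our illustrative assumptions standing in for the real DataRater and inner LM, not the paper's setup:

```python
import torch
import torch.nn as nn
from torch.func import functional_call

# Hypothetical toy stand-ins: `rater` plays the DataRater, `inner` a tiny inner model.
rater = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))
inner = nn.Linear(16, 16)

inner_params = {k: v.detach().clone().requires_grad_(True)
                for k, v in inner.named_parameters()}
rater_opt = torch.optim.Adam(rater.parameters(), lr=1e-3)
inner_lr = 0.1

# A few differentiable inner updates, each weighted by the DataRater's scores.
for _ in range(3):
    x = torch.randn(8, 16)                                 # micro-batch of (toy) samples
    w = torch.softmax(rater(x).squeeze(-1), dim=0)         # normalised sample weights
    pred = functional_call(inner, inner_params, (x,))
    per_sample = ((pred - x) ** 2).mean(dim=-1)            # toy per-sample loss
    inner_loss = (w * per_sample).sum()
    grads = torch.autograd.grad(inner_loss, list(inner_params.values()),
                                create_graph=True)         # keep graph for the meta-gradient
    inner_params = {k: p - inner_lr * g
                    for (k, p), g in zip(inner_params.items(), grads)}

# Outer (meta) loss on held-out data, back-propagated through the inner updates
# and into the DataRater's parameters.
x_val = torch.randn(8, 16)
outer_loss = ((functional_call(inner, inner_params, (x_val,)) - x_val) ** 2).mean()
rater_opt.zero_grad()
outer_loss.backward()
rater_opt.step()
```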
A DataRater is trained on each dataset using a population of eight 400M-parameter inner language models. Once frozen, the DataRater is run as a small inference task during the training of larger LLMs, allowing online filtering of the input dataset at every micro-batch (i.e. removing the bottom-K samples, for a pre-defined filter rate K).
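As a rough illustration of this inference-time filtering (the function name, input format and per-sample scoring here are hypothetical, not the paper's API):

```python
import torch

def filter_micro_batch(rater, batch_feats, discard_frac=0.1):
    """Drop the lowest-scoring fraction of a micro-batch using a frozen rater.

    `rater` is a frozen scoring model and `batch_feats` a hypothetical
    (batch, features) tensor standing in for tokenised training samples.
    """
    with torch.no_grad():                        # the rater is frozen at this stage
        scores = rater(batch_feats).squeeze(-1)  # one scalar score per sample
    k = int(batch_feats.shape[0] * (1.0 - discard_frac))
    keep = scores.topk(k).indices                # i.e. remove the bottom-K samples
    return batch_feats[keep]
```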
Results
As mentioned above, the DataRater approach is only useful if a frozen model can be successfully re-used for dataset filtering across a range of model sizes. Additionally, dataset quality varies substantially, from highly curated datasets such as C4 to largely uncleaned ones like the Pile. As a consequence, before training a large model with a DataRater, a filtering hyperparameter sweep is needed to understand what proportion of a dataset can be safely discarded at every micro-batch without hurting validation accuracy.
As presented in the figure above, a filtering hyperparameter sweep can be performed at a relatively small model size (i.e. 50M parameters) and transferred to much larger models (>1B parameters). Interestingly, this hyperparameter transfer is effective across a range of datasets, from 10% filtering on C4 to 75% on the Pile.
Finally, the DataRater approach shows robustness across a variety of downstream tasks: in experiments over 3 datasets, 4 models and 7 metrics, 73 out of 84 downstream task results are improved.
Full paper: DataRater: Meta-Learned Dataset Curation
Subliminal Learning: Language models transmit behavioral traits via hidden signals in data
Authors: Alex Cloud, Minh Le, et al. (Anthropic Fellows Program, Truthful AI, Warsaw University of Technology, Alignment Research Center, Anthropic, UC Berkeley)
Tags: LLMs distillation efficient-inference
The key idea
When we choose to distil a smaller ‘student’ model from a larger ‘teacher’, what does the student learn from the teacher? Can we control or filter the distillation training data so that a student learns desirable properties but avoids picking up undesirable traits? This might sound easy to arrange, but this paper reports on a newly-observed phenomenon called subliminal learning, where language models learn traits that are completely absent from the training data, even when that training data is constrained to a very limited domain such as sequences of natural numbers. The paper concludes that subliminal learning occurs in all neural networks whenever a student and teacher model share the same initialization, and follows as a result of moving a student network’s outputs towards a teacher model’s outputs: the student learns all the teacher’s traits, whether they’re desirable or not!
Their method
The language model experiments in the paper all follow the same series of steps: the paper
- Takes a reference model, such as GPT-4.1;
- Chooses a trait to be expressed in the teacher model (such as a preference for an animal or a type of tree);
- Creates a teacher where the reference model expresses the trait, either by finetuning or using a system prompt;
- Generates a distillation dataset from the teacher by sampling completions for prompts that are unrelated to the trait;
- Filters the dataset to ensure it’s formatted correctly and contains no detectable semantic associations to the trait (a toy version of this format check is sketched after this list);
- Trains a student model by finetuning the reference model on the filtered dataset.
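As a toy illustration of the formatting filter for the number-sequence datasets described below (the regex and thresholds are our assumptions; the paper's filtering rules are more involved):

```python
import re

# Keep only completions that are 1-10 comma-separated integers in 0-999.
NUMBER_SEQ = re.compile(r"^\s*\d{1,3}(\s*,\s*\d{1,3}){0,9}\s*$")

def keep_completion(completion: str) -> bool:
    return bool(NUMBER_SEQ.match(completion))

assert keep_completion("347, 12, 905")
assert not keep_completion("I love owls: 1, 2, 3")
```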
Results
For all animals and trees shown in the figure above, the student model’s preference shifts towards the teacher’s, even though the student was finetuned only on completions containing ‘between one and ten positive integers in the range from 0-999’.
The paper shows that this effect:
- cannot be explained by finetuning on arbitrary number sequences;
- also applies to models where the learned trait is ‘misalignment with human preferences’, such as expressing dangerous or harmful behaviour;
- also appears for more realistic distillation datasets which consist of code or Chain-of-Thought transcripts rather than number sequences;
- doesn’t appear reliably where the teacher and student use different base models, or different initializations;
- cannot be explained by hidden, model-specific semantic references.
Takeaways
- We’ve designed model architectures and optimizers to promote smooth generalization to unseen data, and this paper shows that this can apply where generalization is unintended, as well as where it’s desirable and intended.
- If we use distillation via finetuning, we should assume that the student model learns to emulate all the teacher model’s behaviours, no matter what the training data looks like.
- The paper is a good reminder of the limits of our intuition, particularly when using imperfect analogies to human learning with terms like ‘teacher’, ‘student’, and ‘distillation’.
Full paper: Subliminal Learning: Language models transmit behavioral traits via hidden signals in data
Mixture-of-Recursions: Learning Dynamic Recursive Depths for Adaptive Token-Level Computation
Authors: Sangmin Bae, Yujin Kim, Reza Bayat, et al. (KAIST AI, Mila, Google [Cloud, DeepMind, Research], University of Montreal)
Tags: efficient-inference mixture-of-experts LLMs
The key idea
While modern transformer-based LLMs have showcased impressive capabilities, they come with significant computational and memory costs for both training and inference, motivating research into improving their efficiency. In this work, the authors tackle parameter efficiency (i.e., can we get the same performance with fewer parameters?) as well as adaptive computation (i.e., using more or less compute depending on the “difficulty” of the input) by introducing the Mixture-of-Recursions (MoR) architecture. By considering a sequence of transformer layers as a single “recursion block”, the authors train a routing layer that decides how many times each token should be passed through the block. They show that this architecture can match the standard transformer for the same computational cost, while significantly decreasing the parameter count thanks to its recursive nature.
Their method
The authors take inspiration from the standard Mixture-of-Experts (MoE) approaches (see e.g., Switch Transformers), as well as previous attempts at adaptive computation using router-based networks (Mixture-of-Depths), in order to construct an approach that can adaptively re-apply a sequence of layers based on the estimated input difficulty.
Recursive block
The recursive block is constructed either by grouping the full sequence of transformer layers into a block that can be cyclically re-applied (“cycle strategy”), or by keeping the initial and final layers as standard (non-recursive) layers and re-applying only the middle ones (“middle-cycle strategy”).
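A minimal sketch of the middle-cycle layout follows; the layer sizes, the use of `nn.TransformerEncoderLayer` and the fixed recursion count are illustrative choices, and in the full MoR model the number of recursions is decided per token by the router described next:

```python
import torch
import torch.nn as nn

class MiddleCycleBlock(nn.Module):
    """'Middle-cycle' sketch: unique first/last layers, a shared middle layer
    re-applied up to `max_recursions` times."""

    def __init__(self, d_model=256, n_heads=4, max_recursions=3):
        super().__init__()
        make_layer = lambda: nn.TransformerEncoderLayer(d_model, n_heads, batch_first=True)
        self.first = make_layer()
        self.shared_middle = make_layer()     # weights shared across recursion depths
        self.last = make_layer()
        self.max_recursions = max_recursions

    def forward(self, x):
        x = self.first(x)
        for _ in range(self.max_recursions):  # here every token recurses the full number of times
            x = self.shared_middle(x)
        return self.last(x)

out = MiddleCycleBlock()(torch.randn(1, 12, 256))   # (batch, seq, d_model)
```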
Mixture-of-Recursions
In order to decide how many times each token should pass through the recursive block, the authors adapt the methods from the MoE literature. There are two main approaches that they consider:
- Expert-choice routing: Each recursion depth is regarded as an “expert” that chooses the top-$k$ tokens to pass through the stage. Only the tokens that passed through the previous recursion stage can be selected at the current one; thus, tokens “exit early” when they are not selected at a particular recursion depth. The router takes the previous hidden state and produces a scalar score for each token in the sequence, and at each recursion stage the $k$ tokens with the highest router scores are selected (Figure 2a). Note that, at each stage, $k$ is chosen so that progressively fewer tokens pass to the next one (the figure shows $k = 9, 6, 3$).
- Token-choice routing: The recursion depth for each token is decided immediately at the input to the block. By considering each recursion depth $r$ as an “expert”, the router calculates a score for each expert (with the number of experts equal to the maximum number of recursions allowed, $N_r$), and the token is assigned to the expert corresponding to its largest score (top-1 routing). This is shown in Figure 2b.
In both cases, the router consists of a linear layer followed by a non-linearity; for expert-choice it produces a scalar score using a tanh/sigmoid function, while for token-choice it produces an $N_r$-sized vector using a softmax non-linearity.
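The two routing schemes can be sketched as follows; the shapes, the $k$ schedule and the sigmoid/softmax choices follow the description above, while everything else (including reusing the same hidden states at every depth) is a simplification of ours:

```python
import torch
import torch.nn as nn

d_model, seq_len, n_recursions = 256, 12, 3
h = torch.randn(1, seq_len, d_model)              # hidden states for one sequence

# Expert-choice: a scalar score per token; each recursion depth keeps its top-k.
ec_router = nn.Sequential(nn.Linear(d_model, 1), nn.Sigmoid())
scores = ec_router(h).squeeze(-1)                 # shape (1, seq_len)
active = torch.arange(seq_len)                    # all tokens enter the first recursion
for k in [9, 6, 3]:                               # progressively fewer tokens (cf. Figure 2a)
    keep = scores[0, active].topk(k).indices      # (the paper re-scores at each depth)
    active = active[keep]                         # unselected tokens "exit early"

# Token-choice: a softmax over N_r "experts"; top-1 fixes each token's depth up front.
tc_router = nn.Sequential(nn.Linear(d_model, n_recursions), nn.Softmax(dim=-1))
depth_per_token = tc_router(h).argmax(dim=-1)     # shape (1, seq_len), values in [0, N_r)
```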
Both approaches come with well-known pros and cons. In the case of expert-choice, there is an issue with “information leakage”, as later tokens can influence the routing decisions of earlier ones in the sequence; this can be addressed by training an auxiliary router that predicts whether a token would be selected without looking at the other tokens in the sequence. On the other hand, token-choice typically needs an additional auxiliary “load balancing” objective to make sure each expert is assigned a similar number of tokens.
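One plausible form of such a load-balancing term, in the style of Switch Transformers (this is our sketch, not necessarily the exact auxiliary loss used in the paper):

```python
import torch

def load_balancing_loss(router_probs, expert_index, n_experts):
    """Switch-Transformer-style auxiliary loss: encourages the fraction of tokens
    routed to each expert (recursion depth) to match the mean router probability."""
    ones = torch.ones_like(expert_index, dtype=torch.float)
    f = torch.zeros(n_experts).scatter_add_(0, expert_index, ones) / expert_index.numel()
    p = router_probs.mean(dim=0)                  # mean router probability per expert
    return n_experts * torch.sum(f * p)

probs = torch.softmax(torch.randn(32, 3), dim=-1) # fake router outputs for 32 tokens
aux = load_balancing_loss(probs, probs.argmax(dim=-1), n_experts=3)
```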
KV caching
As some tokens “exit early”, they will not have their key-value pairs available at later recursive iterations. In order to deal with this, the authors try two approaches (depicted in Figure 2c):
- Recursion-wise KV caching: Attention is restricted to the tokens that are available at the given recursion depth. The number of tokens that can be attended to therefore shrinks with each recursive step, leading to a less computationally intensive attention operation at each step.
- Recursive KV sharing: As all tokens pass through the recursive block at least once, another approach is to re-use the key-value pairs from the first recursion depth for each of the subsequent ones. Thus, at each depth the queries can attend to the full sequence, but the key-value pairs are only calculated once, during the first recursive pass. (Both strategies are sketched below.)
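A rough sketch of the two strategies; causal masking, multiple heads and the actual cache data structure are omitted, and the `active` index set is a made-up example:

```python
import torch
import torch.nn.functional as F

d, seq_len = 64, 8
x = torch.randn(1, seq_len, d)                     # hidden states at some recursion depth
wq, wk, wv = (torch.nn.Linear(d, d) for _ in range(3))
active = torch.tensor([0, 2, 3, 7])                # tokens still recursing at this depth

# Recursion-wise KV caching: keys/values exist only for the active tokens, so
# attention is restricted to (and cheaper over) that shrinking subset.
q = wq(x[:, active])
out_cached = F.scaled_dot_product_attention(q, wk(x[:, active]), wv(x[:, active]))

# Recursive KV sharing: keys/values computed at the first recursion are reused,
# so the active queries still attend over the full sequence.
k_shared, v_shared = wk(x), wv(x)                  # stand-in for the depth-1 KV cache
out_shared = F.scaled_dot_product_attention(q, k_shared, v_shared)
```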
Results
The main results are shown in Table 3. The downstream tasks used are LAMBADA (LD), HellaSwag (HS), PIQA (PQ), WinoGrande (WG), ARC (Easy and Challenge), and MMLU. The best-performing setting uses expert-choice routing (Expert) with recursion-wise caching (Cache) and the middle-cycle recursion scheme (M-Cyc). For the same training FLOPs, the MoR model matches or beats the vanilla transformer while using roughly half the number of parameters. Increasing the number of recursive steps ($N_r$) can further decrease the total parameter count, but at some cost to performance.
The authors also test the performance of the models at the same compute budget as the size is scaled up, for a fixed recursion depth $N_r = 3$ (Figure 3). The results indicate that the MoR architecture can outperform the vanilla transformer; however, the gap narrows as the compute budget is increased, suggesting that the smaller parameter count may be hitting a capacity limit.
Finally, the authors show how the approach can improve throughput compared to the vanilla transformer for a fixed number of effective parameters (360M), as shown in Figure 4 (the “maximum” batch size line indicates the throughput when the largest batch size that fits on the GPU is used). As the maximum number of recursion depths is increased, throughput improves at the expense of some performance.
Takeaways
The authors present an interesting spin on the standard MoE approach: while MoE techniques improve model performance at a fixed computational budget (at the expense of higher memory requirements), this paper shows that the same routing techniques can be used to dynamically adjust the computation applied to each token, yielding similar performance to the standard transformer with a smaller memory footprint.
Full paper: Mixture-of-Recursions: Learning Dynamic Recursive Depths for Adaptive Token-Level Computation