Out-of-Distribution Generalization via Composition: A Lens Through Induction Heads in Transformers
Study on Out-of-Distribution Generalization and Composition Mechanisms in Large Language Models
Paper Background
In recent years, large language models (LLMs) such as GPT-4 have shown a remarkable ability to handle novel tasks, often solving problems from just a few examples. Such tasks require models to generalize to distributions that differ from the training data, a capability known as out-of-distribution (OOD) generalization. Despite the tremendous success of LLMs, how they achieve OOD generalization remains an open question. This paper explores the phenomenon by examining LLMs’ performance on tasks generated according to hidden rules, focusing on a type of attention head called the “induction head” (IH) to reveal the relationship between OOD generalization and composition mechanisms.
The study primarily investigates LLMs’ performance in symbolic reasoning tasks, exploring how these models infer hidden rules from input prompts without fine-tuning. Through empirical studies of training dynamics, the authors found that LLMs learn rules by composing two self-attention layers, thereby achieving OOD generalization. Additionally, they propose the “common bridge representation hypothesis,” suggesting that a shared latent subspace in the embedding (or feature) space aligns early and later layers, acting as a bridge for composition.
Paper Source
This paper was co-authored by Jiajun Song, Zhuoyan Xu, and Yiqiao Zhong, from the Beijing Academy of Artificial Intelligence and the University of Wisconsin-Madison. It was published in PNAS (Proceedings of the National Academy of Sciences) on February 7, 2025, titled “Out-of-Distribution Generalization via Composition: A Lens through Induction Heads in Transformers”.
Research Process and Results
Research Process
Synthetic Task Experiment
The authors first conducted experiments on a synthetic task, the “copying sequence” task. Given a sequence that contains a repeated pattern (e.g., …, [a], [b], [c], …, [a], [b]), the model must predict [c] as the next token after seeing [a], [b] again. The experiments used a two-layer Transformer model with standard self-attention mechanisms and residual connections.
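The copying task can be sketched with a small data generator. The following is a minimal illustration in Python; the vocabulary size, pattern length, and number of repetitions are illustrative assumptions, not the paper’s exact configuration.

    import random

    def make_copy_example(vocab_size=64, pattern_len=10, num_repeats=2, seed=None):
        # Draw a random token pattern and repeat it; the model receives all but
        # the last token and must predict the final token, which completes the
        # repeated pattern.
        rng = random.Random(seed)
        pattern = [rng.randrange(vocab_size) for _ in range(pattern_len)]
        seq = pattern * num_repeats
        return seq[:-1], seq[-1]

    # With the pattern [a, b, c] repeated twice, the input ends in "... a b"
    # and the correct prediction is "c".
    inputs, target = make_copy_example(seed=0)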
Training Dynamics Analysis
During training, the authors observed two phases: a weak-learning phase and a rule-learning phase. In the weak-learning phase, the model learned only simple statistical features of the input sequences and failed to generalize to OOD data. In the rule-learning phase, the model learned the copying rule and performed well on both ID and OOD data.
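One way to see the two phases empirically is to track accuracy on separate in-distribution and out-of-distribution evaluation sets over the course of training. A minimal sketch, assuming (for illustration only) that OOD examples use held-out tokens and longer patterns:

    import random

    def make_eval_set(n, token_range, pattern_len, seed=0):
        # n copying examples whose tokens are drawn from the half-open token_range.
        rng = random.Random(seed)
        examples = []
        for _ in range(n):
            pattern = [rng.randrange(*token_range) for _ in range(pattern_len)]
            seq = pattern * 2
            examples.append((seq[:-1], seq[-1]))
        return examples

    # ID evaluation set: token/length combinations seen during training.
    # OOD evaluation set: held-out tokens and a longer pattern.
    id_set = make_eval_set(200, token_range=(0, 32), pattern_len=10)
    ood_set = make_eval_set(200, token_range=(32, 64), pattern_len=20)

    def accuracy(predict_fn, examples):
        # predict_fn maps a list of input tokens to a predicted next token.
        return sum(predict_fn(x) == y for x, y in examples) / len(examples)

    # In the weak-learning phase, accuracy on id_set rises while accuracy on
    # ood_set stays near chance; in the rule-learning phase both approach 1.0.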
Role of Induction Heads (IHs)
By analyzing the training dynamics, the authors found that IHs play a crucial role in OOD generalization. An IH is an attention head that locates an earlier occurrence of the current token and attends to the token that followed it, allowing the model to complete repeated patterns. The experiments showed that the model achieves OOD generalization by composing two self-attention layers that handle positional information and token information separately.
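A common diagnostic for identifying induction heads is to measure how much attention a head places on the “induction target”, i.e., the position right after the most recent earlier occurrence of the current token. The sketch below is a simplified version of this score; the paper’s own criteria for selecting IHs may differ.

    import numpy as np

    def induction_score(attn, tokens):
        # attn: (seq_len, seq_len) attention weights for one head (row = query position).
        # For each position t, the induction target is the position right after
        # the most recent earlier occurrence of tokens[t]; a strong induction
        # head puts most of its attention mass there.
        scores = []
        for t in range(1, len(tokens)):
            prev = [j for j in range(t) if tokens[j] == tokens[t]]
            if prev:
                scores.append(float(attn[t, prev[-1] + 1]))
        return sum(scores) / len(scores) if scores else 0.0

    # Toy check: on the sequence a b c a b, a perfect induction head would, at
    # the final "b", attend almost entirely to position 2 (the "c" that followed
    # the first "a b").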
Common Bridge Representation Hypothesis
The authors further proposed the common bridge representation hypothesis: in multi-layer, multi-head models, a shared latent subspace serves as a bridge for composition. When the corresponding subspaces of early and later layers are aligned, the model can generalize to OOD data.
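One way to make the hypothesis quantitative is to measure how well the top singular subspaces of two heads’ weight matrices align, for example via principal angles. The sketch below assumes each head is summarized by a single matrix and that alignment is computed between their top-r left singular subspaces; the choice of matrices and of r is illustrative, not the paper’s exact procedure.

    import numpy as np

    def top_left_subspace(W, r):
        # Orthonormal basis for the top-r left singular subspace of W.
        U, _, _ = np.linalg.svd(W, full_matrices=False)
        return U[:, :r]

    def subspace_alignment(W_a, W_b, r=16):
        # Mean squared cosine of the principal angles between the two top-r
        # subspaces: 1.0 for identical subspaces, roughly r / d for two random
        # r-dimensional subspaces of R^d.
        Ua = top_left_subspace(W_a, r)
        Ub = top_left_subspace(W_b, r)
        cosines = np.linalg.svd(Ua.T @ Ub, compute_uv=False)
        return float(np.mean(cosines ** 2))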
Experiments on Pretrained LLMs
To validate the hypothesis, the authors conducted extensive experiments on a range of pretrained LLMs, including LLaMA, Mistral, and Falcon. The results indicate that IHs are critical for tasks such as symbolic reasoning and mathematical reasoning, especially on OOD data.
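Ablations of this kind are typically implemented by zeroing a head’s contribution before the attention output projection. The sketch below uses PyTorch forward pre-hooks and assumes a LLaMA-style module layout (model.model.layers[i].self_attn.o_proj) with a head_dim attribute; the model name and the layer/head indices are hypothetical, and the paper’s own ablation code is available in the linked repository.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "meta-llama/Llama-2-7b-hf"  # illustrative; any LLaMA-style model
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)

    def ablate_head(layer_idx, head_idx):
        # Zero one attention head by masking its slice of the concatenated head
        # outputs that feed the output projection (LLaMA-style module layout).
        attn = model.model.layers[layer_idx].self_attn
        lo = head_idx * attn.head_dim
        hi = lo + attn.head_dim

        def pre_hook(module, args):
            hidden = args[0].clone()
            hidden[..., lo:hi] = 0.0
            return (hidden,) + args[1:]

        return attn.o_proj.register_forward_pre_hook(pre_hook)

    # Hypothetical indices: compare task accuracy before and after ablating a
    # suspected induction head, then remove the hook to restore the model.
    handle = ablate_head(layer_idx=12, head_idx=5)
    # ... evaluate the model on the task prompts here ...
    handle.remove()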
Research Results
Results from Synthetic Tasks
In the synthetic task, the two-layer Transformer model demonstrated OOD generalization, while a one-layer model achieved only weak learning. Generalization improved markedly during the rule-learning phase, particularly on longer repeated sequences.
Experimental Results of IHs
Removing IHs significantly reduced the model’s performance on OOD data across different tasks. For example, in symbolic reasoning tasks, removing IHs decreased accuracy from nearly 90% to below 50%.
Validation of Common Bridge Representation Hypothesis
Experimental results showed that IHs and previous-token heads (PTHs) share a latent subspace, and that this alignment is what allows the model to generalize to OOD data. The hypothesis was further validated through weight-matrix projection experiments.
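The projection idea can be sketched as follows, assuming the candidate bridge subspace is given as an orthonormal basis U (for example, obtained from an SVD as in the alignment sketch above). Which weight matrices are edited and how performance is re-measured follow the paper’s experiments and are not reproduced here.

    import numpy as np

    def project_onto(W, U):
        # Keep only the component of W whose columns lie in span(U);
        # U has orthonormal columns spanning the candidate bridge subspace.
        return U @ (U.T @ W)

    def project_out(W, U):
        # Remove the span(U) component from W.
        return W - U @ (U.T @ W)

    # If the bridge hypothesis holds, replacing a relevant weight matrix with
    # project_onto(W, U) should preserve most of the copying behavior, while
    # project_out(W, U) should disrupt it.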
Conclusion and Significance
The main conclusion of this paper is that LLMs achieve OOD generalization through composition mechanisms, with IHs and PTHs playing a critical role in this process. The common bridge representation hypothesis provides a new perspective on how LLMs learn rules and generalize in novel tasks.
Scientific Value
Revealing Generalization Mechanisms
This study reveals how LLMs generalize to OOD data, filling a gap in the research field.
Proposing New Hypotheses
The common bridge representation hypothesis offers a new theoretical framework for understanding LLMs’ composition mechanisms, facilitating further research into model internals.
Application Value
The findings can guide improvements in LLM training methods and model design, particularly in enhancing performance on novel tasks.
Research Highlights
Novel Research Perspective
By focusing on IHs, this study uncovers the internal mechanisms of OOD generalization in LLMs, a less explored area.
Extensive Experimental Validation
The study not only ran experiments on synthetic tasks but also validated its conclusions extensively on various large-scale LLMs, strengthening the generality of the results.
Theoretical Innovation
The common bridge representation hypothesis offers a new theoretical perspective on how LLMs achieve generalization through composition, which is of significant academic value.
Additional Valuable Information
The code and data for this paper have been made publicly available on GitHub at: https://github.com/jiajunsong629/ood-generalization-via-composition. This facilitates replication and extension of the research by other researchers.
Summary
By delving into the mechanisms of OOD generalization in LLMs, this study reveals the key role of composition mechanisms in learning rules and achieving generalization. This not only deepens our understanding of LLM internals but also provides important theoretical support for future model design and optimization.