Authors:
(1) David Raposo, Google DeepMind (equal contribution);
(2) Sam Ritter, Google DeepMind;
(3) Blake Richards, Google DeepMind and McGill University & Mila;
(4) Timothy Lillicrap, Google DeepMind;
(5) Peter Conway Humphreys, Google DeepMind;
(6) Adam Santoro, Google DeepMind (equal contribution).
Editor's note: this is part 1 of 5 of a study detailing a way to make transformer-based language models more efficient by dynamically allocating computational resources. Read the rest below.
Table of Links
- Introduction
- Background
- Implementing Mixture-of-Depths Transformers
  - 3.1. Defining a compute budget
  - 3.2. Routing around transformer blocks
  - 3.3. Routing schemes
  - 3.4. Routing implementation
  - 3.5. Sampling and 3.6. Training methods
- Results
  - 4.1. Training, isoFLOP comparisons
  - 4.2. Auto-regressive Evaluation and 4.3. Mixture-of-Depths-and-Experts (MoDE)
- Discussion and References
Transformer-based language models spread FLOPs uniformly across input sequences. In this work we demonstrate that transformers can instead learn to dynamically allocate FLOPs (or compute) to specific positions in a sequence, optimising the allocation along the sequence for different layers across the model depth. Our method enforces a total compute budget by capping the number of tokens (𝑘) that can participate in the self-attention and MLP computations at a given layer. The tokens to be processed are determined by the network using a top-𝑘 routing mechanism. Since 𝑘 is defined a priori, this simple procedure uses a static computation graph with known tensor sizes, unlike other conditional computation techniques. Nevertheless, since the identities of the 𝑘 tokens are fluid, this method can expend FLOPs non-uniformly across the time and model depth dimensions. Thus, compute expenditure is entirely predictable in sum total, but dynamic and context-sensitive at the token level. Not only do models trained in this way learn to dynamically allocate compute, they do so efficiently. These models match baseline performance for equivalent FLOPs and wall-clock time to train, but require a fraction of the FLOPs per forward pass, and can be upwards of 50% faster to step during post-training sampling.
1. Introduction
Not all problems require the same amount of time or effort to solve. Analogously, in language modeling not all tokens and sequences require the same time or effort to accurately make a prediction. And yet, transformer models expend the same amount of compute per token in a forward pass. Ideally, transformers would use smaller total compute budgets by not spending compute unnecessarily.
Conditional computation is a technique that tries to reduce total compute by expending it only when needed (Bengio et al., 2016; Bengio, 2013; Bengio et al., 2013). Various algorithms offer solutions to when and how much compute should be used (Ainslie et al., 2023; Bapna et al., 2020; Fedus et al., 2022). However, general formulations of this challenging problem may not work well with existing hardware constraints since they tend to introduce dynamic computation graphs (Dehghani et al., 2018; Graves, 2016). The most promising conditional computation methods may instead be those that are harmonious with our current hardware stack, which prioritizes static computation graphs, and known tensor sizes that are selected to maximize hardware utilization.
Here we consider the problem of language modeling using a static compute budget that can be made less than that used by a vanilla transformer. The network must learn how to dynamically allocate the available compute by making decisions per-token, in each layer, about where to spend compute from the available budget. In our implementation total compute is user defined and unchanging prior to training, rather than being a function of the network’s on-the-fly decisions. Thus, hardware efficiency gains—such as reduced memory footprint, or reduced FLOPs per forward pass—can be anticipated and exploited ahead of time. As we will show, these gains can be had without sacrificing overall performance.
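To make the "anticipated ahead of time" point concrete, here is a back-of-the-envelope sketch (not from the paper; the 12.5% capacity, sequence length, and model width are illustrative assumptions) of how a fixed per-layer token budget yields a FLOP count that is known before training begins:

```python
# Rough forward-pass FLOP count for one transformer block, counting each
# multiply-accumulate as 2 FLOPs. All numbers below are illustrative.

def block_flops(num_tokens: int, d_model: int) -> float:
    attn_proj = 4 * num_tokens * d_model * d_model       # Q, K, V, and output projections
    attn_scores = 2 * num_tokens * num_tokens * d_model  # QK^T and attention-weighted values
    mlp = 2 * num_tokens * d_model * (4 * d_model)       # two matmuls with a 4x hidden width
    return 2 * (attn_proj + attn_scores + mlp)

seq_len, d_model, capacity = 2048, 1024, 0.125
k = int(capacity * seq_len)  # 256 tokens allowed into the block; fixed before training

dense = block_flops(seq_len, d_model)
routed = block_flops(k, d_model)
print(f"Routed block uses {routed / dense:.1%} of the dense block's FLOPs")
```

Because 𝑘 is a hyperparameter rather than a runtime decision, this number holds for every forward pass, regardless of which tokens the router ultimately selects.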
We leverage an approach akin to Mixture of Experts (MoE) transformers, in which dynamic token-level routing decisions are made across the network depth. Departing from MoE, we choose to either apply a computation to a token (as would be the case for a standard transformer), or pass it through a residual connection (remaining unchanged and saving compute). Also in contrast to MoE, we apply this routing to both forward MLPs and multi-head attention. Since this therefore also impacts the keys and queries we process, the routing makes decisions not only about which tokens to update, but also which tokens are made available to attend to. We refer to this strategy as Mixture-of-Depths (MoD) to emphasize how individual tokens pass through different numbers of layers, or blocks, through the depth of the transformer (see figure 1).
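As a concrete illustration of this routing pattern, the sketch below shows one way such a block could look in PyTorch. It is not the paper's implementation: the `MoDBlock` wrapper, its default capacity fraction, and the assumption that `block` returns the attention-plus-MLP update before the residual add are all illustrative choices.

```python
import torch
import torch.nn as nn

class MoDBlock(nn.Module):
    """A minimal Mixture-of-Depths-style wrapper around a transformer block.

    `block` is assumed to take [batch, k, d_model] and return the block's
    update (attention + MLP output *before* the residual add), same shape.
    """

    def __init__(self, d_model: int, block: nn.Module, capacity_fraction: float = 0.125):
        super().__init__()
        self.block = block
        self.router = nn.Linear(d_model, 1)   # scalar routing score per token
        self.capacity_fraction = capacity_fraction

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, seq_len, d_model = x.shape
        k = max(1, int(self.capacity_fraction * seq_len))  # static per-layer budget

        # Score every token, then keep only the top-k per sequence.
        scores = self.router(x).squeeze(-1)                 # [batch, seq_len]
        topk_scores, topk_idx = torch.topk(scores, k, dim=-1)

        # Gather the selected tokens; downstream shapes depend on k,
        # not on which tokens happened to be chosen.
        gather_idx = topk_idx.unsqueeze(-1).expand(-1, -1, d_model)
        selected = torch.gather(x, 1, gather_idx)           # [batch, k, d_model]

        # Selected tokens are processed by attention + MLP; multiplying by the
        # router score keeps the router on the gradient path.
        update = self.block(selected) * topk_scores.unsqueeze(-1)

        # Unselected tokens skip the block entirely via the residual stream.
        out = x.clone()
        out.scatter_add_(1, gather_idx, update)
        return out
```

One caveat this sketch glosses over: top-𝑘 selection over a whole sequence is non-causal, which is why the paper treats autoregressive sampling separately (Section 3.5).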
The MoD technique also allows one to trade-off performance with speed. On the one hand, one can train an MoD transformer that improves upon vanilla transformers by as much as 1.5% on the final log probability training objective for equivalent training FLOPs (isoFLOP), and while taking an equivalent amount of wall-clock time to train. On the other hand, one can train an MoD transformer that achieves training loss parity with an isoFLOP optimal vanilla transformer, but which uses a fraction of the FLOPs (upwards of 50%) per forward pass, and hence is faster to step. Together, these results imply that MoD transformers learn to route intelligently (i.e., skipping computations that are unnecessary) since they can achieve equal or better log probabilities per sequence despite a smaller FLOP footprint per forward pass.
This paper is available on arxiv under CC BY 4.0 DEED license.