memory-efficient attention: Nonlinear Function
Created: February 19, 2024
Modified: February 19, 2024


This page is from my personal notes, and has not been specifically reviewed for public consumption. It might be incomplete, wrong, outdated, or stupid. Caveat lector.

Training a transformer layer on a sequence of length $T$ requires the output of the attention computation

$$X = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V$$

where $Q, K$ are $T \times D_k$ matrices, $V$ is a $T \times D_v$ matrix, $d = D_k$, and the softmax is applied to each row. This computes attention at all $T$ positions simultaneously, requiring $O(D_k T^2)$ FLOPs to evaluate $QK^T$, the dot product of each query with all keys (often this is restricted to all previous keys, giving us 'causal' attention, but that doesn't change the asymptotics).

There is no way around the quadratic computation if we want to compute this expression exactly. But in practice the real limitation is often memory rather than compute. We can always get more compute by just taking more time, but GPU memory is a hard constraint. And while a naive implementation of the above does use quadratic memory to materialize the 'attention matrix' $A = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)$, it turns out that this can be avoided.
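For reference, here is a minimal pure-Python sketch of the naive computation. It uses only the standard library. The names `softmax` and `naive_attention` are my own, and each of `Q`, `K`, `V` is stored as a list of per-position vectors. The sketch builds the whole $T \times T$ matrix `A` before multiplying by $V$. That is exactly the quadratic memory the rest of this page tries to avoid.

```python
import math


def softmax(row):
    """Softmax of one row of scores, shifted by its max so exp() cannot overflow."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def naive_attention(Q, K, V, causal=False):
    """X = softmax(Q K^T / sqrt(d)) V, materializing the full T x T matrix A.

    Q and K are lists of T key-dimension vectors; V is a list of T value vectors.
    """
    T, d = len(Q), len(Q[0])
    A = []  # the T x T attention matrix: this is the O(T^2) memory
    for i in range(T):
        scores = [sum(a * b for a, b in zip(Q[i], K[j])) / math.sqrt(d) for j in range(T)]
        if causal:
            # Query i may only attend to keys 0..i.
            scores = [s if j <= i else -math.inf for j, s in enumerate(scores)]
        A.append(softmax(scores))
    # Each output row is an attention-weighted average of the value vectors.
    d_v = len(V[0])
    return [[sum(A[i][j] * V[j][c] for j in range(T)) for c in range(d_v)] for i in range(T)]
```

With `causal=True`, future keys get a score of $-\infty$. `math.exp` maps that score to a weight of exactly zero.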

Self-attention Does Not Need $O(n^2)$ Memory

Reference: Rabe and Staats, 2021 (https://arxiv.org/abs/2112.05682)

Lazy softmax. The basic observation is that we can do a 'lazy softmax'. For a given query vector $q$, consider the (scaled) dot products $s_i = qk_i^T/\sqrt{d}$ of the query with each key $k_i$, the corresponding row of the attention matrix given by their softmax,

$$s' = \text{softmax}(s_{1:T}) = \frac{e^{s_{1:T}}}{\sum_{j=1}^T e^{s_j}},$$

and the final output of the attention computation

$$x = \sum_i s'_i v_i.$$

It turns out that we don't need to keep the full vector of attention scores $s$ and its softmax $s'$ in memory in order to compute this. Since each entry of the softmax has the same normalizing constant $\sum_{j=1}^T e^{s_j}$, we can pull that normalization back to the final step

$$x = \frac{\sum_i e^{s_i} v_i}{\sum_j e^{s_j}}$$

so that the attention output is the ratio of two sums, each of which can be computed incrementally. We initialize $v^* = 0^{D_v}$ and $s^* = 0$, and for each element of the sequence we update

$$
\begin{aligned}
v^* &\leftarrow v^* + e^{s_i}v_i\\
s^* &\leftarrow s^* + e^{s_i},
\end{aligned}
$$

so that in the end we compute the attention output $x = v^*/s^*$ using only constant memory. This saves us a factor of $T$ memory compared to the naive algorithm.
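The lazy softmax for a single query can be sketched in a few lines of pure Python. This version is unstabilized, and the name `lazy_attention_single` is mine. It streams over keys and values while keeping only $v^*$ and $s^*$, and normalizes once at the end.

```python
import math


def lazy_attention_single(q, K, V):
    """Attention output for one query q without storing its scores.

    Only the running sums v_star = sum_i e^{s_i} v_i and s_star = sum_i e^{s_i}
    are kept, so memory does not grow with the number of keys T.
    Unstabilized: math.exp(s_i) overflows once a score exceeds about 709.
    """
    d = len(q)
    v_star = [0.0] * len(V[0])
    s_star = 0.0
    for k_i, v_i in zip(K, V):
        s_i = sum(a * b for a, b in zip(q, k_i)) / math.sqrt(d)
        w = math.exp(s_i)
        v_star = [acc + w * x for acc, x in zip(v_star, v_i)]
        s_star += w
    return [acc / s_star for acc in v_star]  # normalize once, at the very end
```

For readability the list `v_star` is rebuilt at each step. An in-place update would use the same constant amount of memory.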

To extend this to self-attention (computing the attention outputs for the queries $q_j$ at all positions), we can either parallelize over the $T$ positions, now requiring $O(T)$ memory instead of the naive $O(T^2)$, or, to be really extreme, compute each position sequentially, in which case we need only $O(\log T)$ memory to store the index of the current position.

Stability. A naive accumulation of the exponentiated scores $e^{s_i}$ can easily overflow in floating point. The paper proposes incrementally normalizing the scores using $m^*$, the maximum score seen thus far (initialized to $-\infty$):

$$
\begin{aligned}
m_i &= \max(m^*, s_i)\\
v^* &\leftarrow v^* e^{m^* - m_i} + v_i e^{s_i - m_i}\\
s^* &\leftarrow s^* e^{m^* - m_i} + e^{s_i - m_i}\\
m^* &\leftarrow m_i
\end{aligned}
$$

and shows empirically that this computation yields results effectively indistinguishable from the traditional approach.
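Here is the same single-query sketch with the running-max update added. The name `online_attention_single` is hypothetical. Both exponents passed to `math.exp` are $\le 0$, so the update cannot overflow.

```python
import math


def online_attention_single(q, K, V):
    """Attention output for one query using the running-max (stabilized) update.

    State: m_star (max score so far), s_star (sum of e^{s_i - m_star}) and
    v_star (sum of e^{s_i - m_star} v_i). Every exponent is <= 0.
    """
    d = len(q)
    m_star = -math.inf
    s_star = 0.0
    v_star = [0.0] * len(V[0])
    for k_i, v_i in zip(K, V):
        s_i = sum(a * b for a, b in zip(q, k_i)) / math.sqrt(d)
        m_i = max(m_star, s_i)
        scale = math.exp(m_star - m_i)   # rescale old sums to the new max (0.0 on step 1)
        w = math.exp(s_i - m_i)
        v_star = [acc * scale + w * x for acc, x in zip(v_star, v_i)]
        s_star = s_star * scale + w
        m_star = m_i
    return [acc / s_star for acc in v_star]
```

Consider `q = [2000.0, 0.0]` with a key `[1.0, 0.0]`, which gives a score of about 1414. The unstabilized loop would call `math.exp(1414.2...)`, and Python raises `OverflowError`. This version instead returns essentially the value vector of the highest-scoring key.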

Differentiation. The essence of this trick is to convert a parallel computation to a sequential one, saving memory because we only need to represent one step at a time. But naively differentiating through this computation would again require us to remember each intermediate step, losing all of our advantage. This is addressed using the same techniques for memory-efficient backprop developed for other sequential computations (RNNs or very deep networks), namely checkpointing: storing only a few intermediate states and recomputing the rest during the backward pass.

Chunking. The purest form of incremental attention would accumulate one single position at a time, as described above. But in practice we can trade off memory for speed by accumulating chunks of adjacent positions, computing the attention scores for each chunk in parallel. This recovers traditional attention as the chunk size grows to the full sequence length, but any fixed chunk size retains the crucial guarantee that memory usage stays bounded no matter how long the sequence is.

There are really two possible dimensions of chunking. For a given query vector $q$ we can accumulate its scores against a chunk of key vectors at a time. But we can also work with a chunk of query vectors at a time. Chunking in both of these dimensions corresponds to working with blocks of the attention matrix. The important thing is that each block of the output matrix is ultimately summarized by the accumulated $v^*, s^*, m^*$ representing the
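Chunking in both dimensions can be sketched as tiling over blocks of `Br` queries and `Bc` keys. The block-size parameter names are my own choice. Each query row keeps its own $(m^*, s^*, v^*)$. A whole key block is merged at once using the tile's maximum score, so only one `Br x Bc` tile of scores ever exists.

```python
import math


def blocked_attention(Q, K, V, Br=2, Bc=2):
    """Attention computed one (query block) x (key block) tile at a time."""
    T, d, d_v = len(Q), len(Q[0]), len(V[0])
    out = []
    for q0 in range(0, T, Br):                       # outer loop: query blocks
        Qb = Q[q0:q0 + Br]
        m = [-math.inf] * len(Qb)                    # per-row running max
        s = [0.0] * len(Qb)                          # per-row running denominator
        v = [[0.0] * d_v for _ in Qb]                # per-row running numerator
        for k0 in range(0, len(K), Bc):              # inner loop: key/value blocks
            Kb, Vb = K[k0:k0 + Bc], V[k0:k0 + Bc]
            for r, q in enumerate(Qb):
                # Row r of the current score tile; no other scores are kept.
                tile = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in Kb]
                m_new = max(m[r], max(tile))
                scale = math.exp(m[r] - m_new)       # rescale the previous sums
                w = [math.exp(t - m_new) for t in tile]
                s[r] = s[r] * scale + sum(w)
                v[r] = [scale * acc + sum(wj * vj[c] for wj, vj in zip(w, Vb))
                        for c, acc in enumerate(v[r])]
                m[r] = m_new
        out.extend([vr[c] / sr for c in range(d_v)] for vr, sr in zip(v, s))
    return out
```

With `Br = Bc = 1`, this is the one-position-at-a-time update. With `Br = Bc = T`, a single tile covers the whole matrix and the computation reduces to an ordinary max-shifted softmax. Any block sizes in between give the same result up to rounding.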

FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

Reference: Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré (2022), https://arxiv.org/abs/2205.14135

FlashAttention implements these ideas specifically to optimize the round trips between a GPU's main memory ("high-bandwidth memory" or HBM; an A100 GPU has 40-80GB with bandwidth 1.5-2TB/s) and its much faster on-chip SRAM (an A100 has 192KB for each of its 108 streaming multiprocessors, with bandwidth estimated at around 19TB/s).
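Some quick arithmetic on these figures makes the constraint concrete. I assume "KB" means KiB and that attention scores are stored as 2-byte fp16 values. Under those assumptions, the SRAM on the entire chip adds up to about 20 MiB. A single head's $T \times T$ attention matrix at $T = 8192$ already takes 128 MiB. So the matrix cannot stay on-chip: it either round-trips through HBM or is never materialized.

```python
# Assumptions: "KB" means KiB (1024 bytes) and scores are stored in fp16 (2 bytes).
SM_COUNT = 108
SRAM_PER_SM = 192 * 1024                  # bytes of SRAM per streaming multiprocessor
TOTAL_SRAM = SM_COUNT * SRAM_PER_SM       # SRAM summed over the whole chip
MiB = 2**20


def attention_matrix_bytes(T, bytes_per_score=2):
    """Bytes needed to materialize one head's T x T attention matrix."""
    return T * T * bytes_per_score


print(TOTAL_SRAM / MiB)                    # 20.25
print(attention_matrix_bytes(8192) / MiB)  # 128.0
```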

This brings in some systems language. Breaking the softmax into chunks, as described above, is here called 'tiling'. Exponentiating the scores in the same pass that we compute them is called 'fusion'. Applying memory-efficient backprop is called 'recomputation'. The paper notes that the intermediate values in backprop can be recovered from the Q, K, V matrices and from the accumulated normalization statistics, so it's using a mix of checkpointing and inversion.
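A single-query sketch of recomputation in the backward pass follows. The assumption is that we want gradients of $L = g \cdot x$ for some upstream gradient vector $g$ (`dx` in the code). The names `forward`, `backward` and `dot` are hypothetical helpers. The forward pass saves only the output $x$ and the final statistics $m^*$ and $s^*$, called `m` and `l` here. The backward pass rebuilds each weight $p_i = e^{s_i - m^*}/s^*$ from $q$, $k_i$ and those statistics, rather than reading it from memory. It then applies the softmax identity $\partial L/\partial s_i = p_i(g \cdot v_i - g \cdot x)$.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def forward(q, K, V):
    """Single-query attention with the online update.

    Saves only the output x and the final statistics m (max score) and
    l (sum of e^{s_i - m}); no per-key values are kept.
    """
    scale = 1.0 / math.sqrt(len(q))
    m, l, acc = -math.inf, 0.0, [0.0] * len(V[0])
    for k, v in zip(K, V):
        s = dot(q, k) * scale
        m_new = max(m, s)
        c, w = math.exp(m - m_new), math.exp(s - m_new)
        acc = [c * a + w * vi for a, vi in zip(acc, v)]
        l, m = c * l + w, m_new
    return [a / l for a in acc], m, l


def backward(q, K, V, x, m, l, dx):
    """Gradients of L = dx . x with respect to q, K and V.

    Each attention weight p_i = e^{s_i - m} / l is recomputed from q, k_i and
    the saved statistics instead of being stored during the forward pass.
    """
    scale = 1.0 / math.sqrt(len(q))
    delta = dot(dx, x)                             # equals sum_i p_i (dx . v_i)
    dq, dK, dV = [0.0] * len(q), [], []
    for k, v in zip(K, V):
        p = math.exp(dot(q, k) * scale - m) / l    # recomputed, not stored
        dV.append([p * g for g in dx])             # dL/dv_i = p_i dx
        ds = p * (dot(dx, v) - delta)              # dL/ds_i via the softmax Jacobian
        dq = [a + ds * scale * kj for a, kj in zip(dq, k)]
        dK.append([ds * scale * qj for qj in q])
    return dq, dK, dV
```

The backward loop keeps only running gradient accumulators plus the per-key outputs `dK` and `dV`, which are the same size as the inputs. No $T$-length vector of attention weights is ever stored.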

Conceptually the main contribution of this paper is the focus on IO-awareness. The insight is that memory efficiency isn't just about saving resources or running larger models; it also enables faster computation by making efficient use of the memory hierarchy. Even if we have enough HBM that we could fit the whole attention matrix in memory, it's actually faster to compute it incrementally so that we can avoid reading and writing the whole matrix from HBM multiple times. The paper has also been extremely impactful because the authors provided a CUDA implementation that became widely adopted.
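To make the IO argument concrete, here is a toy count of elements moved between HBM and SRAM. This is my own simplified model, not the paper's analysis, which also accounts for the SRAM size and the running statistics. It assumes a single head with $D_k = D_v = D$. In the unfused path, the $T \times T$ score and probability matrices are written to HBM and read back. In the fused path, a block of `Br` queries is read once, $K$ and $V$ are streamed once per query block, and scores never touch HBM. The function names are hypothetical.

```python
def standard_hbm_traffic(T, D):
    """Elements read/written by an unfused three-step algorithm."""
    s_step = 2 * T * D + T * T        # read Q, K; write S = Q K^T
    p_step = T * T + T * T            # read S; write P = softmax(S)
    o_step = T * T + T * D + T * D    # read P, V; write O = P V
    return s_step + p_step + o_step   # = 4 T^2 + 4 T D


def tiled_hbm_traffic(T, D, Br):
    """Toy model of a fused kernel with an outer loop over query blocks."""
    n_blocks = -(-T // Br)            # ceil(T / Br)
    # Read Q once, stream K and V once per query block, write O once.
    return T * D + n_blocks * 2 * T * D + T * D


print(standard_hbm_traffic(4096, 64))    # 68157440
print(tiled_hbm_traffic(4096, 64, 128))  # 17301504
```

In this model the unfused traffic grows like $4T^2$, while the fused traffic grows like $2T^2 D / B_r$. Larger query blocks, limited by how much fits in SRAM, therefore widen the gap.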

Note that some of the same authors later went on to develop Mamba, which also gets much of its performance from being aware of the memory hierarchy; this seems to be a generally useful thing to consider.

Blockwise parallel transformers

Reference: Liu and Abbeel, https://arxiv.org/abs/2305.19370

The usual transformer architecture applies a feedforward block after each attention block. The feedforward layer performs the same computation at each position of the sequence. A naive implementation takes the attention output for the whole sequence and passes it into the feedforward layer. But just as we can compute the attention output incrementally, one position at a time, we can do the feedforward computation for each position as soon as we've computed its attention output.

Conceptually, the 'inner loop' of online-softmax attention is over blocks of key/value entries for a given query position, and the 'outer loop' is over (blocks of) query positions. Each step of the outer loop produces a complete output for that (block of) positions in the sequence. So we can incorporate the feedforward computation for that (block of) positions into the outer loop.

This saves memory because we never materialize the full intermediate activations of the feedforward block. Typically the first feedforward layer blows up the dimension by a factor of 4, and the second projects back down to the original dimension. By only computing that intermediate for one chunk at a time (and recomputing it during backprop), we only need to store the input activations to the feedforward layer, which are smaller by a factor of 4. So this technique can save up to roughly a factor of 4 in memory usage.
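Here is a minimal sketch of the fusion. The simplifications are my own assumptions: no output projection, residual connections or layer norm, and a ReLU MLP with hidden width $4D$. The helper `attend_row` computes one finished attention row. For brevity it keeps that row's scores, whereas the online update would avoid even that. The point is the loop structure: each block of finished attention rows goes through the feedforward network immediately. As a result, only `block * 4D` hidden activations are ever needed at once.

```python
import math


def attend_row(q, K, V):
    """Finished attention output for one query (max-shifted softmax)."""
    scale = 1.0 / math.sqrt(len(q))
    s = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
    m = max(s)
    w = [math.exp(x - m) for x in s]
    z = sum(w)
    return [sum(wi * v[c] for wi, v in zip(w, V)) / z for c in range(len(V[0]))]


def ffn(x, W1, W2):
    """Position-wise MLP: ReLU(x W1) W2 with W1 of shape D x 4D, W2 of shape 4D x D."""
    h = [max(0.0, sum(xi * W1[i][j] for i, xi in enumerate(x))) for j in range(len(W1[0]))]
    return [sum(hj * W2[j][c] for j, hj in enumerate(h)) for c in range(len(W2[0]))]


def blockwise_layer(Q, K, V, W1, W2, block=2):
    """Attention followed by the FFN, fused into the outer loop over query blocks.

    Returns the layer outputs and the largest number of FFN hidden activations
    needed for any one block (block * 4D, instead of T * 4D for the whole sequence).
    """
    out, peak_hidden = [], 0
    for b0 in range(0, len(Q), block):
        attn = [attend_row(q, K, V) for q in Q[b0:b0 + block]]  # finished rows only
        peak_hidden = max(peak_hidden, len(attn) * len(W1[0]))
        out.extend(ffn(x, W1, W2) for x in attn)                 # FFN right away
    return out, peak_hidden
```

The outputs do not depend on `block`, since each position's FFN sees the same input either way. Only the peak number of hidden activations changes, scaling with the block size rather than the sequence length.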