ring attention: Nonlinear Function
Created: February 19, 2024
Modified: February 19, 2024

This page is from my personal notes, and has not been specifically reviewed for public consumption. It might be incomplete, wrong, outdated, or stupid. Caveat lector.

Ring attention is a technique for implementing long-context transformer attention (see: long-term context in Transformers) in a distributed system.

It builds on blockwise-parallel transformers (BPT), which loop over (chunks of) output positions and, at each position, compute the output of a fused attention-feedforward block. Each step of that outer loop involves an inner loop over all key-value positions; this inner loop is the online softmax / FlashAttention computation.
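A minimal NumPy sketch of that inner loop for one (chunk of) query positions, assuming plain scaled dot-product attention with no mask and ignoring the feedforward half of the fused block. The function names `blockwise_attention` and `reference_attention` are illustrative, not from any library. It keeps a running max, a running softmax denominator, and an unnormalized output per query row, so only one key-value chunk needs to be in memory at a time.

```python
import numpy as np

def blockwise_attention(q, k, v, chunk_size):
    """Attention for queries q, streaming over key/value chunks (online softmax).

    Assumes unmasked scaled dot-product attention (an illustrative sketch).
    """
    scale = 1.0 / np.sqrt(q.shape[-1])
    m = np.full(q.shape[0], -np.inf)           # running max score per query row
    l = np.zeros(q.shape[0])                   # running sum of exp(score - m)
    acc = np.zeros((q.shape[0], v.shape[-1]))  # running sum of exp(score - m) * v
    for start in range(0, k.shape[0], chunk_size):
        kc = k[start:start + chunk_size]
        vc = v[start:start + chunk_size]
        s = (q @ kc.T) * scale                 # scores against this chunk only
        m_new = np.maximum(m, s.max(axis=1))
        correction = np.exp(m - m_new)         # rescale old stats to the new max
        p = np.exp(s - m_new[:, None])
        l = l * correction + p.sum(axis=1)
        acc = acc * correction[:, None] + p @ vc
        m = m_new
    return acc / l[:, None]

def reference_attention(q, k, v):
    """Ordinary full-softmax attention, for comparison."""
    s = (q @ k.T) / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=1, keepdims=True))
    return (p / p.sum(axis=1, keepdims=True)) @ v

rng = np.random.default_rng(0)
q = rng.standard_normal((4, 8))
k = rng.standard_normal((32, 8))
v = rng.standard_normal((32, 8))
print(np.allclose(blockwise_attention(q, k, v, chunk_size=8), reference_attention(q, k, v)))  # True
```

The rescaling by `correction` is what lets the loop see only one chunk at a time while still producing the exact softmax, not an approximation.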

Ring attention adapts this to a distributed system by assigning each node a (chunk of) output position(s), i.e., it parallelizes the BPT outer loop. Each node then needs to run a full 'inner loop' that accesses key-value information from every chunk of the sequence. But a node has finite memory, so it can't store all of the key-value information for a potentially unbounded-length sequence. Instead, the nodes pass the key-value chunks around a ring as they process them, so

  • each node only holds a couple of chunks at a time (the one it's working on and sending to the next node in the ring, and the next one it will work on, which it is receiving from the previous node in the ring), and
  • we only need $O(N)$ links to connect the $N$ nodes in a ring, and each link transfers only a single chunk at a time, so it doesn't become congested, and
  • the chunks are transferred concurrently with the compute, so as long as the network link can transfer a chunk faster than the node can process it, there is no slowdown from network overhead.
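The ring schedule can be simulated in a single process to check that it reproduces exact attention. This is a sketch under simplifying assumptions: unmasked attention, one Python list entry standing in for each node, and a sequential loop in place of real parallel compute and concurrent send/receive (so the overlap of communication with compute is not modeled). `ring_attention` is a hypothetical name. Node `i` owns query chunk `i` and, at step `s`, holds key-value chunk `(i - s) mod N`.

```python
import numpy as np

def ring_attention(q_chunks, k_chunks, v_chunks):
    """Sequential simulation of ring attention over N 'nodes' (unmasked).

    Node i owns query chunk i. At step s it holds key/value chunk
    (i - s) mod N, folds it into its online-softmax state, then passes
    it on to node (i + 1) mod N. After N steps every node has seen
    every chunk exactly once.
    """
    n = len(q_chunks)
    scale = 1.0 / np.sqrt(q_chunks[0].shape[-1])
    # Per-node state: running max, softmax denominator, unnormalized output.
    m = [np.full(q.shape[0], -np.inf) for q in q_chunks]
    l = [np.zeros(q.shape[0]) for q in q_chunks]
    acc = [np.zeros((q.shape[0], v_chunks[0].shape[-1])) for q in q_chunks]
    held = list(zip(k_chunks, v_chunks))  # node i starts with its own K/V chunk
    for _ in range(n):
        for i in range(n):  # on real hardware these iterations run in parallel
            kc, vc = held[i]
            s = (q_chunks[i] @ kc.T) * scale
            m_new = np.maximum(m[i], s.max(axis=1))
            corr = np.exp(m[i] - m_new)
            p = np.exp(s - m_new[:, None])
            l[i] = l[i] * corr + p.sum(axis=1)
            acc[i] = acc[i] * corr[:, None] + p @ vc
            m[i] = m_new
        # Rotate around the ring: node i receives the chunk node i - 1 just used.
        held = [held[(i - 1) % n] for i in range(n)]
    return [a / d[:, None] for a, d in zip(acc, l)]

rng = np.random.default_rng(0)
n_nodes, c, d = 4, 8, 16
q, k, v = (rng.standard_normal((n_nodes * c, d)) for _ in range(3))
outs = ring_attention(np.split(q, n_nodes), np.split(k, n_nodes), np.split(v, n_nodes))
s = q @ k.T / np.sqrt(d)
p = np.exp(s - s.max(axis=1, keepdims=True))
print(np.allclose(np.concatenate(outs), (p / p.sum(axis=1, keepdims=True)) @ v))  # True
```

Each simulated node only ever touches one key-value chunk per step, which mirrors the bounded per-node memory described in the first bullet; a real implementation would double-buffer so the next chunk arrives while the current one is being processed.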

The key point about attention is that it is invariant to the order in which key-value positions are visited: each output is a softmax-weighted sum over all key-value positions, and positional information enters through the position encodings, not through the processing order. So it's fine that each node will start its inner loop from an arbitrary point in the sequence and 'wrap around'.
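One way to see this order-independence concretely: summarize the work done on each key-value chunk as a (running max, denominator, unnormalized output) triple, and note that merging two triples is symmetric and, up to floating-point rounding, associative. The sketch below assumes unmasked attention; a causal mask would additionally require each chunk's absolute positions, though still not any particular visiting order. The helpers `partial_state`, `merge`, and `finish` are hypothetical names for illustration.

```python
import numpy as np

def partial_state(q, kc, vc):
    """Online-softmax summary of queries q against one key/value chunk."""
    s = (q @ kc.T) / np.sqrt(q.shape[-1])
    m = s.max(axis=1)
    p = np.exp(s - m[:, None])
    return m, p.sum(axis=1), p @ vc

def merge(a, b):
    """Combine two summaries; symmetric in a and b."""
    (ma, la, acca), (mb, lb, accb) = a, b
    m = np.maximum(ma, mb)
    wa, wb = np.exp(ma - m), np.exp(mb - m)  # rescale both to the shared max
    return m, la * wa + lb * wb, acca * wa[:, None] + accb * wb[:, None]

def finish(state):
    """Normalize the accumulated output by the softmax denominator."""
    _, l, acc = state
    return acc / l[:, None]

rng = np.random.default_rng(1)
q = rng.standard_normal((4, 8))
k_chunks = [rng.standard_normal((8, 8)) for _ in range(4)]
v_chunks = [rng.standard_normal((8, 8)) for _ in range(4)]
parts = [partial_state(q, kc, vc) for kc, vc in zip(k_chunks, v_chunks)]

def run(order):
    """Fold the chunk summaries together in the given visiting order."""
    state = parts[order[0]]
    for j in order[1:]:
        state = merge(state, parts[j])
    return finish(state)

# Starting at chunk 0 or 'wrapping around' from chunk 2 gives the same output.
print(np.allclose(run([0, 1, 2, 3]), run([2, 3, 0, 1])))  # True
```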

This technique allows exact attention computation on arbitrarily long sequences, limited only by the number of nodes in the ring. Of course there's no computational free lunch: for a long sequence (with a fixed chunk size) we need $O(T)$ nodes working for $O(T)$ sequential steps (so that each member of the ring sees the whole input), so the total work is still quadratic. But at least the sequence length is no longer limited by the memory of individual devices.
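The cost accounting can be written out explicitly. Here $c$ (chunk size) and $d$ (head dimension) are symbols introduced for this sketch and assumed fixed; $T$ is the sequence length and $N = T/c$ is the number of nodes. Each step, every node scores one $c \times c$ block of query-key pairs:

```latex
% Assumptions: fixed chunk size c, head dimension d, N = T / c nodes.
\begin{aligned}
\text{work per node per step} &= O(c^2 d) \\
\text{work per node} &= N \cdot O(c^2 d) = O(T\,c\,d) \\
\text{total work} &= N \cdot O(T\,c\,d) = O(T^2 d) \\
\text{memory per node} &= O(c\,d) \quad \text{(a constant number of chunks, independent of } T\text{)}
\end{aligned}
```

The total matches ordinary attention's quadratic cost; what changes is that per-node memory no longer grows with $T$.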