Ring Attention Explained


By Kilian Haefeli, Simon Zirui Guo, Bonnie Li

Context length in Large Language Models has expanded rapidly over the last few years: from GPT-3.5’s 16k tokens, to Claude 2.1’s 200k tokens, and recently Gemini 1.5 Pro’s 1 million tokens. The longer the context window, the more information the model can incorporate and reason about, unlocking many exciting use cases!

However, increasing the context length poses significant technical challenges, constrained by GPU memory capacity. What if we could use multiple devices to scale to a near-infinite context window? Ring Attention is a promising approach to do so, and we will dive into its tricks and details in this blog.

It took us a few tries to understand how Ring Attention works beneath its magic. Several tricks have to click in your head to really grasp it! In this blog, we will go through them step by step:

  • We will build up why scaling long context across GPUs is so challenging.
  • Understand how we can divide the attention calculation into chunks along the sequence dimension.
  • See how we can map the divided-up attention calculation onto multiple GPUs and orchestrate the computation with minimal overhead, by cleverly overlapping communication and computation.

Notably, the techniques we present here leading up to Ring Attention are general tricks for splitting the computation of attention into parallel parts. These techniques have been discovered by different amazing researchers and used in other transformer optimizations such as Self-attention Does Not Need $O(n^2)$ Memory and Flash Attention. Note that these optimizations and Ring Attention are orthogonal concepts, in that they can be used independently of each other.

Attention and Memory

Why is long context so hard? Let’s look at computing attention in the transformer. We have query, key, value - each is a matrix of shape $s \times d$ (sequence length times model dimension). We need to compute $O = \text{softmax}(QK^T)V$, which looks like this:

[Figure: visualization of the attention computation $O = \text{softmax}(QK^T)V$]

Notice that the Score Matrix $S = QK^T$ and the Attention Matrix $A = \text{softmax}(QK^T)$ are both of size $s \times s$, so the memory complexity of naive attention is quadratic in the sequence length! Even with advanced optimizations like Flash Attention, the memory complexity is still linear in the sequence length. Thus for long contexts, scaling the sequence length is limited by memory capacity.
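To make the quadratic term concrete, here is a minimal NumPy sketch of naive attention (the sizes are made up for illustration); the $s \times s$ score matrix is exactly the object that blows up memory:

```python
import numpy as np

s, d = 8192, 128                          # illustrative sequence length and model dimension
rng = np.random.default_rng(0)
Q = rng.standard_normal((s, d), dtype=np.float32)
K = rng.standard_normal((s, d), dtype=np.float32)
V = rng.standard_normal((s, d), dtype=np.float32)

S = Q @ K.T                               # score matrix, shape (s, s)
A = np.exp(S - S.max(axis=-1, keepdims=True))
A /= A.sum(axis=-1, keepdims=True)        # attention matrix, also (s, s)
O = A @ V                                 # output, shape (s, d)

print(S.nbytes / 1e9, "GB just for the score matrix")
```

Even at this modest size the score matrix alone takes about a quarter of a gigabyte in float32, and it grows quadratically from there.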

The goal of splitting the computation into parts is to split the memory cost, which is dominated by the sequence length, into pieces that each require only a fraction of the total memory. For $N$ GPUs, we want to split the computation into parts each requiring $1/N$-th of the whole memory cost. To achieve that, we need to split both $Q$ and $K, V$ along the sequence dimension.

However, to compute the final attention result, we still need to access all of these matrices, which are now split up across GPUs. Doing so naively could add a lot of communication overhead, especially as these matrices grow linearly with the sequence length!

Ring Attention presents a clever way to address this! A sneak peek of how this is going to play out: we will rotate blocks between devices to parallelize all computation and hide the communication overhead completely. It looks something like this:

GPUs are arranged in a ring, each holding a portion of the query. During the Ring Attention process, the GPUs pass blocks of K and V along to each other; eventually each device will have seen all the blocks needed to compute its attention output! We will show why and how this works in the rest of the blog.

Splitting Q

How can computation be split along the sequence length of $Q$? If you want to find out for yourself, follow the orange colours in the attention visualization below. You will see that the computation of each output row only depends on a single query row in $Q$. Assume we split $Q$ into $B_Q$ chunks along its rows. Each chunk $Q_i$ is of shape $C_Q \times d$, where $C_Q$ is the chunk size and $C_Q \times B_Q = s$. Let’s assign one such query chunk to each available GPU; what would each GPU need to compute attention? It still needs all of $K$ and $V$! So far we have split the memory of $Q$ into $B_Q$ chunks, and the corresponding Score and Attention matrices are also split into $B_Q$ chunks. Now we need to split the keys and values.

[Figure: the attention computation with $Q$ split into chunks along its rows]
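Here is a small NumPy sketch (our own illustration, not the authors’ implementation) that checks this: splitting $Q$ along its rows and concatenating the per-chunk outputs reproduces full attention, but every chunk still needs the entire $K$ and $V$.

```python
import numpy as np

def full_attention(Q, K, V):
    """Ordinary (safe) softmax attention."""
    S = Q @ K.T
    A = np.exp(S - S.max(axis=-1, keepdims=True))
    return (A / A.sum(axis=-1, keepdims=True)) @ V

s, d, B_Q = 1024, 64, 4                   # sequence length, model dim, number of Q chunks
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((s, d)) for _ in range(3))

# Each chunk of Q rows is processed independently -- but still against ALL of K and V.
O = np.vstack([full_attention(Q_i, K, V) for Q_i in np.split(Q, B_Q, axis=0)])

assert np.allclose(O, full_attention(Q, K, V))
```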

Splitting K and V

How can we split up the sequence length for keys and values? Hint: It is not as trivial as splitting queries was!

Let’s look at the computation for a single query chunk in the illustration above. What would happen if we additionally split the Keys along their rows (indicated by the blue colour)? Importantly, note how the matrix multiplication $S = QK^T$ is now split along both rows and columns! The problem is that the softmax operation needs to be computed over full rows of $S$ at a time:

$$\text{softmax}(s_i) = \frac{\exp(s_i)}{\sum_{j=1}^{N}\exp(s_j)}$$

Don’t worry, we will see how we can work around this challenge. First we need to understand why the softmax seems to require access to the whole row at once: looking at the softmax equation, it is clear that the upper part of the fraction (the numerator) does not need access to the whole row at once, as it simply takes the exponential of each individual entry. The lower part of the fraction, however, requires you to compute the sum of ALL elements in the row - at least that is how it seems upon first inspection ;)

We will later see how we can handle the lower part (denominator) of softmax in a parallel fashion, but for now imagine that the lower part simply does not exist.

We already know that we can easily partition computation along the rows (sequence dimension) of $Q$. Hence, here we partition the computation of a single $Q_i, A_i$ chunk into $B_{KV}$ independent sub-parts, each involving only a single $K_j, V_j$ chunk, $j \in \{1, \dots, B_{KV}\}$, at each step. The computation of a single query/output chunk $Q_i, A_i$ is:

[Figure: ignoring the denominator, the unnormalized output chunk decomposes into a sum of per-block terms, $A_i = \sum_{j=1}^{B_{KV}} \exp(Q_i K_j^T)\, V_j$]

This directly leads to our desired result! The computation is now split into an outer loop over the $B_Q$ chunks of $Q$ and an inner loop over the $B_{KV}$ chunks of $K$ and $V$. At each step we compute only $\exp(Q_iK_j^T)\cdot V_j \in \mathbb{R}^{c_Q \times d}$, which requires only single chunks $Q_i$, $K_j$, $V_j$ at a time, successfully dividing the memory cost over the sequence length. (Setting the number of chunks to $N$, i.e. $B_{KV} = B_Q = N$, results in $1/N$-th of the memory complexity of storing the full $Q, K, V$.)
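As a sanity check, here is a hedged NumPy sketch of this inner loop for a single query chunk, with the softmax denominator deliberately ignored for now (as in the text above); each step only touches one $K_j, V_j$ block:

```python
import numpy as np

rng = np.random.default_rng(0)
s, d, B_KV = 1024, 64, 4
Q_i = rng.standard_normal((s // 4, d))        # one query chunk of shape (c_Q, d)
K, V = rng.standard_normal((s, d)), rng.standard_normal((s, d))

# Unnormalized output for this query chunk, accumulated one K/V block at a time.
A_i = np.zeros((Q_i.shape[0], d))
for K_j, V_j in zip(np.split(K, B_KV), np.split(V, B_KV)):
    A_i += np.exp(Q_i @ K_j.T) @ V_j          # only one K_j, V_j block needed per step

# Same (still unnormalized) result as using the full K, V at once.
assert np.allclose(A_i, np.exp(Q_i @ K.T) @ V)
```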

Put it in a Ring: We have the outer loop (parallel over query chunks) and the inner loop (parallel over key and value chunks). But how are the computation steps of these loops allocated to the devices?

Note how the outer loop can be computed completely independently, whereas the inner loop accumulates a sum of its results. We partition the computation across devices in the following way: each device gets one $Q_i$ chunk (one outer-loop index) and calculates the inner loop iteratively over the $K_j, V_j$ chunks. At a given step $j$, each device only needs to keep track of the cumulative sum $A_i$ of shape $c_Q \times d$, and only needs a single $K_j, V_j$ block at a time, along with its $Q_i$. Therefore the memory requirement for each device is successfully split, even for $K$ and $V$.

Online softmax

Now let’s look at the normalization constant: how can we compute it on each device with just local results?

There is a simple solution! We accumulate the partial sum $l^{j} = l^{j-1} + \sum_{k_t \in K_j} \exp(Q_i k_t^T)$ every time we receive a key and value chunk. By the end of the inner loop, the device has accumulated $l = l^{B_{KV}} = \sum_{j=1}^{s} \exp(Q_i k_j^T)$, which is exactly the normalization constant we need. Note that the order of normalizing and multiplying with the value matrix $V$ does not make a difference, which is why we can accumulate the sum and execute the actual normalization after everything else.

Therefore each device $i$ (holding $Q_i$), in addition to its cumulative sum $A^{j} = A^{j-1} + \exp(Q_iK_j^T)V_j$, also keeps track of its current $l^{j} \in \mathbb{R}^{c_Q}$ while executing the inner loop. At the end of the inner loop, each device finishes by dividing its computed unnormalized attention by the normalization constant: $A^{B_{KV}} / l^{B_{KV}}$.
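Here is a minimal NumPy sketch of this online normalization for one query chunk (our own illustration, still without the max subtraction discussed next); the accumulators mirror the $A^{j}$ and $l^{j}$ above:

```python
import numpy as np

rng = np.random.default_rng(0)
s, d, B_KV = 1024, 64, 4
Q_i = rng.standard_normal((s // 4, d))
K, V = rng.standard_normal((s, d)), rng.standard_normal((s, d))

c_Q = Q_i.shape[0]
A_i = np.zeros((c_Q, d))          # running unnormalized output
l_i = np.zeros((c_Q, 1))          # running softmax denominator, one entry per query row

for K_j, V_j in zip(np.split(K, B_KV), np.split(V, B_KV)):
    P = np.exp(Q_i @ K_j.T)       # exponentials for this block only
    A_i += P @ V_j
    l_i += P.sum(axis=-1, keepdims=True)

O_i = A_i / l_i                   # normalize once, after the inner loop

# Reference: ordinary softmax attention for the same query chunk.
S = Q_i @ K.T
ref = (np.exp(S) / np.exp(S).sum(axis=-1, keepdims=True)) @ V
assert np.allclose(O_i, ref)
```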

Safe softmax

The exponential operation can easily grow out of bounds, leading to numerical issues and overflow. Typically to compute softmax, we subtract the maximum from every element. That is

$$\text{softmax}(s_{1:N}) = \frac{\exp(s_{1:N})}{\sum_i \exp(s_i)} \cdot \frac{\exp(-s_{max})}{\exp(-s_{max})} = \frac{\exp(s_{1:N} - s_{max})}{\sum_i \exp(s_i - s_{max})}$$

This is called the safe softmax. How can we subtract the maximum when we are computing the softmax in blocks? We keep track of the maximum so far. Say our current sum is $A^{j}$ and our current maximum is $m^{j}$. When we receive a new key block $K_{j+1}$ and value block $V_{j+1}$, we compute $Q_iK^T_{j+1}$, obtain a new maximum $m^{j+1} = \max(m^{j}, \max(Q_iK^T_{j+1}))$, and update our result as:

$$A^{j+1} = A^{j} \cdot \exp(m^{j} - m^{j+1}) + \exp(Q_iK^T_{j+1} - m^{j+1}) \cdot V_{j+1}$$

$$l^{j+1} = l^{j} \cdot \exp(m^{j} - m^{j+1}) + \sum_{k_t \in K_{j+1}} \exp(Q_i k_t^T - m^{j+1})$$

To conclude: for an inner step $j+1$, before computing the cumulative sum $A^{j+1}$ and the normalization constant $l^{j+1}$, we first compute the current maximum $m^{j+1}$, then rescale the previous $A^{j}, l^{j}$ using our newfound maximum, and finally compute the updated $A^{j+1}, l^{j+1}$.
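The sketch below (our own illustrative NumPy code, not the reference implementation) applies exactly these update equations for one query chunk and checks the result against ordinary softmax attention:

```python
import numpy as np

rng = np.random.default_rng(0)
s, d, B_KV = 1024, 64, 4
Q_i = rng.standard_normal((s // 4, d))
K, V = rng.standard_normal((s, d)), rng.standard_normal((s, d))

c_Q = Q_i.shape[0]
A_i = np.zeros((c_Q, d))                     # running unnormalized output A^j
l_i = np.zeros((c_Q, 1))                     # running denominator l^j
m_i = np.full((c_Q, 1), -np.inf)             # running row-wise maximum m^j

for K_j, V_j in zip(np.split(K, B_KV), np.split(V, B_KV)):
    S_j = Q_i @ K_j.T
    m_new = np.maximum(m_i, S_j.max(axis=-1, keepdims=True))
    scale = np.exp(m_i - m_new)              # rescale old accumulators to the new max
    P = np.exp(S_j - m_new)
    A_i = A_i * scale + P @ V_j
    l_i = l_i * scale + P.sum(axis=-1, keepdims=True)
    m_i = m_new

O_i = A_i / l_i

# Reference: safe softmax attention over the full K, V.
S = Q_i @ K.T
P_full = np.exp(S - S.max(axis=-1, keepdims=True))
assert np.allclose(O_i, (P_full / P_full.sum(axis=-1, keepdims=True)) @ V)
```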

Putting it together

Finally, we have assembled all the tools we need to construct ring attention:

  1. Splitting along the sequence length of $Q$ into an independent outer loop.
  2. Applying the online safe softmax in order to split along the sequence length of $K$ and $V$, resulting in an inner loop that computes the attention cumulatively.

As hinted at before, the way this is parallelized is by assigning each of the $N$ available devices one chunk $Q_i$ of $Q$. Therefore we need to split $Q$ into $N$ equal parts ($B_Q = N$). Each device individually computes its output block $\text{Output}(Q_i, K, V) = \text{softmax}(Q_iK^T)V$ iteratively, by performing the inner loop over the blocks of keys and values. The challenge is that the devices are not able to store the full $K$ and $V$ matrices at once.

[Figure: splitting the Query and Output along the sequence dimension] For example, if we have 4 GPUs, we split the Query into 4 blocks along the sequence dimension, one per device. Each device then computes its output using its local Q and the whole K, V. The final output is the concatenation of these local outputs along the row dimension.

Remember how we showed that we can further split K and V in the inner loop? K and V are now split into $B_{KV} = B_Q = N$ blocks, and we initialize the devices so that each device holds a single $Q_i$ block along with a single key block $K_j$ and value block $V_j$. For simplicity, we can assume that device $i$ holds $Q_i, K_{j=i}, V_{j=i}$ in the beginning.

After the devices have computed one inner-loop step with their current $K_j, V_j$, each device needs to receive the next key and value block to continue the inner loop. Sending and waiting for these matrices is a big overhead that we do not want to pay! This is where the ring in Ring Attention comes into play: we lay out the $N$ devices in a ring, where device $i$ can send data to device $i+1$, and so on, as illustrated:

[Figure: one step of overlapped K, V communication and computation] Observe that for GPU1, while it is computing output using $Q_1$ (its local query) and $K_1, V_1$ (the local K, V blocks it currently has), it is also receiving $K_4, V_4$ from GPU4 (the previous host in the ring) and sending $K_1, V_1$ to GPU2 (the next host in the ring). The network communication is illustrated by the blinking arrows. If we select the block size correctly, by the time GPU1 has computed the output using $Q_1, K_1, V_1$, it has already received the block $K_4, V_4$ it needs to compute the output in the next iteration!

Computing a step of the inner loop on device $i$ with $Q_i, K_j, V_j$ takes a certain amount of time. If during that time device $i$ can also send its current $K_j, V_j$ to device $i+1$ and simultaneously receive $K_{j-1}, V_{j-1}$ from device $i-1$, then the latency of sending and receiving the key and value chunks is hidden behind the actual computation, as long as the time to transmit is lower than the time it takes to compute. We can now completely hide the communication overhead!

[Figure: the full ring rotation] Here we illustrate that with $N$ devices, it takes $N$ iterations to finish the whole output computation. Watch how, at each iteration, each device computes a different partial sum of the output with the K, V block it currently holds; eventually it sees all the K, V blocks and has all the partial sums for its output!
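To tie everything together, here is a small NumPy simulation of the ring (our own sketch; the communication is replaced by a simple list rotation, so nothing actually overlaps here): $N$ “devices” each hold one $Q_i$ and a rotating $K_j, V_j$ block, run the online safe softmax update for $N$ steps, and together reproduce full attention.

```python
import numpy as np

rng = np.random.default_rng(0)
s, d, N = 1024, 64, 4                                   # sequence length, model dim, "devices"
Q, K, V = (rng.standard_normal((s, d)) for _ in range(3))

Q_blocks = np.split(Q, N)                               # device i permanently holds Q_i
K_blocks = list(np.split(K, N))                         # device i starts with K_i, V_i
V_blocks = list(np.split(V, N))

c = s // N
A = [np.zeros((c, d)) for _ in range(N)]                # per-device accumulators A^j
l = [np.zeros((c, 1)) for _ in range(N)]                # per-device denominators l^j
m = [np.full((c, 1), -np.inf) for _ in range(N)]        # per-device running maxima m^j

for step in range(N):                                   # N ring steps in total
    for i in range(N):                                  # on real hardware this runs in parallel
        S_j = Q_blocks[i] @ K_blocks[i].T
        m_new = np.maximum(m[i], S_j.max(axis=-1, keepdims=True))
        scale = np.exp(m[i] - m_new)
        P = np.exp(S_j - m_new)
        A[i] = A[i] * scale + P @ V_blocks[i]
        l[i] = l[i] * scale + P.sum(axis=-1, keepdims=True)
        m[i] = m_new
    # "Rotate" the K/V blocks: device i sends its block to device i+1.
    K_blocks = [K_blocks[-1]] + K_blocks[:-1]
    V_blocks = [V_blocks[-1]] + V_blocks[:-1]

O = np.vstack([A[i] / l[i] for i in range(N)])          # concatenate per-device outputs

S = Q @ K.T                                             # reference: full safe softmax attention
P_full = np.exp(S - S.max(axis=-1, keepdims=True))
assert np.allclose(O, (P_full / P_full.sum(axis=-1, keepdims=True)) @ V)
```

On real hardware the rotation is an asynchronous send/receive that runs concurrently with each step’s computation, which is exactly what hides the communication latency.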

Memory and Arithmetic Complexity

For this analysis we will use the bfloat16 format that is commonly used in deep learning applications. Parallel processing accelerators such as GPUs or TPUs are usually characterized by their FLOPS $F$, the number of floating-point operations the device can theoretically execute per second. In practice we never really see full utilization, but for the sake of this analysis we will assume it. Furthermore, we assume the connections between the different devices have a bandwidth of $B$ $\frac{\text{Bytes}}{\text{sec}}$.

Memory Complexity: In order to receive, send, and compute at the same time, we need memory for receiving the new key and value blocks. Storing the current key and value blocks requires $2 \cdot d \cdot c$ floats, or $4 \cdot d \cdot c$ Bytes. The buffer for receiving the new key and value blocks is also $2 \cdot d \cdot c$ floats, or $4 \cdot d \cdot c$ Bytes. Now, assuming that the computation itself does not require more memory (this would require a more in-depth discussion, but there are ways to achieve it, such as Flash Attention or blockwise attention), the output of the current step requires $d \cdot c$ floats, or $2 \cdot d \cdot c$ Bytes. Furthermore, each device needs to store its $Q_i$ block, which also takes $d \cdot c$ floats, or $2 \cdot d \cdot c$ Bytes. In total we require $6 \cdot d \cdot c$ floats, or $12 \cdot d \cdot c$ Bytes, of memory per device.
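As a rough worked example of this tally (the values of $d$ and $c$ below are assumptions for illustration, not numbers from the analysis):

```python
# Per-device memory for the block buffers, following the 12*d*c-Bytes tally above.
d = 128          # assumed model (or per-head) dimension
c = 4096         # assumed tokens per block on each device

bytes_per_device = 12 * d * c     # Q_i + current K/V + receive buffers + output, in bf16
print(f"~{bytes_per_device / 1e6:.1f} MB per device for the block buffers")   # ~6.3 MB
```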

  • A note for people familiar with Flash Attention: Ring Attention is an orthogonal concept to techniques like Flash Attention and can be used together with them (Flash Attention is actually used in the inner loop of Ring Attention). Those methods have the goal of not materializing the full Attention Matrix, and thus achieve memory complexity that is linear in the sequence length rather than quadratic like the naive implementation. Ring Attention splits the memory cost of both the naive and the Flash Attention variant by at least a factor of $N$ using $N$ devices, because it splits everything into at least $N$ parts (it splits the Keys, Queries and Values into $N$ parts each, and the Attention Matrix into $N^2$ parts)! No matter whether the memory complexity is dominated by the Keys, Queries and Values or by the Attention Matrix, Ring Attention manages to split the memory cost by at least $N$ times.

Communication Complexity: During a single step, each device needs to send $2 \cdot c \cdot d$ floating-point values, from $K_j, V_j \in \mathbb{R}^{c \times d}$, to the next device over a channel with bandwidth $B$. Each bfloat16 value consists of two Bytes, so the time it takes to transmit both blocks is approximately $4 \cdot c \cdot d / B$.

Arithmetic Complexity: Computing an inner-loop step requires $2 \cdot d \cdot c^2$ FLOPs for computing $Q_iK_j^T$, about $2 \cdot c \cdot d$ FLOPs for computing the softmax along with the $l^{j}_i, m^{j}_i$ normalization and safety statistics, and $2 \cdot d \cdot c^2$ FLOPs for computing $A_i^j \cdot V_j$. Assuming the device operates at its maximum possible FLOPS (in reality we would use the achieved average), the time it takes to compute is roughly $4 \cdot d \cdot c^2 / F$.

In order to effectively overlap the communication and computation (i.e. hide the communication overhead), we need the time to transmit the K, V blocks to be at most the time it takes to compute with the local Q, K, V blocks:

$$\frac{4 \cdot c \cdot d}{B} \leq \frac{4 \cdot d \cdot c^2}{F} \iff c \geq \frac{F}{B} \iff \frac{s}{N} \geq \frac{F}{B}$$
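As a back-of-the-envelope example with assumed hardware numbers (roughly in the range of a modern accelerator and interconnect; they are not measurements):

```python
# Check of the overlap condition c >= F / B derived above, with assumed numbers.
F = 300e12      # assumed peak compute, ~300 TFLOP/s in bfloat16
B = 100e9       # assumed inter-device bandwidth, ~100 GB/s

c_min = F / B   # minimum block size (tokens per device) so transfer hides behind compute
print(f"need roughly {c_min:.0f} tokens per device, i.e. s >= {c_min:.0f} * N")
```

With these assumed numbers, each device would need a block of roughly 3000 tokens, i.e. a total sequence length of at least about $3000 \cdot N$, for the communication to hide fully behind the computation.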

Further optimizations

One interesting case of Ring Attention is when it is used for causal transformer models; recall that a triangular mask is used in the attention calculation. This means that some GPUs do not need to compute over the whole sequence, which leaves them idle for most of the time. An extension of Ring Attention, Striped Attention, addresses this constraint and provides a scheme to distribute computation more evenly, making Ring Attention even faster!

Besides techniques like Ring Attention and Flash Attention that enable the standard Transformer architecture to have longer context lengths, there are also attempts to experiment with the model architecture itself, such as state space models (SSMs) like Mamba, but that is a deep dive for another day.

Summary

So let’s review how we arrived at this magical zero-overhead solution:

  1. Attention requires memory quadratic (or linear, if optimized) in the sequence length. Scaling transformers to any desired sequence length therefore requires splitting the memory cost over several devices, and we want to do that in a low-overhead manner.
  2. We want to parallelize the attention calculation and shard the $Q, K, V$ matrices across devices, because they grow linearly with the sequence length.
  3. One particular way is to shard the Query into blocks along the sequence dimension, but each sharded query still needs to compute with all of $K, V$ (which can be huge)!
  4. We showed that it is possible to further parallelize the softmax calculation by creating an inner loop over $K, V$ chunks, with the correct normalization to recover the final softmax! This leverages the fact that the operations commute: all we need to do is cumulatively sum up the normalization, while keeping track of and rescaling by the current maximum.
  5. Ring Attention builds on this technique of parallelizing the attention calculation by distributing the outer loop over $Q$ blocks to the individual GPUs, while letting the inner loop be computed in a ring-reduce fashion. Instead of waiting for new key and value blocks, we effectively hide the transmission overhead by “rotating” the key and value blocks around the ring of devices while computing the partial sums. With the correct block size, we can overlap communication with computation and fully hide the overhead induced by communication.

Acknowledgement

We would like to thank Jay Alammar and Bowen Yang for their insights! We would also like to thank Daanish Shabbir, Yuka Ikarashi, Ani Nrusimha, and Anne Ouyang for their helpful feedback!