Attention Variants: MHA, MQA, GQA, and the Memory Math Behind Them
Multi-head attention was the original. Multi-query attention was the efficient approximation. Grouped-query attention is the synthesis that modern LLMs converged on — and the reason is bandwidth, not FLOPs.
The number that decides which attention variant a model uses is not a FLOP count. It is how many bytes of key-value cache each token drags through GPU memory on every decode step. Standard multi-head attention gives every one of its heads a private set of query, key, and value projections, each of dimension $d_h$, and during generation the keys and values for every past token have to be kept around so the next token can attend to them. The per-token cost of that cache, in each layer, is set by the heads that hold keys and values:

$$\text{bytes per token per layer} = 2 \cdot n_h \cdot d_h \cdot b$$

The factor of two counts keys and values, $n_h$ is the head count, and $b$ is bytes per parameter. For LLaMA-2-70B's dimensions under full multi-head attention, $n_h = 64$, $d_h = 128$, in bf16 ($b = 2$), that is $2 \cdot 64 \cdot 128 \cdot 2 = 32$KB per token, which at a 4096-token context comes to 131MB of HBM per layer for a single sequence, before you have batched anything and before the model's 80 layers multiply it again. That number is the whole story: it is what caps batch size, it is what caps context length, and it is what the next two variants exist to shrink.
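The arithmetic above is easy to check by hand, but a small helper makes it reusable. This is a minimal sketch, assuming the cost is counted per layer; the function and parameter names are illustrative and do not come from any library.

```python
def kv_cache_bytes_per_token(n_kv_heads: int, head_dim: int, bytes_per_param: int) -> int:
    """Bytes of keys plus values stored for one token in one layer."""
    # The leading 2 counts one key vector and one value vector per KV head.
    return 2 * n_kv_heads * head_dim * bytes_per_param


def kv_cache_bytes(n_kv_heads: int, head_dim: int, bytes_per_param: int,
                   context_len: int, n_layers: int = 1, batch: int = 1) -> int:
    """Total cache bytes for `batch` sequences of `context_len` tokens."""
    per_token = kv_cache_bytes_per_token(n_kv_heads, head_dim, bytes_per_param)
    return per_token * context_len * n_layers * batch


# Full multi-head attention at LLaMA-2-70B's dimensions, bf16 (2 bytes).
per_token = kv_cache_bytes_per_token(n_kv_heads=64, head_dim=128, bytes_per_param=2)
per_layer = kv_cache_bytes(64, 128, 2, context_len=4096)
print(per_token)           # 32768 bytes, i.e. 32 KB
print(per_layer / 2**20)   # 128.0 MiB for one layer and one sequence
```

The per-layer figure is 134,217,728 bytes, or 128 MiB, which the text quotes as 131MB (32 KB times 4096 tokens). Passing `n_layers` and `batch` scales the same number up to a whole model and a whole serving batch.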
Multi-query attention: one key, one value, many queries
Shazeer's multi-query attention made the most aggressive cut available. Keep all query heads, because the query is what gives each head its distinct view, but collapse the keys and values to a single shared projection that every query head reads from.
With a single shared key-value head, the per-token cost falls to

$$\text{bytes per token per layer} = 2 \cdot d_h \cdot b$$

where $d_h$ is the head dimension and $b$ is bytes per parameter. The head count has dropped out of the cache cost entirely, so for the same 70B model the per-token cost is $2 \cdot 128 \cdot 2 = 512$ bytes and the same 4096-token cache drops from 131MB to roughly 2MB, a factor of 64. That is the difference between fitting one sequence and fitting hundreds, and it is why MQA decode throughput is dramatically higher. The catch is quality: forcing 64 query heads to share a single key-value subspace removes most of the representational slack that multi-head attention was buying, and on large models the perplexity hit is real, not negligible.
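The sharing is easiest to see in the shapes. The following is a minimal NumPy sketch of one multi-query decode step, assuming a single layer, no batching, and no masking; `mqa_decode_step` is a hypothetical name, not a library function.

```python
import numpy as np


def mqa_decode_step(q, k_cache, v_cache):
    """One decode step of multi-query attention.

    q:        (n_heads, d_h)   one query vector per head for the new token
    k_cache:  (n_tokens, d_h)  a single key per past token, shared by all heads
    v_cache:  (n_tokens, d_h)  a single value per past token, shared by all heads
    """
    d_h = q.shape[-1]
    scores = q @ k_cache.T / np.sqrt(d_h)            # (n_heads, n_tokens)
    scores -= scores.max(axis=-1, keepdims=True)     # stabilize the softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v_cache                         # (n_heads, d_h)


rng = np.random.default_rng(0)
n_heads, d_h, n_tokens = 64, 128, 16
q = rng.standard_normal((n_heads, d_h))
k_cache = rng.standard_normal((n_tokens, d_h))
v_cache = rng.standard_normal((n_tokens, d_h))

out = mqa_decode_step(q, k_cache, v_cache)
print(out.shape)  # (64, 128)
```

Every head still produces its own output row, because every head has its own query. What has disappeared is the head axis on the cache, which is where the 64-fold saving comes from.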
Grouped-query attention: the interpolation that won
Grouped-query attention is the obvious middle point, and the fact that it is obvious is exactly why it stuck. Partition the $n_h$ query heads into $g$ groups and give each group, not each head and not the whole layer, its own key and value projection. The per-token cache cost then scales with the number of groups rather than the number of heads:

$$\text{bytes per token per layer} = 2 \cdot g \cdot d_h \cdot b$$

where $d_h$ is the head dimension and $b$ is bytes per parameter. Setting $g = n_h$ recovers full multi-head attention and setting $g = 1$ recovers multi-query, so GQA is a dial between the two rather than a third thing. LLaMA-2-70B sets $g = 8$: eight key-value heads feeding 64 query heads, for a per-token cache of $2 \cdot 8 \cdot 128 \cdot 2 = 4$KB, an 8× reduction against MHA, while the quality gap to full attention is far smaller than MQA's, because eight independent key-value subspaces preserve most of what one subspace threw away. Eight is not magic; it is the empirically observed knee where you have recovered nearly all the quality for nearly all the memory savings.
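The dial can be made concrete with one function whose behaviour depends only on how many key-value heads the cache holds. This is a sketch under simplifying assumptions (single layer, no batching, no masking, consecutive query heads assigned to the same group); the names are illustrative.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def gqa_decode_step(q, k_cache, v_cache):
    """One decode step of grouped-query attention.

    q:        (n_heads, d_h)
    k_cache:  (g, n_tokens, d_h)  one key per past token per group
    v_cache:  (g, n_tokens, d_h)  one value per past token per group
    """
    n_heads, d_h = q.shape
    g = k_cache.shape[0]
    assert n_heads % g == 0, "query heads must divide evenly into groups"
    # Consecutive query heads share a group: heads [0, n_heads/g) use KV head 0, etc.
    q_grouped = q.reshape(g, n_heads // g, d_h)
    scores = np.einsum("ghd,gtd->ght", q_grouped, k_cache) / np.sqrt(d_h)
    weights = softmax(scores, axis=-1)
    out = np.einsum("ght,gtd->ghd", weights, v_cache)
    return out.reshape(n_heads, d_h)


rng = np.random.default_rng(0)
n_heads, d_h, n_tokens = 64, 128, 16
q = rng.standard_normal((n_heads, d_h))
for g in (64, 8, 1):  # MHA, GQA with eight groups, MQA
    k_cache = rng.standard_normal((g, n_tokens, d_h))
    v_cache = rng.standard_normal((g, n_tokens, d_h))
    out = gqa_decode_step(q, k_cache, v_cache)
    print(g, out.shape, 2 * g * d_h * 2)  # groups, output shape, bf16 bytes per token
```

The output shape never changes; only the leading axis of the cache does, and the bytes per token shrink from 32768 to 4096 to 512 as the group count goes from 64 to 8 to 1.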
Why this is a bandwidth problem, not a compute problem
It is worth being precise about why the cache, and not the matmul, is the bottleneck. During decoding the model emits one token at a time, and each new token must attend to every token before it, so the attention read scans the entire KV cache for that sequence on every single step. The arithmetic per step is tiny, one query against $n$ keys for a sequence of $n$ tokens so far, but the cache that query must stream over grows linearly with the conversation, and streaming it out of HBM is what the GPU actually spends its time on. GQA wins not by doing less math but by moving fewer bytes, which is the resource that runs out first when you try to serve long contexts and large batches at the same time. The full accounting of how that cache is allocated, grows, and is freed lives in the KV cache lifecycle post.
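One way to put a number on this is arithmetic intensity: FLOPs performed per byte of cache read. The sketch below is a rough estimate under stated assumptions. It counts only the score and value-weighting multiply-adds, ignores the softmax and the traffic for queries and outputs, and uses a hypothetical function name.

```python
def decode_attention_intensity(n_q_heads: int, n_kv_heads: int,
                               head_dim: int, bytes_per_param: int) -> float:
    """Approximate FLOPs per byte of KV cache read, per cached token per layer."""
    # Per query head and cached token: one dot product with the key
    # (head_dim multiplies + head_dim adds) and one scaled accumulation of
    # the value (the same again), so about 4 * head_dim FLOPs.
    flops = n_q_heads * 4 * head_dim
    # One key and one value vector are read per KV head.
    bytes_read = 2 * n_kv_heads * head_dim * bytes_per_param
    return flops / bytes_read


for name, n_kv in (("MHA", 64), ("GQA-8", 8), ("MQA", 1)):
    print(name, decode_attention_intensity(64, n_kv, 128, 2))
# MHA 1.0
# GQA-8 8.0
# MQA 64.0
```

The FLOP count is identical in all three cases; only the bytes change. Recent datacenter GPUs need on the order of a hundred or more FLOPs per byte of memory traffic to keep their matrix units busy (an approximate figure that varies by part), so at one FLOP per byte the multi-head decode step is waiting on memory almost the whole time.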
FlashAttention: stop materializing the matrix
The same memory-bound logic explains the most important attention kernel of the era. Naive attention computes the full $n \times n$ score matrix, writes it to HBM, reads it back to softmax it, and reads it again to weight the values. That intermediate matrix costs $O(n^2)$ memory and, worse, round trips to slow memory. FlashAttention never writes the matrix down. It tiles the computation into blocks small enough to live in the GPU's on-chip SRAM, computes the softmax incrementally with a running normalizer as it streams over key-value tiles, and produces the same output in $O(n)$ memory with the HBM traffic slashed. The result is exact attention, matching the naive kernel up to floating-point rounding, and several times faster in wall-clock; FlashAttention-2 rebalances the work across thread blocks and warps to keep the tensor cores fed, and FlashAttention-3 targets the H100's asynchronous tensor cores and FP8 path, together pushing long-context attention 5–9× past the naive kernel.
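The running-normalizer trick is the part that is hard to picture from prose. The following NumPy sketch shows only the algebra: it tiles over keys and values, keeps a running maximum and running sum per query row, and never forms the full score matrix. It is not the GPU kernel (no SRAM management, no query tiling, no causal mask), and the function names are illustrative.

```python
import numpy as np


def naive_attention(q, k, v):
    scores = q @ k.T / np.sqrt(q.shape[-1])           # full (n_q, n_k) matrix
    scores = scores - scores.max(axis=-1, keepdims=True)
    p = np.exp(scores)
    return (p / p.sum(axis=-1, keepdims=True)) @ v


def tiled_attention(q, k, v, tile=4):
    n_q, d = q.shape
    m = np.full((n_q, 1), -np.inf)        # running row maximum
    l = np.zeros((n_q, 1))                # running softmax normalizer
    acc = np.zeros((n_q, v.shape[-1]))    # running unnormalized output
    for start in range(0, k.shape[0], tile):
        k_t = k[start:start + tile]
        v_t = v[start:start + tile]
        s = q @ k_t.T / np.sqrt(d)        # only an (n_q, tile) block of scores
        m_new = np.maximum(m, s.max(axis=-1, keepdims=True))
        scale = np.exp(m - m_new)         # rescales earlier tiles to the new max
        p = np.exp(s - m_new)
        l = l * scale + p.sum(axis=-1, keepdims=True)
        acc = acc * scale + p @ v_t
        m = m_new
    return acc / l


rng = np.random.default_rng(0)
q = rng.standard_normal((6, 8))
k = rng.standard_normal((10, 8))
v = rng.standard_normal((10, 8))
print(np.allclose(naive_attention(q, k, v), tiled_attention(q, k, v)))  # True
```

Whenever a later tile raises the row maximum, everything accumulated so far is multiplied by `exp(m - m_new)`, which is what lets the softmax be finished in one pass without ever seeing all the scores at once.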
Sliding-window attention: bound the cache by construction
The last lever caps the cache instead of compressing it. Sliding-window attention restricts each token to attend only to the previous $w$ tokens rather than the entire history, which makes the per-step read, and the cache that must be retained, bounded by $w$ instead of growing with sequence length. Mistral 7B uses a 4096-token window, and the trick is that information still propagates further than one window: stacking $k$ windowed layers gives an effective receptive field of roughly $k \cdot w$ tokens, because each layer can relay what it saw to the next, so local windows on most layers plus a few full-attention layers preserve long-range reach at a fraction of the memory.
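Both claims, the bounded read and the multiplied reach, can be checked with a few lines. This is a sketch with illustrative names; it assumes a window that includes the current token, and the 32-layer figure used below is Mistral 7B's layer count.

```python
import numpy as np


def sliding_window_mask(n_tokens: int, window: int) -> np.ndarray:
    """Boolean mask where entry [i, j] is True if token i may attend to token j."""
    i = np.arange(n_tokens)[:, None]
    j = np.arange(n_tokens)[None, :]
    # Causal (j <= i) and within the last `window` positions, counting i itself.
    return (j <= i) & (j > i - window)


def receptive_field(window: int, n_layers: int) -> int:
    """Rough upper bound on how far back information can travel through the stack."""
    return window * n_layers


mask = sliding_window_mask(n_tokens=8, window=3)
print(mask.astype(int))
print(mask.sum(axis=1))              # [1 2 3 3 3 3 3 3]
print(receptive_field(4096, 32))     # 131072
```

The row sums stop growing once the sequence is longer than the window, which is the per-step read becoming constant. The receptive-field figure is an upper bound on reach, not a guarantee that distant information survives every hop.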
That is the full inventory of how attention is made to fit in memory: share the keys and values, fuse the kernel, or bound the window. The next and final post climbs above attention to the block that holds two-thirds of a transformer's parameters — the feed-forward network — and the choices there, from SwiGLU to mixture-of-experts, that separate one model family's quality from another's.