S. Roy

Blog Post

GPU Architecture for LLM Inference

LLM inference is shaped by GPU hardware: HBM bandwidth, SRAM per SM, tensor core throughput, and the roofline that connects them. This post maps the memory hierarchy from HBM to tensor core, shows where decode and prefill sit on the roofline, and explains why FlashAttention exists.

Every performance number in the previous parts of this series — the 7 ms decode floor, the 156 FLOP/byte ridge point, the reason FlashAttention matters — traces back to the same physical object: a GPU whose memory is organised in levels of wildly different size, bandwidth, and latency. Understanding that hierarchy is not optional background; it is the reason the serving stack is built the way it is.

The H100 at a glance

The H100 SXM5 has:

  • 80 GB HBM3 at 3.35 TB/s — the main memory where weights and KV caches live
  • 50 MB L2 cache on-chip, roughly 12 TB/s effective
  • 132 Streaming Multiprocessors (SMs), each with 228 KB of programmable SRAM
  • 4 tensor cores per SM executing 16×16×16 matrix-multiply-accumulate in ~8 cycles
  • Peak compute: 989 TFLOP/s fp16 (without sparsity)

The ridge point, the arithmetic intensity where the roofline transitions from bandwidth-bound to compute-bound, is $989 / 3.35 \approx 295\ \text{FLOP/byte}$. Any workload with intensity below 295 leaves compute idle and is bottlenecked on memory bandwidth. Any workload above it leaves bandwidth idle and is bottlenecked on the tensor cores.
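The ridge-point arithmetic can be checked with a few lines of Python. This is a minimal sketch of the roofline model using the H100 figures listed above; the function and constant names are illustrative, not taken from any library.

```python
def ridge_point(peak_tflops: float, bandwidth_tbps: float) -> float:
    """Arithmetic intensity (FLOP/byte) where the roofline goes flat."""
    return peak_tflops / bandwidth_tbps  # the 1e12 factors cancel


def attainable_tflops(intensity: float, peak_tflops: float, bandwidth_tbps: float) -> float:
    """Roofline: capped by bandwidth * intensity or by peak compute, whichever is lower."""
    return min(peak_tflops, bandwidth_tbps * intensity)


H100_PEAK_TFLOPS = 989.0  # fp16 dense, from the spec list above
H100_HBM_TBPS = 3.35      # HBM3 bandwidth, from the spec list above

ridge = ridge_point(H100_PEAK_TFLOPS, H100_HBM_TBPS)
print(f"ridge point: {ridge:.0f} FLOP/byte")
for intensity in (1, 32, 295, 1024):
    perf = attainable_tflops(intensity, H100_PEAK_TFLOPS, H100_HBM_TBPS)
    print(f"I = {intensity:>4} FLOP/byte -> {perf:7.1f} TFLOP/s")
```

Below the ridge the attainable throughput grows linearly with intensity; above it the curve is flat at the tensor-core peak.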

The memory hierarchy: five levels

[Figure: GPU Memory Hierarchy (H100). Five levels: HBM, L2, SRAM, Registers, Tensor Core.]

Working from largest and slowest to smallest and fastest:

HBM is where everything starts. Model weights in fp16 cost 2 bytes per parameter, so a 70B model occupies 140 GB, which already requires two H100s. Every decode step reads these weights once to compute the single-token forward pass. At 3.35 TB/s, reading 140 GB puts a lower bound of $140 / 3350 \approx 42\ \text{ms}$ on each decode step, whatever the batch size. Batching amortises this: with a batch of 32, the same weight read produces 32 tokens, dropping the per-token cost to $\sim 1.3\ \text{ms}$.
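The decode floor is a one-line division, sketched here so the numbers can be varied. The helper name is hypothetical, and the calculation follows the text in using a single H100's 3.35 TB/s and ignoring KV cache and activation traffic.

```python
def decode_step_floor_ms(n_params: float, bytes_per_param: float, bandwidth_gb_s: float) -> float:
    """Time to stream every weight from HBM once, in milliseconds."""
    weight_gb = n_params * bytes_per_param / 1e9
    return weight_gb / bandwidth_gb_s * 1e3


# 70B parameters, fp16 (2 bytes each), 3350 GB/s of HBM bandwidth
step_ms = decode_step_floor_ms(70e9, 2, 3350)
for batch in (1, 8, 32):
    # One weight read serves the whole batch, so the per-token cost divides by it.
    print(f"batch {batch:>2}: step >= {step_ms:.1f} ms, per token >= {step_ms / batch:.2f} ms")
```

Halving `bytes_per_param` (fp8 instead of fp16) halves the floor, which is the same lever quantisation pulls.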

L2 cache (50 MB on H100) sits between HBM and the SMs. For small models, projections that are reused across many SMs can live here rather than being re-fetched from HBM, which effectively multiplies the apparent bandwidth. But 50 MB holds only about 25 million fp16 parameters, so it cannot hold the full weight set of any model of practical size.

SRAM (228 KB per SM) is the critical resource for attention. Unlike HBM, SRAM is on-chip and accessible by all threads within an SM at ~19 TB/s with ~30-cycle latency. The tension is its size: 228 KB is not much when an attention head has keys and values for thousands of tokens.

Registers (65,536 × 32-bit per SM, shared across threads) hold the live operands for the tensor core. When a kernel demands too many registers per thread, the SM fits fewer concurrent warps. Register pressure therefore directly governs occupancy, which in turn governs how many warps the scheduler can switch between to hide memory latency.
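The trade-off between registers per thread and resident warps can be estimated as follows. This is a rough sketch under stated assumptions: 32 threads per warp, a cap of 64 resident warps per SM (an assumed hardware limit), and no allowance for register allocation granularity or shared-memory limits, both of which lower the real number further.

```python
REGS_PER_SM = 65_536     # 32-bit registers per SM, from the text
THREADS_PER_WARP = 32
MAX_WARPS_PER_SM = 64    # assumed cap on resident warps per SM


def resident_warps(regs_per_thread: int) -> int:
    """Upper bound on warps an SM can hold when only registers are limiting."""
    regs_per_warp = regs_per_thread * THREADS_PER_WARP
    return min(MAX_WARPS_PER_SM, REGS_PER_SM // regs_per_warp)


for regs in (32, 64, 128, 255):
    warps = resident_warps(regs)
    print(f"{regs:>3} regs/thread -> {warps:>2} warps ({warps / MAX_WARPS_PER_SM:.0%} occupancy)")
```

Doubling the registers a thread uses halves the warps available for latency hiding once the cap is no longer the limit.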

Tensor cores are the throughput engine. Each SM has 4; the whole chip has 528. They execute a 16×16×16 MMA — two 16×16 input tiles (A and B, in fp16 or bf16) against a 16×16 fp32 accumulator — in a few cycles. The peak throughput only materialises if operands arrive in registers on time. A tensor core blocked on a register load is a tensor core producing zero FLOP/s.

Where decode and prefill sit on the roofline

[Figure: Roofline: Decode vs Prefill. x-axis: arithmetic intensity (FLOP/byte); y-axis: TFLOP/s, up to 989; ridge at 295 FLOP/byte; marked points: decode (B=1) and prefill.]

Arithmetic intensity (FLOP/byte) places each workload on the roofline. Decode is memory-bound; prefill is compute-bound. Batch size shifts the decode point rightward.

Decode has arithmetic intensity equal to the batch size. For a model with $N$ parameters in fp16:

$$I_{\text{decode}} = \frac{2NB\ \text{FLOP}}{2N\ \text{bytes}} = B\ \text{FLOP/byte}$$

At batch 1, intensity is 1 FLOP/byte, more than two orders of magnitude below the H100's ridge point of 295. The tensor cores sit idle 99.7% of the time; the step is gated entirely on how fast HBM can ship weights to the SMs. Increasing batch size shifts the decode point rightward on the roofline. You cross into compute-bound territory only at $B \approx 295$, well beyond what the KV cache budget allows for large models.

Prefill has intensity proportional to the sequence length $P$:

$$I_{\text{prefill}} = \frac{2NP\ \text{FLOP}}{2N\ \text{bytes}} = P\ \text{FLOP/byte}$$

A 1024-token prefill has intensity 1024 FLOP/byte — past the ridge on both A100 and H100. Prefill is compute-bound, which is why adding GPUs or tensor parallelism helps it and why quantisation (which reduces the byte count denominator) helps decode more than prefill.
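Both intensity formulas reduce to "tokens processed per weight read", so one helper covers decode (tokens = batch size) and prefill (tokens = prompt length). The sketch below is a simplification under the same assumptions as the formulas: it counts only the weight matmuls, ignores attention FLOPs and KV cache traffic, and treats the weights as read by a single H100. The function name and defaults are illustrative.

```python
def step_time_ms(n_params: float, tokens: int, peak_tflops: float = 989.0,
                 bandwidth_tbps: float = 3.35, bytes_per_param: int = 2):
    """Return (time in ms, limiting resource) for one forward pass over `tokens`."""
    flops = 2 * n_params * tokens               # one multiply-add per weight per token
    bytes_moved = bytes_per_param * n_params    # weights are streamed once
    compute_s = flops / (peak_tflops * 1e12)
    memory_s = bytes_moved / (bandwidth_tbps * 1e12)
    regime = "compute" if compute_s > memory_s else "memory"
    return max(compute_s, memory_s) * 1e3, regime


for label, tokens in (("decode, B=1", 1), ("decode, B=32", 32), ("prefill, P=1024", 1024)):
    ms, regime = step_time_ms(70e9, tokens)
    print(f"{label:<16} intensity {tokens:>4} FLOP/byte  {ms:6.1f} ms  {regime}-bound")
```

In the memory-bound regime the step time does not change with the token count, which is why batching decode is close to free until the ridge is reached.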

Why FlashAttention exists

The naive implementation of attention computes $\text{softmax}(QK^\top / \sqrt{d_k})V$ by materialising the full $N \times N$ score matrix $S$ in HBM, where $N$ now denotes the sequence length rather than the parameter count:

  1. Write $S = QK^\top / \sqrt{d_k}$ to HBM: $O(N^2)$ bytes written
  2. Read $S$ back from HBM to apply softmax row-wise: $O(N^2)$ bytes read
  3. Write softmax output $P$ to HBM: $O(N^2)$ bytes written
  4. Read $P$ and $V$ from HBM to compute output: $O(N^2 + Nd_k)$ bytes read

For $N = 8192$ and $d_k = 128$, the score matrix alone is $8192^2 \times 2\ \text{bytes} \approx 134\ \text{MB}$ per head, larger than the entire L2 cache. An $N \times N$ matrix crosses the HBM bus four times: $S$ is written and read back, then $P$ is written and read back. This is the memory access pattern that makes attention the bottleneck for long contexts.
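The byte counts for the four steps can be tallied directly. This sketch counts only the traffic listed above for one head of one sequence in fp16; the reads of Q and K in step 1 and the final output write are left out, and the function name is illustrative.

```python
def naive_attention_hbm_bytes(seq_len: int, d_k: int, bytes_per_el: int = 2) -> dict:
    """HBM traffic of the four naive-attention steps, per head, in bytes."""
    score = seq_len * seq_len * bytes_per_el   # one N x N matrix (S or P)
    v = seq_len * d_k * bytes_per_el           # the V matrix
    return {
        "write S": score,
        "read S": score,
        "write P": score,
        "read P and V": score + v,
    }


traffic = naive_attention_hbm_bytes(seq_len=8192, d_k=128)
total = sum(traffic.values())
for step, nbytes in traffic.items():
    print(f"{step:<13} {nbytes / 1e6:7.1f} MB")
print(f"{'total':<13} {total / 1e6:7.1f} MB per head")
```

The V read is about 2 MB against more than 500 MB of score-matrix traffic, which is why removing the $N \times N$ round trips matters so much.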

[Figure: FlashAttention: SRAM Tiling. Panels: Q tile in SRAM, K/V tile in SRAM, output O (accumulating, partial sum in SRAM), and the score matrix S (N×N = 8×8, rows q0 to q7, columns k0 to k7), coloured by whether each block is computed in the current outer pass, was computed earlier, or is still to come.]

FlashAttention loops over Q tiles (outer) and K/V tiles (inner). At each step only the current tiles are in SRAM, so the N² score matrix is never materialised in HBM.

FlashAttention (Dao et al., 2022) reorders the computation to keep tiles in SRAM throughout:

  1. Divide $Q$ into outer tiles of size $B_r$ rows; divide $K$ and $V$ into inner tiles of size $B_c$ columns.
  2. For each outer tile $Q_i$, load it into SRAM once.
  3. For each inner tile $K_j$, $V_j$: load into SRAM, compute $S_{ij} = Q_i K_j^\top$, update the running softmax statistics (online softmax), accumulate into the output tile $O_i$.
  4. Write $O_i$ to HBM once when the inner loop finishes.
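The four steps above can be sketched in plain Python and checked against the naive computation. This is an illustrative CPU sketch, not the CUDA kernel: Python lists stand in for HBM, the per-tile local variables stand in for SRAM, and the function names, tile sizes and random test data are all assumptions made for the example.

```python
import math
import random


def naive_attention(q, k, v):
    """Reference: build each full score row, softmax it, then weight V."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for q_row in q:
        scores = [scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k]
        top = max(scores)
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        out.append([sum(w * v_row[c] for w, v_row in zip(weights, v)) / total
                    for c in range(len(v[0]))])
    return out


def tiled_attention(q, k, v, block_r=4, block_c=4):
    """Tiled attention with online softmax; no full score row is ever stored."""
    scale = 1.0 / math.sqrt(len(q[0]))
    d_v = len(v[0])
    out = []
    for i in range(0, len(q), block_r):              # outer loop over Q tiles
        q_tile = q[i:i + block_r]                    # loaded once per outer step
        row_max = [-math.inf] * len(q_tile)          # running max per query row
        row_sum = [0.0] * len(q_tile)                # running softmax denominator
        acc = [[0.0] * d_v for _ in q_tile]          # unnormalised output tile
        for j in range(0, len(k), block_c):          # inner loop over K/V tiles
            k_tile, v_tile = k[j:j + block_c], v[j:j + block_c]
            for r, q_row in enumerate(q_tile):
                s = [scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k_tile]
                new_max = max(row_max[r], max(s))
                fix = math.exp(row_max[r] - new_max)  # rescales earlier partial results
                p = [math.exp(x - new_max) for x in s]
                row_sum[r] = row_sum[r] * fix + sum(p)
                acc[r] = [a * fix + sum(w * v_row[c] for w, v_row in zip(p, v_tile))
                          for c, a in enumerate(acc[r])]
                row_max[r] = new_max
        # single "write to HBM" of the finished output tile
        out.extend([a / row_sum[r] for a in acc[r]] for r in range(len(q_tile)))
    return out


rng = random.Random(0)
q = [[rng.gauss(0, 1) for _ in range(4)] for _ in range(10)]
k = [[rng.gauss(0, 1) for _ in range(4)] for _ in range(10)]
v = [[rng.gauss(0, 1) for _ in range(4)] for _ in range(10)]
ref = naive_attention(q, k, v)
got = tiled_attention(q, k, v, block_r=4, block_c=3)
max_err = max(abs(a - b) for ra, rb in zip(ref, got) for a, b in zip(ra, rb))
print(f"max abs difference vs naive: {max_err:.2e}")
```

The rescaling factor `fix` is what makes the tiling exact: whenever a later K tile raises the row maximum, the partial sum and partial output accumulated so far are scaled down to the new reference point before the new tile is added.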

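The $N \times N$ score matrix is never written to HBM, so the extra memory footprint drops from $O(N^2)$ to $O(N \cdot d_k)$, linear rather than quadratic. HBM traffic falls as well, but it stays quadratic in $N$: every $K_j$, $V_j$ tile is re-read once per outer tile, for $O(N^2 d_k / B_r)$ bytes in total, so larger tiles mean proportionally less traffic. The tile sizes are chosen to fit $Q_i$, $K_j$, $V_j$, and $O_i$ simultaneously in SRAM: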

$$B_r = \left\lfloor \frac{M / 4}{d_k} \right\rfloor, \quad B_c = \min\!\left(\left\lfloor \frac{M / 4}{d_k} \right\rfloor, d_k\right)$$

where $M$ is the SRAM capacity. With 228 KB of SRAM and $d_k = 128$, $B_r$ comes to roughly 450 rows if $M$ is counted in bytes, or about half that if it is counted in fp16 elements. Both exceed typical head dimensions, so $B_c$ is capped at $d_k$. The result is that attention, previously one of the most bandwidth-hungry operations in the forward pass, becomes compute-bound for long sequences.
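The tile-size formula is easy to evaluate for both readings of $M$. This is a small sketch of the formula as written above; the function name is hypothetical, and the choice between counting $M$ in bytes or in fp16 elements is an assumption the caller makes.

```python
def flash_tile_sizes(sram_capacity: int, d_k: int) -> tuple[int, int]:
    """Return (B_r, B_c) from the tile-size formula for a given SRAM capacity M."""
    rows = (sram_capacity // 4) // d_k   # floor((M / 4) / d_k)
    return rows, min(rows, d_k)


SRAM_BYTES = 228 * 1024  # 228 KB per SM

print("M in bytes:        ", flash_tile_sizes(SRAM_BYTES, 128))
print("M in fp16 elements:", flash_tile_sizes(SRAM_BYTES // 2, 128))
```

Real kernels pick smaller, power-of-two tiles than this upper bound, since registers and pipelining buffers also compete for on-chip space.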

SM compute pipeline: loading a tile

[Figure: SM Compute Pipeline, One MMA Step. Stages: HBM, L2 Cache, SRAM, Registers, Tensor Cores. Step 1 of 6 (idle): the SM is waiting for work and warps are stalled, with no data in flight yet.]

A single matrix-multiply tile flows through the SM in six stages. The time budget is dominated by the HBM load. A 16×16 fp16 tile is 512 bytes; at the full-chip 3.35 TB/s that transfer takes about 0.15 ns, and at a single SM's 1/132 share of the bandwidth (roughly 25 GB/s) about 20 ns, before any access latency is counted. Each warp has 32 threads, and an SM may have dozens of warps in flight, all issuing loads; the effective per-warp bandwidth is a fraction of the peak, and the scheduling of those loads (latency hiding through warp switching) is what keeps the tensor cores fed.

The pipeline is:

$$\text{HBM load} \xrightarrow{\text{3.35 TB/s}} \text{L2 (cache hit)} \xrightarrow{\sim 200\ \text{cyc}} \text{SRAM stage} \xrightarrow{30\ \text{cyc}} \text{register load} \xrightarrow{1\ \text{cyc}} \text{tensor core MMA} \xrightarrow{8\ \text{cyc}} \text{accumulator}$$

For a bandwidth-bound kernel like single-sequence decode, the HBM load dominates. For a compute-bound kernel like large-batch prefill, the tensor core is the constraint and the memory pipeline must overlap loads for the next tile with the MMA for the current one — the software pipelining that CUTLASS and Triton handle automatically.

What this means for the serving stack

The memory hierarchy imposes three constraints that the entire serving stack is built around:

Decode is HBM-bandwidth-bound for any realistic batch. The weight read is the floor. Every technique that shrinks the bytes read per generated token is an attack on this floor: fp8 quantisation halves the weight bytes, speculative decoding yields multiple tokens from one weight read, and GQA reduces the KV head count and with it the KV cache read. Every technique that keeps the batch large (continuous batching, careful memory management) moves the decode point rightward on the roofline.

Attention is SRAM-constrained for long contexts. The N² HBM traffic of naive attention is what FlashAttention eliminates by tiling into SRAM. Grouped-query attention (GQA) and multi-query attention (MQA) reduce the K and V head count, shrinking both the SRAM footprint per tile and the HBM KV cache.
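The effect of the KV head count on the HBM-resident cache follows from the standard size formula: two tensors (K and V) per layer, each of shape sequence length × KV heads × head dimension. The model configuration below (80 layers, 64 query heads, head dimension 128) is a hypothetical example chosen for illustration, not a figure from this series.

```python
def kv_cache_bytes(n_layers: int, n_kv_heads: int, d_head: int,
                   seq_len: int, batch: int = 1, bytes_per_el: int = 2) -> int:
    """Bytes of K and V cached across all layers for `batch` sequences."""
    return 2 * n_layers * n_kv_heads * d_head * seq_len * batch * bytes_per_el


# Hypothetical configuration, fp16 cache, one 8192-token sequence
for name, kv_heads in (("MHA, 64 KV heads", 64), ("GQA,  8 KV heads", 8), ("MQA,  1 KV head ", 1)):
    size = kv_cache_bytes(n_layers=80, n_kv_heads=kv_heads, d_head=128, seq_len=8192)
    print(f"{name}: {size / 1e9:6.2f} GB per sequence")
```

Cutting the KV heads from 64 to 8 shrinks the cache eightfold, which is memory that can go to a larger batch instead.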

Tensor core utilisation requires register discipline. The peak 989 TFLOP/s is only available if operands arrive in registers without stalling. Occupancy, the number of warps an SM can hold simultaneously, is the lever: more warps give the scheduler more independent work to switch to while loads are in flight, but more registers per thread mean fewer warps. Kernel writers (and the compilers inside Triton/CUTLASS) spend more effort on this pipeline than on the arithmetic itself.

The next post works through the other structural gap in the series: how vLLM's PagedAttention manages the KV cache as a paged virtual address space, eliminating the fragmentation that would otherwise prevent the large batches the roofline demands.

Cite this work

Roy, Swastik. (2024). GPU Architecture for LLM Inference. S. Roy. https://swastikroy.me/blog/inference-gpu-architecture
