Writing Fused Kernels with Triton
Kernel fusion eliminates the HBM round-trips between chained operations. Triton makes this practical in Python. This post builds a fused online softmax from scratch, then extends it to a fused RMSNorm + linear projection — the kind of kernel that actually speeds up LLM inference.
The previous post established the CUDA programming model: grids, blocks, warps, coalescing, shared memory. This post uses those ideas to build something useful — a fused kernel that eliminates unnecessary HBM round-trips between operations that are always chained together.
Why fusion matters
Consider the forward pass through an attention layer. The unfused sequence is:
1. Compute $S = QK^\top$ → write $S$ to HBM
2. Read $S$ from HBM → apply causal mask → write masked $S$ to HBM
3. Read masked $S$ from HBM → apply softmax → write $P$ to HBM
4. Read $P$ and $V$ from HBM → compute $O = PV$ → write output $O$ to HBM
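Steps 1–4 add up to six HBM reads (counting the reads of $Q$ and $K$ in step 1) and four HBM writes, and three of each move a full $N \times N$ matrix, where $N$ is the sequence length. Each of those is a round-trip through the memory hierarchy at 3.35 TB/s (the HBM bandwidth of an H100 SXM), which is slow compared with on-chip SRAM.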
The fused version (FlashAttention) completes steps 1–4 inside SRAM without any intermediate HBM writes. The savings are not marginal: for a sequence length of 4096 with , the unfused path writes and reads ~4 GB of attention scores; the fused path writes and reads ~128 KB (the output only). The operation shifts from memory-bound to compute-bound.
The key insight is that as long as operations are purely element-wise or reductions over the same data, they can be fused: load the data once, apply all transformations in registers or SRAM, write the final output once.
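The traffic argument can be checked with back-of-envelope arithmetic. The sketch below assumes fp16 scores (2 bytes per element) and a hypothetical configuration of 32 heads at batch size 1. The function names are illustrative, and 3.35 TB/s is the bandwidth figure used in the text.

```python
HBM_BANDWIDTH = 3.35e12  # bytes/s, the H100 figure quoted in the text


def score_matrix_bytes(seq_len, n_heads=1, batch=1, bytes_per_elem=2):
    """Bytes in one full set of N x N attention scores (fp16 by default)."""
    return batch * n_heads * seq_len * seq_len * bytes_per_elem


def transfer_ms(n_bytes, bandwidth=HBM_BANDWIDTH):
    """Best-case time to move n_bytes through HBM, in milliseconds."""
    return n_bytes / bandwidth * 1e3


one_pass = score_matrix_bytes(4096, n_heads=32)  # assumed: 32 heads, batch 1
# The unfused path writes an N x N intermediate three times and reads one back three times.
unfused_passes = 6
print(f"one pass:   {one_pass / 2**30:.2f} GiB, {transfer_ms(one_pass):.2f} ms")
print(f"six passes: {transfer_ms(unfused_passes * one_pass):.2f} ms per layer")
```

These are lower bounds on transfer time: they assume the full peak bandwidth is reached, which real kernels rarely manage.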
Triton's programming model
Triton is a Python DSL that compiles to GPU kernels. Its abstraction level is higher than CUDA C but lower than PyTorch — you write tile-level operations explicitly, but Triton handles the warp-level scheduling, bank-conflict analysis, and autotuning.
The key concepts:
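- tl.program_id(axis): the index of the current program instance along a grid axis (equivalent to blockIdx)
- tl.arange(0, BLOCK): a vector of element offsets within the block; Triton maps them onto threads for you
- tl.load(ptr + offsets, mask=...): vectorized load with bounds check
- tl.store(ptr + offsets, data, mask=...): vectorized store
- tl.sum(x, axis=0), tl.max(x, axis=0): block-level reductions, lowered to warp shuffles and shared memory by the compiler
- tl.constexpr: compile-time constants, such as the block sizes that the autotuner sweeps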
Triton compiles each @triton.jit function to PTX and then to SASS (the GPU's native ISA). The @triton.autotune decorator runs the kernel with multiple configurations and picks the fastest — equivalent to manually sweeping block sizes but automated.
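To make the block-level model concrete without a GPU, here is a plain-Python emulation of a masked vector add. The helper names (arange, load, store, add_program, launch_add) are hypothetical stand-ins that mimic the shape of the Triton calls listed above. This is not Triton's API or implementation, and the program instances run sequentially here rather than in parallel.

```python
def arange(start, end):
    """Stand-in for tl.arange: the offsets handled by one program."""
    return list(range(start, end))


def load(buf, offsets, mask, other=0.0):
    """Stand-in for a masked tl.load: out-of-bounds lanes get `other`."""
    return [buf[o] if ok else other for o, ok in zip(offsets, mask)]


def store(buf, offsets, values, mask):
    """Stand-in for a masked tl.store: out-of-bounds lanes are skipped."""
    for o, v, ok in zip(offsets, values, mask):
        if ok:
            buf[o] = v


def add_program(pid, x, y, out, n, BLOCK):
    """What one program instance does: handle one BLOCK-sized slice."""
    offsets = [pid * BLOCK + i for i in arange(0, BLOCK)]
    mask = [o < n for o in offsets]
    a = load(x, offsets, mask)
    b = load(y, offsets, mask)
    store(out, offsets, [ai + bi for ai, bi in zip(a, b)], mask)


def launch_add(x, y, BLOCK=4):
    n = len(x)
    out = [0.0] * n
    grid = (n + BLOCK - 1) // BLOCK  # ceil-divide, like triton.cdiv
    for pid in range(grid):  # a GPU would run these concurrently
        add_program(pid, x, y, out, n, BLOCK)
    return out


result = launch_add([1.0, 2.0, 3.0, 4.0, 5.0], [10.0, 20.0, 30.0, 40.0, 50.0])
```

With five elements and BLOCK=4, the second program has three masked-off lanes. That is the same situation the mask argument handles in a real kernel when the problem size is not a multiple of the block size.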
Example 1: fused online softmax
The numerically stable softmax needs two reductions over the input: one to find the maximum (for stability), one to compute the exponentials and their sum. A naive implementation composed from separate max, exp, and sum ops runs them as separate kernels with an HBM round-trip between them. A fused kernel loads each row once and does both reductions on-chip; when a row is too long to hold on-chip, the online softmax algorithm still manages it in one pass using running statistics.
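The online algorithm maintains two accumulators per row:
- $m$: the running maximum seen so far
- $d$: the denominator, $d = \sum_j e^{x_j - m}$ over the elements seen so far
When a new value $x_i$ arrives:
$$m_{\text{new}} = \max(m, x_i), \qquad d_{\text{new}} = d \cdot e^{m - m_{\text{new}}} + e^{x_i - m_{\text{new}}}$$
After processing all elements, the softmax output for element $i$ is $e^{x_i - m} / d$.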
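The update rule is easiest to see in scalar form. Below is a minimal pure-Python sketch of the two accumulators, checked against the conventional two-pass version. The function names are illustrative, and finite inputs are assumed.

```python
import math


def two_pass_softmax(xs):
    """Conventional stable softmax: one pass for the max, one for the sum."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    d = sum(exps)
    return [e / d for e in exps]


def online_softmax(xs):
    """Same result, but max and denominator come from a single streaming pass."""
    m = float("-inf")  # running maximum
    d = 0.0            # running denominator, always relative to the current m
    for x in xs:
        m_new = max(m, x)
        # Rescale the old denominator to the new maximum, then add the new term.
        d = d * math.exp(m - m_new) + math.exp(x - m_new)
        m = m_new
    return [math.exp(x - m) / d for x in xs]


values = [1.0, 2.0, 3.0, 1000.0]
assert all(abs(a - b) < 1e-12
           for a, b in zip(two_pass_softmax(values), online_softmax(values)))
```

The rescaling factor is what keeps the denominator valid when a later element raises the maximum: every term already accumulated is multiplied by the same correction.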

    import torch
    import triton
    import triton.language as tl


    @triton.jit
    def fused_softmax_kernel(
        x_ptr, out_ptr,
        stride_row,
        N_COLS,
        BLOCK_SIZE: tl.constexpr,
    ):
        # Each program handles one row
        row_idx = tl.program_id(0)
        row_start = row_idx * stride_row
        offsets = tl.arange(0, BLOCK_SIZE)
        mask = offsets < N_COLS
        # Load the whole row once; padding lanes get -inf so they contribute exp(-inf) = 0
        row = tl.load(x_ptr + row_start + offsets, mask=mask, other=-float('inf'))
        row = row.to(tl.float32)  # do the math in fp32, whatever the input dtype
        # The row is now on-chip, so max and denominator are plain block reductions;
        # the running (online) update is only needed once the row is split into tiles
        row_max = tl.max(row, axis=0)
        numerator = tl.exp(row - row_max)
        denom = tl.sum(numerator, axis=0)
        # Normalize and store (tl.store casts back to the output dtype)
        out = numerator / denom
        tl.store(out_ptr + row_start + offsets, out, mask=mask)


    def fused_softmax(x: torch.Tensor) -> torch.Tensor:
        # Assumes a 2D, row-major (contiguous) CUDA tensor, so x and out share strides
        N_ROWS, N_COLS = x.shape
        out = torch.empty_like(x)
        # BLOCK_SIZE must cover the whole row, and tl.arange needs a power of two
        BLOCK_SIZE = triton.next_power_of_2(N_COLS)
        fused_softmax_kernel[(N_ROWS,)](
            x, out,
            x.stride(0),
            N_COLS,
            BLOCK_SIZE=BLOCK_SIZE,
        )
        return out

The entire row fits in registers for sequences up to ~8192 columns on H100. Beyond that, the kernel needs to tile the row across SRAM blocks and carry the running maximum and denominator from tile to tile, which is exactly what FlashAttention does for the attention score rows.
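The tiled case can be sketched in plain Python. Each tile contributes a tile-level max and a tile-level sum, and the running statistics are rescaled whenever the maximum changes. This is a CPU illustration under assumed names and an assumed default block size, not a Triton kernel.

```python
import math


def tiled_softmax(row, block=4):
    """Softmax over `row`, reading it in tiles of `block` elements."""
    m = float("-inf")  # running maximum across tiles
    d = 0.0            # running denominator, relative to m
    for start in range(0, len(row), block):
        tile = row[start:start + block]
        m_new = max(m, max(tile))                           # tile-level max reduction
        tile_sum = sum(math.exp(x - m_new) for x in tile)   # tile-level sum reduction
        d = d * math.exp(m - m_new) + tile_sum              # rescale, then accumulate
        m = m_new
    # Second sweep: normalize with the final statistics.
    return [math.exp(x - m) / d for x in row]


probs = tiled_softmax([0.5, 2.0, -1.0, 3.0, 1000.0, 999.0], block=4)
```

The result does not depend on the block size, so the tile width becomes a pure performance knob.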
What this eliminates: two separate HBM reads (one for max, one for exp) and one intermediate write. For a attention score matrix, that is of saved traffic — at 3.35 TB/s, 1 GB of saved reads = 0.3 ms.
Example 2: fused RMSNorm + linear projection
RMSNorm followed by a linear layer is one of the most common patterns in transformer blocks. The unfused version:

    # Unfused: three HBM reads, two HBM writes
    x_normed = rms_norm(x)    # read x, write x_normed
    out = x_normed @ W.T      # read x_normed and W, write out

The fused version computes the RMS in registers, normalizes each tile of x as it is loaded, and immediately feeds the result into the matmul accumulator, so x_normed never touches HBM.

    @triton.jit
    def fused_rmsnorm_linear_kernel(
        x_ptr, w_ptr, weight_ptr, out_ptr,  # w_ptr: linear weight W; weight_ptr: RMSNorm gain (gamma)
        M, N, K,  # x: (M, K), w: (N, K), out: (M, N), all assumed row-major contiguous
        eps,
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    ):
        pid_m = tl.program_id(0)
        pid_n = tl.program_id(1)
        # Tile indices
        m_offsets = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        n_offsets = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        # Accumulator for output tile
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        # Step 1: row-wise sum of squares over the full K dimension (needs the whole row)
        rms_sum = tl.zeros((BLOCK_M,), dtype=tl.float32)
        for k in range(0, K, BLOCK_K):
            k_offsets = k + tl.arange(0, BLOCK_K)
            x_tile = tl.load(x_ptr + m_offsets[:, None] * K + k_offsets[None, :],
                             mask=(m_offsets[:, None] < M) & (k_offsets[None, :] < K), other=0.)
            x_tile = x_tile.to(tl.float32)  # square in fp32 so fp16 inputs cannot overflow
            rms_sum += tl.sum(x_tile * x_tile, axis=1)
        rms_inv = tl.rsqrt(rms_sum / K + eps)  # (BLOCK_M,)
        # Step 2: apply norm + scale + multiply with W tile
        for k in range(0, K, BLOCK_K):
            k_offsets = k + tl.arange(0, BLOCK_K)
            x_tile = tl.load(x_ptr + m_offsets[:, None] * K + k_offsets[None, :],
                             mask=(m_offsets[:, None] < M) & (k_offsets[None, :] < K), other=0.)
            gamma = tl.load(weight_ptr + k_offsets, mask=k_offsets < K, other=1.)
            x_normed = x_tile.to(tl.float32) * rms_inv[:, None] * gamma[None, :]
            w_tile = tl.load(w_ptr + n_offsets[:, None] * K + k_offsets[None, :],
                             mask=(n_offsets[:, None] < N) & (k_offsets[None, :] < K), other=0.)
            # tl.dot expects both operands in the same dtype
            acc += tl.dot(x_normed.to(w_tile.dtype), tl.trans(w_tile))
        # Store output (tl.store casts the fp32 accumulator to the output dtype)
        tl.store(out_ptr + m_offsets[:, None] * N + n_offsets[None, :], acc,
                 mask=(m_offsets[:, None] < M) & (n_offsets[None, :] < N))

The x read happens twice (once for the RMS computation, once for the matmul), but the second read is served from L2 if the tile is still resident in cache. The x_normed intermediate is never written to HBM. One cost of this layout: every program along the N axis recomputes the row statistics for the same rows of x.
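One way to drop the second read of x, offered as a sketch rather than as what production kernels necessarily do: the RMS factor is a per-row scalar, so it factors out of the dot product. The sum of squares and the matmul can then share one loop over K, with the scaling applied once at the end. The pure-Python reference below uses hypothetical function names and nested lists standing in for tensors, and checks that both orderings agree.

```python
import math


def rmsnorm_linear_two_pass(x, gamma, W, eps=1e-6):
    """Reference: normalize each row first, then project."""
    out = []
    for row in x:
        rms_inv = 1.0 / math.sqrt(sum(v * v for v in row) / len(row) + eps)
        normed = [v * rms_inv * g for v, g in zip(row, gamma)]
        out.append([sum(a * w for a, w in zip(normed, w_row)) for w_row in W])
    return out


def rmsnorm_linear_one_pass(x, gamma, W, eps=1e-6):
    """Single loop over K: accumulate sum of squares and dot products together."""
    out = []
    for row in x:
        sq_sum = 0.0
        acc = [0.0] * len(W)
        for k, v in enumerate(row):
            sq_sum += v * v
            scaled = v * gamma[k]
            for n, w_row in enumerate(W):
                acc[n] += scaled * w_row[k]
        rms_inv = 1.0 / math.sqrt(sq_sum / len(row) + eps)
        out.append([a * rms_inv for a in acc])  # apply the per-row factor once
    return out


x = [[1.0, 2.0, 3.0, 4.0], [0.5, -1.5, 2.5, -3.5]]
gamma = [1.0, 0.5, 2.0, 1.5]
W = [[0.1, 0.2, 0.3, 0.4], [-1.0, 0.0, 1.0, 2.0], [0.5, 0.5, 0.5, 0.5]]
ref = rmsnorm_linear_two_pass(x, gamma, W)
fast = rmsnorm_linear_one_pass(x, gamma, W)
```

The trade-off is numerical: the accumulator now holds unnormalized products, so in fp16 or bf16 the rounding behaviour differs from the two-pass version and should be validated against a reference before use.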
The fusion decision tree
Not every pair of operations is worth fusing. The rule of thumb:
Fuse when: the intermediate result is large relative to the compute involved, as with softmax scores, attention masks, activation functions applied to matmul outputs, and normalization layers. These are memory-bandwidth-bound when unfused and move toward compute-bound once fused.
Don't fuse when: the operations have different parallelism structures that can't be expressed as a single tile schedule — for example, operations along different axes of a 3D tensor may require different block decompositions that cannot be merged without serialization.
The kernel fusion test: if the intermediate tensor's size is large relative to the compute time, fusion wins. For a attention matrix at N=4096, the intermediate is ; at 3.35 TB/s, eliminating it saves per layer — meaningful at 32 layers.
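One way to make this test quantitative is a roofline-style check: compare the arithmetic intensity (FLOPs per byte of HBM traffic) against the ratio of peak compute to bandwidth. The sketch below assumes roughly 989 TFLOP/s of dense BF16 throughput for an H100 SXM alongside the 3.35 TB/s used in the text. Both the peak figure and the function names are assumptions for illustration.

```python
PEAK_FLOPS = 989e12      # assumed: H100 SXM dense BF16 tensor-core peak, FLOP/s
HBM_BANDWIDTH = 3.35e12  # bytes/s, the figure used in the text


def arithmetic_intensity(flops, bytes_moved):
    """FLOPs performed per byte of HBM traffic."""
    return flops / bytes_moved


def is_memory_bound(flops, bytes_moved, peak_flops=PEAK_FLOPS, bandwidth=HBM_BANDWIDTH):
    """True when the kernel sits left of the roofline ridge point."""
    ridge = peak_flops / bandwidth
    return arithmetic_intensity(flops, bytes_moved) < ridge


n = 4096
# Standalone elementwise op on an n x n fp16 matrix: ~1 FLOP per element,
# one 2-byte read and one 2-byte write per element.
elementwise = is_memory_bound(flops=n * n, bytes_moved=4 * n * n)
# n x n x n fp16 matmul: 2*n^3 FLOPs, three n x n matrices of traffic.
matmul = is_memory_bound(flops=2 * n**3, bytes_moved=3 * 2 * n * n)
print(f"elementwise memory-bound: {elementwise}, matmul memory-bound: {matmul}")
```

An op that lands far below the ridge on its own is a fusion candidate: folding it into a neighbouring kernel removes its traffic without adding meaningful compute.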
Autotuning
Triton's autotuner runs the kernel with multiple BLOCK_SIZE configurations and measures wall time:
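
    @triton.autotune(
        configs=[
            triton.Config({'BLOCK_SIZE': 256}),
            triton.Config({'BLOCK_SIZE': 512}),
            triton.Config({'BLOCK_SIZE': 1024}),
            triton.Config({'BLOCK_SIZE': 2048}),
        ],
        key=['N_COLS'],
    )
    @triton.jit
    def fused_softmax_kernel(..., BLOCK_SIZE: tl.constexpr):
        ...

The key ['N_COLS'] tells Triton to cache the best config per unique value of N_COLS, so a 512-column softmax and a 2048-column softmax each get their own tuned block size. Compiled kernels are cached to disk; the timing sweep itself is, by default, repeated in each new process. Two caveats: the caller must stop passing BLOCK_SIZE at launch once the autotuner owns it, and sweeping BLOCK_SIZE only makes sense for a kernel that loops over the row in tiles. The single-block kernel from Example 1 is only correct when BLOCK_SIZE >= N_COLS, so for that version the parameter to sweep is num_warps (passed as triton.Config({...}, num_warps=...)).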
Autotuning is the right tool for: block sizes, number of pipeline stages, vectorization widths. It is not a substitute for correct coalescing or avoiding bank conflicts — those must be correct by construction.
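The keyed-cache behaviour can be modelled in a few lines. The class below is a toy, not Triton's implementation: ToyAutotuner and fake_bench are hypothetical names, and the cost model is made up so the example runs deterministically without a GPU.

```python
class ToyAutotuner:
    """Toy model of keyed autotuning; not Triton's implementation."""

    def __init__(self, configs, bench):
        self.configs = configs
        self.bench = bench   # callable: (config, key) -> measured time
        self.cache = {}      # key -> best config

    def best_config(self, key):
        if key not in self.cache:
            # Time every config once for this key and keep the fastest.
            timings = [(self.bench(cfg, key), i) for i, cfg in enumerate(self.configs)]
            _, best_i = min(timings)
            self.cache[key] = self.configs[best_i]
        return self.cache[key]


calls = []


def fake_bench(cfg, n_cols):
    """Hypothetical cost model standing in for real GPU timings."""
    calls.append((cfg["BLOCK_SIZE"], n_cols))
    return abs(cfg["BLOCK_SIZE"] - n_cols)  # pretend the closest block size wins


tuner = ToyAutotuner([{"BLOCK_SIZE": b} for b in (256, 512, 1024, 2048)], fake_bench)
first = tuner.best_config(512)
second = tuner.best_config(2048)
again = tuner.best_config(512)  # served from the cache, no new benchmarks
```

Each new key costs one full sweep over the configs, which is why keys should be chosen from arguments that take few distinct values in practice.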
What to fuse in practice for LLM inference
The highest-value fusion targets in a transformer decoder, in rough order of impact:
| Operation pair | Why it matters |
|---|---|
| QK attention + softmax + masked fill | Eliminates N² intermediate; the FlashAttention insight |
| RMSNorm + Q/K/V projection | Eliminates normalized activations from HBM |
| SiLU/GeLU + gate multiply (SwiGLU) | Fuse gate with pointwise activation; saves one intermediate |
| Dequantize + matmul | Read INT8 weights, dequantize to fp16 in registers, multiply — one weight read |
| Matmul + residual add + RMSNorm | Post-attention merge; saves two reads of the residual stream |
Each of these is a kernel that ships in production inference libraries (vLLM, TensorRT-LLM, SGLang) and has measurable latency impact at 32+ layers.
The next post in the series covers profiling: how to use nsys and Nsight Compute to measure whether your kernel is actually memory-bound or compute-bound, read roofline charts from real traces, and identify which operation is the bottleneck in a real forward pass.