S. Roy

Blog Post

Writing Fused Kernels with Triton

Kernel fusion eliminates the HBM round-trips between chained operations. Triton makes this practical in Python. This post builds a fused online softmax from scratch, then extends it to a fused RMSNorm + linear projection — the kind of kernel that actually speeds up LLM inference.

8 min read

The previous post established the CUDA programming model: grids, blocks, warps, coalescing, shared memory. This post uses those ideas to build something useful — a fused kernel that eliminates unnecessary HBM round-trips between operations that are always chained together.

Why fusion matters

Consider the forward pass through an attention layer. The unfused sequence is:

  1. Compute $S = QK^\top / \sqrt{d_k}$ → write $S$ to HBM
  2. Read $S$ from HBM → apply causal mask → write masked $S$ to HBM
  3. Read masked $S$ from HBM → apply softmax → write $P$ to HBM
  4. Read $P$ and $V$ from HBM → compute $PV$ → write output to HBM

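Steps 1–4 produce six HBM reads ($Q$, $K$, $S$, masked $S$, $P$, $V$) and four HBM writes ($S$, masked $S$, $P$, output). Three of those writes and three of those reads move an $N \times N$ matrix, so their size grows as $O(N^2)$. Each such transfer goes through HBM at about 3.35 TB/s on an H100, which is slow compared with on-chip SRAM.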

The fused version (FlashAttention) completes steps 1–4 inside SRAM without any intermediate HBM writes. The savings are not marginal: at a sequence length of 4096 in fp16, one head's score matrix is $4096^2 \times 2$ bytes = 32 MiB, and the unfused path writes three such matrices and reads them back (192 MiB per head, before multiplying by batch size and head count). With $d_k = 128$, the fused path's only write is the $4096 \times 128$ output, 1 MiB per head. The operation shifts from memory-bound to compute-bound.
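The per-head arithmetic can be checked in a few lines. This is a back-of-the-envelope sketch that assumes fp16 (2-byte) elements and counts only the $N \times N$ transfers (three writes in steps 1–3, three reads in steps 2–4); it ignores the $Q$, $K$, $V$ reads.

```python
BYTES_FP16 = 2  # assumption: scores and outputs stored in fp16


def score_matrix_bytes(n: int, bytes_per_elem: int = BYTES_FP16) -> int:
    """Size of one N x N attention-score matrix for a single head."""
    return n * n * bytes_per_elem


def unfused_score_traffic(n: int) -> int:
    """HBM bytes spent on N x N intermediates in the unfused path.

    S, masked S and P are each written once and read back once: 3 writes + 3 reads.
    """
    return 6 * score_matrix_bytes(n)


def output_bytes(n: int, d: int, bytes_per_elem: int = BYTES_FP16) -> int:
    """Size of the N x d output, the only tensor the fused path writes."""
    return n * d * bytes_per_elem


n, d_k = 4096, 128
print(score_matrix_bytes(n) / 2**20, "MiB per score matrix")
print(unfused_score_traffic(n) / 2**20, "MiB of score traffic per head")
print(output_bytes(n, d_k) / 2**20, "MiB of output per head")
```

Multiplying by batch size and head count turns these per-head figures into per-layer totals.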

The key insight is that as long as operations are purely element-wise or reductions over the same data, they can be fused: load the data once, apply all transformations in registers or SRAM, write the final output once.
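The load-once, write-once idea can be seen in plain Python. In this illustration (not GPU code), lists stand in for HBM buffers: the unfused version materializes one intermediate list per operation, while the fused version applies the whole chain to each element as it is read. The chain itself (scale, shift, ReLU) is an arbitrary example.

```python
def unfused(xs, scale, shift):
    # Each step reads a full buffer and writes a new one, like separate kernels.
    scaled = [x * scale for x in xs]        # read xs, write scaled
    shifted = [y + shift for y in scaled]   # read scaled, write shifted
    return [max(z, 0.0) for z in shifted]   # read shifted, write output


def fused(xs, scale, shift):
    # One read per input element, all transforms applied "in registers", one write.
    return [max(x * scale + shift, 0.0) for x in xs]


xs = [-2.0, -0.5, 0.0, 1.5, 3.0]
print(fused(xs, 2.0, 1.0))  # [0.0, 0.0, 1.0, 4.0, 7.0]
```

Both functions return the same values; the fused one simply never allocates `scaled` or `shifted`, which on a GPU would be two full HBM writes and two reads.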

Triton's programming model

Triton is a Python DSL that compiles to GPU kernels. Its abstraction level is higher than CUDA C but lower than PyTorch — you write tile-level operations explicitly, but Triton handles the warp-level scheduling, bank-conflict analysis, and autotuning.

The key concepts:

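  • tl.program_id(axis): the program (block) index along a grid axis, equivalent to blockIdx
  • tl.arange(0, BLOCK): a block of consecutive integers 0, 1, …, BLOCK − 1 (BLOCK must be a power of two), used as per-element offsets; Triton decides how they map onto threads
  • tl.load(ptr + offsets, mask=..., other=...): vectorized load with bounds check; masked-off lanes take the value other
  • tl.store(ptr + offsets, data, mask=...): vectorized store of the unmasked lanes
  • tl.sum(x, axis=0), tl.max(x, axis=0): reductions along an axis of a block, which the compiler lowers to warp shuffles and shared memory
  • tl.constexpr: compile-time constants such as block sizes, which are the parameters the autotuner can sweep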
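To make the block-level semantics concrete, here is a plain-Python emulation of how a 1D launch partitions work. It is not Triton's API: `launch`, `scale_program`, `masked_load` and `masked_store` are hypothetical helpers written for this illustration. Each "program" gets an id (the role of `tl.program_id(0)`), builds a vector of offsets (the role of `tl.arange`), and uses a mask so the last, partial block never touches out-of-bounds elements.

```python
import math


def masked_load(buf, offsets, mask, other):
    # Plays the role of tl.load(ptr + offsets, mask=mask, other=other)
    return [buf[o] if m else other for o, m in zip(offsets, mask)]


def masked_store(buf, offsets, values, mask):
    # Plays the role of tl.store(ptr + offsets, values, mask=mask)
    for o, v, m in zip(offsets, values, mask):
        if m:
            buf[o] = v


def scale_program(pid, x, out, n, factor, BLOCK):
    offsets = [pid * BLOCK + i for i in range(BLOCK)]  # pid * BLOCK + tl.arange(0, BLOCK)
    mask = [o < n for o in offsets]                    # offsets < n
    vals = masked_load(x, offsets, mask, other=0.0)
    masked_store(out, offsets, [v * factor for v in vals], mask)


def launch(n, BLOCK, x, factor):
    out = [0.0] * n
    grid = math.ceil(n / BLOCK)       # same role as triton.cdiv(n, BLOCK)
    for pid in range(grid):           # on a GPU these programs run in parallel
        scale_program(pid, x, out, n, factor, BLOCK)
    return out


print(launch(10, 4, [float(i) for i in range(10)], 0.5))
# [0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5]
```

With `n = 10` and `BLOCK = 4`, three programs run and the third one masks off its last two lanes.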

Triton compiles each @triton.jit function to PTX and then to SASS (the GPU's native ISA). The @triton.autotune decorator runs the kernel with multiple configurations and picks the fastest — equivalent to manually sweeping block sizes but automated.

Example 1: fused online softmax

The numerically stable softmax needs two statistics from each row: the maximum (for stability) and the sum of the exponentials. A naive PyTorch version built from separate ops (max, subtract and exp, sum, divide) launches one kernel per step, with an HBM round-trip between each. A fused kernel loads the row once and computes both statistics on chip; when a row is too long to hold on chip, the online softmax algorithm computes both in a single streaming pass using running statistics.

The online algorithm maintains two accumulators per row:

  • $m$: the running maximum seen so far
  • $d$: the denominator, $\sum_{j \le t} \exp(x_j - m)$

When a new value $x_{t+1}$ arrives:

$$m_{\text{new}} = \max(m, x_{t+1}), \qquad d_{\text{new}} = d \cdot \exp(m - m_{\text{new}}) + \exp(x_{t+1} - m_{\text{new}})$$

After processing all elements, the softmax output for element $j$ is $\exp(x_j - m) / d$.
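The recurrence translates directly into a streaming loop. This pure-Python sketch is a reference for checking the math, not a kernel: one pass builds $m$ and $d$ with the update rule above, then the outputs are written using the final statistics. A conventional two-pass softmax is included for comparison.

```python
import math


def online_softmax(xs):
    m = -math.inf  # running maximum
    d = 0.0        # running denominator, expressed relative to the current m
    for x in xs:
        m_new = max(m, x)
        # Rescale the old sum to the new maximum, then add the new term.
        d = d * math.exp(m - m_new) + math.exp(x - m_new)
        m = m_new
    return [math.exp(x - m) / d for x in xs]


def two_pass_softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]


# A naive exp(1000.0) would overflow; subtracting the running max avoids it.
print(online_softmax([1.0, 3.0, -2.0, 1000.0]))  # [0.0, 0.0, 0.0, 1.0]
```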

import torch
import triton
import triton.language as tl

@triton.jit
def fused_softmax_kernel(
    x_ptr, out_ptr,
    stride_row,
    N_COLS,
    BLOCK_SIZE: tl.constexpr,  # must be >= N_COLS: one program loads a whole row
):
    # Each program handles one row
    row_idx = tl.program_id(0)
    row_start = row_idx * stride_row
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < N_COLS

    # Load the row once; padding lanes get -inf so they contribute exp(-inf) = 0.
    # Compute in fp32 even for fp16/bf16 inputs.
    row = tl.load(x_ptr + row_start + offsets, mask=mask, other=-float('inf')).to(tl.float32)

    # The whole row is in registers, so max and sum are plain reductions;
    # the online recurrence is only needed when the row is processed in chunks.
    row_max = tl.max(row, axis=0)
    numerator = tl.exp(row - row_max)
    denom = tl.sum(numerator, axis=0)

    # Normalize and store (tl.store casts back to the output dtype)
    out = numerator / denom
    tl.store(out_ptr + row_start + offsets, out, mask=mask)


def fused_softmax(x: torch.Tensor) -> torch.Tensor:
    # The kernel assumes a 2D tensor with unit stride along the columns
    assert x.ndim == 2
    x = x.contiguous()
    N_ROWS, N_COLS = x.shape
    out = torch.empty_like(x)
    BLOCK_SIZE = triton.next_power_of_2(N_COLS)
    fused_softmax_kernel[(N_ROWS,)](
        x, out,
        x.stride(0),
        N_COLS,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return out

The entire row fits in registers for rows up to roughly 8192 columns on an H100 (the exact limit depends on dtype and num_warps). Beyond that, the kernel needs to walk the row in SRAM-sized chunks using the online recurrence, which is exactly what FlashAttention does for the attention score rows.
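When a row is processed in chunks, the same rescaling rule merges a whole block's statistics at a time. This plain-Python sketch uses a hypothetical `block` chunk size standing in for BLOCK_SIZE; it shows only the max/denominator merge, whereas FlashAttention additionally rescales its partial $PV$ accumulator.

```python
import math


def tiled_softmax(xs, block):
    m, d = -math.inf, 0.0
    # Pass 1: stream over the row in chunks, merging (max, sum) statistics.
    for start in range(0, len(xs), block):
        chunk = xs[start:start + block]
        m_chunk = max(chunk)
        d_chunk = sum(math.exp(x - m_chunk) for x in chunk)
        m_new = max(m, m_chunk)
        # Rescale both partial sums to the shared maximum before adding them.
        d = d * math.exp(m - m_new) + d_chunk * math.exp(m_chunk - m_new)
        m = m_new
    # Pass 2: re-read the row and write the normalized outputs.
    return [math.exp(x - m) / d for x in xs]


probs = tiled_softmax([0.1 * i for i in range(10)], block=4)
```

The result does not depend on the chunk size, which is what lets a kernel pick whatever block fits in SRAM.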

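What this eliminates: compared with a version that makes separate passes for the max and for the exponentials, the fused kernel avoids re-reading the input and avoids writing an intermediate to HBM. That is at least two extra passes over a $B \times N$ attention score matrix, or $2 \times B \times N \times 2$ bytes of traffic for fp16 scores. At 3.35 TB/s, every 1 GB of avoided traffic saves about 0.3 ms.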

Example 2: fused RMSNorm + linear projection

RMSNorm followed by a linear layer is one of the most common patterns in transformer blocks. The unfused version:

import torch.nn.functional as F

# Unfused: three HBM reads, two HBM writes
x_normed = F.rms_norm(x, (K,), weight=gamma, eps=eps)  # read x (and gamma), write x_normed
out = x_normed @ W.T                                    # read x_normed and W, write out

The fused version computes the RMS factor on chip, normalizes each tile of x in registers, and feeds the result directly into the matmul accumulator, so x_normed never touches HBM.

import torch
import triton
import triton.language as tl

@triton.jit
def fused_rmsnorm_linear_kernel(
    x_ptr, w_ptr, gamma_ptr, out_ptr,
    M, N, K,              # x: (M, K), w: (N, K), out: (M, N), all row-major contiguous
    eps,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    # Tile indices
    m_offsets = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    n_offsets = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # Accumulator for the output tile
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # Step 1: row-wise RMS factor (needs the full row, so loop over all of K)
    rms_sum = tl.zeros((BLOCK_M,), dtype=tl.float32)
    for k in range(0, K, BLOCK_K):
        k_offsets = k + tl.arange(0, BLOCK_K)
        x_tile = tl.load(x_ptr + m_offsets[:, None] * K + k_offsets[None, :],
                         mask=(m_offsets[:, None] < M) & (k_offsets[None, :] < K), other=0.)
        x_tile = x_tile.to(tl.float32)  # square in fp32 to avoid fp16 overflow
        rms_sum += tl.sum(x_tile * x_tile, axis=1)
    rms_inv = tl.rsqrt(rms_sum / K + eps)  # (BLOCK_M,)

    # Step 2: apply norm + gamma scale, then multiply with the W tile
    for k in range(0, K, BLOCK_K):
        k_offsets = k + tl.arange(0, BLOCK_K)
        x_tile = tl.load(x_ptr + m_offsets[:, None] * K + k_offsets[None, :],
                         mask=(m_offsets[:, None] < M) & (k_offsets[None, :] < K), other=0.)
        gamma = tl.load(gamma_ptr + k_offsets, mask=k_offsets < K, other=1.)
        x_normed = x_tile.to(tl.float32) * rms_inv[:, None] * gamma[None, :].to(tl.float32)
        w_tile = tl.load(w_ptr + n_offsets[:, None] * K + k_offsets[None, :],
                         mask=(n_offsets[:, None] < N) & (k_offsets[None, :] < K), other=0.)
        # tl.dot needs matching operand dtypes; accumulation stays in fp32
        acc += tl.dot(x_normed.to(w_tile.dtype), tl.trans(w_tile))

    # Store the output tile (tl.store casts to out's dtype)
    tl.store(out_ptr + m_offsets[:, None] * N + n_offsets[None, :], acc,
             mask=(m_offsets[:, None] < M) & (n_offsets[None, :] < N))


def fused_rmsnorm_linear(x, W, gamma, eps=1e-6, BLOCK_M=64, BLOCK_N=64, BLOCK_K=64):
    # x: (M, K) activations, W: (N, K) as stored by nn.Linear, gamma: (K,) RMSNorm weight
    x, W = x.contiguous(), W.contiguous()
    M, K = x.shape
    N = W.shape[0]
    out = torch.empty((M, N), device=x.device, dtype=x.dtype)
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
    fused_rmsnorm_linear_kernel[grid](x, W, gamma, out, M, N, K, eps,
                                      BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K)
    return out

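x is read twice (once for the RMS computation, once for the matmul), but the second read is likely to be served from L2 when the BLOCK_M × K slice of x fits in cache. Each program also recomputes the RMS factor for its rows, so that pass is repeated once per column tile (pid_n); its arithmetic is small next to the tile's matmul. The x_normed intermediate is never written to HBM.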
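To sanity-check the kernel's math without a GPU, a small reference helps. This pure-Python sketch is a stand-in, not the Triton kernel: `fused_rmsnorm_linear_ref` mirrors its two loops over K (sum of squares first, then normalize-and-accumulate tile by tile), and `unfused_rmsnorm_linear_ref` materializes `x_normed` the way the unfused path does. Both names and the `block_k` parameter are hypothetical.

```python
import math


def unfused_rmsnorm_linear_ref(x, W, gamma, eps):
    # Materializes x_normed, the tensor the fused kernel never writes.
    K = len(gamma)
    x_normed = []
    for row in x:
        inv = 1.0 / math.sqrt(sum(v * v for v in row) / K + eps)
        x_normed.append([v * inv * g for v, g in zip(row, gamma)])
    return [[sum(a * b for a, b in zip(xr, wr)) for wr in W] for xr in x_normed]


def fused_rmsnorm_linear_ref(x, W, gamma, eps, block_k=4):
    K, N = len(gamma), len(W)
    out = []
    for row in x:
        # Loop 1 over K tiles: accumulate the sum of squares.
        sq = 0.0
        for k0 in range(0, K, block_k):
            sq += sum(v * v for v in row[k0:k0 + block_k])
        inv = 1.0 / math.sqrt(sq / K + eps)
        # Loop 2 over K tiles: normalize one tile and fold it into the accumulators.
        acc = [0.0] * N
        for k0 in range(0, K, block_k):
            tile = [v * inv * g for v, g in zip(row[k0:k0 + block_k], gamma[k0:k0 + block_k])]
            for n in range(N):
                acc[n] += sum(t * w for t, w in zip(tile, W[n][k0:k0 + block_k]))
        out.append(acc)
    return out
```

Only one normalized tile exists at a time in the fused version, which is the property the kernel relies on.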

The fusion decision tree

Not every pair of operations is worth fusing. The rule of thumb:

Fuse when: the intermediate result is large relative to the compute involved — softmax scores, attention masks, activation functions applied to matmul outputs, normalization layers. These are memory-bandwidth-bound unfused and become compute-bound fused.

Don't fuse when: the operations have different parallelism structures that can't be expressed as a single tile schedule — for example, operations along different axes of a 3D tensor may require different block decompositions that cannot be merged without serialization.

The kernel fusion test: if $\text{intermediate size} \times 2\ \text{(read + write)} \times \text{number of operations}$, divided by HBM bandwidth, is large relative to the compute time, fusion wins. For a $B \times N$ attention score matrix at $N = 4096$, where $B$ counts every score row (batch × heads × query positions), the intermediate is $B \times 4096 \times 2$ bytes in fp16. With one sequence and 32 heads ($B = 32 \times 4096$) that is exactly 1 GiB, and writing plus re-reading it at 3.35 TB/s takes about 0.6 ms per layer, which is meaningful at 32 layers.
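The test can be written as a tiny cost model. This is a rough sketch under simple assumptions: a fixed 3.35 TB/s HBM bandwidth, a caller-supplied compute time, one write plus one read per materialized intermediate, and a hypothetical `threshold` for what counts as "large". Real kernels overlap memory and compute, so treat the output as an estimate.

```python
HBM_BYTES_PER_S = 3.35e12  # assumption: H100 SXM-class HBM bandwidth


def intermediate_traffic_s(intermediate_bytes, n_ops=1, bandwidth=HBM_BYTES_PER_S):
    """Seconds spent writing and re-reading intermediates that fusion would remove."""
    return intermediate_bytes * 2 * n_ops / bandwidth


def fusion_pays_off(intermediate_bytes, compute_s, n_ops=1, threshold=0.1):
    """Heuristic: fuse if the avoidable traffic is a meaningful fraction of compute time."""
    return intermediate_traffic_s(intermediate_bytes, n_ops) > threshold * compute_s


# A 1 GiB fp16 intermediate, e.g. 32 * 4096 score rows of 4096 columns
scores = 32 * 4096 * 4096 * 2
print(f"{intermediate_traffic_s(scores) * 1e3:.2f} ms")  # 0.64 ms
```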

Autotuning

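Triton's autotuner runs the kernel with multiple configurations and measures wall time. Block size is the natural knob for a tiled softmax that loops over each row in BLOCK_SIZE chunks; the single-load kernel above cannot tune it freely, because its BLOCK_SIZE must cover the whole row: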

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 256}),
        triton.Config({'BLOCK_SIZE': 512}),
        triton.Config({'BLOCK_SIZE': 1024}),
        triton.Config({'BLOCK_SIZE': 2048}),
    ],
    key=['N_COLS'],
)
@triton.jit
def fused_softmax_kernel(..., BLOCK_SIZE: tl.constexpr):
    ...

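The key ['N_COLS'] tells Triton to re-run the benchmark whenever N_COLS takes a new value and to reuse the winning config afterwards, so a 512-column softmax and a 2048-column softmax each get their own tuned block size. Each config is compiled once and the binary is cached to disk. Once a parameter is autotuned, the launch must stop passing it explicitly (here, drop BLOCK_SIZE=BLOCK_SIZE from the call), or Triton reports a conflicting meta-parameter.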
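The effect of `key` can be imitated in plain Python: benchmark every candidate once per distinct key value and reuse the winner afterwards. This is a hypothetical sketch of the caching idea only; `ToyAutotuner` and the `fake_time` cost function are invented for illustration, and Triton's real autotuner also handles warmup, repetitions and compilation.

```python
class ToyAutotuner:
    def __init__(self, configs, bench):
        self.configs = configs  # candidate meta-parameter dicts
        self.bench = bench      # callable(config, n_cols) -> measured time
        self.cache = {}         # key value (n_cols) -> best config
        self.bench_calls = 0

    def best_config(self, n_cols):
        if n_cols not in self.cache:
            # New key value: benchmark every candidate once and remember the winner.
            self.bench_calls += len(self.configs)
            self.cache[n_cols] = min(self.configs, key=lambda cfg: self.bench(cfg, n_cols))
        return self.cache[n_cols]


def fake_time(cfg, n_cols):
    # Hypothetical cost: one unit per chunk plus a small penalty for wide blocks.
    b = cfg['BLOCK_SIZE']
    chunks = -(-n_cols // b)  # ceiling division
    return chunks + 0.0005 * b


configs = [{'BLOCK_SIZE': b} for b in (256, 512, 1024, 2048)]
tuner = ToyAutotuner(configs, fake_time)
print(tuner.best_config(512), tuner.best_config(2048), tuner.best_config(512))
# {'BLOCK_SIZE': 512} {'BLOCK_SIZE': 2048} {'BLOCK_SIZE': 512}
print(tuner.bench_calls)  # 8: each distinct key value was benchmarked once
```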

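Autotuning is the right tool for block sizes, number of warps, and number of pipeline stages. It is not a substitute for a good access pattern: if a kernel's offsets jump through memory instead of walking contiguous addresses, no configuration will make its loads coalesced, so the indexing has to be right by construction.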

What to fuse in practice for LLM inference

The highest-value fusion targets in a transformer decoder, in rough order of impact:

Operation pair | Why it matters
$QK^\top$ + masked fill + softmax | Eliminates the $N^2$ intermediate; the FlashAttention insight
RMSNorm + Q/K/V projection | Eliminates normalized activations from HBM
SiLU/GeLU + gate multiply (SwiGLU) | Fuses the gate with the pointwise activation; saves one intermediate
Dequantize + matmul | Reads INT8 weights, dequantizes to fp16 in registers, multiplies; one weight read
Matmul + residual add + RMSNorm | Post-attention merge; saves two reads of the residual stream
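The dequantize + matmul row is easy to state in code. This pure-Python sketch assumes symmetric per-output-channel INT8 quantization (one float scale per row of W, no zero point), which is one common scheme but not the only one. The fused version folds the scale into the dot product instead of first building a dequantized weight matrix; a real kernel does the same per tile in registers.

```python
def quantize_rows(W):
    """Symmetric per-row INT8 quantization: W[n][k] ~= q[n][k] * scales[n]."""
    q, scales = [], []
    for row in W:
        s = max(abs(v) for v in row) / 127 or 1.0  # avoid a zero scale for all-zero rows
        scales.append(s)
        q.append([max(-127, min(127, round(v / s))) for v in row])
    return q, scales


def dequant_then_matmul(x, q, scales):
    # Unfused: materialize the full dequantized weight matrix first.
    W_deq = [[qv * s for qv in row] for row, s in zip(q, scales)]
    return [[sum(a * b for a, b in zip(xr, wr)) for wr in W_deq] for xr in x]


def fused_dequant_matmul(x, q, scales):
    # Fused: read each int8 weight once and apply the row scale to the accumulated sum.
    return [[s * sum(a * qv for a, qv in zip(xr, row)) for row, s in zip(q, scales)]
            for xr in x]
```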

Each of these is a kernel that ships in production inference libraries (vLLM, TensorRT-LLM, SGLang) and has measurable latency impact at 32+ layers.

The next post in the series covers profiling: how to use nsys and Nsight Compute to measure whether your kernel is actually memory-bound or compute-bound, read roofline charts from real traces, and identify which operation is the bottleneck in a real forward pass.

Cite this work


Roy, Swastik. (2024). Writing Fused Kernels with Triton. S. Roy. https://swastikroy.me/blog/gpu-kernel-triton-fused-ops
