DiT: Replacing the U-Net with a Transformer
DDPM, DDIM, and latent diffusion all use a U-Net backbone. DiT replaces it with a transformer — and finds that diffusion scales with model size the same way language models do.
Every model in this series so far has shared a piece of machinery it never questioned: the denoiser is a U-Net. DDPM used one, DDIM sampled from the same trained weights, and latent diffusion shrank the images the U-Net had to process but left the architecture itself intact. The choice was reasonable — a U-Net's encoder compresses an image down through a bottleneck while its decoder reconstructs at full resolution, and the skip connections that wire each encoder stage to its matching decoder stage let fine spatial detail bypass the bottleneck. For pixel-to-pixel prediction that inductive bias is genuinely useful. The problem is what happens when you try to make the network bigger.
A U-Net does not scale gracefully. The skip connections fix where information may flow — shallow features can only rejoin the computation at one predetermined depth — and the multi-resolution structure bakes in assumptions about how features at different scales should combine, so adding parameters means widening hand-designed stages rather than scaling a single uniform primitive. The rest of deep learning had already learned the lesson that the field keeps relearning: the transformer, a stack of identical blocks with no architectural opinion about its input, scales more predictably than anything purpose-built, across language, vision, audio, and beyond. The question DiT (Peebles & Xie, 2022) asked is whether diffusion is an exception, and the answer is that it is not.
Turning a latent into something a transformer can read takes one step, borrowed directly from Vision Transformers. The 8×-downsampling VAE from the latent diffusion post turns the $256\times256$ image DiT trains on into a $32\times32\times4$ latent, and you cut that into non-overlapping patches. With a patch size of $p$, the $I\times I$ grid (here $I = 32$) becomes an $(I/p)\times(I/p)$ array of patches, each flattened and linearly projected to the model width $d$, which gives a sequence of $T = (I/p)^2$ tokens. DiT tries $p \in \{2, 4, 8\}$, so the same latent becomes 256, 64, or 16 tokens.
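A minimal NumPy sketch of the patchify step, assuming a channel-first latent of shape (C, I, I) and a random matrix standing in for the learned projection; the function and argument names are ours, not DiT's code.

```python
import numpy as np

def patchify(latent: np.ndarray, p: int, w_proj: np.ndarray) -> np.ndarray:
    """Cut a (C, I, I) latent into non-overlapping p x p patches and project each one to width d.

    w_proj has shape (C * p * p, d) and stands in for the learned linear projection.
    Returns a (T, d) token sequence with T = (I / p) ** 2, patches in row-major grid order.
    """
    c, h, w = latent.shape
    assert h % p == 0 and w % p == 0, "patch size must divide the latent grid"
    gh, gw = h // p, w // p
    # (C, gh, p, gw, p) -> (gh, gw, C, p, p): gather each patch's values together
    patches = latent.reshape(c, gh, p, gw, p).transpose(1, 3, 0, 2, 4)
    # Flatten every patch into one vector of length C * p * p
    flat = patches.reshape(gh * gw, c * p * p)
    return flat @ w_proj

rng = np.random.default_rng(0)
latent = rng.standard_normal((4, 32, 32))  # 4-channel latent of a 256 x 256 image
d = 1152  # DiT-XL's model width
for p in (2, 4, 8):
    tokens = patchify(latent, p, rng.standard_normal((4 * p * p, d)))
    print(p, tokens.shape)  # p=2 -> (256, 1152), p=4 -> (64, 1152), p=8 -> (16, 1152)
```

Implementations usually express this as a convolution whose kernel size and stride both equal $p$. That computes the same per-patch linear map in a single call.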
Add a 2D positional embedding so the transformer knows where each patch sat in the grid (DiT uses fixed sine-cosine embeddings, as in ViT). From there it is a standard transformer: self-attention and feed-forward blocks with ordinary residual connections, but no U-Net-style skip connections between distant layers and no resolution hierarchy. Every token can attend to every other from the first layer.
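The fixed 2D sine-cosine embedding can be sketched in a few lines. This minimal version assumes that half of the width encodes a patch's row and half its column, in the style of ViT/MAE-type embeddings. Reference implementations may order the channels differently, and the function names here are ours.

```python
import numpy as np

def sincos_1d(dim, positions):
    """Sine-cosine embedding of integer positions: dim/2 sines followed by dim/2 cosines."""
    # Geometric ladder of frequencies from 1 down to about 1/10000
    omega = 1.0 / 10000 ** (np.arange(dim // 2) / (dim / 2))
    angles = np.outer(positions, omega)  # (num_positions, dim/2)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)

def sincos_2d(dim, grid_size):
    """Fixed 2D embedding for a grid_size x grid_size patch grid, in row-major token order."""
    assert dim % 4 == 0, "width must split evenly into row/column halves, each into sin/cos"
    rows, cols = np.meshgrid(np.arange(grid_size), np.arange(grid_size), indexing="ij")
    emb_row = sincos_1d(dim // 2, rows.reshape(-1))  # first half of the width encodes the row
    emb_col = sincos_1d(dim // 2, cols.reshape(-1))  # second half encodes the column
    return np.concatenate([emb_row, emb_col], axis=1)  # (grid_size**2, dim)

pos = sincos_2d(1152, 16)  # 16 x 16 grid: a 32 x 32 latent with patch size 2
print(pos.shape)  # (256, 1152)
```

The result is simply added to the patch tokens before the first block. Because nothing here is learned, the same function covers any grid size.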
What the patchify step leaves unanswered is how the conditioning signals get into a network that has no cross-attention. Those signals are the timestep, which every diffusion model needs, and here also the class label. DiT's answer is adaptive layer normalization (adaLN). Both signals are embedded and summed into a single conditioning vector $c$. Instead of learning fixed scale and shift parameters inside each LayerNorm, a small per-block MLP predicts them from $c$.
The conditioning therefore reaches every normalization in the network and modulates the activation statistics directly. Each transformer block puts one adaLN in front of its attention sublayer and another in front of its feed-forward sublayer.
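Written out, one such block looks roughly like this. Here $c$ is the summed timestep and class embedding, $\gamma_i(c)$ and $\beta_i(c)$ are the scale and shift predicted by the block's MLP and broadcast over all tokens, and $\mathrm{LN}$ is a LayerNorm with no learned affine parameters of its own. The symbol names are ours, chosen for this sketch.

```latex
\begin{aligned}
c &= \mathrm{emb}_t(t) + \mathrm{emb}_y(y) \\
\mathrm{adaLN}_i(h, c) &= \gamma_i(c) \odot \mathrm{LN}(h) + \beta_i(c), \qquad i \in \{1, 2\} \\
h &\leftarrow h + \mathrm{Attn}\big(\mathrm{adaLN}_1(h, c)\big) \\
h &\leftarrow h + \mathrm{FFN}\big(\mathrm{adaLN}_2(h, c)\big)
\end{aligned}
```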
For a compact conditioning vector like $c$, which combines only a timestep and a class index, this is markedly cheaper than giving every block its own cross-attention to a conditioning sequence. Each block pays for one small MLP applied to a single vector instead of a full attention operation. The best DiT variant, adaLN-Zero, goes further by also using $c$ to predict a per-block residual gate that scales each sublayer's output, so the network learns how much each sublayer should contribute. The layer that predicts these values is initialized to zero, so every block starts training as the identity function.
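A minimal NumPy sketch of one adaLN-Zero block follows. It assumes a single attention head and uses ReLU in place of DiT's GELU. The parameter names (`w_mod`, `b_mod`, `init_params`) are hypothetical, and this is an illustration rather than the reference implementation. Following a common convention, the predicted scale enters as $1 + \text{scale}$, so a zero prediction leaves the normalized activations unchanged.

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

def layer_norm(x, eps=1e-6):
    # No learned affine parameters: adaLN supplies the scale and shift instead
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def self_attention(x, wq, wk, wv, wo):
    # Single-head softmax attention over all tokens
    q, k, v = x @ wq, x @ wk, x @ wv
    scores = q @ k.T / np.sqrt(q.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return (weights @ v) @ wo

def ffn(x, w1, w2):
    # ReLU keeps the sketch short; DiT's MLP uses GELU
    return np.maximum(x @ w1, 0.0) @ w2

def adaln_zero_block(h, c, params):
    """h: (T, d) tokens; c: (d,) conditioning vector (timestep embedding + class embedding)."""
    # SiLU + one linear layer on c predict six d-dim vectors: shift, scale, gate per sublayer
    mod = silu(c) @ params["w_mod"] + params["b_mod"]
    shift1, scale1, gate1, shift2, scale2, gate2 = np.split(mod, 6)
    # Attention sublayer: modulated LayerNorm in, gated residual out
    h = h + gate1 * self_attention(layer_norm(h) * (1 + scale1) + shift1, *params["attn"])
    # Feed-forward sublayer, same pattern with its own shift, scale, and gate
    h = h + gate2 * ffn(layer_norm(h) * (1 + scale2) + shift2, *params["ffn"])
    return h

def init_params(d, rng):
    return {
        # adaLN-Zero: the modulation layer starts at zero, so all gates are 0 at init
        "w_mod": np.zeros((d, 6 * d)),
        "b_mod": np.zeros(6 * d),
        "attn": [rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(4)],
        "ffn": [rng.standard_normal((d, 4 * d)) / np.sqrt(d),
                rng.standard_normal((4 * d, d)) / np.sqrt(4 * d)],
    }

rng = np.random.default_rng(0)
params = init_params(8, rng)
h = rng.standard_normal((16, 8))
c = rng.standard_normal(8)
print(np.array_equal(adaln_zero_block(h, c, params), h))  # True: the block starts as the identity
```

Because the modulation weights start at zero, every gate is zero at initialization and the block returns its input unchanged. Training then grows each sublayer's contribution from nothing, which is the property the zero initialization is meant to provide.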
With the architecture fixed, scaling is the experiment. DiT runs it cleanly across four sizes (DiT-S, B, L, and XL, from 33M to 675M parameters) while holding everything else constant. Every model is trained on ImageNet at $256\times256$ and measured by FID (Fréchet Inception Distance, where lower means the generated and real image statistics are closer). The result is the curve the field has seen in language modeling over and over: FID falls monotonically as the model grows and as it trains longer, with no sign of the architecture fighting back. Holding the parameter count fixed and instead shrinking the patch size also improves FID, because a smaller patch raises the token count and the compute per forward pass. The gains therefore track total compute rather than any one knob. The headline number lands where it matters: DiT-XL with a patch size of 2 (DiT-XL/2) reaches an FID of 2.27 on ImageNet $256\times256$. That puts it ahead of the prior best diffusion models, pixel-space ones included, even though it does all of its work in the VAE's compressed latent space at a fraction of their compute.
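FID is the Fréchet distance between two Gaussians fitted to Inception-v3 features of real and generated images: $\|\mu_r-\mu_g\|^2 + \mathrm{Tr}\big(\Sigma_r+\Sigma_g-2(\Sigma_r\Sigma_g)^{1/2}\big)$. A minimal NumPy sketch of that formula follows. It assumes the features have already been extracted, and that step is omitted here. Published scores also depend on the exact feature extractor and the sample count (typically 50,000 generated images).

```python
import numpy as np

def fit_gaussian(features):
    # features: (N, D) array of Inception activations, one row per image
    return features.mean(axis=0), np.cov(features, rowvar=False)

def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Squared mean difference plus the covariance term of the Frechet distance."""
    diff = mu1 - mu2
    # sigma1 @ sigma2 is similar to a PSD matrix, so its eigenvalues are real and >= 0;
    # the trace of its matrix square root is the sum of their square roots.
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    tr_covmean = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_covmean)
```

Working with eigenvalues avoids forming the matrix square root explicitly, since only its trace enters the score.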
That number is reported with classifier-free guidance, because DiT inherits everything from the previous part. The class conditioning is trained with the same conditioning dropout that classifier-free guidance requires. At sampling time, the guidance scale therefore becomes one more inference knob to sweep, trading diversity for fidelity exactly as before, and the figure is the best point on that sweep, at a scale of 1.5. None of the diffusion mathematics changed: the forward process, the noise-prediction loss, and the sampler are all as the earlier parts left them. DiT does additionally learn the reverse-process variance, following improved DDPM. Only the function that predicts the noise is different.
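Both halves of classifier-free guidance fit in a few lines. This NumPy sketch assumes a hypothetical noise-prediction callable `model(x_t, t, y)`. It also reserves a hypothetical extra label index, `NULL_CLASS`, to mean "no class". Neither name comes from DiT's code.

```python
import numpy as np

NULL_CLASS = 1000  # hypothetical "no label" index; ImageNet classes occupy 0..999

def drop_labels(labels, p_drop, rng):
    """Training-time conditioning dropout: swap each label for NULL_CLASS with probability p_drop."""
    drop = rng.random(labels.shape) < p_drop
    return np.where(drop, NULL_CLASS, labels)

def guided_eps(model, x_t, t, y, scale):
    """Classifier-free guidance: extrapolate from the unconditional toward the conditional prediction."""
    eps_cond = model(x_t, t, y)
    eps_uncond = model(x_t, t, np.full_like(y, NULL_CLASS))
    return eps_uncond + scale * (eps_cond - eps_uncond)
```

At scale 1 the guided prediction equals the conditional one. Larger scales push further away from the unconditional prediction, which is the diversity-for-fidelity trade that the guidance sweep explores.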
What the transformer buys, beyond the scaling curve, is everything attached to transformers. There are no skip connections quietly deciding that shallow and deep features must merge at fixed depths — the network learns whatever cross-scale mixing it needs through attention — and because the backbone is a standard transformer it slots directly into the infrastructure the field has spent years building: longer context, sparse and flash attention, the whole tooling stack. That compatibility is why the lineage continued. Stable Diffusion 3 and Flux both dropped the U-Net for a transformer backbone, specifically an MMDiT — a multimodal DiT that lets image tokens and text tokens attend to each other jointly rather than routing text in through cross-attention from the side.
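The joint attention that distinguishes an MMDiT can be sketched in NumPy. Each modality keeps its own query, key, value, and output projections, but a single softmax attention runs over the concatenated image and text tokens. This is a simplified illustration, not the SD3 or Flux implementation. It assumes one head, no normalization, no adaLN modulation, and hypothetical weight names.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def joint_attention(img, txt, w_img, w_txt):
    """img: (N_img, d) image tokens; txt: (N_txt, d) text tokens.

    w_img / w_txt: dicts of (d, d) matrices under keys "q", "k", "v", "o", one set per modality.
    """
    # Each modality projects its own tokens with its own weights...
    q = np.concatenate([img @ w_img["q"], txt @ w_txt["q"]])
    k = np.concatenate([img @ w_img["k"], txt @ w_txt["k"]])
    v = np.concatenate([img @ w_img["v"], txt @ w_txt["v"]])
    # ...but one attention runs over the combined sequence, so image and text tokens see each other
    out = softmax(q @ k.T / np.sqrt(q.shape[-1])) @ v
    n = img.shape[0]
    # Split back into streams and apply modality-specific output projections
    return out[:n] @ w_img["o"], out[n:] @ w_txt["o"]
```

Unlike cross-attention, where only image queries look at text keys, here the text tokens are also updated by the image. Both streams then continue through the following blocks.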
DiT reframes what diffusion progress is. The earlier parts each found a better idea: a faster sampler, a cheaper space to work in, a stronger way to condition. DiT's contribution is to show that once the backbone is a transformer, the next gains come from the same lever that drives language models: make it bigger and train it longer. That settles the architecture question. The remaining one is the training objective. The next part replaces DDPM's specific noise process with flow matching, a simpler and more general target that the newest models train with on these same transformer backbones.