Debugging Transformer Training Runs: Reading the Curves
Most training failures leave signatures in the metrics before they fully manifest. Here's how to read loss curves, gradient norms, learning rate schedules, and activation statistics to diagnose what's going wrong.
A training run that diverges at step 40,000 was already in trouble at step 38,000, and a loss curve that plateaus at a bad value stopped learning long before the line went flat. The metrics tell the story in real time; the only question is whether you can read them while there is still a run to save. Almost every failure mode has a signature — a shape in the loss, a movement in the gradient norm, a skew in the per-layer statistics — that shows up well before the run becomes unrecoverable. The skill is matching the shape on the screen to the cause behind it, and the patterns below are the small vocabulary that covers most of what you will actually see.
Loss spike then recovery
loss
│     ╭─╮
│  ╭──╯ ╰──╮
│──╯       ╰────
└──────────────── steps
The training loss jumps sharply — two to ten times its running value — and the global gradient norm spikes in the same step, then both settle back to trend over the next 100 to 500 steps. This is usually a single bad batch: a handful of outlier sequences (a run of repeated tokens, a corrupted document, a language the tokenizer mangles) produces an enormous gradient, the optimizer takes one oversized step, and the model spends a few hundred steps climbing back to where it was. The same shape appears when the peak learning rate is slightly too high and an ordinary-but-unlucky gradient is enough to destabilize the model momentarily.
The first thing to confirm is that gradient clipping is actually applied, because a recovering spike is exactly what clipping is supposed to prevent from becoming a divergence. Clipping rescales the gradient $g$ whenever its norm exceeds a threshold $\tau$:
$$g \leftarrow g \cdot \min\left(1, \frac{\tau}{\lVert g \rVert}\right)$$
With $\tau = 1.0$, the conventional default, any gradient larger than unit norm is shrunk back to it while its direction is preserved, so a single pathological batch can still nudge the model but cannot blow it up. If you are seeing frequent spikes even with clipping at $\tau = 1.0$, lower the peak learning rate by around 20%, and log the batch index alongside the loss so that when a spike happens you can pull up the exact sequences that caused it.
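The rule above fits in a few lines. What follows is a minimal, framework-free sketch of clipping by global norm, with plain Python lists standing in for gradient tensors; `clip_by_global_norm` is a hypothetical helper, not a library function. It returns the norm measured before clipping, which is the number worth logging.

```python
import math


def clip_by_global_norm(grads: list, max_norm: float = 1.0) -> float:
    """Scale every gradient in place so the combined L2 norm is at most max_norm.

    `grads` is a list of per-parameter gradient lists. Returns the norm measured
    *before* clipping, which is the value to log.
    """
    # Global norm: all parameters' gradients treated as one long vector.
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    if total_norm > max_norm:
        # One scalar for every entry, so the direction is unchanged.
        scale = max_norm / total_norm
        for grad in grads:
            for i in range(len(grad)):
                grad[i] *= scale
    return total_norm


# One outlier batch: global norm 13, shrunk back to unit norm.
grads = [[3.0, 4.0], [12.0]]
pre_clip = clip_by_global_norm(grads, max_norm=1.0)
print(pre_clip)  # 13.0
```

In PyTorch the equivalent call is `torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)`, placed between `loss.backward()` and `optimizer.step()`; it also returns the total norm computed before clipping.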
The asymmetry is what matters here: spikes that recover are almost always benign and not worth chasing. Spikes that do not recover are a different pattern entirely, and that distinction, recovery versus no recovery, is the single most useful thing to watch in the few hundred steps after a jump.
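Both habits, logging the batch index and judging a spike by whether it recovers, can live in one small tracker. This is a sketch under assumptions: `SpikeTracker` is hypothetical, a spike is defined as a loss above `spike_ratio` times an exponential moving average of the loss, and "recovered" means back within `tolerance` of the pre-spike baseline inside `recovery_steps`. The thresholds are illustrative, not tuned.

```python
from typing import List, Optional, Tuple


class SpikeTracker:
    """Flag loss spikes against a running average and record whether each one recovers."""

    def __init__(self, spike_ratio: float = 2.0, recovery_steps: int = 500,
                 ema_decay: float = 0.99, tolerance: float = 1.1):
        self.spike_ratio = spike_ratio        # spike: loss > spike_ratio * running average
        self.recovery_steps = recovery_steps  # how long a spike gets to return to trend
        self.ema_decay = ema_decay            # smoothing for the running average
        self.tolerance = tolerance            # "recovered": loss <= tolerance * baseline
        self.ema: Optional[float] = None
        self.open_spike: Optional[Tuple[int, int, float]] = None  # (step, batch_idx, baseline)
        self.events: List[Tuple[int, int, bool]] = []             # (step, batch_idx, recovered)

    def update(self, step: int, batch_idx: int, loss: float) -> None:
        if self.ema is None:
            self.ema = loss
            return
        if self.open_spike is not None:
            spike_step, spike_batch, baseline = self.open_spike
            if loss <= self.tolerance * baseline:
                self.events.append((spike_step, spike_batch, True))
                self.open_spike = None
            elif step - spike_step >= self.recovery_steps:
                # No recovery in time: this is the divergence pattern, not a blip.
                self.events.append((spike_step, spike_batch, False))
                self.open_spike = None
            return  # keep the baseline frozen while a spike is being watched
        if loss > self.spike_ratio * self.ema:
            # Remember the batch index so the offending sequences can be pulled up later.
            self.open_spike = (step, batch_idx, self.ema)
        else:
            self.ema = self.ema_decay * self.ema + (1 - self.ema_decay) * loss


tracker = SpikeTracker(recovery_steps=50)
for step in range(100):
    tracker.update(step, batch_idx=step, loss=2.0)
tracker.update(100, batch_idx=100, loss=10.0)  # one bad batch
for step in range(101, 120):
    tracker.update(step, batch_idx=step, loss=2.0)
print(tracker.events)  # [(100, 100, True)]
```

Freezing the running average during a spike is a deliberate choice: it stops one outlier from dragging the baseline up and hiding the next one.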
Divergence
loss
│            ╱
│         ╱╱╱
│      ╱╱╱
│──────
└──────────────── steps
Here the loss does not recover. After some point it climbs monotonically, the gradient norm climbs with it, and the run ends in NaN. The usual causes, roughly in order of how often they are the real problem: the learning rate is too high (reduce it by 3–10×, not 20%); gradient clipping is disabled, or its threshold is so large it never engages; the data pipeline is feeding corrupt sequences with out-of-range token IDs; the model uses Post-LN at a depth where Post-LN becomes unstable, which Pre-LN fixes by moving the norm inside the residual branch; or attention scores overflow in fp16 because they are computed without the $1/\sqrt{d_k}$ scaling, where $d_k$ is the head dimension.
The diagnostic that separates these is the gradient norm in the steps just before the loss turns up. If the norm was already creeping upward 100 steps before divergence, the instability was building and the cause is a learning rate or normalization problem — something systemic that the optimizer was slowly losing control of. If the norm was flat and then jumped in a single step, suspect the data: one bad batch tipped a model that was otherwise fine, and the question is why clipping did not catch it.
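That test can be run mechanically on the logged pre-clip gradient norms. A rough sketch, assuming one norm per step ending at the step where the loss turned up; `classify_divergence`, the window, and both ratios are hypothetical choices, and the output is a hint about where to look rather than a verdict.

```python
from statistics import mean
from typing import List


def classify_divergence(grad_norms: List[float], window: int = 100,
                        creep_ratio: float = 1.5, jump_ratio: float = 5.0) -> str:
    """Read the gradient-norm history leading into a divergence.

    grad_norms: pre-clip global norms, one per step, ending at the step where loss turned up.
    Returns "systemic" (norm creeping up: learning rate or normalization),
    "data" (flat, then a single-step jump), or "inconclusive".
    """
    history = grad_norms[-(window + 1):-1]  # the `window` steps before the final one
    if len(history) < 2:
        return "inconclusive"
    last = grad_norms[-1]
    half = len(history) // 2
    early, late = mean(history[:half]), mean(history[half:])
    if late > creep_ratio * early:
        return "systemic"  # instability was building before the loss turned
    if last > jump_ratio * late:
        return "data"      # one batch tipped an otherwise stable model
    return "inconclusive"


flat_then_jump = [1.0] * 100 + [20.0]
creeping = [1.0 + 0.02 * i for i in range(101)]
print(classify_divergence(flat_then_jump))  # data
print(classify_divergence(creeping))        # systemic
```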
Once a NaN appears in a single weight it does not stay local. The next matrix multiply that touches that weight produces NaN outputs, the gradients of those outputs are NaN, and within a forward and backward pass the corruption has propagated through the entire network. There is no recovering a run in this state, so the only correct response is to detect it and stop immediately: check torch.isnan(loss).any() every step and halt the moment it fires. The expensive mistake is letting a NaN run continue and checkpoint, because a poisoned checkpoint silently contaminates every run that later resumes from it.
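Here is a minimal sketch of that guard around a training loop. `train_step` and `save_checkpoint` are hypothetical hooks standing in for your own forward/backward/update and checkpoint code, and the loss is assumed to arrive as a Python float. With a tensor, `torch.isfinite(loss).all()` is the equivalent test; like `math.isfinite`, it also rejects infinities, which turn into NaN as soon as they meet an operation such as inf − inf.

```python
import math
from typing import Callable


def run_with_nan_guard(num_steps: int,
                       train_step: Callable[[int], float],
                       save_checkpoint: Callable[[int], None],
                       ckpt_every: int = 1000) -> int:
    """Run up to num_steps, stopping at the first non-finite loss.

    Returns the number of steps that completed with a finite loss.
    """
    for step in range(num_steps):
        loss = train_step(step)
        # isfinite rejects NaN and +/-inf alike.
        if not math.isfinite(loss):
            # Halt *before* any checkpoint is written, so the last saved state stays clean.
            print(f"non-finite loss ({loss}) at step {step}; halting without checkpointing")
            return step
        if (step + 1) % ckpt_every == 0:
            save_checkpoint(step)
    return num_steps


losses = [3.0, 2.9, 2.8, float("nan"), 2.7]
saved = []
completed = run_with_nan_guard(len(losses), lambda s: losses[s], saved.append, ckpt_every=2)
print(completed, saved)  # 3 [1]
```

In a real loop the check is cheapest right after the forward pass, before `optimizer.step()` has a chance to write NaN gradients into the weights.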
Loss plateau early
loss
│╲
│ ╲
│  ╲___________
│
└──────────────── steps
The loss drops quickly through the first 1–5% of training and then flattens at a value that is clearly too high, with no further improvement no matter how long you wait. The model learned the easy structure — unigram frequencies, the most common bigrams — and then stopped. The common causes are that the learning rate is too low for the optimizer to take steps large enough to leave the initial basin; that warmup is too long, so the model crawls at a negligible learning rate for far too many steps before the schedule lets it move; that a large fraction of each batch is degenerate or constant, so there is little signal to learn from (check the token entropy per batch — a healthy text batch sits well above zero); or that gradients are vanishing in part of the network.
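The entropy check is cheap enough to run on every batch. A minimal sketch, assuming each batch arrives as lists of token IDs; `batch_token_entropy` is a hypothetical helper, and it counts padding like any other token, so filter out your pad ID first if you want entropy over real text only.

```python
import math
from collections import Counter
from typing import List


def batch_token_entropy(batch: List[List[int]]) -> float:
    """Shannon entropy, in bits, of the token-ID distribution across one batch."""
    counts = Counter(tok for seq in batch for tok in seq)
    total = sum(counts.values())
    # H = sum_i p_i * log2(1 / p_i); 0 bits means every token in the batch is identical.
    return sum((c / total) * math.log2(total / c) for c in counts.values())


print(batch_token_entropy([[7, 7, 7, 7], [7, 7, 7, 7]]))  # 0.0
print(batch_token_entropy([[0, 1, 2, 3], [3, 2, 1, 0]]))  # 2.0
```

Logging this per batch, and alerting when it falls far below the run's usual value, is one way to catch a loader that has started emitting padding or repeated tokens.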
That last case is worth instrumenting directly, because "the loss is flat" does not tell you where the signal is dying. Log the gradient norm of every parameter, keyed by name.
```python
import torch

def grad_norms_by_name(model: torch.nn.Module) -> dict:
    # Call after loss.backward(); parameters with no gradient are skipped.
    return {
        name: param.grad.norm().item()
        for name, param in model.named_parameters()
        if param.grad is not None
    }
```
If the early layers show norms near zero while the later layers look normal, the gradient is failing to propagate back through depth, and that is almost always a normalization problem rather than a learning-rate one: the fix is in where and how the norm is applied, not in the optimizer. A uniform near-zero reading across all layers means something else, and the gradient-norm-by-layer pattern below covers how to tell those apart.
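Per-parameter norms, keyed by name, are easier to read once they are rolled up into one number per transformer block. A minimal sketch, assuming the block index appears in each parameter name as a dotted integer (your model's naming may differ); `norms_by_layer` and `LAYER_INDEX` are hypothetical names.

```python
import math
import re
from collections import defaultdict
from typing import Dict

# Assumed naming convention: the block index appears as a dotted integer, as in
# "model.layers.3.mlp.up_proj.weight" or "transformer.h.3.attn.c_attn.weight".
LAYER_INDEX = re.compile(r"\.(\d+)\.")


def norms_by_layer(param_norms: Dict[str, float]) -> Dict[int, float]:
    """Combine per-parameter gradient norms into one L2 norm per numbered block."""
    squared = defaultdict(float)
    for name, norm in param_norms.items():
        match = LAYER_INDEX.search(name)
        if match is None:
            continue  # embeddings, final norm, LM head: not inside a numbered block
        # The norm of concatenated vectors is sqrt(sum of squared norms).
        squared[int(match.group(1))] += norm ** 2
    return {layer: math.sqrt(total) for layer, total in sorted(squared.items())}


example = {
    "model.layers.0.self_attn.q_proj.weight": 3.0,
    "model.layers.0.mlp.up_proj.weight": 4.0,
    "model.layers.1.self_attn.q_proj.weight": 0.5,
    "model.embed_tokens.weight": 9.0,
}
print(norms_by_layer(example))  # {0: 5.0, 1: 0.5}
```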
Validation loss diverges from training loss
loss
│╲ valid  ╱╱╱
│ ╲      ╱
│  ╲────╱
│   train
└──────────────── steps
Training loss keeps decreasing while validation loss bottoms out and then starts to rise. This is overfitting, and it is worth being precise about when it actually happens. In large-scale pretraining it is rare, because the data is the bottleneck long before model capacity is — you typically run out of compute to make another pass over the corpus well before the model is able to memorize it. The shape shows up far more often in fine-tuning, where the dataset is small and a high-capacity model can fit it in a few epochs.
For the fine-tuning case the levers are familiar: add dropout to the adapter layers so the model cannot rely on any single pathway, cut the number of epochs, lower the learning rate, and stop early on the validation curve rather than the training one. Early stopping is doing the real work here — the minimum of the validation loss is the checkpoint you want, and every step past it is the model trading generalization for memorization.
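Early stopping on the validation curve takes little code. A minimal sketch, with `EarlyStopping` a hypothetical helper; `patience` counts validation evaluations rather than optimizer steps, and both it and `min_delta` are illustrative defaults.

```python
import math
from typing import Optional


class EarlyStopping:
    """Track the best validation loss and signal once it has stopped improving."""

    def __init__(self, patience: int = 3, min_delta: float = 0.0):
        self.patience = patience      # evaluations without improvement before stopping
        self.min_delta = min_delta    # an improvement smaller than this does not count
        self.best = math.inf
        self.best_step: Optional[int] = None
        self.bad_evals = 0

    def update(self, step: int, val_loss: float) -> bool:
        """Record one validation result; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            # New minimum: this is the checkpoint worth keeping.
            self.best, self.best_step, self.bad_evals = val_loss, step, 0
        else:
            self.bad_evals += 1
        return self.bad_evals >= self.patience


stopper = EarlyStopping(patience=3)
for step, val_loss in [(100, 2.0), (200, 1.8), (300, 1.7), (400, 1.72), (500, 1.75), (600, 1.8)]:
    if stopper.update(step, val_loss):
        break
print(step, stopper.best_step)  # 600 300
```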
There is a pretraining version that looks similar but means something different. If the validation loss plateaus while the training loss keeps inching down, the model has not necessarily overfit in the classic sense, and the validation number itself may be the problem. Before concluding the model is overfitting, check for contamination: if your validation documents leaked into the training corpus, the validation loss is partly measuring memorization and is not a clean signal at all.
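A first-pass contamination check compares token n-grams between the two sets. A sketch, assuming both sets are already tokenized into lists of token IDs; `contaminated_fraction` is hypothetical, the window length `n` is an assumption (13-token windows are one common choice), and holding every training n-gram in an in-memory set only works at toy scale, so a real corpus would need hashing or a Bloom filter.

```python
from typing import List, Set, Tuple


def ngrams(tokens: List[int], n: int) -> Set[Tuple[int, ...]]:
    return {tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)}


def contaminated_fraction(train_docs: List[List[int]], val_docs: List[List[int]],
                          n: int = 13) -> float:
    """Fraction of validation documents sharing at least one n-gram with the training docs."""
    if not val_docs:
        return 0.0
    train_grams = set()
    for doc in train_docs:
        train_grams |= ngrams(doc, n)
    hits = sum(1 for doc in val_docs if ngrams(doc, n) & train_grams)
    return hits / len(val_docs)


train = [[1, 2, 3, 4, 5, 6]]
val = [[9, 3, 4, 5, 8], [7, 7, 7, 7]]
print(contaminated_fraction(train, val, n=3))  # 0.5
```

Validation documents that hit can be dropped from the held-out set, or tracked as a separate curve, before the validation loss is read as a generalization signal.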
Loss oscillates without converging
loss
│╲╱╲╱╲╱╲╱╲╱╲╱
│
└──────────────── steps
The loss bounces up and down with no clear downward trend, and the gradient norm oscillates along with it. The optimizer is overshooting the minimum: each step is large enough to jump past the bottom of the local basin and land on the far wall, and the next step jumps back. The cause is a learning rate that is too large relative to the effective batch size, which is really a statement about gradient variance — a small batch gives a noisy estimate of the true gradient, and a noisy gradient scaled by a large step size walks erratically.
The variance of the mini-batch gradient estimate falls off inversely with the batch size $B$:
$$\operatorname{Var}\left[\hat{g}_B\right] = \frac{\sigma^2}{B}$$
where $\sigma^2$ is the per-example gradient variance and the examples in a batch are assumed to be drawn independently.
Doubling the batch halves the variance of the gradient estimate, which is why the fixes all converge on the same idea: lower the learning rate, raise the batch size, or, when memory will not allow a larger batch, accumulate gradients over more micro-steps before each update, which buys the same variance reduction at the cost of throughput. One more thing to check is Adam's $\beta_2$: if it has been set below 0.9, the second-moment estimate is averaged over too short a window to be stable, the per-coordinate step sizes jitter, and the optimizer itself becomes a source of the oscillation rather than a cure for it.
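Both claims are easy to check numerically. The sketch below is a toy, assuming scalar Gaussian samples with unit variance stand in for per-example gradients; the helper names are hypothetical. It estimates the variance of the batch mean at two batch sizes, shows that averaging equal-sized micro-batch means (gradient accumulation) reproduces the full-batch mean, and uses the usual $1/(1-\beta_2)$ rule of thumb for how many recent steps Adam's second-moment average effectively spans.

```python
import math
import random
from typing import List


def batch_mean_variance(batch_size: int, trials: int = 20_000, seed: int = 0) -> float:
    """Empirical variance of the mean of `batch_size` unit-variance per-example samples."""
    rng = random.Random(seed)
    means = [sum(rng.gauss(0.0, 1.0) for _ in range(batch_size)) / batch_size
             for _ in range(trials)]
    centre = sum(means) / trials
    return sum((m - centre) ** 2 for m in means) / trials


def accumulated_mean(micro_batches: List[List[float]]) -> float:
    """Gradient accumulation: average equal-sized micro-batch means, then take one step."""
    return sum(sum(mb) / len(mb) for mb in micro_batches) / len(micro_batches)


def adam_window(beta2: float) -> float:
    """Approximate number of recent steps Adam's second-moment EMA averages over."""
    return 1.0 / (1.0 - beta2)


var4, var8 = batch_mean_variance(4), batch_mean_variance(8)
print(var4, var8, var4 / var8)  # roughly 0.25, 0.125, 2

per_example = [0.5, 1.5, -2.0, 4.0, 3.0, -1.0, 0.0, 2.0]
full_batch = sum(per_example) / len(per_example)
micro = [per_example[i:i + 2] for i in range(0, 8, 2)]
print(math.isclose(full_batch, accumulated_mean(micro)))  # True

print(round(adam_window(0.9)), round(adam_window(0.999)))  # 10 1000
```

The accumulation equivalence assumes equal-sized micro-batches and a mean-reduced loss; with unequal sizes each micro-batch would need to be weighted by its share of the examples.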
Gradient norm by layer
The previous patterns all looked at the loss; this one looks at the gradients directly, and it is the single most informative thing you can log when something is wrong but the loss is ambiguous. A healthy run has gradient norms that are roughly uniform across layers — not identical, depth and width introduce real variation, but within an order of magnitude of each other. The departures from uniformity each have a specific reading.
When the first few layers carry near-zero gradients and the later layers are normal, the gradient is vanishing on its way back through depth. This is almost always a normalization issue — verify that Pre-LN is correctly applied at every layer and that no residual branch is missing its norm. When the last few layers show exploding gradients while the earlier ones are fine, the learning rate is too high for the output end of the network, or the output projection was initialized at the wrong scale. When every layer reads near zero, either the effective learning rate is zero — warmup misconfigured, schedule pointing at the wrong step count — or the loss has saturated and there is genuinely no gradient signal left to propagate, which loops back to the degenerate-data check from the plateau pattern. And when a single layer is spiky while its neighbors are calm, that layer's weights likely have a numerical problem of their own; log the weight statistics — mean, standard deviation, maximum absolute value — for that specific layer and watch whether the max blows up before the spike.
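Those four readings can be turned into a first-pass check on a single snapshot of per-layer norms, ordered from input to output. A sketch under stated assumptions: `read_layer_norms` is hypothetical, "the first few" and "the last few" layers are taken to mean a quarter of the stack, and the `spread` of 10 encodes the within-an-order-of-magnitude rule of thumb above. A layer that keeps showing up as the single outlier across snapshots is the spiky-layer case.

```python
from statistics import median
from typing import List


def read_layer_norms(norms: List[float], spread: float = 10.0, tiny: float = 1e-8) -> str:
    """Map one snapshot of per-layer gradient norms (index 0 = nearest the input) to a reading."""
    if max(norms) < tiny:
        return "all_near_zero"    # LR effectively zero, or no gradient signal left
    mid = median(norms)
    k = max(1, len(norms) // 4)   # "first few" / "last few" layers
    if max(norms[:k]) < mid / spread:
        return "vanishing_early"  # check Pre-LN placement and every residual branch's norm
    if min(norms[-k:]) > mid * spread:
        return "exploding_late"   # LR too high for the output end, or output-proj init scale
    outliers = [i for i, n in enumerate(norms) if not mid / spread <= n <= mid * spread]
    if len(outliers) == 1:
        return f"single_layer_{outliers[0]}"  # log that layer's weight mean, std, max |w|
    return "uniform" if not outliers else "mixed"


print(read_layer_norms([1.0, 0.8, 1.2, 0.9, 1.1, 1.0, 0.7, 1.3]))     # uniform
print(read_layer_norms([1e-5, 1e-4, 1e-3, 0.5, 1.0, 1.0, 1.0, 1.0]))  # vanishing_early
print(read_layer_norms([1.0, 1.0, 1.0, 40.0, 1.0, 1.0, 1.0, 1.0]))    # single_layer_3
```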
Token loss by position
Cross-entropy averaged over the whole sequence hides a structure that is diagnostic on its own: break the loss down by token position. A healthy run shows higher loss on early positions, where the model has little context to predict from, and lower loss on later positions, where it has the preceding tokens to condition on. That downward slope across position is evidence the model is using context at all, and its absence is the failure to watch for.
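Computing the breakdown needs only the unreduced per-token losses. In PyTorch, `torch.nn.functional.cross_entropy(..., reduction="none")` returns one loss per token; the sketch below assumes those values have already been moved into nested Python lists, with a 0/1 mask marking padding. `loss_by_position` and `context_gap` are hypothetical helpers.

```python
import math
from typing import List


def loss_by_position(token_losses: List[List[float]], mask: List[List[int]]) -> List[float]:
    """Mean loss at each sequence position, skipping padded (mask == 0) tokens."""
    seq_len = len(token_losses[0])
    totals, counts = [0.0] * seq_len, [0] * seq_len
    for losses, keep in zip(token_losses, mask):
        for t, (loss, m) in enumerate(zip(losses, keep)):
            if m:
                totals[t] += loss
                counts[t] += 1
    # Positions that are padding in every sequence have no data: report NaN.
    return [tot / c if c else math.nan for tot, c in zip(totals, counts)]


def context_gap(pos_loss: List[float], bucket: int) -> float:
    """Mean loss over the first `bucket` positions minus mean over the last `bucket`.

    Clearly positive: later tokens are easier, so context is being used.
    Near zero: a flat profile, the failure to watch for.
    """
    head, tail = pos_loss[:bucket], pos_loss[-bucket:]
    return sum(head) / len(head) - sum(tail) / len(tail)


losses = [[4.0, 3.0, 2.0, 1.0],
          [6.0, 3.0, 2.0, 9.0]]
mask = [[1, 1, 1, 1],
        [1, 1, 1, 0]]          # last token of the second sequence is padding
per_pos = loss_by_position(losses, mask)
print(per_pos)                  # [5.0, 3.0, 2.0, 1.0]
print(context_gap(per_pos, 2))  # 2.5
```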
If the per-position loss is flat (the same at position 1 as at position 1000), the model is not conditioning on context, and there are two usual culprits. The attention mask may be zeroed or malformed, so every position effectively attends to nothing; or the positional encoding may be misapplied (RoPE rotated with the wrong frequencies or the wrong position offset, for example), so the model cannot tell positions apart. If instead the loss is fine up to some position and worse beyond it, the model has learned short-range structure but not long-range structure, which is expected early in training and a problem only if it persists. When it does persist, the cause is usually in how the sequences were assembled: check the data pipeline's sequence packing and document masking, because a model that attends across document boundaries learns the wrong thing about long-range dependence, and a model whose long sequences were silently truncated never sees the range you are measuring.
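Document masking for packed sequences amounts to combining the causal mask with a same-document test. A sketch, assuming the packer records which source document each token came from (`doc_ids`, a hypothetical field); the nested-list boolean mask stands in for a tensor and follows the convention PyTorch's `scaled_dot_product_attention` uses for boolean masks, where True means the pair takes part in attention.

```python
from typing import List


def document_causal_mask(doc_ids: List[int]) -> List[List[bool]]:
    """mask[i][j] is True when query position i may attend to key position j.

    Attention is causal (j <= i) and never crosses a document boundary.
    """
    n = len(doc_ids)
    return [[j <= i and doc_ids[i] == doc_ids[j] for j in range(n)] for i in range(n)]


# Two documents packed into one 5-token sequence: tokens 0-1 and tokens 2-4.
for row in document_causal_mask([0, 0, 1, 1, 1]):
    print("".join("x" if allowed else "." for allowed in row))
# x....
# xx...
# ..x..
# ..xx.
# ..xxx
```

A plain causal mask would mark the `.` entries in columns 0 and 1 of rows 2 to 4 as allowed, letting the second document attend into the first.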
Learning rate schedule issues
loss
│╲
│ ╲╲╲
│    ╲╲──────────────────────
└──────────────────────────── steps
       ↑ LR hits floor too early
This plateau is subtler than the early one because the model was learning fine and then stopped, not because it converged but because the schedule starved it. A cosine schedule that decays to roughly 10% of its peak and then sits there for the remainder of training will flatten the loss while the model still has progress left in it — the steps simply became too small to make any. The tell is that the plateau coincides with the learning rate reaching its floor, which is why logging the learning rate itself, not just trusting that the schedule does what you configured, is worth the one extra line.
The fix is not to lower the loss target but to give the schedule more room: extend the training duration so the decay is stretched over more steps, or switch to a schedule that does not collapse so aggressively — a linear decay, or a cosine with a longer period and a higher floor. The mistake this pattern guards against is reading a schedule-induced plateau as convergence and stopping a run that had real headroom remaining.
What to log
The patterns above are only readable if the right signals are being recorded, and the right cadence matters as much as the right metrics — too sparse and you miss the 200-step window where a divergence announced itself, too dense and the logging itself becomes the throughput bottleneck. A reasonable default is every 50 steps in normal operation and every 10 when you are actively debugging.
At that cadence, log the training loss as the primary signal; the global gradient norm before clipping, because the pre-clip norm is what tells you how violent the raw gradient was; the current learning rate, to confirm the schedule is where you think it is; and tokens per second, as a cheap sanity check that catches data-loader stalls and hardware problems the moment throughput drops. The heavier signals — per-position loss and per-layer gradient norm — are expensive enough to log on a slower clock, every 500 steps, which is frequent enough to catch a developing vanishing-gradient problem but rare enough not to slow the run. Validation loss goes on its own schedule, every 1000 steps over a fixed held-out set, fixed so the curve is comparable across the whole run rather than drifting with whatever happened to be sampled.
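The cadence above can be written down as a small schedule object, so the intervals live in one place. A sketch, with `LogSchedule` and the signal names hypothetical; the intervals are the defaults suggested in this section.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class LogSchedule:
    core_every: int = 50     # drop to 10 while actively debugging
    heavy_every: int = 500   # per-position loss, per-layer gradient norms
    val_every: int = 1000    # fixed held-out set, so the curve stays comparable

    def signals(self, step: int) -> List[str]:
        """Names of the signals to record at this optimizer step."""
        out: List[str] = []
        if step % self.core_every == 0:
            out += ["train_loss", "grad_norm_preclip", "lr", "tokens_per_sec"]
        if step % self.heavy_every == 0:
            out += ["loss_by_position", "grad_norm_by_layer"]
        if step % self.val_every == 0:
            out.append("val_loss")
        return out


print(LogSchedule().signals(1_000))
print(LogSchedule(core_every=10).signals(30))
```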
The earlier posts in this series were about choices made before a single step runs — where the normalization goes, which optimizer drives the update, how the data is packed, which attention variant holds the KV cache, how position is encoded, and how the FFN is gated and routed. This post is about what happens when those choices interact badly at scale — a learning rate that was fine until it met a particular batch, a normalization placement that held until the network got deep enough to expose it. None of those failures arrive without warning. They are written into the curves first, and the run is salvageable for exactly as long as it takes you to read them.