Logit Lens: How Predictions Form Layer by Layer
Applying the unembedding matrix at intermediate layers to watch how a transformer's prediction evolves — and what direct logit attribution tells us about which components matter.
The unembedding matrix converts the final residual stream vector into a distribution over the vocabulary. It is applied exactly once, at the end. But the residual stream exists at every layer, and nothing stops you from applying the unembedding earlier. That is the logit lens: a diagnostic that treats every intermediate residual stream vector as if it were the final one, and reads off what token the model "would have predicted" had computation stopped there.
The Basic Construction
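Formally, let $h_\ell^{(t)}$ be the residual stream at position $t$ after layer $\ell$, and let $W_U \in \mathbb{R}^{|V| \times d_{\text{model}}}$ be the unembedding matrix. The standard logit lens applies the unembedding directly:

$$\hat{p}_\ell^{(t)} = \mathrm{softmax}\left(W_U\, h_\ell^{(t)}\right)$$

In practice, the model's final LayerNorm is applied to $h_\ell^{(t)}$ before the unembedding, exactly as it would be at the last layer.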
At $\ell = L$ (the final layer), this recovers the model's actual output distribution. At $\ell < L$, it produces an intermediate "prediction" that reflects how much of the final answer is already encoded in the stream.
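A minimal NumPy sketch of the construction, assuming the residual stream at one position has already been cached as one row per layer. The weights are random stand-ins, the function names (`logit_lens`, `top_k_per_layer`) are hypothetical, and the parameter-free LayerNorm is a simplification of a real model's learned final LayerNorm.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    """Parameter-free LayerNorm over the last axis (simplification: real
    models also apply a learned gain and bias)."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def logit_lens(resid, W_U, apply_ln=True):
    """resid: (n_layers + 1, d_model), one row per layer at a single position.
    W_U: (vocab, d_model). Returns (n_layers + 1, vocab) probabilities,
    where row l is softmax(W_U @ h_l)."""
    h = layer_norm(resid) if apply_ln else resid
    return softmax(h @ W_U.T)

def top_k_per_layer(probs, k=5):
    # Token ids of the k most probable tokens at each layer, best first.
    return np.argsort(-probs, axis=-1)[:, :k]

# Toy usage with random weights and hypothetical sizes.
rng = np.random.default_rng(0)
n_layers, d_model, vocab = 12, 64, 100
W_U = rng.normal(size=(vocab, d_model))
# Each layer adds a random update to the stream, mimicking a residual network.
resid = np.cumsum(rng.normal(size=(n_layers + 1, d_model)), axis=0)
probs = logit_lens(resid, W_U)
print(top_k_per_layer(probs, k=5)[-1])  # top-5 token ids at the final layer
```

Reading `top_k_per_layer(probs)` row by row is the "watch the prediction evolve" view described in the rest of this post.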
Nostalgebraist (2020) coined the term and reported the key observation: for many predictions, the correct answer appears in the top-5 of the logit lens distribution well before the final layer. The last few layers are not computing the answer; they are sharpening a prediction that was already mostly formed.
What the Lens Reveals
Running the logit lens on a factual recall task (e.g., "The capital of France is ___") typically shows a consistent pattern. In the first few layers, the top token is noise — punctuation, frequent function words, or whatever the embedding geometry happens to project to. By the middle layers, semantically plausible tokens begin appearing (European cities, country names). In the last quarter of the network, the correct answer ("Paris") locks into the top position and stays there, with the final layers increasing its probability mass rather than changing the prediction.
This pattern is not universal. For tasks requiring genuine multi-step reasoning — counting, arithmetic, logical composition — the correct answer may not crystallize until very late layers, or may not be robustly present in intermediate layers at all. The logit lens is most informative for tasks where knowledge retrieval dominates over multi-step computation.
For failure cases, the lens is diagnostically useful: if the model predicts the wrong token, you can watch when the wrong prediction locks in and narrow down which layers are likely responsible for the error. A wrong prediction that solidifies at layer 15 (of 32) suggests the failure is in the early-to-mid processing stages; one that only diverges in the final three layers points to late-stage refinement going wrong.
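The notion of a prediction "locking in" can be made concrete: given the lens's top-1 token at each layer, find the earliest layer after which it never changes. This sketch assumes top-1 ids have already been extracted from a lens run; the helper name and the trajectory are hypothetical, and top-1 agreement is only one possible criterion (a rank or probability threshold would also work).

```python
def lock_in_layer(top1):
    """Given the top-1 token id at each layer (a list), return the earliest
    layer from which the top-1 prediction stays equal to the final one."""
    final = top1[-1]
    layer = len(top1) - 1
    # Walk backwards while earlier layers agree with the final prediction.
    while layer > 0 and top1[layer - 1] == final:
        layer -= 1
    return layer

# Hypothetical top-1 trajectory over a 10-layer stack (token ids are made up).
trajectory = [17, 4, 4, 91, 52, 52, 52, 52, 52, 52]
print(lock_in_layer(trajectory))  # → 4
```

Applied to a wrong final prediction, the returned layer is where the error solidified, which is the quantity the paragraph above uses to separate early-to-mid failures from late-stage ones.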
The Tuned Lens
The logit lens has a systematic problem: early residual stream vectors are not in the same coordinate system as the final one. The unembedding matrix was trained to decode the final layer's representation; applying it to earlier layers is asking it to decode a representation that hasn't yet been rotated into the "prediction subspace." The result is noisy intermediate predictions, and the lens may understate how much is already encoded in the stream.
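Belrose et al. (2023) address this with the tuned lens. Instead of applying $W_U$ directly to the residual stream $h_\ell^{(t)}$, they train a learned affine translator $(A_\ell, b_\ell)$ for each layer, then apply the unembedding:

$$\tilde{p}_\ell^{(t)} = \mathrm{softmax}\left(W_U\,(A_\ell\, h_\ell^{(t)} + b_\ell)\right)$$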
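Each translator $(A_\ell, b_\ell)$ is trained to minimize the KL divergence between the tuned lens distribution at layer $\ell$ and the model's actual final prediction $\hat{p}_L^{(t)} = \mathrm{softmax}(W_U h_L^{(t)})$, with a small L2 penalty to keep $A_\ell$ close to the identity. This makes the intermediate predictions more faithful to the model's output: a tuned lens prediction at layer $\ell$ is, up to that penalty and over the training distribution, the best affine approximation to the final prediction achievable from the layer-$\ell$ residual stream.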
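To make the training objective concrete, here is a small NumPy sketch that fits one layer's translator by full-batch gradient descent on toy data. It is not the Belrose et al. implementation: the identity initialization, the penalty weight `lam`, the learning rate, the step count, the forward direction of the KL, and the synthetic data are all assumptions, and the final LayerNorm is omitted for brevity. Because the logits are affine in $(A, b)$ and cross-entropy is convex in the logits, this objective is convex, so plain gradient descent with a small step size reduces it steadily.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def mean_kl(p, q, eps=1e-12):
    """Mean over rows of KL(p || q)."""
    return np.mean(np.sum(p * (np.log(p + eps) - np.log(q + eps)), axis=-1))

def train_translator(H, P_final, W_U, lam=1e-3, lr=0.1, steps=500):
    """Fit one layer's affine translator (A, b) by full-batch gradient descent.

    H: (n, d) residual vectors at layer l.
    P_final: (n, vocab) final-layer distributions for the same inputs.
    W_U: (vocab, d) unembedding matrix.
    Objective: mean KL(P_final || softmax(W_U (A h + b))) + lam * ||A - I||_F^2.
    """
    n, d = H.shape
    A, b = np.eye(d), np.zeros(d)  # identity init: starts as the raw logit lens
    for _ in range(steps):
        Q = softmax((H @ A.T + b) @ W_U.T)
        G = (Q - P_final) / n      # gradient of the mean KL w.r.t. the logits
        G_t = G @ W_U              # back through the unembedding
        grad_A = G_t.T @ H + 2 * lam * (A - np.eye(d))
        grad_b = G_t.sum(axis=0)
        A -= lr * grad_A
        b -= lr * grad_b
    return A, b

# Toy data (assumption): final predictions come from a hidden linear map M of
# the layer-l residuals, so the raw lens is miscalibrated and a translator helps.
rng = np.random.default_rng(0)
n, d, vocab = 256, 16, 50
W_U = rng.normal(size=(vocab, d)) / np.sqrt(d)
H = rng.normal(size=(n, d))
M = np.eye(d) + 0.5 * rng.normal(size=(d, d)) / np.sqrt(d)
P_final = softmax(H @ M.T @ W_U.T)

before = mean_kl(P_final, softmax(H @ W_U.T))  # raw lens: A = I, b = 0
A, b = train_translator(H, P_final, W_U)
after = mean_kl(P_final, softmax((H @ A.T + b) @ W_U.T))
print(f"KL raw lens: {before:.4f}  tuned lens: {after:.4f}")
```

Starting from the identity means the translator begins as the raw logit lens, so the drop from `before` to `after` is the improvement the tuned lens buys at this layer.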
The tuned lens typically shows prediction crystallization happening earlier than the raw logit lens suggests, because it compensates for the coordinate mismatch. Used together, the two lenses give a rough range: the raw lens tends to understate how much is known early, while the tuned lens gives a less biased estimate.
Direct Logit Attribution
The logit lens shows how the aggregate prediction evolves. Direct logit attribution (DLA) decomposes the final logit into per-component contributions.
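Because the residual stream at the final layer is a sum of all component outputs, and $W_U$ is linear, the final logit for token $y$ decomposes as:

$$\mathrm{logit}(y) = W_U[y] \cdot h_L = W_U[y] \cdot e + \sum_{\ell=1}^{L} \sum_{i} W_U[y] \cdot a_{\ell,i} + \sum_{\ell=1}^{L} W_U[y] \cdot m_\ell$$

where $W_U[y]$ is the unembedding row for token $y$, $e$ is the token-plus-positional embedding, $a_{\ell,i}$ is the additive contribution of attention head $i$ at layer $\ell$, and $m_\ell$ is the MLP contribution at layer $\ell$. Each term is a scalar: the "vote" that one component casts for token $y$. (The final LayerNorm is not linear; in practice its scale is frozen at the value computed on the forward pass so that the decomposition still holds, with any bias terms treated as extra constant components.)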
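The decomposition can be checked numerically. This sketch uses random stand-ins for the embedding, per-head, and per-MLP outputs at a single position and ignores the final LayerNorm (equivalently, it assumes the LayerNorm scale has already been frozen and folded in); `dla` is a hypothetical helper, not a library function.

```python
import numpy as np

rng = np.random.default_rng(0)
n_layers, n_heads, d_model, vocab = 4, 8, 32, 100

# Hypothetical per-component writes into the residual stream at one position.
embed = rng.normal(size=d_model)                          # e
head_out = rng.normal(size=(n_layers, n_heads, d_model))  # a_{l,i}
mlp_out = rng.normal(size=(n_layers, d_model))            # m_l
W_U = rng.normal(size=(vocab, d_model))

# The final residual stream is the sum of everything written into it.
h_L = embed + head_out.sum(axis=(0, 1)) + mlp_out.sum(axis=0)

def dla(y):
    """Direct logit attribution for token y: each component's output
    dotted with the unembedding row W_U[y]."""
    u = W_U[y]
    return {
        "embed": embed @ u,
        "heads": head_out @ u,  # shape (n_layers, n_heads)
        "mlps": mlp_out @ u,    # shape (n_layers,)
    }

y = 42
parts = dla(y)
total = parts["embed"] + parts["heads"].sum() + parts["mlps"].sum()
assert np.isclose(total, W_U[y] @ h_L)  # the votes add up to the full logit
```

Sorting `parts["heads"]` or summing it per layer gives the "which layers matter most?" view discussed next.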
DLA turns an opaque logit into a sum of legible votes. For the IOI task ("When John and Mary went to the store, Mary gave a drink to ___", where the answer is the indirect object "John"), Wang et al. (2022) found that name-mover heads in late layers have large positive direct effects on the logit difference in favor of "John" at the final position, while a small set of negative name-mover heads push against it. Earlier components of their circuit, such as duplicate-token and S-inhibition heads, act indirectly by steering what the name movers attend to, so they show little direct attribution. DLA provides both a quick diagnostic (which layers matter most?) and a quantitative check on circuit hypotheses (do the components you identified account for the total logit difference?).
Logit Difference as a Diagnostic Signal
For binary classification tasks (or any setting with a clear correct vs. incorrect token), cross-entropy loss is a noisy metric for circuit analysis because it conflates the model's confidence on the correct answer with its distribution over distractors. A cleaner signal is the logit difference:

$$\Delta = \mathrm{logit}(y_{\text{correct}}) - \mathrm{logit}(y_{\text{incorrect}})$$
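For the IOI task with subject "Mary" and indirect object "John", this is $\Delta = \mathrm{logit}(\text{John}) - \mathrm{logit}(\text{Mary})$. Because the softmax normalizer cancels, it also equals $\log p(\text{John}) - \log p(\text{Mary})$. This scalar is positive when the model prefers the correct name, negative when it prefers the wrong one, and its magnitude reflects confidence on the margin that matters.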
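A quick numerical check of why the logit difference is a clean metric: subtracting two logits cancels the softmax normalizer, so the result equals the log-probability difference no matter how mass is spread over the rest of the vocabulary. The token ids and logits below are made up.

```python
import numpy as np

def log_softmax(z):
    z = z - z.max()  # numerical stability; does not change the result
    return z - np.log(np.exp(z).sum())

def logit_diff(logits, correct, incorrect):
    return logits[correct] - logits[incorrect]

rng = np.random.default_rng(0)
logits = rng.normal(size=1000)  # hypothetical final-position logits
JOHN, MARY = 17, 256            # hypothetical token ids
ld = logit_diff(logits, JOHN, MARY)
lp = log_softmax(logits)
# The normalizer cancels, so the two quantities agree exactly (up to float error).
assert np.isclose(ld, lp[JOHN] - lp[MARY])
```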
DLA directly decomposes the logit difference: each component's contribution is the difference of its votes for the two tokens. Components that are interpretable as "promoting the correct answer" should have large positive DLA for the difference; components that introduce confusion should be negative. This makes logit difference DLA a particularly sharp tool for circuit verification — it tells you not just which components matter, but whether each component is helping or hurting the specific decision.
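Logit-difference DLA needs only one direction in residual space, the difference of the two tokens' unembedding rows $W_U[\text{John}] - W_U[\text{Mary}]$: projecting each component's output onto it gives that component's signed contribution to the decision. This sketch uses random stand-in activations, hypothetical token ids, and a hypothetical two-head "circuit", so the printed numbers mean nothing; the point is the bookkeeping, including the check that all contributions add up to the full logit difference (final LayerNorm again ignored).

```python
import numpy as np

rng = np.random.default_rng(1)
n_layers, n_heads, d_model, vocab = 4, 8, 32, 100
W_U = rng.normal(size=(vocab, d_model))
embed = rng.normal(size=d_model)                          # stand-in embedding
head_out = rng.normal(size=(n_layers, n_heads, d_model))  # stand-in head writes
mlp_out = rng.normal(size=(n_layers, d_model))            # stand-in MLP writes

JOHN, MARY = 17, 56             # hypothetical ids: correct vs. incorrect name
ld_dir = W_U[JOHN] - W_U[MARY]  # residual direction measuring "John over Mary"

# Signed contribution of each component to the logit difference.
head_ld = head_out @ ld_dir     # (n_layers, n_heads): > 0 helps, < 0 hurts
mlp_ld = mlp_out @ ld_dir       # (n_layers,)

# Rank heads from most helpful to most harmful.
ranked = sorted(
    ((l, i) for l in range(n_layers) for i in range(n_heads)),
    key=lambda li: head_ld[li],
    reverse=True,
)
print("most helpful heads:", ranked[:3])
print("most harmful heads:", ranked[-3:])

# Bookkeeping check: all contributions sum to the full logit difference.
h_L = embed + head_out.sum(axis=(0, 1)) + mlp_out.sum(axis=0)
total = embed @ ld_dir + head_ld.sum() + mlp_ld.sum()
logits = W_U @ h_L
assert np.isclose(total, logits[JOHN] - logits[MARY])

# Share of the logit difference carried by a hypothetical two-head circuit.
circuit = [(3, 2), (3, 7)]
explained = sum(head_ld[l, i] for l, i in circuit)
print(f"circuit share of logit difference: {explained / total:.1%}")
```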
The IOI paper uses logit difference throughout as its primary metric, and the approach has since become a common default in mechanistic interpretability work on circuits.
Token Space vs. Semantic Space
There is a tension worth naming explicitly. The logit lens operates in token space: it maps residual stream vectors to probability distributions over discrete vocabulary items. This is interpretable — we can read token probabilities directly — but it may not faithfully reflect the model's internal computation.
The model reasons in a continuous, high-dimensional semantic space from layer 0 through layer $L$. The projection into token space happens only at the final unembedding step. Intermediate residual stream vectors encode geometric relationships (distances, angles, subspace structure) that are meaningful for the model's computation but that don't cleanly map to individual tokens. A vector that, when projected to token space, says "Paris" with 30% confidence might computationally be far closer to the final "Paris" representation than a naive reading of that probability suggests.
Dar et al. (2022) push in a related direction: they project not only activations but the model's parameters (attention and MLP weight matrices) into embedding space, and show that the projections are interpretable in terms of groups of related vocabulary items rather than a single dominant token. The broader lesson is that a top-1 readout discards much of the structure in the representation. The logit lens is a useful first pass; for careful circuit analysis, it should be supplemented with methods that respect the geometry of the representation space rather than collapsing it to a vocabulary distribution.
That geometry — and how to recover interpretable features from it — is the subject of the next post in this series.
Practical Applications
Beyond circuit analysis, the logit lens has direct debugging utility:
Identifying critical layers for fine-tuning. If the logit lens shows that the correct answer is stably predicted from layer 16 onward in a 32-layer model, fine-tuning layers 0–15 for that task may add little, since the early stack already delivers the answer. Layer-selective fine-tuning strategies (like training only the last layers) can then be informed by logit lens observations rather than left to arbitrary hyperparameter search.
Finding where factual knowledge is stored. For factual probes, the logit lens can identify the layer at which a specific fact first enters the prediction. Combined with MLP activation analysis, this can help localize factual recall to specific MLP layers, which is the kind of localization evidence that weight-editing methods build on.
Diagnosing hallucinations. If a model hallucinates a confident wrong answer, the logit lens can show the wrong prediction crystallizing in mid-to-late layers, sometimes overwriting an earlier correct prediction. When that happens, the failure is probably not in knowledge storage (the correct answer was briefly present) but in some later processing stage, such as a circuit inhibiting the correct answer or promoting a wrong one.
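The hallucination diagnostic can be phrased as a check over a lens trajectory: was the correct token ever the top-1 prediction, and if so, where was it last displaced? A minimal sketch, assuming top-1 ids per layer are already available; the function name and trajectory are hypothetical, and requiring top-1 (rather than, say, top-5) is a deliberately strict assumption.

```python
def correct_then_overwritten(top1, correct):
    """If the correct token was the lens's top-1 at some layer but not at the
    final layer, return (first layer it led, layer where it was last
    displaced). Otherwise return None."""
    hits = [layer for layer, tok in enumerate(top1) if tok == correct]
    if not hits or top1[-1] == correct:
        return None
    return hits[0], hits[-1] + 1

# Hypothetical trajectory: correct id 7 leads mid-stack, then wrong id 9 takes over.
traj = [3, 3, 12, 7, 7, 7, 9, 9, 9, 9]
print(correct_then_overwritten(traj, correct=7))  # → (3, 6)
```

If the function returns `None` while the final prediction is wrong, the correct token never led at any layer, which is weaker evidence for the "overwritten" story; a tuple marks the layer range worth inspecting for an inhibiting circuit.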
References
- nostalgebraist (2020). "Interpreting GPT: the logit lens." LessWrong. lesswrong.com/posts/AcKRB8wDpdaN6v6ru
- Belrose, N., et al. (2023). "Eliciting Latent Predictions from Transformers with the Tuned Lens." arXiv:2303.08112.
- Wang, K., et al. (2022). "Interpretability in the Wild: a Circuit for Indirect Object Identification in GPT-2 small." arXiv:2211.00593.
- Dar, G., et al. (2022). "Analyzing Transformers in Embedding Space." arXiv:2209.02535.