S. Roy


How PPO Computes Loss Over a Language Model Output

Most explanations of PPO stay at the algorithm level. This post goes one level deeper: how the surrogate loss is actually computed token by token for a language model response.


Picture the concrete objects PPO has in hand at the moment it computes a loss. There is a prompt $x$, a completion $y = (y_1, \dots, y_L)$ of $L$ tokens that the policy sampled, a frozen copy of the policy taken at the start of this iteration ($\pi_{\theta_\text{old}}$), a separate frozen reference model ($\pi_\text{ref}$, the SFT checkpoint), and a single scalar from the reward model scoring how good $y$ was. Everything below is the arithmetic that turns those objects into one number to backpropagate.

Log-probs per token, in one pass each

The first step is to recover the per-token log-probabilities under both the current policy and the frozen snapshot. Each model runs a single forward pass over the concatenation $[x; y]$ with $y$ supplied as labels, and at position $t$ we read off the log-probability the model assigns to the token that was actually generated, $\log \pi(y_t \mid x, y_{<t})$:

$$\ell^\theta_t = \log \pi_\theta(y_t \mid x, y_{<t}), \qquad \ell^\text{old}_t = \log \pi_{\theta_\text{old}}(y_t \mid x, y_{<t})$$

Teacher forcing is what makes this cheap: because the whole completion is already known, one pass emits the logits at every position simultaneously, and a single gather along the vocabulary axis picks out the realized token at each step. There is no autoregressive decode loop here — generation happened earlier, during rollout; loss computation just re-scores a fixed string.
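A stdlib-only sketch of that re-scoring step on toy numbers. It assumes the usual causal-LM convention that the logits at position $i$ predict the token at position $i+1$, so the row used for completion token $t$ is the one emitted one position earlier; `log_softmax` and `completion_logps` are illustrative helpers, not a library API.

```python
import math

def log_softmax(row):
    # Subtract the max before exponentiating so large logits cannot overflow.
    m = max(row)
    lse = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - lse for v in row]

def completion_logps(logits, input_ids, prompt_len):
    """Per-token log pi(y_t | x, y_<t) for the completion part of [x; y].

    logits[i] is the model's prediction for input_ids[i + 1]
    (causal-LM convention, assumed here). Requires prompt_len >= 1.
    """
    out = []
    for pos in range(prompt_len, len(input_ids)):
        row = log_softmax(logits[pos - 1])  # prediction made one step earlier
        out.append(row[input_ids[pos]])     # "gather" the realized token
    return out

# Toy example: vocabulary of 3, prompt = [0], completion = [2, 1].
logits = [[0.0, 0.0, 0.0], [1.0, 2.0, 3.0], [5.0, 5.0, 5.0]]
ell = completion_logps(logits, input_ids=[0, 2, 1], prompt_len=1)
print(ell)  # one log-prob per completion token; the last logits row is unused
```

The single pass over the full string supplies every row at once; the only bookkeeping is the one-position offset between a row and the token it scores.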

Ratios in log space

With both log-prob vectors in hand, the probability ratio at each token is their difference exponentiated. We never form $\pi_\theta / \pi_{\theta_\text{old}}$ directly, because dividing two tiny softmax probabilities is a fast route to underflow; instead the ratio is computed as the exponential of a log-difference, which keeps every intermediate in a numerically safe range.

$$r_t = \exp\!\big(\ell^\theta_t - \ell^\text{old}_t\big)$$
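To make the underflow point concrete, here is a stdlib-only sketch with deliberately extreme log-probabilities (around -750, chosen only because Python floats are 64-bit; a 32-bit float would underflow far sooner). Exponentiating each log-prob and dividing fails, while exponentiating the log-difference does not.

```python
import math

def ratio_naive(logp_new, logp_old):
    # Exponentiate each log-prob first, then divide: both can underflow to 0.0.
    return math.exp(logp_new) / math.exp(logp_old)

def ratio_logspace(logp_new, logp_old):
    # r_t = exp(ell_theta - ell_old): only the small difference is exponentiated.
    return math.exp(logp_new - logp_old)

new, old = -749.8, -750.0  # exaggerated so that float64 underflows
print(ratio_logspace(new, old))  # about exp(0.2)
try:
    ratio_naive(new, old)
except ZeroDivisionError:
    print("naive ratio: 0.0 / 0.0")
```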

At the very first gradient step of an iteration, $\theta = \theta_\text{old}$, so every $r_t = 1$ (up to floating-point noise between the two forward passes); the ratios drift away from $1$ only as the inner epochs update $\theta$ against the same batch of rollouts.
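A toy illustration of that drift, assuming a "policy" that is just a softmax over three logits with no context. Before any update the ratio is exactly 1; one gradient-ascent step on $r_t \hat A$ with a positive advantage raises the sampled token's probability and pushes $r_t$ above 1. The logits, learning rate, and advantage are arbitrary.

```python
import math

def softmax(logits):
    m = max(logits)  # shift by the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def ratio(logits, old_logits, token):
    # r = exp(log pi_theta(token) - log pi_theta_old(token))
    return math.exp(math.log(softmax(logits)[token]) - math.log(softmax(old_logits)[token]))

old_logits = [0.5, 1.0, -0.3]  # frozen snapshot pi_theta_old (toy, context-free)
theta = list(old_logits)        # the trained policy starts as an exact copy
token, adv, lr = 0, 1.0, 0.5    # sampled token, its advantage, step size (arbitrary)

print(ratio(theta, old_logits, token))  # 1.0 at the first step

# At r = 1 the gradient of r * A w.r.t. the logits is A * (onehot(token) - softmax).
p = softmax(theta)
theta = [z + lr * adv * ((1.0 if i == token else 0.0) - p_i)
         for i, (z, p_i) in enumerate(zip(theta, p))]

print(ratio(theta, old_logits, token))  # above 1 after one ascent step with A > 0
```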

One advantage, broadcast across tokens

The advantage $\hat A$ scales each ratio, and in vanilla outcome-reward RLHF it is a single scalar per completion, derived from the reward-model score (typically after whitening across the batch and subtracting a value baseline). The same $\hat A$ multiplies the term at every token position $t$: the reward model judged the response as a whole, so credit is assigned uniformly to all the tokens that produced it.

$$\hat A_t = \hat A \quad \text{for all } t \in \{1, \dots, L\}$$
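A stdlib-only sketch of this outcome-level version: whiten one reward-model score per completion across the batch, then repeat that single number at every token position. Using the batch mean as the baseline is an assumption made for brevity; a learned value estimate is another common choice.

```python
import statistics

def whiten(scores, eps=1e-8):
    # Subtract the batch mean (the baseline assumed here) and divide by the batch std.
    mu = statistics.fmean(scores)
    sigma = statistics.pstdev(scores)
    return [(s - mu) / (sigma + eps) for s in scores]

def broadcast(advantages, lengths):
    # hat A_t = hat A for all t: copy each completion's scalar to its L_b positions.
    return [[a] * n for a, n in zip(advantages, lengths)]

rm_scores = [1.0, 3.0]  # one reward-model scalar per completion (toy values)
lengths = [2, 3]        # completion lengths L_b
per_token = broadcast(whiten(rm_scores), lengths)
print(per_token)  # about [[-1, -1], [1, 1, 1]]
```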

This is exactly the assumption a process reward model relaxes: a PRM emits a score at intermediate steps, so $\hat A_t$ genuinely varies along the sequence, and tokens in a flawed reasoning step can be penalized while a correct prefix is spared. Outcome-only RLHF with a broadcast advantage cannot make that distinction: every token in a 400-token answer shares the fate of the final scalar.
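For contrast, a sketch of how step-level scores could be spread over tokens. It assumes the completion has already been segmented into reasoning steps given as token spans, and that each step's score applies to every token in that step; both the segmentation and the score values are hypothetical.

```python
def step_advantages(step_spans, step_scores, length):
    """Map one score per reasoning step onto the tokens of that step.

    step_spans:  list of (start, end) token indices, end exclusive (hypothetical)
    step_scores: one PRM-derived advantage per step (hypothetical values)
    """
    adv = [0.0] * length
    for (start, end), score in zip(step_spans, step_scores):
        for t in range(start, end):
            adv[t] = score
    return adv

# A 6-token answer in three steps; the middle step is judged flawed.
print(step_advantages([(0, 2), (2, 4), (4, 6)], [0.5, -1.0, 0.2], length=6))
# [0.5, 0.5, -1.0, -1.0, 0.2, 0.2]
```

Unlike the broadcast case, the correct first step keeps a positive advantage while only the flawed middle step is pushed down.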

Clip, minimize, average

Each token now contributes the standard PPO clipped term: the minimum of the raw ratio-weighted advantage and a version whose ratio is clamped to $[1-\epsilon,\,1+\epsilon]$. The per-token terms are negated (optimizers minimize) and averaged over the $L_b$ tokens of each completion, then over the $B$ completions in the batch.

$$\mathcal{L} = -\frac{1}{B}\sum_{b=1}^{B} \frac{1}{L_b}\sum_{t=1}^{L_b} \min\!\big(r_t\,\hat A_b,\; \operatorname{clip}(r_t, 1-\epsilon, 1+\epsilon)\,\hat A_b\big)$$
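A stdlib-only sketch of the formula on toy ratios, with one scalar advantage per completion and $\epsilon = 0.2$. Whenever the clipped branch is strictly smaller, that term no longer depends on $r_t$, so it contributes no gradient; the comments point out where that happens.

```python
def clip(x, lo, hi):
    return max(lo, min(x, hi))

def ppo_clip_loss(ratios, advantages, eps=0.2):
    """L = -(1/B) sum_b (1/L_b) sum_t min(r_t * A_b, clip(r_t, 1-eps, 1+eps) * A_b)."""
    per_seq = []
    for r_seq, a in zip(ratios, advantages):
        terms = [min(r * a, clip(r, 1 - eps, 1 + eps) * a) for r in r_seq]
        per_seq.append(sum(terms) / len(terms))  # mean over the L_b tokens
    return -sum(per_seq) / len(per_seq)           # mean over B, negated

# Completion 1 (A = +1): r = 1.5 is capped at 1.2, so raising it further earns nothing.
# Completion 2 (A = -1): r = 0.5 is held at 0.8, so lowering it further earns nothing.
loss = ppo_clip_loss([[1.0, 1.5], [0.5]], [1.0, -1.0])
print(loss)  # approximately -0.15
```

Note the asymmetry: with a negative advantage and $r_t > 1+\epsilon$, the unclipped term is the smaller one, so the min keeps the full penalty rather than capping it.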

Averaging within a sequence before averaging across the batch (rather than pooling all tokens into one mean) keeps a long completion from dominating the gradient simply by contributing more terms, though implementations differ on this and some pool over all valid tokens at once.
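A minimal comparison of the two reductions on toy per-token losses. Under per-sequence averaging each token of completion $b$ carries weight $1/(B L_b)$; under pooling every token carries $1/\sum_b L_b$, so longer completions get proportionally more say.

```python
def sequence_mean(per_token):
    # average within each completion, then across completions
    return sum(sum(seq) / len(seq) for seq in per_token) / len(per_token)

def token_mean(per_token):
    # pool every valid token from every completion into one mean
    flat = [v for seq in per_token for v in seq]
    return sum(flat) / len(flat)

# one short completion with a large loss, one long completion with zero loss
batch = [[1.0], [0.0, 0.0, 0.0]]
print(sequence_mean(batch))  # 0.5: each completion weighs 1/B
print(token_mean(batch))     # 0.25: each token weighs 1/(total tokens)
```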

Where the reference model enters

The reference model $\pi_\text{ref}$ behind the KL penalty is a separate object from $\pi_{\theta_\text{old}}$, and conflating the two is the usual point of confusion. The snapshot $\pi_{\theta_\text{old}}$ exists only to define the ratio and is refreshed every PPO iteration; the reference $\pi_\text{ref}$ is the SFT model, frozen for the entire run, and its job is to anchor the policy so reward optimization does not wander into degenerate text. It enters as a per-token penalty folded into the reward before advantages are computed:

$$r^{\text{shaped}}_t = r^{\text{RM}}_t - \beta\big(\ell^\theta_t - \log \pi_\text{ref}(y_t \mid x, y_{<t})\big)$$

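where $r^{\text{RM}}_t$ (a reward, not to be confused with the ratio $r_t$) is the reward-model scalar placed at the final token and zero elsewhere. Because this shaping happens once per iteration, before the inner epochs, $\ell^\theta_t$ here is evaluated at $\theta = \theta_\text{old}$ and treated as a fixed number rather than something gradients flow through. So the SFT model shows up twice in spirit but at different stages: as the KL anchor shaping the reward, and (at iteration boundaries) as the eventual ancestor of $\pi_{\theta_\text{old}}$.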
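A stdlib-only sketch of the shaping step, assuming the log-probs were recorded at rollout time (so they are plain numbers, not trainable quantities) and that the reward-model scalar lands on the last token as described. The value of `beta` is an arbitrary illustration, not a recommended setting.

```python
def shaped_rewards(policy_logps, ref_logps, rm_score, beta):
    """r^shaped_t = r^RM_t - beta * (ell_t - log pi_ref(y_t | x, y_<t)).

    r^RM_t equals rm_score at the final token and 0 elsewhere.
    """
    n = len(policy_logps)
    rewards = []
    for t, (lp, lp_ref) in enumerate(zip(policy_logps, ref_logps)):
        r_rm = rm_score if t == n - 1 else 0.0
        rewards.append(r_rm - beta * (lp - lp_ref))
    return rewards

r = shaped_rewards([-1.0, -2.0], [-1.5, -2.0], rm_score=3.0, beta=0.1)
print(r)       # [-0.05, 3.0]: the first token is penalized for drifting above pi_ref
print(sum(r))  # about 2.95: the completion's undiscounted shaped return
```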

The computation in code

Stripped to its essentials, the loss over a batch of completions is a handful of tensor operations. The log-probs come from gathering the realized tokens out of the model's log-softmax; everything after that is the algebra above.

import torch
import torch.nn.functional as F
 
def per_token_logps(logits, labels):
    # logits: (B, L, V), already shifted one step so logits[:, t] is the
    #         prediction for labels[:, t]; labels: (B, L) realized tokens
    logp = F.log_softmax(logits, dim=-1)
    return logp.gather(-1, labels.unsqueeze(-1)).squeeze(-1)  # (B, L)
 
def ppo_loss(policy_logits, labels, old_logp, advantages, mask, clip_eps=0.2):
    # policy_logits: (B, L, V) from the model being trained (grad flows here)
    # labels:       (B, L)    the realized completion tokens
    # old_logp:     (B, L)    from the frozen snapshot, detached
    # advantages:   (B,)      one scalar per completion (outcome RM), or
    #               (B, L)    per-token advantages (e.g. from GAE)
    # mask:         (B, L)    1 for completion tokens, 0 for prompt/padding
    new_logp = per_token_logps(policy_logits, labels)        # (B, L)

    ratio = torch.exp(new_logp - old_logp)                   # log space -> stable
    # a (B,) advantage becomes (B, 1) and broadcasts over L; (B, L) is used as-is
    adv = advantages.unsqueeze(-1) if advantages.dim() == 1 else advantages
 
    unclipped = ratio * adv
    clipped = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * adv
    per_token = -torch.min(unclipped, clipped)               # (B, L)
 
    # average within each sequence over valid tokens, then over the batch
    seq_loss = (per_token * mask).sum(-1) / mask.sum(-1).clamp(min=1)
    return seq_loss.mean()

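The advantages passed in are already reward-shaped: before this function runs, the pipeline subtracts $\beta$ times the policy-vs-reference KL from the reward-model score and runs GAE against the value head. Strictly, GAE produces a separate advantage for every token, a tensor of shape (B, L) rather than (B,); treating it as one scalar per completion is the outcome-level simplification used throughout this post. By the time we reach the clip, the reference model has done its work, and what remains is the token-by-token min that makes PPO PPO.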

Cite this work


Roy, Swastik. (2024). How PPO Computes Loss Over a Language Model Output. S. Roy. https://swastikroy.me/blog/ppo-loss-per-token
