How PPO Computes Loss Over a Language Model Output
Most explanations of PPO stay at the algorithm level. This post goes one level deeper: how the surrogate loss is actually computed token by token for a language model response.
Picture the concrete objects PPO has in hand at the moment it computes a loss. There is a prompt $x$, a completion $y = (y_1, \dots, y_T)$ of $T$ tokens that the policy sampled, a frozen copy of the policy taken at the start of this iteration ($\pi_{\theta_{\text{old}}}$), a separate frozen reference model ($\pi_{\text{ref}}$, the SFT checkpoint), and a single scalar $R$ from the reward model scoring how good $y$ was. Everything below is the arithmetic that turns those objects into one number to backpropagate.
Log-probs per token, in one pass each
The first step is to recover the per-token log-probabilities under both the current policy $\pi_\theta$ and the frozen snapshot $\pi_{\theta_{\text{old}}}$. Each model runs a single forward pass over the concatenation $[x; y]$ with $y$ supplied as labels, and at position $t$ we read off the log-probability the model assigns to the token that was actually generated, $\log \pi(y_t \mid x, y_{<t})$.
Teacher forcing is what makes this cheap: because the whole completion is already known, one pass emits the logits at every position simultaneously, and a single gather along the vocabulary axis picks out the realized token at each step. There is no autoregressive decode loop here — generation happened earlier, during rollout; loss computation just re-scores a fixed string.
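The re-scoring step can be sketched without any framework. The snippet below is a minimal, dependency-free illustration rather than the post's implementation: `log_softmax` and `completion_logps` are hypothetical helper names, the logits are made-up numbers over a 3-token vocabulary, and it assumes the usual causal-LM convention that the logits at position i score the token at position i + 1.

```python
import math

def log_softmax(row):
    """Numerically stable log-softmax over one row of vocabulary logits."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - log_z for v in row]

def completion_logps(logits, token_ids, prompt_len):
    """Log-prob of each realized completion token.

    logits:    one row of vocabulary logits per input position, all produced
               by a single forward pass over prompt + completion
    token_ids: the full prompt + completion token ids
    Assumption: logits[i] scores the token at position i + 1.
    """
    out = []
    for pos in range(prompt_len, len(token_ids)):
        row = log_softmax(logits[pos - 1])   # prediction made just before pos
        out.append(row[token_ids[pos]])      # "gather" the realized token
    return out

# Toy input: 2 prompt tokens, 2 completion tokens, vocabulary of size 3.
token_ids = [0, 2, 1, 2]
logits = [
    [0.1, 0.2, 0.3],   # scores position 1 (still prompt, ignored)
    [2.0, 4.0, 1.0],   # scores position 2, the first completion token (id 1)
    [0.5, 0.5, 3.0],   # scores position 3, the second completion token (id 2)
    [1.0, 1.0, 1.0],   # scores a token that was never generated (ignored)
]
logps = completion_logps(logits, token_ids, prompt_len=2)
```

The loop is only there to make the indexing explicit; in a tensor library the same selection is one gather over the vocabulary axis, with no decoding involved.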
Ratios in log space
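With both log-prob vectors in hand, the probability ratio at each token is their difference exponentiated:

$$r_t(\theta) = \exp\big(\log \pi_\theta(y_t \mid x, y_{<t}) - \log \pi_{\theta_{\text{old}}}(y_t \mid x, y_{<t})\big)$$

We never form the quotient $\pi_\theta / \pi_{\theta_{\text{old}}}$ directly: dividing two tiny softmax probabilities is a fast route to underflow, so the ratio is computed as the exponential of a log-difference, which keeps every intermediate in a numerically safe range.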
At the very first gradient step of an iteration $\theta = \theta_{\text{old}}$, so every $r_t = 1$ exactly; the ratios drift away from $1$ only as the inner epochs update $\theta$ against the same batch of rollouts.
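A small numeric sketch of why the subtraction happens before the exponential. The log-probs here are deliberately exaggerated so that the failure appears even in Python's 64-bit floats; the function names are illustrative, not part of any library.

```python
import math

def ratio_naive(logp_new, logp_old):
    # Converts back to probabilities first; both can underflow to 0.0.
    return math.exp(logp_new) / math.exp(logp_old)

def ratio_log_space(logp_new, logp_old):
    # Subtract in log space, exponentiate once.
    return math.exp(logp_new - logp_old)

# Exaggerated values chosen so that exp() underflows in float64.
logp_old, logp_new = -800.5, -800.0

try:
    naive = ratio_naive(logp_new, logp_old)
except ZeroDivisionError:
    naive = None   # the denominator underflowed to 0.0

stable = ratio_log_space(logp_new, logp_old)   # exp(0.5), a well-behaved number

# First gradient step of an iteration: identical log-probs, so the ratio is 1.
first_step = ratio_log_space(logp_old, logp_old)
```

In reduced-precision training the underflow threshold is far closer to realistic per-token log-probs, which is why the log-space form is the default.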
One advantage, broadcast across tokens
The advantage $\hat{A}$ scales each ratio, and in vanilla RLHF it is a single scalar per completion, derived from the reward-model score (typically after whitening across the batch and folding in the value baseline through GAE). The same $\hat{A}$ multiplies the ratio term at every token position: the reward model judged the response as a whole, so credit is assigned uniformly to all the tokens that produced it.
This is exactly the assumption a process reward model relaxes: a PRM emits a score at intermediate steps, so $\hat{A}_t$ genuinely varies along the sequence, and tokens in a flawed reasoning step can be penalized while a correct prefix is spared. Outcome-only RLHF cannot make that distinction: every token in a 400-token answer shares the fate of the final scalar.
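The difference between the two credit-assignment schemes is easy to see on toy numbers. The sketch below is a simplification under stated assumptions: it whitens raw scores across the batch and skips the value baseline and GAE entirely, `whiten` and `broadcast` are hypothetical helper names, and the process-reward values are invented for illustration.

```python
import math

def whiten(values, eps=1e-8):
    """Shift to zero mean and scale to (roughly) unit variance across the batch."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [(v - mean) / math.sqrt(var + eps) for v in values]

def broadcast(seq_advantages, lengths):
    """Copy each completion's scalar advantage to every one of its tokens."""
    return [[a] * n for a, n in zip(seq_advantages, lengths)]

# Outcome reward: one score per completion (made-up numbers).
scores = [1.0, 3.0, 2.0]
lengths = [2, 4, 3]
outcome_adv = broadcast(whiten(scores), lengths)
# Every token of completion 1 now carries the same value.

# Process reward (hypothetical per-step values): the signal changes inside
# one completion, so a flawed step is penalized while the prefix is not.
process_adv = [[0.5, 0.5, -1.0, -1.0]]
```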
Clip, minimize, average
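Each token now contributes the standard PPO clipped term, the minimum of the raw ratio-weighted advantage and a version whose ratio is clamped to $[1 - \epsilon, 1 + \epsilon]$:

$$L_t^{\text{CLIP}}(\theta) = \min\big(r_t(\theta)\,\hat{A},\ \operatorname{clip}(r_t(\theta),\, 1 - \epsilon,\, 1 + \epsilon)\,\hat{A}\big)$$

These per-token terms are negated (optimizers minimize) and averaged over the tokens of the completion and over the completions in the batch.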
Averaging within a sequence before averaging across the batch (rather than pooling all tokens into one mean) keeps a long completion from dominating the gradient simply by contributing more terms, though implementations differ on this and some pool over all valid tokens at once.
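To make the two reductions concrete, here is a plain-Python sketch with invented ratios and advantages. The helper names are hypothetical and `clip_eps=0.2` is an assumed value; the point is only that the choice of averaging changes the number being backpropagated.

```python
def clipped_term(ratio, adv, clip_eps=0.2):
    """Per-token PPO loss: negated min of the unclipped and clipped objectives."""
    clamped = max(1.0 - clip_eps, min(ratio, 1.0 + clip_eps))
    return -min(ratio * adv, clamped * adv)

def per_sequence_mean(batch):
    """Mean over tokens inside each completion, then mean over completions."""
    seq_means = [sum(seq) / len(seq) for seq in batch]
    return sum(seq_means) / len(seq_means)

def pooled_token_mean(batch):
    """One mean over every token in the batch."""
    tokens = [t for seq in batch for t in seq]
    return sum(tokens) / len(tokens)

# A short completion with positive advantage and a long one with negative.
ratios = [[1.5, 0.9], [1.1, 0.5, 1.0, 1.0, 1.0, 1.0]]
advs = [1.0, -1.0]
losses = [[clipped_term(r, a) for r in seq] for seq, a in zip(ratios, advs)]

loss_a = per_sequence_mean(losses)   # each completion counts once
loss_b = pooled_token_mean(losses)   # the 6-token completion counts three times as much
```

On these numbers the two reductions disagree even in sign: `loss_a` is slightly negative while `loss_b` is clearly positive, because the long completion supplies six of the eight pooled terms.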
Where the reference model enters
The KL penalty against $\pi_{\text{ref}}$ is a separate object from $\pi_{\theta_{\text{old}}}$, and conflating the two is the usual point of confusion. The snapshot $\pi_{\theta_{\text{old}}}$ exists only to define the ratio and is refreshed every PPO iteration; the reference $\pi_{\text{ref}}$ is the SFT model, frozen for the entire run, and its job is to anchor the policy so reward optimization does not wander into degenerate text. It enters as a per-token penalty folded into the reward before advantages are computed:

$$\tilde{R}_t = R \cdot \mathbb{1}[t = T] - \beta \big(\log \pi_\theta(y_t \mid x, y_{<t}) - \log \pi_{\text{ref}}(y_t \mid x, y_{<t})\big)$$

where $R$ is the reward-model scalar, placed at the final token and zero elsewhere, and $\beta$ sets the strength of the penalty. So the SFT model shows up twice in spirit but at different stages: as the KL anchor shaping the reward, and (at iteration boundaries) as the eventual ancestor of $\pi_{\theta_{\text{old}}}$.
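The reward shaping amounts to a few lines. This is a sketch under assumptions: `shaped_rewards` is a hypothetical helper, `beta=0.1` is an arbitrary illustrative coefficient, and the log-probs are made up. It uses the common per-token estimate of the KL term, the difference of policy and reference log-probs on the sampled token.

```python
def shaped_rewards(logp_policy, logp_ref, score, beta=0.1):
    """Per-token reward: KL penalty everywhere, reward-model score on the last token.

    logp_policy, logp_ref: log-probs of the sampled tokens under the policy
                           and under the frozen reference model
    score:                 the single reward-model scalar for the completion
    """
    rewards = [-beta * (lp - lr) for lp, lr in zip(logp_policy, logp_ref)]
    rewards[-1] += score   # the scalar lands on the final token only
    return rewards

logp_policy = [-1.0, -0.5, -2.0]
logp_ref = [-1.2, -0.5, -1.0]
rewards = shaped_rewards(logp_policy, logp_ref, score=2.0, beta=0.1)
```

Tokens the policy makes more likely than the reference does are taxed, tokens it makes less likely receive a small bonus, and only the last position carries the reward-model score. These shaped rewards are what advantage estimation consumes.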
The computation in code
Stripped to its essentials, the loss over a batch of completions is a handful of tensor operations. The log-probs come from gathering the realized tokens out of the model's log-softmax; everything after that is the algebra above.
```python
import torch
import torch.nn.functional as F


def per_token_logps(logits, labels):
    # logits: (B, L, V) predicting token t+1; labels: (B, L) realized tokens
    logp = F.log_softmax(logits, dim=-1)
    return logp.gather(-1, labels.unsqueeze(-1)).squeeze(-1)  # (B, L)


def ppo_loss(policy_logits, labels, old_logp, advantages, mask, clip_eps=0.2):
    # policy_logits: (B, L, V) from the model being trained (grad flows here)
    # labels:        (B, L)    the realized completion tokens
    # old_logp:      (B, L)    from the frozen snapshot, detached
    # advantages:    (B,)      one scalar per completion (outcome RM)
    # mask:          (B, L)    1 for completion tokens, 0 for prompt/padding
    new_logp = per_token_logps(policy_logits, labels)  # (B, L)
    ratio = torch.exp(new_logp - old_logp)             # log space -> stable
    adv = advantages.unsqueeze(-1)                     # (B, 1) broadcasts over L
    unclipped = ratio * adv
    clipped = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * adv
    per_token = -torch.min(unclipped, clipped)         # (B, L)
    # average within each sequence over valid tokens, then over the batch
    seq_loss = (per_token * mask).sum(-1) / mask.sum(-1).clamp(min=1)
    return seq_loss.mean()
```

The advantages passed in are already reward-shaped: before this function runs, the pipeline subtracts $\beta$ times the policy-vs-reference KL from the reward-model score and runs GAE against the value head. By the time we reach the clip, the reference model has done its work and what remains is the token-by-token min that makes PPO PPO.