There are mainly two training routines for transformer language models:
- Causal Language Model (given the previous tokens, predict the next token)
- Masked Language Model (given a sequence with some positions masked out, predict the token that best fits each mask)
Most probably what you want to understand, given the post-2022 popularity of GPT-style models, is the Causal Language Model (CLM) training routine.
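To make the difference concrete, here's a minimal sketch (not the transformers implementation) of how the labels are built for each objective, using made-up token ids:
import torch

input_ids = torch.tensor([5, 23, 99, 7, 42])

# Causal LM: every position predicts the *next* token, so the labels are the
# inputs shifted left by one position.
clm_inputs = input_ids[:-1]          # [5, 23, 99, 7]
clm_labels = input_ids[1:]           # [23, 99, 7, 42]

# Masked LM: a few positions are replaced by a [MASK] id and only those
# positions carry labels; the rest are ignored (commonly marked with -100 in PyTorch).
MASK_ID, IGNORE = 103, -100          # 103 is just an example mask id
mlm_inputs = input_ids.clone()
mlm_labels = torch.full_like(input_ids, IGNORE)
mlm_inputs[2] = MASK_ID              # mask the third token
mlm_labels[2] = input_ids[2]         # only that position is scored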
Here's some example code that demonstrates when the backpropagation takes place; look out for loss.backward() / accelerator.backward(): https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm_no_trainer.py#L616C1-L626C38
for step, batch in enumerate(active_dataloader):
    with accelerator.accumulate(model):
        outputs = model(**batch)
        loss = outputs.loss
        # We keep track of the loss at each epoch
        if args.with_tracking:
            total_loss += loss.detach().float()
        accelerator.backward(loss)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
We see that backpropagation runs at the end of each batch, and each batch is made up of a few input data points from the DataLoader
object. In the case of language datasets, a data point is most probably an input sentence/paragraph of a fixed length, say 512 or 1024 subword tokens.
Q: When does backpropagation take place in a causal language model?
A: After a forward pass of the model over a batch of size n with maximum sequence length l, we compute one loss over the n x l token predictions and run backpropagation once per batch.
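To make the shapes concrete, here's a hedged plain-PyTorch sketch of one such step, with n, l, and vocab_size as placeholder values and random logits standing in for the model's output:
import torch
import torch.nn.functional as F

n, l, vocab_size = 8, 1024, 50257            # hypothetical batch size, seq length, vocab

# Pretend the model already produced next-token logits for every position.
logits = torch.randn(n, l, vocab_size, requires_grad=True)   # [n, l, V]
labels = torch.randint(0, vocab_size, (n, l))                 # [n, l]

# One scalar loss covers all n * l predictions ...
loss = F.cross_entropy(logits.reshape(-1, vocab_size), labels.reshape(-1))

# ... and one backward call per batch pushes gradients through all of them.
loss.backward()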
Next we look into how loss.backward()
works. In a generic vanilla transformer, most probably the loss is computed as follows (a rough sketch of both follows this list):
- for decoder-only language modelling, the loss is the cross-entropy / negative log-likelihood of the next token, with perplexity (the exponentiated average NLL) reported as the metric
- for encoder-decoder translation tasks, label smoothing, which computes the KL divergence between the predicted words and a softened distribution over the actual words
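As a rough, toy-numbers sketch of those two flavours (not any particular model's implementation; the tensor values and the smoothing factor here are made up):
import torch
import torch.nn.functional as F

vocab_size, eps = 6, 0.1                       # toy vocab and smoothing factor
logits = torch.randn(1, vocab_size)            # model prediction for one target token
target = torch.tensor([3])                     # the "correct" next word id

# Decoder-only style: plain cross-entropy / negative log-likelihood.
nll = F.cross_entropy(logits, target)

# Label smoothing style: KL divergence against a softened one-hot target.
smooth_target = torch.full((1, vocab_size), eps / (vocab_size - 1))
smooth_target[0, target] = 1.0 - eps
kl = F.kl_div(F.log_softmax(logits, dim=-1), smooth_target, reduction="batchmean")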
For each of the n x l predictions, we compute a loss on whether we predicted the right word; see https://towardsdatascience.com/cross-entropy-negative-log-likelihood-and-all-that-jazz-47a95bd2e81 for some concrete examples.
The NLL / perplexity loss is, in most cases, computed per 1 x l sequence for each of the n sentences, and the loss for one sentence's predictions doesn't usually affect the other sentences in the batch, so averaging the n x l losses by n is reasonable.
Note that we don't report an NLL loss per individual token but rather the NLL loss over all tokens in the sentence; for more details (it's kinda long) see Section 7.7 of https://web.stanford.edu/~jurafsky/slp3/7.pdf. Generally the idea is the same: we compute the output sequence in each forward pass, score every predicted token against its target token (the per-token loss is the negative log-probability the model assigned to the correct token), and then the loss for the whole sequence of predictions is summed or averaged per sentence.
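Here's a hedged, self-contained sketch of that bookkeeping with toy shapes; which reduction you pick (per sentence first, or over all tokens at once) is an implementation choice:
import torch
import torch.nn.functional as F

n, l, vocab_size = 2, 4, 10                     # toy batch: 2 sentences of 4 tokens
logits = torch.randn(n, l, vocab_size)
labels = torch.randint(0, vocab_size, (n, l))

# Per-token NLL: -log p(correct token) at every position, shape [n, l].
per_token = F.cross_entropy(
    logits.reshape(-1, vocab_size), labels.reshape(-1), reduction="none"
).reshape(n, l)

per_sentence = per_token.mean(dim=1)            # the 1 x l losses collapsed per sentence, shape [n]
batch_loss = per_sentence.mean()                # then averaged over the n sentences

# Note: F.cross_entropy's default reduction="mean" instead averages over all
# n * l tokens directly, which only differs when sentences contribute
# different numbers of non-ignored tokens.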
Q: What loss is computed when training a language model?
A: Most probably cross-entropy / negative log-likelihood (reported as perplexity) for simple decoder-only language models, and KL divergence or label-smoothed cross-entropy for seq2seq tasks.
Q: Stop stalling, answer the question, "Do we average the losses across the 8192 sequences?"
A: Assuming that 8192 is the n * l number of token predictions for each batch, i.e. the batch is made up of 8 sentences of length 1024, then yes, we compute a loss for each batch.
The definition of "mini-batch" or "full-batch" or "epoch" differs depending on who you ask, so let's call 8 sentences a batch in this case.
Okay, so "Do we average?"
A: Kinda. In the most common routine, we have 2 losses: the training loss and the evaluation/validation loss.
For the validation loss, the usual case is that we compute the loss for each batch, average over all of them, and exponentiate the result to report perplexity: https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm_no_trainer.py#L642C1-L656C38
model.eval()
losses = []
for step, batch in enumerate(eval_dataloader):
    with torch.no_grad():
        outputs = model(**batch)

    loss = outputs.loss
    losses.append(accelerator.gather_for_metrics(loss.repeat(args.per_device_eval_batch_size)))

losses = torch.cat(losses)
try:
    eval_loss = torch.mean(losses)
    perplexity = math.exp(eval_loss)
except OverflowError:
    perplexity = float("inf")
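One non-obvious bit in that snippet is loss.repeat(args.per_device_eval_batch_size): presumably it replicates the per-batch mean loss once per example so that the final torch.mean weights examples rather than batches. A single-process sketch of the effect, with made-up numbers:
import torch

# Suppose two eval batches: one with 8 examples, one ragged last batch with 2.
batch_losses = [(torch.tensor(2.0), 8), (torch.tensor(4.0), 2)]

# Naive per-batch mean treats both batches equally: (2.0 + 4.0) / 2 = 3.0
naive = torch.stack([loss for loss, _ in batch_losses]).mean()

# Repeating each batch's loss per example weights by example count:
# (8 * 2.0 + 2 * 4.0) / 10 = 2.4
weighted = torch.cat([loss.repeat(size) for loss, size in batch_losses]).mean()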
For the training loss, each batch has its loss computed over its n * l tokens; the per-batch losses are summed into total_loss, then averaged over the number of batches when the logger reports the training loss:
if args.with_tracking:
    total_loss = 0
if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
    # We skip the first `n` batches in the dataloader when resuming from a checkpoint
    active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
else:
    active_dataloader = train_dataloader
for step, batch in enumerate(active_dataloader):
    with accelerator.accumulate(model):
        outputs = model(**batch)
        loss = outputs.loss
        # We keep track of the loss at each epoch
        if args.with_tracking:
            total_loss += loss.detach().float()
        accelerator.backward(loss)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
...
if args.with_tracking:
    accelerator.log(
        {
            "perplexity": perplexity,
            "eval_loss": eval_loss,
            "train_loss": total_loss.item() / len(train_dataloader),
            "epoch": epoch,
            "step": completed_steps,
        },
        step=completed_steps,
    )
And for the GPT-2 model specifically, you can see that it does the same per-batch loss computation over all n * l shifted token predictions,
see https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py#L1269
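Paraphrased (not copied verbatim from modeling_gpt2.py), the pattern at that line is the classic shift-by-one followed by a single cross-entropy over all the shifted predictions in the batch:
import torch
from torch.nn import CrossEntropyLoss

# Placeholders for the real tensors: lm_logits is [n, l, vocab], labels is [n, l].
n, l, vocab = 2, 8, 50257
lm_logits = torch.randn(n, l, vocab)
labels = torch.randint(0, vocab, (n, l))

# Predictions at position t are scored against the token at position t + 1.
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

loss_fct = CrossEntropyLoss()   # default reduction averages over all kept tokens
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))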
Q: Why are you not answering whether it's an average?
A: Now that's a good question! In most cases you compute a loss for every batch, and yes, you average the per-batch perplexity/cross-entropy losses over the whole dataset. But ultimately the training routine is up to whoever codes the model.
In most models, you can average the n x l losses by n per batch and then macro-average the per-batch losses by the number of batches in the epoch to report the training loss. Whether that averaging divides by n (one mean per 1 x l sentence, then over sentences) or directly by the full n x l token count depends on the reduction the implementation picks.
While this is intuitive, every model inside the transformers
library picks its own loss: https://github.com/huggingface/transformers/tree/main/src/transformers/models
There are many other factors that affect when a model gets updated and how the loss is computed. If gradients are accumulated and updates are delayed, e.g. https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments.gradient_accumulation_steps , then the average is not just across n x l but maybe d x n x l, where d is the number of steps you've delayed.
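A hedged sketch of what such delayed updates look like in plain PyTorch (the accelerate/Trainer internals differ, but the idea is the same); the toy model, data, and accum_steps (playing the role of d) are all made up:
import torch
from torch import nn

# Toy stand-ins; in the real loop these come from transformers / your dataset.
model = nn.Linear(16, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
dataloader = [(torch.randn(8, 16), torch.randint(0, 4, (8,))) for _ in range(8)]

accum_steps = 4                                   # d: micro-batches per optimizer step
criterion = nn.CrossEntropyLoss()

for step, (x, y) in enumerate(dataloader):
    loss = criterion(model(x), y) / accum_steps   # scale so gradients average over d batches
    loss.backward()                               # gradients accumulate in .grad across steps
    if (step + 1) % accum_steps == 0:
        optimizer.step()                          # one parameter update per d batches
        optimizer.zero_grad()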
Epilogue
So to correctly answer whether the loss is averaged, or how/when the model updates its weights based on the backpropagated loss, you have to specify:
- Which model/architecture is used, and whether there's an existing code base you can check against. Different implementations might have different loss computation routines too.
- What training hyperparameters are used, esp. the optimizer-related ones?
- What kind of language model are you training? In most cases I've been describing causal language models, but there are others, like masked language models or even non-autoregressive models, that use different losses.