
I am struggling to understand how backprop works for transformer-based LLMs.

Here is my guess of how this process works. Given a sequence of tokens with length 64, we process the sequence in parallel using teacher forcing (i.e., for each ACTUAL consecutive subsequence starting from the first token we PREDICT the next token and calculate a loss based on the new predicted token and the actual next token, therefore creating 63 cross-entropy loss values).

We do this for many sequences at a time (let's say a batch size of 8192) in one minibatch during pretraining. We then backpropagate through the network and adjust the weights; so far we've only taken a single step. We then move on to the next batch of 8192 sequences.

  1. Is this understanding correct?
  2. If so, do we average the 63 losses for a single sequence?
  3. Do we average the losses across the 8192 sequences?
  4. If not averaging, how are the losses accumulated to backpropagate for a single minibatch, and why?

I tried searching for papers that explain this process in detail for language models, but couldn't find any. Most covered neural networks in general and did not clarify the questions I have about language sequences.

  • This question might get closed and flagged as a better fit for datascience.stackexchange.com, but I've tried to give some explanation based on the `transformers` library and how it's coded. – alvas Aug 18 '23 at 00:13


There are two main training routines for transformer language models:

  • Causal Language Model (given the preceding tokens, predict the next token)
  • Masked Language Model (given a sequence with some tokens masked out, predict the original tokens at the masked positions)

Given the popularity of GPT-style models since 2022, what you most likely want to understand is the Causal Language Model (CLM) training routine.
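To connect this with the 64-token example in the question: with teacher forcing, a causal LM turns one sequence of length L into L - 1 (context, next-token) prediction problems, all scored in a single forward pass. Below is a minimal plain-Python sketch; `make_clm_pairs` is a hypothetical helper for illustration, not part of `transformers`.

```python
def make_clm_pairs(tokens):
    """Return the (context, target) pairs a causal LM is trained on.

    Position t sees the *actual* tokens[0..t] (teacher forcing) and must
    predict tokens[t + 1]. A real model computes all of these in one
    forward pass thanks to the causal attention mask.
    """
    return [(tokens[: t + 1], tokens[t + 1]) for t in range(len(tokens) - 1)]


sequence = list(range(64))      # stand-in for 64 token ids
pairs = make_clm_pairs(sequence)
print(len(pairs))               # 63 predictions, hence 63 loss terms
print(pairs[0])                 # ([0], 1)
print(pairs[-1][1])             # 63
```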

Here's example code that shows when backpropagation takes place; look for loss.backward() / accelerator.backward(): https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm_no_trainer.py#L616C1-L626C38

        for step, batch in enumerate(active_dataloader):
            with accelerator.accumulate(model):
                outputs = model(**batch)
                loss = outputs.loss
                # We keep track of the loss at each epoch
                if args.with_tracking:
                    total_loss += loss.detach().float()
                accelerator.backward(loss)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

We see that backpropagation runs at the end of each batch. Each batch is made up of a few data points from the DataLoader object; for language datasets, a data point is usually a chunk of text of fixed length, e.g. 512 or 1024 subword tokens. (If gradient accumulation is enabled, `accelerator.accumulate(model)` only lets the optimizer actually update the weights every few batches.)

Q: When does backpropagation take place in a causal language model?

A: After a forward pass of the model over a batch of n sequences of maximum length l, we compute a single loss from the n x l next-token predictions and then backpropagate it once for the whole batch.


Next, let's look at what loss.backward() actually backpropagates. In a vanilla transformer language model, the loss is most likely computed with cross-entropy (negative log-likelihood) over the next-token predictions.

For each of the n x l positions, the loss measures how well we predicted the right next word; see https://towardsdatascience.com/cross-entropy-negative-log-likelihood-and-all-that-jazz-47a95bd2e81 for some concrete examples.
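The per-position term is not a binary right/wrong check: it is the negative log of the probability the model gave to the true next token, so a confident correct prediction costs little and a confident wrong one costs a lot. A minimal plain-Python sketch follows; the helper names are hypothetical, and in PyTorch this per-position quantity is what `torch.nn.functional.cross_entropy` computes.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum for x in logits]


def token_cross_entropy(logits, target):
    """Cross-entropy (negative log-likelihood) for one position.

    logits: unnormalised scores over the vocabulary at this position.
    target: index of the actual next token.
    """
    return -log_softmax(logits)[target]


# Toy vocabulary of 3 tokens; the model is fairly confident in token 2.
logits = [0.5, 1.0, 3.0]
print(token_cross_entropy(logits, target=2))  # small loss: right token was likely
print(token_cross_entropy(logits, target=0))  # large loss: right token was unlikely
```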

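The NLL (cross-entropy) loss is computed per token over each 1 x l sentence, and in most cases the losses for one sentence do not affect the other sentences in the batch, so the per-token losses of all n sentences can simply be averaged together. For example, GPT-2 in `transformers` uses a mean over all non-padding target tokens in the batch, i.e. the n x l losses are divided by the number of tokens rather than by n.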
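How the per-token losses are combined matters once sequences in a batch have different numbers of real (non-padding) tokens. The sketch below assumes the per-token losses are already computed and that padding positions are marked with `-100` in the labels (the `ignore_index` convention used by `transformers`). `token_level_mean` mirrors a default `reduction="mean"` cross-entropy, while `sequence_level_mean` is the "average per sentence, then divide by n" alternative; both helper names are hypothetical.

```python
IGNORE_INDEX = -100  # label value for positions that contribute no loss


def token_level_mean(per_token_losses, labels):
    """Average over every non-ignored token in the whole batch."""
    kept = [
        loss
        for seq_losses, seq_labels in zip(per_token_losses, labels)
        for loss, label in zip(seq_losses, seq_labels)
        if label != IGNORE_INDEX
    ]
    return sum(kept) / len(kept)


def sequence_level_mean(per_token_losses, labels):
    """Average within each sequence first, then across the n sequences."""
    seq_means = []
    for seq_losses, seq_labels in zip(per_token_losses, labels):
        kept = [loss for loss, label in zip(seq_losses, seq_labels)
                if label != IGNORE_INDEX]
        seq_means.append(sum(kept) / len(kept))
    return sum(seq_means) / len(seq_means)


# Batch of n = 2 sequences; the second one is mostly padding.
losses = [[1.0, 1.0, 1.0, 1.0], [4.0, 0.0, 0.0, 0.0]]
labels = [[7, 8, 9, 10], [3, IGNORE_INDEX, IGNORE_INDEX, IGNORE_INDEX]]
print(token_level_mean(losses, labels))     # (1 + 1 + 1 + 1 + 4) / 5 = 1.6
print(sequence_level_mean(losses, labels))  # (1.0 + 4.0) / 2 = 2.5
```

With equal-length, unpadded sequences (such as fixed 512/1024-token chunks) the two reductions give the same number; they only diverge when padding or variable lengths are involved.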

Note that although a loss term is computed for every token, what gets backpropagated is a single number that combines the losses of all tokens in the sentence (and in the batch); for more details (it's kind of long), see Section 7.7 of https://web.stanford.edu/~jurafsky/slp3/7.pdf. Generally, the idea is the same: in each forward pass the model outputs, at every position, a probability distribution over the vocabulary for the next token. Rather than a binary right/wrong label, the loss at each position is the negative log-probability the model assigned to the actual target token, and these per-position losses are then combined into one loss for the whole sequence of predictions.

Q: What loss is computed when training a language model?

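A: Most commonly the cross-entropy (negative log-likelihood) of the next-token predictions, both for simple decoder-only language models and for seq2seq tasks. Perplexity, the exponential of the average per-token cross-entropy, is usually what gets reported; some setups, such as knowledge distillation, use a KL-divergence loss instead.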
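Since perplexity is just the exponential of the mean per-token cross-entropy, it is a reporting metric derived from the loss rather than a separate objective. A small plain-Python sketch (the `perplexity` helper is hypothetical):

```python
import math


def perplexity(token_nlls):
    """Perplexity = exp(average negative log-likelihood per token)."""
    return math.exp(sum(token_nlls) / len(token_nlls))


# A model that is uniformly unsure over a vocabulary of size V assigns
# probability 1/V to every target, so each token's NLL is log(V).
V = 50
uniform_nlls = [math.log(V)] * 10
print(perplexity(uniform_nlls))  # ~50.0: "as unsure as picking among 50 tokens"
```

A perfect model (NLL of 0 on every token) has perplexity 1, so lower is better.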


Q: Stop stalling and answer the question: "Do we average the losses across the 8192 sequences?"

A: Assuming that 8192 is the n * l tokens in each batch, e.g. a batch made up of 8 sentences of length 1024, we compute one loss for each batch.

The definitions of "mini-batch", "full-batch" and "epoch" differ somewhat depending on who you ask, so let's call 8 sentences a batch in this case.

Okay, so "Do we average?"

A: Kind of. In the most common routine, we have two losses: the training loss and the evaluation/validation loss.

For the validation loss, the usual approach is to compute the loss for each batch, average those losses over the evaluation set, and then take the exponential of the average to get perplexity, see https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm_no_trainer.py#L642C1-L656C38

        model.eval()
        losses = []
        for step, batch in enumerate(eval_dataloader):
            with torch.no_grad():
                outputs = model(**batch)

            loss = outputs.loss
            losses.append(accelerator.gather_for_metrics(loss.repeat(args.per_device_eval_batch_size)))

        losses = torch.cat(losses)
        try:
            eval_loss = torch.mean(losses)
            perplexity = math.exp(eval_loss)
        except OverflowError:
            perplexity = float("inf")
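Note the order in the snippet above: the per-batch losses are averaged first and `math.exp` is applied once at the end. Averaging per-batch perplexities instead would give a different (never smaller) number, because `exp` is convex. A quick plain-Python check with made-up batch losses:

```python
import math

# Hypothetical mean cross-entropy of three evaluation batches.
batch_losses = [2.0, 3.0, 4.0]

# What the script does: average the losses, then exponentiate once.
ppl_of_mean_loss = math.exp(sum(batch_losses) / len(batch_losses))  # exp(3.0)

# The alternative: exponentiate each batch, then average the perplexities.
mean_of_batch_ppl = sum(math.exp(x) for x in batch_losses) / len(batch_losses)

print(ppl_of_mean_loss, mean_of_batch_ppl)  # ~20.09 vs ~27.36
```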

For the training loss, each batch's loss is computed over its n * l tokens; the per-batch losses are summed into `total_loss` and then averaged over the number of batches when the logger reports the training loss:

        if args.with_tracking:
            total_loss = 0
        if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
            # We skip the first `n` batches in the dataloader when resuming from a checkpoint
            active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
        else:
            active_dataloader = train_dataloader
        for step, batch in enumerate(active_dataloader):
            with accelerator.accumulate(model):
                outputs = model(**batch)
                loss = outputs.loss
                # We keep track of the loss at each epoch
                if args.with_tracking:
                    total_loss += loss.detach().float()
                accelerator.backward(loss)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

        ...

        if args.with_tracking:
            accelerator.log(
                {
                    "perplexity": perplexity,
                    "eval_loss": eval_loss,
                    "train_loss": total_loss.item() / len(train_dataloader),
                    "epoch": epoch,
                    "step": completed_steps,
                },
                step=completed_steps,
            )

For GPT-2 specifically, you can see that the loss is computed the same way, over the n * l tokens of each batch (with the logits and labels shifted by one position), see https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py#L1269
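The core of that GPT-2 loss is a one-position shift: the logits at position t are scored against the label at position t + 1, so the last logit and the first label are dropped, and a length-64 input yields the 63 loss terms from the question. Below is a simplified plain-Python sketch of that idea for a single sequence, averaging over non-ignored tokens; `causal_lm_loss` is a hypothetical helper, and the real code works on tensors of shape (n, l, vocab) for the whole batch at once.

```python
import math

IGNORE_INDEX = -100


def causal_lm_loss(logits, labels):
    """Mean next-token cross-entropy for one sequence.

    logits: list of length l, each a list of vocab-sized scores.
    labels: list of length l of token ids (usually a copy of input_ids),
            with IGNORE_INDEX at positions that should not be scored.
    Returns (mean loss, number of scored positions).
    """
    shift_logits = logits[:-1]   # predictions made at positions 0..l-2
    shift_labels = labels[1:]    # the tokens they should predict: 1..l-1
    losses = []
    for scores, target in zip(shift_logits, shift_labels):
        if target == IGNORE_INDEX:
            continue
        m = max(scores)
        log_z = m + math.log(sum(math.exp(s - m) for s in scores))
        losses.append(log_z - scores[target])  # -log softmax(scores)[target]
    return sum(losses) / len(losses), len(losses)


# 64 positions over a toy vocabulary of 4; all-zero logits mean uniform guesses.
seq_len, vocab = 64, 4
logits = [[0.0] * vocab for _ in range(seq_len)]
labels = [i % vocab for i in range(seq_len)]
mean_loss, n_terms = causal_lm_loss(logits, labels)
print(n_terms)     # 63
print(mean_loss)   # log(4) ≈ 1.386 for uniform predictions
```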

Q: Why are you not answering whether it's averaged?

A: Now that's a good question! In most cases, yes: you compute one loss for every batch, and to report a number for the whole dataset you average those per-batch losses. But the training routine is ultimately up to whoever codes the model.

In many models, the loss for a batch is the average over its n x l token predictions, and the per-batch losses are then macro-averaged over the number of batches in the epoch to report the training loss. Some implementations instead average within each sentence (1 x l predictions) first and then divide by n.
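The reported `train_loss` in the script above (`total_loss.item() / len(train_dataloader)`) is such a macro-average: every batch counts equally, whatever its size. A plain-Python sketch with made-up numbers showing how that can differ from a token-weighted (micro) average when the final batch of the epoch is shorter:

```python
# Hypothetical (mean loss, number of target tokens) for each batch of an epoch.
batches = [(2.0, 8192), (2.0, 8192), (5.0, 1024)]  # last batch is short

# Macro-average: mean of per-batch means (each batch weighted equally).
macro = sum(loss for loss, _ in batches) / len(batches)

# Micro-average: weight each batch mean by how many tokens it covered.
total_tokens = sum(n for _, n in batches)
micro = sum(loss * n for loss, n in batches) / total_tokens

print(round(macro, 3), round(micro, 3))  # prints: 3.0 2.176
```

For logging, the difference is usually small because most batches have the same size; it only matters when batch sizes vary a lot.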

That said, the models inside `transformers` do not all use the same loss, so check each model's implementation: https://github.com/huggingface/transformers/tree/main/src/transformers/models

There are many other factors that affect when a model gets updated and how the loss is computed. If gradients are accumulated and updates are delayed, e.g. https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments.gradient_accumulation_steps , then the average is not just across n x l but effectively across d x n x l, where d is the number of steps the update is delayed (the gradient accumulation steps).
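Why scaling by 1/d keeps the update the same: gradients from successive `backward()` calls are summed, so if each of the d micro-batch losses is divided by d before backpropagation, the accumulated gradient equals the gradient of the average loss over all d micro-batches (assuming equal-sized micro-batches). Below is a toy check in plain Python with a hypothetical one-parameter model `y_hat = w * x` and squared error, where the gradient is written out by hand instead of using autograd; the data and learning rate are made up for illustration.

```python
def mean_sq_error_grad(w, batch):
    """d/dw of mean((w * x - y) ** 2) over a batch of (x, y) pairs."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


data = [(1.0, 2.0), (2.0, 4.1), (3.0, 5.9), (4.0, 8.2)]
w = 0.5

# One big batch: a single loss, a single gradient.
full_grad = mean_sq_error_grad(w, data)

# Same data as d = 2 equal micro-batches; scale each loss (hence its gradient)
# by 1/d and sum, the way repeated backward() calls accumulate gradients.
d = 2
micro_batches = [data[:2], data[2:]]
accumulated = 0.0
for mb in micro_batches:
    accumulated += mean_sq_error_grad(w, mb) / d

print(full_grad, accumulated)  # equal up to floating-point rounding

# Only now, after d micro-batches, take one optimizer step (plain SGD, lr = 0.01).
w -= 0.01 * accumulated
```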

Epilogue

So, to correctly answer whether the loss is averaged, or how and when the model updates its weights based on the backpropagated loss, you have to specify:

  • Which model/architecture is used, and is there an existing code base you can check against? Different implementations might have different loss computation routines too.
  • What are the training hyperparameters, especially the optimizer-related ones?
  • What kind of language model are you training? I've mostly been describing causal language models, but there are others, like masked language models or even non-autoregressive models, that use different losses.
alvas