I am training a BERT model on a relatively small dataset and cannot afford to lose any labelled samples; all of them must be used for training. Due to GPU memory constraints, I am using gradient accumulation to train with a larger effective batch size (e.g. 32). According to the PyTorch documentation, gradient accumulation is implemented as follows:

    scaler = GradScaler()

    for epoch in epochs:
        for i, (input, target) in enumerate(data):
            with autocast():
                output = model(input)
                loss = loss_fn(output, target)
                loss = loss / iters_to_accumulate

            # Accumulates scaled gradients.
            scaler.scale(loss).backward()

            if (i + 1) % iters_to_accumulate == 0:
                # may unscale_ here if desired (e.g., to allow clipping unscaled gradients)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

However, with e.g. 110 training samples, a batch size of 8 and 4 accumulation steps (i.e. an effective batch size of 32), this method only performs optimizer steps for the first 96 samples (3 x 32), wasting the remaining 14. To avoid this, I'd like to modify the code as follows (note the change to the final if statement):

    scaler = GradScaler()

    for epoch in epochs:
        for i, (input, target) in enumerate(data):
            with autocast():
                output = model(input)
                loss = loss_fn(output, target)
                loss = loss / iters_to_accumulate

            # Accumulates scaled gradients.
            scaler.scale(loss).backward()

            if (i + 1) % iters_to_accumulate == 0 or (i + 1) == len(data):
                # may unscale_ here if desired (e.g., to allow clipping unscaled gradients)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

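As a quick sanity check of the numbers above, here is a minimal sketch that only counts samples and does no training. The helper name `samples_stepped` and its `flush_last` flag are hypothetical, and the batch sizes are computed by hand on the assumption that the data loader keeps the final partial batch rather than dropping it:

```python
def samples_stepped(n_samples, batch_size, iters_to_accumulate, flush_last):
    """Count the samples that are followed by an optimizer step within one epoch."""
    # Batch sizes as produced when the last partial batch is kept (assumption).
    batch_sizes = [min(batch_size, n_samples - start)
                   for start in range(0, n_samples, batch_size)]
    stepped = 0   # samples already covered by an optimizer step
    pending = 0   # samples accumulated since the last step
    for i, size in enumerate(batch_sizes):
        pending += size
        is_boundary = (i + 1) % iters_to_accumulate == 0
        is_last = (i + 1) == len(batch_sizes)
        # flush_last=False mirrors the first condition, flush_last=True the modified one.
        if is_boundary or (flush_last and is_last):
            stepped += pending
            pending = 0
    return stepped


print(samples_stepped(110, 8, 4, flush_last=False))  # 96
print(samples_stepped(110, 8, 4, flush_last=True))   # 110
```

With the original condition only 96 of the 110 samples are followed by a step within the epoch; with the extra `(i + 1) == len(data)` check all 110 are.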
Is this correct, and is it really that simple, or will it have side effects? It seems very simple to me, but I've never seen it done before. Any help appreciated!