I am training a multilingual BERT model for a sentiment classification task. I have 2 GPUs on one machine, so I am using Hugging Face Accelerate for distributed training, but when I run the code it throws a RuntimeError.
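For context, I launch the training from a notebook with Accelerate's notebook_launcher, roughly like this (a minimal sketch; the run wrapper only stands in for the setup code that builds the dataset, model, optimizer and scheduler and then calls train_fn):

from accelerate import notebook_launcher

def run():
    # build dataset, DataLoader, model, optimizer and scheduler here,
    # then call train_fn(data_loader, model, optimizer, scheduler)
    ...

# one process per GPU (2 GPUs on this machine)
notebook_launcher(run, num_processes=2)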
Model
class BERTModel(nn.Module):
    def __init__(self):
        super(BERTModel, self).__init__()
        self.bert = transformers.BertModel.from_pretrained("bert-base-multilingual-uncased")
        self.bert_drop = nn.Dropout(0.3)
        self.out = nn.Linear(768 * 2, 1)  # *2 since we concatenate two pooled outputs

    def forward(self, ids, mask, token_type_ids):
        o1, _ = self.bert(
            ids,
            attention_mask=mask,
            token_type_ids=token_type_ids
        )
        mean_pooling = torch.mean(o1, 1)
        max_pooling, _ = torch.max(o1, 1)
        cat = torch.cat((mean_pooling, max_pooling), 1)
        bo = self.bert_drop(cat)
        output = self.out(bo)
        return output
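The forward pass concatenates a mean-pooled and a max-pooled view of the BERT sequence output, which is why the output head is nn.Linear(768 * 2, 1). A quick standalone shape check (a sketch with arbitrary batch and sequence sizes, independent of the model above):

import torch

batch_size, seq_len, hidden = 4, 128, 768
o1 = torch.randn(batch_size, seq_len, hidden)    # stand-in for the BERT sequence output

mean_pooling = torch.mean(o1, 1)                 # (4, 768)
max_pooling, _ = torch.max(o1, 1)                # (4, 768)
cat = torch.cat((mean_pooling, max_pooling), 1)  # (4, 1536), matching nn.Linear(768 * 2, 1)
print(mean_pooling.shape, max_pooling.shape, cat.shape)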
Train Function
def train_fn(data_loader, model, optimizer, scheduler):
    """
    Training function for the model.
    parameters: data_loader - PyTorch DataLoader
                model       - the model to be trained
                optimizer   - the optimizer to be used for training
                scheduler   - the learning rate scheduler
    returns: None
    """
    accelerator = Accelerator()
    model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)

    model.train()
    for bi, d in tqdm(enumerate(data_loader), total=len(data_loader)):
        ids = d["ids"].to(torch.long)
        token_type_ids = d["token_type_ids"].to(torch.long)
        mask = d["mask"].to(torch.long)
        targets = d["targets"].to(torch.float)

        optimizer.zero_grad()
        outputs = model(ids=ids, mask=mask, token_type_ids=token_type_ids)
        loss = loss_fn(outputs, targets)
        if bi % 1000 == 0:
            print(f"bi={bi}, loss={loss}")

        accelerator.backward(loss)
        optimizer.step()
        scheduler.step()
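Note that loss_fn is defined outside this snippet. A minimal sketch of a loss that matches the single-logit output and the float targets (assuming a plain binary cross-entropy on logits) would be:

import torch.nn as nn

def loss_fn(outputs, targets):
    # assumed loss: binary cross-entropy on the raw logit vs. the float target
    return nn.BCEWithLogitsLoss()(outputs, targets.view(-1, 1))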
Error
---------------------------------------------------------------------------
Exception Traceback (most recent call last)
<timed exec> in <module>
/opt/conda/lib/python3.6/site-packages/accelerate/notebook_launcher.py in notebook_launcher(function, args, num_processes, use_fp16, use_port)
107 try:
108 print(f"Launching a training on {num_processes} GPUs.")
--> 109 start_processes(launcher, nprocs=num_processes, start_method="fork")
110 finally:
111 # Clean up the environment variables set.
/opt/conda/lib/python3.6/site-packages/torch/multiprocessing/spawn.py in start_processes(fn, args, nprocs, join, daemon, start_method)
156
157 # Loop on join until it returns True or raises an exception.
--> 158 while not context.join():
159 pass
160
/opt/conda/lib/python3.6/site-packages/torch/multiprocessing/spawn.py in join(self, timeout)
117 msg = "\n\n-- Process %d terminated with the following error:\n" % error_index
118 msg += original_trace
--> 119 raise Exception(msg)
120
121
Exception:
-- Process 1 terminated with the following error:
Traceback (most recent call last):
File "/opt/conda/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 20, in _wrap
fn(i, *args)
File "/opt/conda/lib/python3.6/site-packages/accelerate/utils.py", line 274, in __call__
self.launcher(*args)
File "<timed exec>", line 276, in run
File "<timed exec>", line 74, in train_fn
File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 726, in _call_impl
result = self.forward(*input, **kwargs)
File "/opt/conda/lib/python3.6/site-packages/torch/cuda/amp/autocast_mode.py", line 135, in decorate_autocast
return func(*args, **kwargs)
File "/opt/conda/lib/python3.6/site-packages/torch/nn/parallel/distributed.py", line 585, in forward
self.reducer.prepare_for_backward([])
RuntimeError: Expected to have finished reduction in the prior iteration before
starting a new one. This error indicates that your module has parameters that
were not used in producing loss. You can enable unused parameter detection by (1)
passing the keyword argument `find_unused_parameters=True` to
`torch.nn.parallel.DistributedDataParallel`; (2) making sure all `forward`
function outputs participate in calculating loss. If you already have done the
above two steps, then the distributed data parallel module wasn't able to locate
the output tensors in the return value of your module's `forward` function.
Please include the loss function and the structure of the return value of
`forward` of your module when reporting this issue (e.g. list, dict, iterable).