10

How do I get the total number of batch iterations from a PyTorch DataLoader?

The following is a common code for training

for i, batch in enumerate(dataloader):

Is there any method to get the total number of iterations of this for-loop?

In my NLP problem, the total number of iterations is different from int(n_train_samples/batch_size)...

For example, if I truncate the training data to only 10,000 samples and set the batch size to 1024, then 363 iterations occur in my NLP problem.

I wonder how to get the total number of iterations of the for-loop.

Thank you.

Hyunseung Kim

2 Answers

22

len(dataloader) returns the total number of batches. It depends on the __len__ function of your dataset, so make sure it is set correctly.
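A minimal sketch of this behavior, using a toy `TensorDataset` (which implements `__len__` for you) rather than the asker's NLP dataset: `len(dataloader)` already accounts for the final partial batch.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset: 10,000 samples, batch size 1024 (the numbers from the question)
ds = TensorDataset(torch.zeros(10000, 1))
dl = DataLoader(ds, batch_size=1024)

# len(dl) == ceil(10000 / 1024) == 10, since the last partial batch is kept
print(len(dl))  # -> 10
```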

hkchengrex
  • @HyunseungKim Do you have the `__len__` function in your dataset? The dataloader's `len` depends on the `__len__` function of the dataset. I use it daily, so I am sure that it works :) – hkchengrex Sep 18 '20 at 12:16
  • I'm sorry, I think I made a mistake. Actually, I tried to modify this code (https://github.com/SamLynnEvans/Transformer) for some purpose. In the script "Batch.py", there is a class named "MyIterater" and it returns train_iter. But I'm not sure whether it is a dataloader or not... I have to check it more. – Hyunseung Kim Sep 18 '20 at 12:22
  • Now I understand what you're saying: "It depends on the `__len__` function of your dataset, so make sure it is set correctly". `__len__` must be written. – Hyunseung Kim Sep 18 '20 at 12:26
3

There is one additional parameter when creating the dataloader: drop_last.

If drop_last=True, the length is number_of_training_examples // batch_size. If drop_last=False (the default), the length is number_of_training_examples // batch_size + 1 whenever the sample count is not evenly divisible by the batch size, i.e. ceil(number_of_training_examples / batch_size).

import torchvision
from torch.utils.data import DataLoader

t_train = torchvision.transforms.ToTensor()  # any transform works here
BS = 128
ds_train = torchvision.datasets.CIFAR10('/data/cifar10', download=True, train=True, transform=t_train)
dl_train = DataLoader(ds_train, batch_size=BS, drop_last=True, shuffle=True)

For predefined datasets you may get the number of examples like this:

# number of examples
len(dl_train.dataset) 

The correct number of batches inside the dataloader is always:

# number of batches
len(dl_train) 
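A minimal sketch of the drop_last difference, using a small synthetic `TensorDataset` instead of CIFAR-10 so nothing needs downloading:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.zeros(1000, 1))  # 1000 samples, not divisible by 128

# drop_last=True: the partial final batch is discarded -> floor division
dl_drop = DataLoader(ds, batch_size=128, drop_last=True)
print(len(dl_drop))  # 1000 // 128 = 7

# drop_last=False (the default): the partial batch is kept -> one extra batch
dl_keep = DataLoader(ds, batch_size=128, drop_last=False)
print(len(dl_keep))  # 7 + 1 = 8
```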
prosti