I am using a custom PyTorch Dataset defined as follows:
from torch.utils.data import Dataset

class ImageDataset(Dataset):
    def __init__(self, input_dir, input_num, input_format, transform=None):
        self.input_num = input_num
        # etc.

    def __len__(self):
        return self.input_num

    def __getitem__(self, idx):
        targetnum = idx % self.input_num
        # etc.
However, when I iterate over this dataset directly, iteration wraps back to the start of the dataset instead of terminating at its end. The inner loop effectively becomes infinite, and the epoch print statement is never reached for subsequent epochs:
train_dataset = ImageDataset(input_dir='path/to/directory',
                             input_num=300, input_format="mask")  # Size 300

num_epochs = 10
for epoch in range(num_epochs):
    print("EPOCH " + str(epoch + 1) + "\n")
    num = 0
    for data in train_dataset:
        print(num, end=" ")
        num += 1
        # etc.
Printed output (... stands for the values in between):
EPOCH 1
0 1 2 3 4 5 6 7 ... 298 299 300 301 302 303 304 305 ... 597 598 599 600 601 602 603 604 ...
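My guess is that this is related to Python's fallback iteration protocol: since the class defines __getitem__ but no __iter__, a for loop simply calls __getitem__(0), __getitem__(1), ... until an IndexError is raised, and never consults __len__ at all. Because idx % self.input_num can never raise IndexError, the loop would never stop. A minimal sketch of that behavior, with Seq as a hypothetical stand-in for my dataset:

class Seq:
    """Hypothetical stand-in: __getitem__ wraps indices with %."""
    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        # The modulo means no index ever raises IndexError,
        # so the legacy iteration protocol never terminates.
        return idx % self.n

for i, x in enumerate(Seq(3)):
    print(x, end=" ")   # 0 1 2 0 1 2 ...
    if i >= 8:          # break manually; otherwise this loop runs forever
        break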
Why does basic iteration over the Dataset continue past its defined __len__, and how can I ensure that iteration over the dataset terminates once its length is reached when using this method? Or is manually indexing over the range of the dataset length (as sketched below) the only solution?
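For reference, the manual workaround I have in mind would look something like this:

for epoch in range(num_epochs):
    for idx in range(len(train_dataset)):
        data = train_dataset[idx]   # explicit indexing stops after input_num items
        # etc.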
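I assume that wrapping the dataset in a torch.utils.data.DataLoader would also sidestep this, since (as far as I understand) its default sequential sampler draws indices from range(len(dataset)) rather than relying on the iteration protocol:

from torch.utils.data import DataLoader

train_loader = DataLoader(train_dataset, batch_size=1, shuffle=False)
for data in train_loader:
    ...   # terminates after len(train_dataset) items

But I would still like to understand the direct-iteration behavior.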
Thank you.