
This is a snippet of my PyTorch code. My Jupyter notebook hangs when I use num_workers > 0. I have spent a lot of time on this problem without finding an answer. I do not have a GPU and work only with a CPU.

import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader, Dataset


class IndexedDataset(Dataset):

    def __init__(self, data, targets, test=False):
        self.dataset = data
        if not test:
            # Labels and the labeled/unlabeled mask are only set for training data.
            self.labels = targets.numpy()
            self.mask = np.concatenate((np.zeros(NUM_LABELED), np.ones(NUM_UNLABELED)))

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image = self.dataset[idx]
        return image, self.labels[idx]

    def display(self, idx):
        plt.imshow(self.dataset[idx], cmap='gray')
        plt.show()

train_set = IndexedDataset(train_data, train_target, test=False)

test_set = IndexedDataset(test_data, test_target, test=True)

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, num_workers=2)

test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, num_workers=2)

Any help is appreciated.

Sergii Dymchenko

2 Answers


When num_workers is greater than 0, PyTorch uses multiple processes for data loading.

Jupyter notebooks have known issues with multiprocessing.

One way to resolve this is not to use a Jupyter notebook: write a normal .py file and run it from the command line.

Or try what is suggested here: Jupyter notebook never finishes processing using multiprocessing (Python 3).
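
If you need to stay in the notebook for a while, a stopgap is to fall back to num_workers=0 there, since with 0 workers the DataLoader loads data in the main process and no worker processes are started. This is not part of the suggestions above, only a minimal sketch. The helper names are hypothetical, and checking for ipykernel in sys.modules is a heuristic for "running in a notebook", not an official API.

```python
import sys


def running_in_notebook() -> bool:
    # Heuristic (an assumption, not an official API): Jupyter kernels import
    # the ipykernel package, so its presence suggests a notebook session.
    return "ipykernel" in sys.modules


def pick_num_workers(requested: int) -> int:
    # Use single-process loading (0 workers) inside a notebook,
    # otherwise keep the requested number of workers.
    return 0 if running_in_notebook() else requested
```

The result would then be passed to the loader, for example DataLoader(train_set, batch_size=BATCH_SIZE, num_workers=pick_num_workers(2)). Loading is slower with 0 workers, but the same code runs in both the notebook and a .py script.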

Sergii Dymchenko

Jupyter Notebook has trouble with Python multiprocessing. There are two thin libraries that work around this, and you can install one of them, as mentioned here: 1 and 2.

I solved my problem in two ways, without any external libraries:

  1. Convert the notebook from .ipynb to .py and run it in the terminal, with the code that iterates over the loader placed under the __name__ == '__main__' check, as follows:

    ...
    ...

    train_set = IndexedDataset(train_data, train_target, test=False)

    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, num_workers=4)

    if __name__ == '__main__':
        for images, label in train_loader:
            print(images.shape)
    
  2. Use the multiprocessing library, as follows:

In try.ipynb:

import multiprocessing as mp
import processing as ps  # local module, see processing.py below

...
...

train_set = IndexedDataset(train_data, train_target, test=False)
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE)

if __name__ == "__main__":
    # The pool maps getShape over the batches produced by train_loader.
    with mp.Pool(8) as p:
        r = p.map(ps.getShape, train_loader)
    print(r)

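In processing.py:

def getShape(data):
    # data is one batch from the DataLoader: (images, labels).
    # The loop returns on its first pass, so the result is the
    # shape of the first image in the batch.
    for i in data:
        return i[0].shape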
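
Why getShape lives in a separate processing.py can be illustrated with a small sketch. multiprocessing hands the function to the pool workers through pickle, and pickle stores a function as a reference to its module and name rather than as code. The helper names below are hypothetical. That a function defined in a notebook cell (module __main__) may not be resolvable by a freshly started worker is my understanding of the behavior, not something stated in the answer.

```python
import json
import pickle


def worker_lookup_path(fn) -> str:
    # Module and name that a worker process must be able to import.
    return f"{fn.__module__}.{fn.__qualname__}"


def pickled_by_reference(fn) -> bool:
    # A pickled function carries its module and name, not its code,
    # so both strings appear in the pickled bytes.
    payload = pickle.dumps(fn)
    return (fn.__module__.encode() in payload
            and fn.__qualname__.encode() in payload)


if __name__ == "__main__":
    print(worker_lookup_path(json.dumps))    # json.dumps
    print(pickled_by_reference(json.dumps))  # True
```

For a function defined in processing.py, the lookup path would be processing.getShape, which any worker can import as long as processing.py is on the import path.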