Hello, I am trying to create a GNN for a particle tracking project. I have a class that creates the Dataset as follows:

import numpy as np
import torch
from torch_geometric.data import Data, Dataset


class GraphDataset(Dataset):
    def __init__(self, graph_files, transform=None, pre_transform=None):
        super(GraphDataset, self).__init__()
        self.graph_files = graph_files

    @property
    def raw_file_names(self):
        return self.graph_files

    @property
    def processed_file_names(self):
        return []

    def __getitem__(self, idx):
        with np.load(self.graph_files[idx]) as f:
            x = torch.from_numpy(f['x']).type(torch.FloatTensor)    # cast to float
            edge_attr = torch.from_numpy(f['edge_attr']).type(torch.FloatTensor)
            edge_index = torch.from_numpy(f['edge_index'])          # this is double
            y = torch.from_numpy(f['y']).type(torch.FloatTensor)
            pid = torch.from_numpy(f['pid'])

            # make the graph undirected by adding each edge in both directions
            edge_index = torch.stack([torch.cat([edge_index[:, 0], edge_index[:, 1]], dim=0),
                                      torch.cat([edge_index[:, 1], edge_index[:, 0]], dim=0)], dim=0)
            edge_attr = torch.cat([edge_attr, -1 * edge_attr], dim=0).T
            y = torch.cat([y, y])

            data = Data(x=x, edge_index=edge_index,
                        edge_attr=torch.transpose(edge_attr, 0, 1),
                        y=y, pid=pid)
            data.num_nodes = len(x)
        return data

    def __len__(self):
        return len(self.graph_files)

The graph_files argument is an array with the paths to some .npz files, like:

array(['.../graph_sec_g007.npz'], dtype='<U64')
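For reference, a tiny synthetic file with the same keys as my real files can be generated like this (the shapes here are made up for illustration; only the key names and dtypes match my data):

```python
import numpy as np

# Hypothetical 5-node, 4-edge graph with the same keys my real .npz files use
np.savez(
    "graph_sec_demo.npz",
    x=np.random.rand(5, 6).astype(np.float32),                               # node features
    edge_index=np.array([[0, 1], [1, 2], [2, 3], [3, 4]], dtype=np.int64),   # (E, 2) node pairs
    edge_attr=np.random.rand(4, 7).astype(np.float32),                       # edge features
    y=np.zeros(4, dtype=np.float32),                                         # per-edge labels
    pid=np.arange(5),                                                        # particle IDs per node
)

with np.load("graph_sec_demo.npz") as f:
    print(sorted(f.files))  # ['edge_attr', 'edge_index', 'pid', 'x', 'y']
```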

I create the dataset as follows:

train_set = GraphDataset(graph_files=partition['train'])

Everything looks OK here; when I try the __getitem__ method:

train_set.__getitem__(1)

I get the following:

Data(x=[31530, 6], edge_index=[2, 308624], edge_attr=[308624, 7], y=[308624], pid=[31530])

This is what I expect from the GraphDataset class, but when I run the DataLoader as follows:

train_loader = DataLoader(train_set, batch_size=128)

it creates only one batch. I inspect the DataLoader like this:

for batch_idx, data in enumerate(train_loader):
    print(data)

I get the following output:

DataBatch(x=[311611, 6], edge_index=[2, 2971396], edge_attr=[2971396, 7], y=[2971396], pid=[311611], batch=[311611], ptr=[11])

So the partition['train'] array contains 10 files in this case, and the DataLoader is creating only one batch with all the files, hence the ptr=[11]. I have tried using only one file in the array and it doesn't work either; the only difference is ptr=[2], but it is still a single batch.
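My understanding of ptr (which may be wrong) is that it holds the cumulative node offsets of the graphs packed into one batch, so 10 graphs give 11 entries. A plain-NumPy sketch of that relationship, using made-up node counts rather than my real ones:

```python
import numpy as np

# Hypothetical node counts for 10 graphs packed into a single DataBatch
num_nodes = np.array([31530, 30000, 32000, 31000, 30500,
                      31500, 30800, 31200, 31600, 31481])

# ptr entry i is the offset where graph i starts in the stacked node tensor,
# with one final entry for the total node count
ptr = np.concatenate([[0], np.cumsum(num_nodes)])
print(len(ptr))  # 11 entries for 10 graphs
```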

Does anyone know why this is happening and how to fix it?

Thanks in advance
