Hello, I am trying to create a GNN for a particle tracking project. I have a class that creates the dataset as follows:
class GraphDataset(Dataset):
    """Dataset that builds one undirected graph per ``.npz`` file.

    Files are read lazily, one per access. Each file must provide the arrays
    ``x``, ``edge_attr``, ``edge_index``, ``y`` and ``pid``.

    Args:
        graph_files: Indexable collection of paths to the ``.npz`` files.
        transform: Optional callable applied to every ``Data`` object right
            before it is returned from ``__getitem__``.
        pre_transform: Optional callable forwarded to the base class. This
            dataset has no processing step of its own, so it is only stored.
    """

    def __init__(self, graph_files, transform=None, pre_transform=None):
        # Set the paths before the base initialiser runs so that
        # ``raw_file_names`` is valid from the very start.
        self.graph_files = graph_files
        # Forward the transforms instead of silently discarding them.
        super().__init__(transform=transform, pre_transform=pre_transform)

    @property
    def raw_file_names(self):
        """Return the paths of the graph files backing this dataset."""
        return self.graph_files

    @property
    def processed_file_names(self):
        """Return an empty list: nothing is cached in a processed form."""
        return []

    def __getitem__(self, idx):
        """Load the graph stored in ``graph_files[idx]`` as a ``Data`` object.

        Every edge is duplicated in the reverse direction. The reversed copy
        gets negated edge attributes and the same label as the original.
        """
        with np.load(self.graph_files[idx]) as f:
            x = torch.from_numpy(f['x']).float()
            edge_attr = torch.from_numpy(f['edge_attr']).float()
            # PyG expects integer (torch.long) edge indices; this cast is a
            # no-op when the file already stores int64.
            edge_index = torch.from_numpy(f['edge_index']).long()
            y = torch.from_numpy(f['y']).float()
            pid = torch.from_numpy(f['pid'])

        # Make the graph undirected. The column indexing assumes edge_index is
        # stored as [num_edges, 2] -- TODO confirm against the file writer.
        src, dst = edge_index[:, 0], edge_index[:, 1]
        edge_index = torch.stack(
            [torch.cat([src, dst], dim=0), torch.cat([dst, src], dim=0)],
            dim=0,
        )
        # Rows of edge_attr are edges, so the reversed edges are appended
        # along dim 0. The former `.T` + `transpose(0, 1)` pair cancelled out
        # and has been removed.
        edge_attr = torch.cat([edge_attr, -edge_attr], dim=0)
        y = torch.cat([y, y])

        data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr,
                    y=y, pid=pid)
        data.num_nodes = len(x)

        # Honour the per-access transform promised by the constructor.
        if self.transform is not None:
            data = self.transform(data)
        return data

    def __len__(self):
        """Return the number of graph files (one graph per file)."""
        return len(self.graph_files)
Here, graph_files is an array with the paths to some .npz files, like:
array(['.../graph_sec_g007.npz'], dtype='<U64')
I create the dataset as follows:
train_set = GraphDataset(graph_files=partition['train'])
Everything looks OK here when I try the __getitem__ method:
train_set.__getitem__(1)
I get the following:
Data(x=[31530, 6], edge_index=[2, 308624], edge_attr=[308624, 7], y=[308624], pid=[31530])
This is what I expect from the GraphDataset class, but when I run the DataLoader as follows:
train_loader = DataLoader(train_set,batch_size=128)
it creates only one batch. I inspect the DataLoader like this:
for batch_idx, data in enumerate(train_loader):
print(data)
I get the following output:
DataBatch(x=[311611, 6], edge_index=[2, 2971396], edge_attr=[2971396, 7], y=[2971396], pid=[311611], batch=[311611], ptr=[11]).
The partition['train'] array contains 10 files in this case, and the DataLoader is creating only one batch containing all of the files, hence ptr=[11]. I have also tried using only one file in the array, and it does not work either: the only difference is ptr=[2], but there is still a single batch.
Does anyone know why this is happening and how to fix it?
Thanks in advance