I am working through my first GNN exercise using PyTorch. I have successfully read my own data and converted it to a torch_geometric Data object.

import pandas as pd
import torch
from torch_geometric.data import Data

'''
Read edges and features from CSV files
'''
edges = pd.read_csv('/content/drive/MyDrive/GCN/2009/ROI01/BT_EdgeList.csv')
features_raw = pd.read_csv('/content/drive/MyDrive/GCN/2009/ROI01/BT_NodeList.csv')[['class', 'PD1.', 'IL10.', 'Area']]
coordinate = pd.read_csv('/content/drive/MyDrive/GCN/2009/ROI01/BT_NodeList.csv')[['x', 'y']]


links = edges[['from', 'to']]
links_list = [list(x) for x in links.to_numpy()]

# convert links to torch format (one row per edge: [from, to])
edge_index = torch.tensor(links_list, dtype = torch.long)


# features
features_list = [list(x) for x in features_raw.to_numpy()]
coordinate_list = [list(x) for x in coordinate.to_numpy()]

# convert features to torch format
# (note: torch.long truncates any fractional values)
features = torch.tensor(features_list, dtype = torch.long)

# convert node position to torch format
print(coordinate_list)
pos = torch.tensor(coordinate_list, dtype = torch.long)

# edge_index is transposed to the [2, num_edges] shape that Data expects
data = Data(x = features, y = [[1]], edge_index = edge_index.t().contiguous(), pos = pos)

Now I would like to visualize it to confirm that the graph was built correctly. Here is the code I used for the visualization:

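import matplotlib.pyplot as plt
import networkx as nx
from torch_geometric.utils import to_networkx

# convert the PyG Data object to a networkx graph
g = to_networkx(data)

pos_layout = nx.spring_layout(g)
plt.figure(1,figsize=(14,12))

nx.draw(g, cmap=plt.get_cmap('Set1'),linewidths=6, node_size= 10, pos = pos_layout)
plt.show()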

There are 2 problems:
first, the location of the nodes changes every time I plot, even though I specified the layout.
second, I could not replicate the exact locations of each node as I did when generating the raw data using R. The ideal graph should look like this: [image not included]. I have attached the edge and node files here to reproduce the problem: https://drive.google.com/drive/folders/1KPMJM6PyPs5o_F86MRYuIN3_3I1N6pF-?usp=sharing

DigiPath

1 Answer

Referring to the documentation and this answer:

https://networkx.org/documentation/stable/reference/generated/networkx.drawing.layout.spring_layout.html

https://stackoverflow.com/a/62066949/13643054

There is an additional seed argument that you can pass to layout functions to make the visualization deterministic, which would make your code look like this:

pos_layout = nx.spring_layout(g, seed=12345)
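
To illustrate the effect of the seed, here is a minimal sketch. It uses a small stand-in graph (`nx.path_graph(5)`, an assumption for illustration, not the graph from the question) and checks that two calls to `spring_layout` with the same seed return the same coordinates; it should print `True`.

```python
import networkx as nx
import numpy as np

# Small stand-in graph; the question's graph `g` would be used the same way.
g = nx.path_graph(5)

# Two independent calls that share the same seed.
layout_a = nx.spring_layout(g, seed=12345)
layout_b = nx.spring_layout(g, seed=12345)

# Compare the coordinates node by node.
same = all(np.allclose(layout_a[n], layout_b[n]) for n in g.nodes)
print(same)
```

Note that a fixed seed only makes the computed layout repeatable. It does not make the layout use coordinates stored elsewhere.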

second, I could not replicate the exact locations of each node as I did when generating the raw data using R

Could you elaborate a bit more? I don't understand what currently happens and what should happen instead.

  • Thank you for your answer. I expect the graph to look exactly the same as the figure I put in the question. That graph was plotted using R. I exported the data I used for plotting (attached in the link), but I couldn't reproduce it with my Python code. – DigiPath Nov 06 '21 at 17:30
  • @DigiPath I see, then one of the functions (R or networkx) does not use your node coordinates as positions. I believe that calling spring_layout computes new positions instead of using yours, so you should try passing your own coordinates directly: `nx.draw(g, cmap=plt.get_cmap('Set1'),linewidths=6, node_size= 10, pos = pos)`. An additional answer on this topic: https://stackoverflow.com/a/65004666/13643054 – QuanticDisaster Nov 10 '21 at 15:18
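
The suggestion in the last comment can be sketched as follows. networkx documents the `pos` argument of `nx.draw` as a dictionary keyed by node, so it may be safer to convert the coordinates to that form than to pass a tensor. The helper names `positions_from_coordinates` and `draw_at_coordinates` are hypothetical, the example coordinates are made up, and the sketch assumes that node `i` in the graph corresponds to row `i` of the node list (0-based). If the edge list exported from R is 1-based, the indices would need shifting first.

```python
import networkx as nx


def positions_from_coordinates(coords):
    """Map node index i to its (x, y) pair, the dict form nx.draw expects."""
    return {i: (float(x), float(y)) for i, (x, y) in enumerate(coords)}


def draw_at_coordinates(g, coords, **draw_kwargs):
    """Draw g with nodes fixed at the given coordinates (no layout step)."""
    pos_dict = positions_from_coordinates(coords)
    nx.draw(g, pos=pos_dict, **draw_kwargs)
    return pos_dict


# Made-up coordinates standing in for the 'x' and 'y' columns of the node list.
example_coords = [[0, 0], [10, 0], [10, 5]]
example_pos = positions_from_coordinates(example_coords)
```

With the objects from the question, this might be called as `draw_at_coordinates(g, pos.tolist(), node_size=10)` followed by `plt.show()`. Because no layout function runs, the plot should then be identical on every run.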