I have a heterogeneous graph with different node types and attributes, stored as a NetworkX `MultiGraph`. I want to train a GNN model on it, so I need to convert it into a PyTorch Geometric data type.
I tried the built-in function:

```python
from torch_geometric.utils.convert import from_networkx

data_PyG = from_networkx(G)
```
But it only works if `G` is a `networkx.Graph` or `networkx.DiGraph` whose nodes (and edges) all have the same set of attributes.
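If a single homogeneous `Data` object were acceptable, one possible workaround for the "same set of attributes" check is to give every node the union of all attribute keys and fill the missing ones with a default. The sketch below is only an illustration. `pad_node_attrs` and its `fill_value` default are hypothetical names, not part of PyG or NetworkX. This approach flattens the node types into one feature space, so it discards exactly the heterogeneity the question is about, and edges would need the same treatment.

```python
from typing import Any, Dict, Hashable, List, Tuple


def pad_node_attrs(
    nodes: List[Tuple[Hashable, Dict[str, Any]]],
    fill_value: Any = 0.0,  # hypothetical default for attributes a node lacks
) -> Dict[Hashable, Dict[str, Any]]:
    """Return {node_id: attrs} where every node has the same attribute keys.

    `nodes` is shaped like list(G.nodes(data=True)).
    """
    # Collect the union of attribute names across all nodes.
    all_keys = set()
    for _, attrs in nodes:
        all_keys.update(attrs)

    # Rebuild each node's attribute dict with every key present.
    return {
        node_id: {key: attrs.get(key, fill_value) for key in sorted(all_keys)}
        for node_id, attrs in nodes
    }
```

With NetworkX, the result could be written back with `nx.set_node_attributes(G, pad_node_attrs(list(G.nodes(data=True))))`, which accepts a dict keyed by node whose values are attribute dicts.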
Is there a way to automatically convert an `nx.MultiGraph` into `HeteroData()`?
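I am not aware of a single built-in call for this, but the core of a manual conversion is to split nodes and edges by type and re-index each node within its own type. Below is a minimal pure-Python sketch of that step. It assumes that every node and every edge stores its type under an attribute called `"type"`. That key and the helper `group_by_type` are hypothetical, so adapt them to your graph. Each `(u, v, key)` entry of the multigraph becomes its own edge, so parallel edges are preserved.

```python
from collections import defaultdict
from typing import Any, Dict, Hashable, Iterable, List, Tuple

EdgeType = Tuple[str, str, str]  # (source node type, relation, destination node type)


def group_by_type(
    nodes: Iterable[Tuple[Hashable, Dict[str, Any]]],
    edges: Iterable[Tuple[Hashable, Hashable, Hashable, Dict[str, Any]]],
    type_key: str = "type",  # hypothetical attribute holding the node/edge type
):
    """Split a multigraph by node and edge type.

    `nodes` is shaped like G.nodes(data=True) and `edges` like
    G.edges(keys=True, data=True) for an nx.MultiGraph.
    """
    node_type: Dict[Hashable, str] = {}
    local_index: Dict[Hashable, int] = {}
    node_feats: Dict[str, List[Dict[str, Any]]] = defaultdict(list)

    for node_id, attrs in nodes:
        ntype = attrs[type_key]
        node_type[node_id] = ntype
        # Index of this node within its own type (0, 1, 2, ... per type).
        local_index[node_id] = len(node_feats[ntype])
        node_feats[ntype].append({k: v for k, v in attrs.items() if k != type_key})

    edge_index: Dict[EdgeType, List[Tuple[int, int]]] = defaultdict(list)
    edge_feats: Dict[EdgeType, List[Dict[str, Any]]] = defaultdict(list)

    for u, v, _key, attrs in edges:
        # One entry per (u, v, key), so multi-edges are kept separately.
        etype = (node_type[u], attrs[type_key], node_type[v])
        edge_index[etype].append((local_index[u], local_index[v]))
        edge_feats[etype].append({k: val for k, val in attrs.items() if k != type_key})

    return dict(node_feats), dict(edge_index), dict(edge_feats)
```

For a graph `G`, this would be called as `group_by_type(G.nodes(data=True), G.edges(keys=True, data=True))`. Each group can then be turned into tensors and assigned to a `HeteroData` object, e.g. `data[ntype].x = torch.tensor(...)` for node features and `data[etype].edge_index = torch.tensor(pairs, dtype=torch.long).t().contiguous()` for each edge type. Because an undirected `nx.MultiGraph` reports each edge in only one direction, you may also want reverse edges, for example via `torch_geometric.transforms.ToUndirected`.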