How can I find the names of the input nodes or layers that feed a given layer in PyTorch?
For example, if I have a `torch.cat`, how can I find the names of the tensors or layers it receives its inputs from?
Consider this code from https://rosenfelder.ai/multi-input-neural-network-pytorch/:
class LitClassifier(pl.LightningModule):
    """Two-input model combining an image branch and a tabular branch.

    Image branch: ``conv1`` -> ``conv2`` -> ``conv3`` -> flatten -> ``ln1`` ->
    ReLU -> ``batchnorm`` -> ``dropout`` -> ``ln2`` -> ReLU (5 features).
    Tabular branch: ``ln4`` -> ReLU -> ``ln5`` -> ReLU -> ``ln6`` -> ReLU
    (5 features).  ``forward`` concatenates the two along dim 1 (10 features),
    applies ReLU and maps the result to one output with ``ln7``.

    Args:
        lr: Stored as ``self.lr``. Not used inside this block; presumably
            read by an optimizer hook defined elsewhere.
        num_workers: Stored as ``self.num_workers``. Not used inside this block.
        batch_size: Stored as ``self.batch_size``. Not used inside this block.
    """

    def __init__(
        self, lr: float = 1e-3, num_workers: int = 4, batch_size: int = 32,
    ):
        super().__init__()
        self.lr = lr
        self.num_workers = num_workers
        self.batch_size = batch_size

        # Image branch.  ``conv_block`` is defined outside this class.  ln1's
        # 64 * 26 * 26 in_features presumably matches the flattened conv3
        # output for the expected image size.  TODO: confirm.
        self.conv1 = conv_block(3, 16)
        self.conv2 = conv_block(16, 32)
        self.conv3 = conv_block(32, 64)
        self.ln1 = nn.Linear(64 * 26 * 26, 16)
        # One stateless activation module, reused by both branches and after
        # the concatenation.
        self.relu = nn.ReLU()
        self.batchnorm = nn.BatchNorm1d(16)
        # The tensor reaching this layer is 2-D (batch, 16): it is the output
        # of ln1 and batchnorm applied to a flattened tensor.  ``nn.Dropout2d``
        # is meant for (N, C, H, W) inputs.  On 2-D input it warns that the
        # usage is deprecated (slated to become an error) and already drops
        # each element independently.  ``nn.Dropout`` does the same thing
        # without the warning.  Neither module has parameters, so checkpoints
        # are unaffected.
        self.dropout = nn.Dropout(0.5)
        self.ln2 = nn.Linear(16, 5)

        # Tabular branch.  ln4's in_features means each tabular row must have
        # 5 values.
        self.ln4 = nn.Linear(5, 10)
        self.ln5 = nn.Linear(10, 10)
        self.ln6 = nn.Linear(10, 5)

        # Head: 5 image features + 5 tabular features after torch.cat.
        self.ln7 = nn.Linear(10, 1)

    def forward(self, img: torch.Tensor, tab: torch.Tensor) -> torch.Tensor:
        """Run both branches, concatenate their features and apply ``ln7``.

        Args:
            img: Image batch fed to ``conv1``. Presumably 3 channels, per
                ``conv_block(3, 16)``; TODO confirm.  Dim 0 is treated as the
                batch dimension by the reshape below.
            tab: Tabular batch fed to ``ln4``: 2-D, with 5 values per row.

        Returns:
            Output of ``ln7``, with 1 output feature per row.
        """
        # Image branch -> (batch, 5)
        img = self.conv1(img)
        img = self.conv2(img)
        img = self.conv3(img)
        img = img.reshape(img.shape[0], -1)  # flatten all dims except the batch dim
        img = self.ln1(img)
        img = self.relu(img)
        img = self.batchnorm(img)
        img = self.dropout(img)
        img = self.ln2(img)
        img = self.relu(img)

        # Tabular branch -> (batch, 5)
        tab = self.ln4(tab)
        tab = self.relu(tab)
        tab = self.ln5(tab)
        tab = self.relu(tab)
        tab = self.ln6(tab)
        tab = self.relu(tab)

        # torch.cat's two inputs:
        #   img = output of self.relu applied to self.ln2's output (image branch)
        #   tab = output of self.relu applied to self.ln6's output (tabular branch)
        x = torch.cat((img, tab), dim=1)
        # Both concatenated tensors are already ReLU outputs (non-negative), so
        # this ReLU has no numerical effect.  It is kept so the module's call
        # graph (and any hooks on self.relu) stays unchanged.
        x = self.relu(x)
        return self.ln7(x)
Here, I want to know which layers the `torch.cat` receives its inputs from.
In Keras we have `model.get_layer(index=idx).input.name`. Is there something similar in PyTorch?