
How can I find the names of the input nodes or layers for a given layer in PyTorch? Say I have a torch.cat: how can I know the names of the tensors or layers from which it is getting its input?
Take this code from https://rosenfelder.ai/multi-input-neural-network-pytorch/:

import torch
from torch import nn
import pytorch_lightning as pl

# conv_block is a small convolutional helper defined in the linked article
class LitClassifier(pl.LightningModule):
    def __init__(
        self, lr: float = 1e-3, num_workers: int = 4, batch_size: int = 32,
    ):
        super().__init__()
        self.lr = lr
        self.num_workers = num_workers
        self.batch_size = batch_size

        self.conv1 = conv_block(3, 16)
        self.conv2 = conv_block(16, 32)
        self.conv3 = conv_block(32, 64)

        self.ln1 = nn.Linear(64 * 26 * 26, 16)
        self.relu = nn.ReLU()
        self.batchnorm = nn.BatchNorm1d(16)
        self.dropout = nn.Dropout2d(0.5)
        self.ln2 = nn.Linear(16, 5)

        self.ln4 = nn.Linear(5, 10)
        self.ln5 = nn.Linear(10, 10)
        self.ln6 = nn.Linear(10, 5)
        self.ln7 = nn.Linear(10, 1)  

    def forward(self, img, tab):
        img = self.conv1(img)
        img = self.conv2(img)
        img = self.conv3(img)
        img = img.reshape(img.shape[0], -1)
        img = self.ln1(img)
        img = self.relu(img)
        img = self.batchnorm(img)
        img = self.dropout(img)
        img = self.ln2(img)
        img = self.relu(img)

        tab = self.ln4(tab)
        tab = self.relu(tab)
        tab = self.ln5(tab)
        tab = self.relu(tab)
        tab = self.ln6(tab)
        tab = self.relu(tab)

        x = torch.cat((img, tab), dim=1)
        x = self.relu(x)

        return self.ln7(x)

So how can I find out from which layers the torch.cat is receiving its input?
For Keras we have model.get_layer(id=idx).input.name; is there something similar for PyTorch?
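
One thing in PyTorch that looks close is torch.fx symbolic tracing, which records the forward pass as a graph whose nodes list their inputs by name. Below is a rough sketch against the model above, assuming tracing copes with the LightningModule wrapper (it is still an nn.Module); the conv_block here is only a stand-in so the snippet is self-contained, the real one is defined in the linked article:

import torch
import torch.fx
from torch import nn


def conv_block(in_channels, out_channels):
    # stand-in for the conv_block helper from the linked article
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3),
        nn.ReLU(),
        nn.BatchNorm2d(out_channels),
        nn.MaxPool2d(kernel_size=2),
    )


model = LitClassifier()
traced = torch.fx.symbolic_trace(model)  # records forward() as a graph of named nodes

# traced.graph.print_tabular() prints every op with its inputs (needs the tabulate package);
# here we only look for the torch.cat call and the names of the nodes feeding it.
for node in traced.graph.nodes:
    if node.op == "call_function" and node.target is torch.cat:
        print("torch.cat receives its input from:", [inp.name for inp in node.args[0]])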

  • From... source code? – Alexey S. Larionov Jun 17 '21 at 13:12
  • @AlexeyLarionov I've provided sample source code in the question, thanks! – Sanjiban Sengupta Jun 17 '21 at 13:21
  • The `img` variable contains whatever was last assigned to it, i.e. the last line `img = self.relu(img)`. Similarly with the `tab` variable. I don't get the question; it's just basic Python. – Alexey S. Larionov Jun 17 '21 at 13:26
  • @AlexeyLarionov In Keras, whenever we create a layer, it has an attribute for the input node's name, which can be found via model.get_layer(id=idx).input.name; does something similar exist for PyTorch, where we can find out what the inputs to a particular layer are? The variables tab and img are defined within the scope of forward(); can we extract that data from the model object as a whole? – Sanjiban Sengupta Jun 17 '21 at 14:49
  • There's no such concept as a "layer's input". Input is whatever you pass to the `forward` method; in your example a single `self.relu` layer is called 6 times with different inputs. There's the `nn.Sequential` layer aggregation, which basically implements passing some `x` to the first layer, then the output of this layer to the second layer, and so on for all the layers. To find which layer is where, you can work with the indices of the `nn.Sequential` inner layers. – Alexey S. Larionov Jun 17 '21 at 14:54
  • `can we extract that data from the model object as a whole?` You typically save whatever output you want as a field of your model, e.g. in `forward()` save some intermediate output like this: `self.last_tab = tab`, then access it as `model.last_tab` wherever you want. – Alexey S. Larionov Jun 17 '21 at 14:55
  • @AlexeyLarionov What I basically want is to find out which layers are called inside the model: not just where they are defined, but the exact chain of execution, say, how they appear in the graph representation, and from there I could extract the needed information. Similar to how model.get_layer or model.get_config is defined for Keras. – Sanjiban Sengupta Jun 17 '21 at 17:03
  • As I said, I'm not aware of anything like that. Layers are just layers; how they're applied is well beyond such tools, because layers can be used in arbitrary ways, they're just Python variables. If the model has a simple structure consisting of `nn.Sequential`, then a simple `print(model)` might give what you're looking for, [like here](https://stackoverflow.com/questions/42480111/model-summary-in-pytorch). – Alexey S. Larionov Jun 17 '21 at 20:05
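
Following up on the comment thread above: another way to see the exact chain of module calls at runtime (different from print(model), and not something the comments mention) is to register a forward hook on every submodule and let a probe forward pass record which modules fire and in what order. A rough sketch, assuming the LitClassifier from the question plus the stand-in conv_block from the earlier snippet, and a 224x224 input image (which is what makes the 64 * 26 * 26 expected by ln1 work out):

import torch

model = LitClassifier()
model.eval()  # keeps dropout/batchnorm deterministic for the probe pass

calls = []  # (module name, input shapes) in the order the modules are actually called


def make_hook(name):
    def hook(module, inputs, output):
        shapes = [tuple(t.shape) for t in inputs if isinstance(t, torch.Tensor)]
        calls.append((name, shapes))
    return hook


handles = [
    module.register_forward_hook(make_hook(name))
    for name, module in model.named_modules()
    if name  # skip the root module itself
]

with torch.no_grad():
    model(torch.randn(2, 3, 224, 224), torch.randn(2, 5))

for name, shapes in calls:
    print(name, shapes)

for handle in handles:
    handle.remove()

Note that torch.cat is a function, not an nn.Module, so it does not show up in this hook trace; the torch.fx sketch earlier in the question is what actually exposes it.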

0 Answers