I want to write a function that returns the required input depth (the number of channels of the input image) of a PyTorch convolutional neural network, using only the network object and no external information. I imagine something like this:

def input_depth(network: "torch.nn.Module") -> int:
    # Return input depth, e.g. 3 if the network works on RGB images, 1 if grayscale, etc.
    ...

Is this even possible? If so, is there an elegant way of doing this?

PS: I am aware of the SO question "PyTorch model input shape" and of the ability to inspect the output shape of each layer manually, but that is not what I'm looking for. I'm looking for an automatic way to determine the required input depth of the network.

Erasiel

1 Answer

You could loop through the sub-modules of the provided network, look for the first nn.Conv2d instance, and return its in_channels attribute, which is the number of channels expected by that convolutional layer:

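import torch
from torch import nn

def input_depth(network: "torch.nn.Module") -> int:
    # modules() walks sub-modules in registration order, not forward order
    for v in network.modules():
        if isinstance(v, nn.Conv2d):
            return v.in_channels
    raise Exception('No nn.Conv2d found in provided model')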

Edit: indeed, you are right. The above only works if the convolution layers are initialized in the same order in which they are used inside the forward logic.
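To make that limitation concrete, here is a minimal sketch. The class `OutOfOrderNet` and the helper `first_registered_conv_channels` are hypothetical names made up for this illustration; the layer sizes are arbitrary. The network registers its layers in the opposite order to the one used in `forward`, so a lookup based on `modules()` reports the wrong layer.

```python
import torch
from torch import nn


class OutOfOrderNet(nn.Module):
    """Hypothetical network: layers are registered in reverse forward order."""

    def __init__(self):
        super().__init__()
        self.second = nn.Conv2d(8, 16, kernel_size=3)  # registered first, used second
        self.first = nn.Conv2d(3, 8, kernel_size=3)    # registered second, used first

    def forward(self, x):
        return self.second(self.first(x))


def first_registered_conv_channels(network: nn.Module) -> int:
    # Same lookup as above: returns the first Conv2d in registration order.
    for module in network.modules():
        if isinstance(module, nn.Conv2d):
            return module.in_channels
    raise ValueError("No nn.Conv2d found in provided model")


net = OutOfOrderNet()
out = net(torch.zeros(1, 3, 8, 8))          # the network really takes 3 channels
print(first_registered_conv_channels(net))  # reports 8, taken from `second`
print(net.first.in_channels)                # the layer actually applied first expects 3
```

The forward pass succeeds with a 3-channel input, yet the registration-order lookup returns 8.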

Since the forward logic is defined only by the user and is not registered anywhere on the module, the only way I see this working is by triggering a failure and reading the expected number of channels from the error. This can be done with a try/except clause followed by a regex search on the RuntimeError message.

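import re
import torch

def input_depth(network: "torch.nn.Module") -> int:
    try:
        # A tensor with zero channels should be rejected by the first convolution
        network(torch.empty(1, 0, 1, 1))
    except RuntimeError as e:
        search = re.search(r'(\d+)[^\d]+channels', str(e))
        if search:
            return int(search.group(1))
    raise ValueError('Could not infer the input depth from the error message')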
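The regex step can be checked on its own, without a network. The sketch below runs the same pattern against a sample string; `SAMPLE_ERROR` is an assumed example of the kind of message a channel mismatch produces, and the exact wording may differ between PyTorch versions. `parse_expected_channels` is a hypothetical helper name used only here.

```python
import re

# Assumed example message; the real wording may vary across PyTorch versions.
SAMPLE_ERROR = (
    "Given groups=1, weight of size [64, 3, 7, 7], "
    "expected input[1, 0, 1, 1] to have 3 channels, "
    "but got 0 channels instead"
)


def parse_expected_channels(message: str) -> int:
    """Return the first integer that is followed by non-digits and 'channels'."""
    match = re.search(r"(\d+)[^\d]+channels", message)
    if match is None:
        raise ValueError("no channel count found in message")
    return int(match.group(1))


print(parse_expected_channels(SAMPLE_ERROR))  # 3
```

Because `[^\d]+` cannot cross another digit, the numbers inside the bracketed shapes are skipped and the first match is the expected channel count rather than the "got 0 channels" part. This depends on the message format, so the approach is fragile if that format changes.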
Ivan
  • Thanks for the suggestion, but this does not guarantee the correct depth in every case. The edge case is when the first Conv2d in forward order is not the one that was added first to the Module. Since Module._modules is an OrderedDict, the traversal order of network.modules() is fixed to the order of submodule addition, not the forward order. – Erasiel Jul 31 '22 at 17:13
  • Have a look at my edit, it seems the only way to retrieve this information is by triggering a failure with a tensor input which we expect to fail (*e.g.* a tensor with no elements, the correct shape but an incorrect number of channels). Since `nn.Conv2d` checks for the number of channels, this should work properly. – Ivan Jul 31 '22 at 18:52