I want to write a function that returns the required input depth (the number of channels of the input image) of a PyTorch convolutional neural network, using only the network object itself and no external information. I'd imagine something like this:
```python
def input_depth(network: "torch.nn.Module") -> int:
    """Return the input depth, e.g. 3 if the network works on RGB images, 1 if grayscale, etc."""
```
Is this even possible? If so, is there an elegant way of doing this?
PS: I am aware of this SO question: PyTorch model input shape, and I know I can inspect the output shape of each layer manually, but that is not what I'm looking for. I want an automatic way to determine the required input depth of the network.
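A minimal sketch of one common heuristic, assuming the first convolution registered in the module tree is the layer that receives the raw image: walk `network.modules()` and return the `in_channels` of the first `nn.Conv1d`, `nn.Conv2d` or `nn.Conv3d` found. Note that `modules()` yields submodules in the order they were registered in `__init__`, which usually, but not necessarily, matches the order in which `forward()` runs them, so this is a best guess rather than a guarantee.

```python
from torch import nn

# Public convolution classes whose `in_channels` we trust as the input depth.
_CONV_TYPES = (nn.Conv1d, nn.Conv2d, nn.Conv3d)


def input_depth(network: nn.Module) -> int:
    """Heuristic: return in_channels of the first conv layer in network.modules().

    Assumes registration order matches execution order for the first conv layer.
    """
    # modules() yields the network itself first, then its submodules recursively.
    for module in network.modules():
        if isinstance(module, _CONV_TYPES):
            return module.in_channels
    raise ValueError("no convolutional layer found in the network")


if __name__ == "__main__":
    rgb_net = nn.Sequential(nn.Conv2d(3, 16, kernel_size=3), nn.ReLU(), nn.Conv2d(16, 32, 3))
    print(input_depth(rgb_net))  # 3

    # Nested modules are found too, since modules() recurses.
    gray_net = nn.Sequential(nn.Sequential(nn.Conv2d(1, 8, 3), nn.ReLU()), nn.Flatten())
    print(input_depth(gray_net))  # 1
```

This only works when the depth is stored on a layer. It fails for networks whose `forward()` reorders layers or preprocesses the input (e.g. concatenates extra channels before the first convolution), and lazy convolution modules do not know their input channels until their first forward call.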