Here is an example of PyTorch code from the website:

import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        # 1 input image channel, 6 output channels, 3x3 square convolution
        # kernel
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(16 * 6 * 6, 120)  # 6*6 from image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # Max pooling over a (2, 2) window
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # If the size is a square you can only specify a single number
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        # helper referenced in forward(): all dimensions except the batch dimension
        size = x.size()[1:]
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

In the forward function, we simply apply a series of transformations to x, but never explicitly declare which objects take part in those transformations. Yet when computing the gradient and updating the weights, PyTorch 'magically' knows which weights to update and how the gradient should be calculated.

How does this process work? Is there code analysis going on, or something else that I am missing?
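
For context, the gradient computation and weight update in question happen in a standard training loop; below is a minimal sketch reusing the Net class above, where the dummy input, target, loss, and learning rate are placeholders rather than part of the quoted tutorial code.

import torch
import torch.nn as nn
import torch.optim as optim

net = Net()
optimizer = optim.SGD(net.parameters(), lr=0.01)  # parameters() yields every weight registered in __init__
criterion = nn.MSELoss()

input = torch.randn(1, 1, 32, 32)   # dummy batch: one 1-channel 32x32 image
target = torch.randn(1, 10)         # dummy target for the 10 outputs

optimizer.zero_grad()               # clear gradients from any previous step
output = net(input)                 # forward pass: autograd records the graph here
loss = criterion(output, target)
loss.backward()                     # walk the recorded graph, fill .grad on each parameter
optimizer.step()                    # update exactly the parameters handed to the optimizer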

chessprogrammer
  • I'm not quite sure which part of the code causes the confusion here, but the `x` that is returned does have information on the layers, e.g. `x = self.fc3(x)` ends up as a tensor that has that specific `nn.Linear` applied to it. If you `print(x)`, you can even see that it has a `grad_fn` attribute. – Hai Nguyen Aug 25 '20 at 08:00
  • I'm sorry, but this seems to be a programming issue, so it is off-topic. Please read https://ai.stackexchange.com/help/on-topic. – nbro Aug 25 '20 at 10:16

1 Answer

Yes, there is implicit tracking during the forward pass. If you examine the resulting tensor, you will find something like grad_fn=<CatBackward>; that is a link that allows you to unroll the whole computation graph. The graph is built during the actual forward computation, no matter how you defined your network module, whether in the object-oriented 'nn' style or the 'functional' style.
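
A minimal sketch of that inspection, reusing the Net class from the question (the exact grad_fn class names, e.g. AddmmBackward0, vary between PyTorch versions):

import torch

net = Net()
out = net(torch.randn(1, 1, 32, 32))

print(out.grad_fn)                   # node recorded by the last op (the final Linear layer)
print(out.grad_fn.next_functions)    # links to the nodes that produced its inputs

# follow one chain of links back through the recorded graph
node = out.grad_fn
while node is not None:
    print(type(node).__name__)
    node = node.next_functions[0][0] if node.next_functions else None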

You can exploit this graph for network analysis, as torchviz does here: https://github.com/szagoruyko/pytorchviz/blob/master/torchviz/dot.py
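
For example, a minimal sketch of how torchviz is typically used (assuming torchviz is installed via pip install torchviz and graphviz is available; reusing the Net class from the question):

import torch
from torchviz import make_dot

net = Net()
out = net(torch.randn(1, 1, 32, 32))

# build a graphviz Digraph from the autograd graph recorded during the forward pass
dot = make_dot(out, params=dict(net.named_parameters()))
dot.render("net_graph", format="png")   # writes net_graph.png next to the script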

Alexey Birukov