
I came across the following code in a neural network tutorial. These lines work correctly, although they seem to contradict what I know about Python classes.

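import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 200)
        self.fc2 = nn.Linear(200, 200)
        self.fc3 = nn.Linear(200, 10)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        # dim=1: normalize over the 10 class scores of each sample in the batch
        return F.log_softmax(x, dim=1)


net = Net()
data = torch.randn(64, 28 * 28)  # hypothetical batch of 64 flattened 28x28 images
net_out = net(data)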

Here, data is passed to net itself, yet it ends up in net.forward(), and forward() is executed.

However, as far as I know, we have to write net.forward(data) rather than net(data) to call the forward method. So could anyone explain why forward() runs even though its name is never mentioned? Is there some rule that lets us call a method of a class without writing instance_name.method_name?

Kate Orlova
xxxnxxxnn
  • As far as I know, that shouldn't work either. – classicdude7 Jun 07 '20 at 15:37
  • Keep in mind that there's more code running than what you showed here - the base class `nn.Module` is going to have method definitions of its own, that are also accessible to instances of your class. Apparently it defines a `.__call__()` method (that's what `net(...)` is invoking) that calls (perhaps indirectly) your `.forward()` method. – jasonharper Jun 07 '20 at 15:37

1 Answer


Python has a set of "magic" (double-underscore) methods that a class can override to customize how its objects behave. In this particular case, net_out = net(data) works because calling an object with parentheses invokes the __call__ method defined on its class, as explained in this other post.
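To see the mechanism without PyTorch, here is a minimal sketch using a hypothetical Adder class (not from the tutorial): defining __call__ is all it takes for its instances to be called like functions.

```python
class Adder:
    """Hypothetical toy class whose instances can be called like functions."""

    def __init__(self, n):
        self.n = n

    def __call__(self, x):
        # Runs whenever an instance is called with parentheses: add_three(x)
        return x + self.n


add_three = Adder(3)

print(add_three(4))                  # → 7
print(Adder.__call__(add_three, 4))  # → 7, the explicit form of the same call
print(callable(add_three))           # → True
```

Writing add_three(4) is shorthand for looking up __call__ on the object's class and calling it, which is exactly what happens with net(data).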

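The override happens in nn.Module: its __call__ runs any registered hooks around a call to self.forward() with the arguments you passed. Stripped down to its core, it behaves something like the following: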

class Module:
    ...
    def __call__(self, *args, **kwargs):
        # Simplified: the real nn.Module also runs hooks around this call
        return self.forward(*args, **kwargs)
    ...
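To show why this indirection matters, here is a hypothetical, heavily simplified imitation of nn.Module. MiniModule, add_forward_hook and Doubler are made-up names for illustration, not PyTorch API. Its __call__ wraps forward() with extra work, so calling forward() directly skips that work.

```python
class MiniModule:
    """Hypothetical, heavily simplified stand-in for torch.nn.Module."""

    def __init__(self):
        self._forward_hooks = []

    def add_forward_hook(self, hook):
        # hook(module, inputs, output) runs after each forward pass made through __call__
        self._forward_hooks.append(hook)

    def forward(self, *args, **kwargs):
        raise NotImplementedError("subclasses must define forward()")

    def __call__(self, *args, **kwargs):
        # Subclass's forward() is found through normal method lookup on self
        output = self.forward(*args, **kwargs)
        for hook in self._forward_hooks:
            hook(self, args, output)
        return output


class Doubler(MiniModule):
    def forward(self, x):
        return 2 * x


seen = []
net = Doubler()
net.add_forward_hook(lambda module, inputs, output: seen.append(output))

print(net(5))          # → 10, and the hook records 10
print(net.forward(5))  # → 10, but the hook is skipped
print(seen)            # → [10]
```

The real nn.Module works in the same spirit: PyTorch's documentation advises calling the module instance, net(data), instead of net.forward(data), because the instance call runs registered hooks while a direct forward() call silently ignores them.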
alexfertel