I saw the following code segment for extending nn.Module. What I do not understand is the expression input_ @ self.weight in the forward function. I can see that it somehow combines the input with the layer's weight, but I have only ever seen @ used as a decorator, so why can it appear inside an expression like this?
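For reference, this is the only kind of @ usage I am familiar with, a decorator placed above a function definition (a toy example of my own, not taken from the snippet):

def shout(func):
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs).upper()
    return wrapper

@shout
def greet(name):
    return f"hello, {name}"

print(greet("world"))  # prints "HELLO, WORLD"

The @ in the class below clearly is not doing that, since it sits between two values inside a return statement.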
import torch
import torch.nn as nn
from pyro.nn import PyroModule


class Linear(nn.Module):
    def __init__(self, in_size, out_size):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(in_size, out_size))
        self.bias = nn.Parameter(torch.randn(out_size))

    def forward(self, input_):
        return self.bias + input_ @ self.weight


linear = Linear(5, 2)
assert isinstance(linear, nn.Module)
assert not isinstance(linear, PyroModule)

example_input = torch.randn(100, 5)
example_output = linear(example_input)
assert example_output.shape == (100, 2)
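I did check the shapes involved, and they line up with the assertion above (these prints are my own experiment, not part of the original example):

print(example_input.shape)    # torch.Size([100, 5])
print(linear.weight.shape)    # torch.Size([5, 2])
print(example_output.shape)   # torch.Size([100, 2])

So the code clearly maps a (100, 5) input to a (100, 2) output; my question is only about what the @ syntax means here.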