I saw the following code segment for extending nn.Module. What I do not understand is the `input_ @ self.weight` in the forward function. I can see that it is trying to combine the weights with `input_`, but `@` is normally used for decorators, so why can it be used this way?

import torch
import torch.nn as nn
from pyro.nn import PyroModule  # PyroModule is referenced in the assertions below

class Linear(nn.Module):
    def __init__(self, in_size, out_size):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(in_size, out_size))
        self.bias = nn.Parameter(torch.randn(out_size))

    def forward(self, input_):
        return self.bias + input_ @ self.weight

linear = Linear(5, 2)
assert isinstance(linear, nn.Module)
assert not isinstance(linear, PyroModule)

example_input = torch.randn(100, 5)
example_output = linear(example_input)
assert example_output.shape == (100, 2)
user785099

1 Answer

The `@` symbol is shorthand for the `__matmul__` method: it is Python's matrix multiplication operator, which PyTorch tensors implement.
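
As a concrete illustration (a minimal sketch with two shape-compatible tensors, not code from the original answer), all three spellings below compute the same matrix product:

import torch

a = torch.randn(100, 5)
b = torch.randn(5, 2)

c1 = a @ b               # operator syntax, dispatches to a.__matmul__(b)
c2 = a.matmul(b)         # method syntax
c3 = torch.matmul(a, b)  # functional syntax

assert torch.equal(c1, c2) and torch.equal(c2, c3)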

Ivan
  • I know this `@` operator is used in numpy also. Is it safer to write `torch.matmul(a,b)` so we don't confuse numpy and PyTorch operators? – prosti Oct 05 '21 at 16:39
  • @prosti If you know you are working with `torch.Tensor` and not `numpy.array`, then it makes sense to use `@` as a shorthand and it wouldn't introduce any confusion. Of course, you can still write `torch.matmul` or alternatively use [`torch.Tensor.matmul`](https://pytorch.org/docs/stable/generated/torch.Tensor.matmul.html?highlight=matmul#torch.Tensor.matmul) but it's a little more verbose IMO. It just comes down to personal preferences or team conventions. – Ivan Oct 05 '21 at 18:44
  • Great answer, so the `@` operator is just a function and there is one defined for numpy arrays and one for PyTorch tensors. – prosti Oct 05 '21 at 18:54
  • Is `torch.mm()` almost the same as `torch.matmul()`, or are they the same? I hope my question is not too broad. – prosti Oct 05 '21 at 18:58
  • Indeed, in practice, the interpreter will look at the class of both operands and search for the appropriate operation, which means `tensor_a @ tensor_b` will translate to `tensor_a.matmul(tensor_b)` and end up calling `torch.matmul(tensor_a, tensor_b)`... I believe `torch.mm` only performs matrix multiplications, while `torch.matmul` is a more general operator (more capable than `torch.mm`) and can perform broadcasting on the operands (see [here](https://pytorch.org/docs/stable/generated/torch.matmul.html#torch.matmul)). – Ivan Oct 05 '21 at 19:01
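
To make the `torch.mm` vs. `torch.matmul` distinction from the last comment concrete, here is a quick sketch (tensor shapes chosen only for illustration): `torch.matmul` broadcasts over batch dimensions, while `torch.mm` accepts only 2-D matrices.

import torch

a = torch.randn(10, 3, 4)  # a batch of ten 3x4 matrices
b = torch.randn(4, 5)

out = torch.matmul(a, b)   # b is broadcast across the batch dimension
assert out.shape == (10, 3, 5)

# torch.mm is strictly 2-D and raises a RuntimeError on batched inputs
try:
    torch.mm(a, b)
except RuntimeError as e:
    print("torch.mm rejects batched inputs:", e)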