88

With numpy, I can do a simple matrix multiplication like this:

a = numpy.ones((3, 2))
b = numpy.ones((2, 1))
result = a.dot(b)

However, this does not work with PyTorch:

a = torch.ones((3, 2))
b = torch.ones((2, 1))
result = torch.dot(a, b)

This code throws the following error:

RuntimeError: 1D tensors expected, but got 2D and 2D tensors

How do I perform matrix multiplication in PyTorch?

Mateen Ulhaq
timbmg

4 Answers

127

Use torch.mm:

torch.mm(a, b)

torch.dot() behaves differently from np.dot(). There's been some discussion about what would be desirable here. Specifically, torch.dot() only computes the inner product of two 1D tensors. Early PyTorch releases treated both a and b as flattened vectors (irrespective of their original shape), which made your a a vector of length 6 and your b a vector of length 2, so their inner product could not be computed; current releases reject non-1D inputs outright, which is the error shown in the question. For matrix multiplication in PyTorch, use torch.mm(). NumPy's np.dot(), in contrast, is more flexible: it computes the inner product for 1D arrays and performs matrix multiplication for 2D arrays.
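As a minimal sketch of the difference, using the shapes from the question (the names u, v, inner and np_result are only illustrative):

```python
import numpy as np
import torch

a = torch.ones((3, 2))
b = torch.ones((2, 1))

# torch.mm multiplies two 2D tensors: (3, 2) x (2, 1) -> (3, 1)
result = torch.mm(a, b)
print(result.shape)  # torch.Size([3, 1])

# torch.dot only accepts two 1D tensors of equal length
u = torch.tensor([1.0, 2.0, 3.0])
v = torch.tensor([4.0, 5.0, 6.0])
inner = torch.dot(u, v)
print(inner)  # tensor(32.)

# np.dot covers both cases: inner product for 1D, matrix product for 2D
np_result = np.dot(np.ones((3, 2)), np.ones((2, 1)))
print(np_result.shape)  # (3, 1)
```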

torch.matmul performs matrix multiplications if both arguments are 2D and computes their dot product if both arguments are 1D. For inputs of such dimensions, its behaviour is the same as np.dot. It also lets you do broadcasting or matrix x matrix, matrix x vector and vector x vector operations in batches.

import torch

n, m, k, j = 4, 3, 2, 5  # example sizes, chosen arbitrarily

# 1D inputs, same as torch.dot
a = torch.rand(n)
b = torch.rand(n)
torch.matmul(a, b)  # result shape: torch.Size([])

# 2D inputs, same as torch.mm
a = torch.rand(m, k)
b = torch.rand(k, j)
torch.matmul(a, b)  # result shape: torch.Size([m, j])
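
The batched and broadcasting cases might look like the following sketch. The sizes and variable names are arbitrary assumptions chosen for illustration; the result shapes follow the rules in the torch.matmul documentation.

```python
import torch

batch = torch.rand(10, 3, 4)   # a batch of 10 matrices, each (3, 4)
other = torch.rand(10, 4, 5)   # a matching batch of 10 matrices, each (4, 5)
mat = torch.rand(4, 5)         # a single matrix
vec = torch.rand(4)            # a single vector

# batch x batch: one matrix product per batch element
out_bb = torch.matmul(batch, other)   # shape (10, 3, 5)

# batch x matrix: mat is broadcast across the batch dimension
out_bm = torch.matmul(batch, mat)     # shape (10, 3, 5)

# batch x vector: one matrix-vector product per batch element
out_bv = torch.matmul(batch, vec)     # shape (10, 3)

print(out_bb.shape, out_bm.shape, out_bv.shape)
```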
Mateen Ulhaq
mbpaulus
  • 8
    Since this is accepted answer, I think you should include torch.matmul. It performs dot product for 1D arrays and matrix multiplication for 2D arrays. – unlut Jul 08 '19 at 14:36
57

To perform a matrix (rank 2 tensor) multiplication, use any of the following equivalent ways:

AB = A.mm(B)

AB = torch.mm(A, B)

AB = torch.matmul(A, B)

AB = A @ B  # Python 3.5+ only

There are a few subtleties. From the PyTorch documentation:

torch.mm does not broadcast. For broadcasting matrix products, see torch.matmul().

For instance, you cannot multiply two 1-dimensional vectors with torch.mm, nor multiply batched matrices (rank 3). To this end, you should use the more versatile torch.matmul. For an extensive list of the broadcasting behaviours of torch.matmul, see the documentation.
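
A small sketch of that restriction, assuming torch.mm signals the shape problem with a RuntimeError (which is what the releases I have seen do; the flag names are hypothetical):

```python
import torch

u = torch.rand(3)
v = torch.rand(3)

# torch.mm requires both arguments to be 2D
try:
    torch.mm(u, v)
    mm_accepts_1d = True
except RuntimeError:
    mm_accepts_1d = False

# torch.matmul treats two 1D inputs as an inner product (0-dim result)
inner = torch.matmul(u, v)

x = torch.rand(8, 2, 3)
y = torch.rand(8, 3, 4)

# torch.mm also rejects batched (rank 3) inputs
try:
    torch.mm(x, y)
    mm_accepts_3d = True
except RuntimeError:
    mm_accepts_3d = False

# torch.matmul multiplies the last two dimensions for each batch element
batched = torch.matmul(x, y)  # shape (8, 2, 4)

print(mm_accepts_1d, mm_accepts_3d, inner.shape, batched.shape)
```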

For element-wise multiplication, use * (A and B must have the same shape or be broadcastable to a common shape):

A * B  # element-wise matrix multiplication (Hadamard product)
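
To make the contrast with the matrix product concrete, here is a short sketch with small hand-checkable values (the tensors and the name row are illustrative). It also shows that * follows the usual broadcasting rules, so a row vector can scale every row of a matrix.

```python
import torch

A = torch.tensor([[1.0, 2.0],
                  [3.0, 4.0]])
B = torch.tensor([[5.0, 6.0],
                  [7.0, 8.0]])

hadamard = A * B            # [[ 5., 12.], [21., 32.]]
same = torch.mul(A, B)      # function form of the * operator
matprod = A @ B             # [[19., 22.], [43., 50.]]

# Broadcasting: a (2,) tensor is applied to each row of the (2, 2) matrix
row = torch.tensor([10.0, 100.0])
scaled = A * row            # [[ 10., 200.], [ 30., 400.]]

print(hadamard, matprod, scaled, sep="\n")
```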
Mateen Ulhaq
BiBi
  • 8
    I *love* the one-character `@` operator. `w @ x` will be my goto – Nathan majicvr.com Jul 31 '19 at 01:52
  • 2
    `torch.matmul` and `@` are equivalent only for a rank 2 tensor. The `@` operation is "actually" `torch.bmm` (*b*atch *m*atrix *m*ultiply) in which the matrix multiply is done on the last two dimensions (https://discuss.pytorch.org/t/how-does-the-sign-work-in-this-instance/11232). – ponadto Aug 20 '22 at 13:31
11

Use torch.mm(a, b) or torch.matmul(a, b).
For 2D tensors, both compute the same matrix product.

>>> torch.mm
<built-in method mm of type object at 0x11712a870>
>>> torch.matmul
<built-in method matmul of type object at 0x11712a870>

There's one more option worth knowing: the @ operator. @Simon H.

>>> a = torch.randn(2, 3)
>>> b = torch.randn(3, 4)
>>> a@b
tensor([[ 0.6176, -0.6743,  0.5989, -0.1390],
        [ 0.8699, -0.3445,  1.4122, -0.5826]])
>>> a.mm(b)
tensor([[ 0.6176, -0.6743,  0.5989, -0.1390],
        [ 0.8699, -0.3445,  1.4122, -0.5826]])
>>> a.matmul(b)
tensor([[ 0.6176, -0.6743,  0.5989, -0.1390],
        [ 0.8699, -0.3445,  1.4122, -0.5826]])    

The three give the same results.
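
Rather than comparing printed values by eye, the agreement can be checked programmatically. The sketch below uses torch.allclose with a small tolerance (the tolerance and variable names are my own choices). It also checks rank 3 inputs, where torch.mm no longer applies but @ still matches torch.matmul, and torch.bmm gives the same result when both inputs are rank 3 with equal batch size.

```python
import torch

a = torch.randn(2, 3)
b = torch.randn(3, 4)

# 2D case: @, mm and matmul agree
same_2d = (
    torch.allclose(a @ b, a.mm(b), atol=1e-6)
    and torch.allclose(a @ b, a.matmul(b), atol=1e-6)
)

x = torch.randn(5, 2, 3)
y = torch.randn(5, 3, 4)

# Batched case: @ matches torch.matmul, and torch.bmm for rank 3 inputs
same_3d = torch.allclose(x @ y, torch.matmul(x, y), atol=1e-6)
same_bmm = torch.allclose(x @ y, torch.bmm(x, y), atol=1e-6)

print(same_2d, same_3d, same_bmm)
```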

Related links:
Matrix multiplication operator
PEP 465 -- A dedicated infix operator for matrix multiplication

David Jung
  • Are `torch.mm(a,b)`, `torch.matmul(a,b)` and `a@b` equivalent? I can't find any documentation on the @ operator. – Simon Hessner Feb 21 '19 at 21:42
  • Yeah, it seems that there isn't any documentation about `@` operator. But, there are several notations in the document that include `@` in it that give the semantic of the matrix multiplication. So I think that the `@` operator has been overloaded by PyTorch in the meaning of matrix multiplication. – David Jung Feb 22 '19 at 00:39
  • 1
    Added links to @ operator. – David Jung Feb 22 '19 at 01:10
8

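You can use "@" to compute the matrix product of two tensors in PyTorch (for two 1D tensors it gives the dot product).

import torch

a = torch.tensor([[1, 2],
                  [3, 4]])
b = torch.tensor([[5, 6],
                  [7, 8]])
c = a @ b  # matrix product: tensor([[19, 22], [43, 50]])
c

d = a * b  # element-wise multiplication: tensor([[ 5, 12], [21, 32]])
d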
Nivesh Gadipudi