
When I have a tensor m of shape [12, 10] and a vector s of scalars with shape [12], how can I multiply each row of m with the corresponding scalar in s?

iacob
Chris

4 Answers


You need to add a corresponding singleton dimension:

m * s[:, None]

s[:, None] has shape (12, 1). When you multiply a (12, 10) tensor by a (12, 1) tensor, PyTorch broadcasts s along the second (singleton) dimension and performs the element-wise product correctly.
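A quick sanity check of the broadcasting described above. It uses random example data with the question's shapes (illustrative only) and compares the broadcast product against an explicit row-by-row loop:

```python
import torch

# Example data matching the question's shapes (random values, for illustration only)
m = torch.rand(12, 10)
s = torch.rand(12)

s_col = s[:, None]   # shape (12, 1): a trailing singleton dimension is added
result = m * s_col   # (12, 10) * (12, 1) broadcasts to (12, 10)

# Reference: scale each row explicitly in a Python loop
expected = torch.stack([m[i] * s[i] for i in range(m.shape[0])])

print(s_col.shape)                       # torch.Size([12, 1])
print(result.shape)                      # torch.Size([12, 10])
print(torch.allclose(result, expected))  # True
```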

iacob
Shai
  • Consider also `m * s.reshape((s.numel(), 1))`. This is easier to understand when reading code imo – jstm Dec 20 '22 at 03:46
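The reshape form from the comment gives the same (12, 1) column as s[:, None]. The sketch below checks this on random example data. It also includes s.unsqueeze(1), which is not mentioned in the thread but is a standard PyTorch spelling of the same reshape:

```python
import torch

# Random example data with the question's shapes (illustrative only)
m = torch.rand(12, 10)
s = torch.rand(12)

via_none = m * s[:, None]                     # indexing with None
via_reshape = m * s.reshape((s.numel(), 1))   # the form suggested in the comment
via_unsqueeze = m * s.unsqueeze(1)            # another equivalent spelling

print(torch.equal(via_none, via_reshape), torch.equal(via_none, via_unsqueeze))  # True True
```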

You can broadcast a vector to a higher-dimensional tensor like so:

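import torch

def row_mult(tensor, vector):
    # Reshape vector to (N, 1, ..., 1) so it broadcasts over all trailing dims of tensor
    extra_dims = (1,) * (tensor.dim() - 1)
    return tensor * vector.view(-1, *extra_dims)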
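A usage sketch of row_mult on random example data. It covers both the 2-D case from the question and a 4-D tensor, and compares each result against the hardcoded None indexing from the accepted answer:

```python
import torch

def row_mult(tensor, vector):
    # Append (tensor.dim() - 1) singleton dims so vector has shape (N, 1, ..., 1)
    extra_dims = (1,) * (tensor.dim() - 1)
    return tensor * vector.view(-1, *extra_dims)

s = torch.rand(12)

# 2-D case from the question: each row of a (12, 10) tensor scaled by s[i]
m = torch.rand(12, 10)
out2d = row_mult(m, s)

# Higher-dimensional case: each slice x[i] of a (12, 2, 3, 4) tensor scaled by s[i]
x = torch.rand(12, 2, 3, 4)
out4d = row_mult(x, s)

print(out2d.shape, torch.equal(out2d, m * s[:, None]))              # torch.Size([12, 10]) True
print(out4d.shape, torch.equal(out4d, x * s[:, None, None, None]))  # torch.Size([12, 2, 3, 4]) True
```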
iacob

A technique that is slightly hard to understand at first, but very powerful, is Einstein summation:

torch.einsum('i,ij->ij', s, m)
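In the subscript string 'i,ij->ij', the index i is shared by both operands and j is kept in the output. No index is dropped, so nothing is summed, and the result is out[i, j] = s[i] * m[i, j]. The sketch below checks this against the broadcasting answer on random example data:

```python
import torch

# Random example data with the question's shapes (illustrative only)
m = torch.rand(12, 10)
s = torch.rand(12)

# No index disappears from the output, so this is a pure element-wise product:
# out[i, j] = s[i] * m[i, j]
out = torch.einsum('i,ij->ij', s, m)

print(torch.allclose(out, m * s[:, None]))  # True
```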
Filip

Shai's answer works if you know the number of dimensions in advance and can hardcode the correct number of Nones. It can be extended to extra dimensions if required:

import torch

mask = (torch.rand(12) > 0.5).int()
data = torch.rand(12, 2, 3, 4)
result = data * mask[:,None,None,None]

result.shape                  # torch.Size([12, 2, 3, 4])
mask[:,None,None,None].shape  # torch.Size([12, 1, 1, 1])

If you are dealing with data that has a variable or unknown number of dimensions, you may need to manually extend mask to the correct shape:

mask = (torch.rand(12) > 0.5).int()
while mask.dim() < data.dim(): mask.unsqueeze_(1)
result = data * mask

result.shape  # torch.Size([12, 2, 3, 4])
mask.shape    # torch.Size([12, 1, 1, 1])

This is a bit of an ugly solution, but it does work. There is probably a more elegant way to reshape the mask tensor inline for a variable number of dimensions.
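One possible inline alternative computes the target shape directly and passes it to reshape. This is similar in spirit to the row_mult answer, and it assumes mask's dimensions line up with the leading dimensions of data. Unlike unsqueeze_, it does not modify mask in place. This is a sketch, not something the thread itself proposes:

```python
import torch

mask = (torch.rand(12) > 0.5).int()
data = torch.rand(12, 2, 3, 4)

# Append as many trailing singleton dims as data has beyond mask's own dims.
# reshape returns a new tensor (a view here), so mask itself keeps shape (12,).
shaped = mask.reshape(tuple(mask.shape) + (1,) * (data.dim() - mask.dim()))
result = data * shaped

print(mask.shape)    # torch.Size([12])
print(shaped.shape)  # torch.Size([12, 1, 1, 1])
print(result.shape)  # torch.Size([12, 2, 3, 4])
```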

James McGuigan