4

Suppose I have two tensors S and T defined as:

import torch

S = torch.rand((3, 2, 1))
T = torch.ones((3, 2, 1))

We can think of these as containing batches of tensors with shapes (2, 1). In this case, the batch size is 3.

I want to concatenate all possible pairings between the batch elements of S and T. Concatenating one element from each batch produces a tensor of shape (4, 1), and there are 3*3 combinations, so ultimately the resulting tensor C must have a shape of (3, 3, 4, 1).

One solution is to do the following:

C = torch.empty((3, 3, 4, 1))
for i in range(S.shape[0]):
  for j in range(T.shape[0]):
    C[i, j, :, :] = torch.cat((S[i, :, :], T[j, :, :]))

But the for loop doesn't scale well to large batch sizes. Is there a PyTorch command to do this?

user11128

3 Answers

1

I don't know of any out-of-the-box command that does such an operation. However, you can pull it off in a straightforward way using a single matrix multiplication.


The trick is to start from the already stacked S, T tensor and construct all pairs of batch elements by multiplying it with a properly chosen mask tensor. Throughout this method, keeping track of shapes and dimension sizes is essential.

  1. The stack is given by the following (notice the reshape: we essentially flatten the batch elements from S and T into a single batch axis on ST):

    >>> ST = torch.stack((S, T)).reshape(6, 2)
    >>> ST
    tensor([[0.7792, 0.0095],
            [0.1893, 0.8159],
            [0.0680, 0.7194],
            [1.0000, 1.0000],
            [1.0000, 1.0000],
            [1.0000, 1.0000]])
    # ST.shape = (6, 2)
    
  2. You can retrieve all (S[i], T[j]) pairs using range and itertools.product:

    >>> from itertools import product
    >>> indices = torch.tensor(list(product(range(0, 3), range(3, 6))))
    >>> indices
    tensor([[0, 3],
            [0, 4],
            [0, 5],
            [1, 3],
            [1, 4],
            [1, 5],
            [2, 3],
            [2, 4],
            [2, 5]])
    # indices.shape = (9, 2)
    
  3. From there, we construct one-hot encodings of the indices using torch.nn.functional.one_hot:

    >>> from torch.nn.functional import one_hot
    >>> mask = one_hot(indices).float()
    >>> mask
    tensor([[[1., 0., 0., 0., 0., 0.],
             [0., 0., 0., 1., 0., 0.]],
    
            [[1., 0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 1., 0.]],
    
            [[1., 0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0., 1.]],
    
            [[0., 1., 0., 0., 0., 0.],
             [0., 0., 0., 1., 0., 0.]],
    
            [[0., 1., 0., 0., 0., 0.],
             [0., 0., 0., 0., 1., 0.]],
    
            [[0., 1., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0., 1.]],
    
            [[0., 0., 1., 0., 0., 0.],
             [0., 0., 0., 1., 0., 0.]],
    
            [[0., 0., 1., 0., 0., 0.],
             [0., 0., 0., 0., 1., 0.]],
    
            [[0., 0., 1., 0., 0., 0.],
             [0., 0., 0., 0., 0., 1.]]])
    # mask.shape = (9, 2, 6)
    
  4. Finally, we compute the matrix multiplication and reshape it to the final form:

    >>> (mask@ST).reshape(3, 3, 4, 1)
    tensor([[[[0.7792],
              [0.0095],
              [1.0000],
              [1.0000]],
    
             [[0.7792],
              [0.0095],
              [1.0000],
              [1.0000]],
    
             [[0.7792],
              [0.0095],
              [1.0000],
              [1.0000]]],
    
    
            [[[0.1893],
              [0.8159],
              [1.0000],
              [1.0000]],
    
             [[0.1893],
              [0.8159],
              [1.0000],
              [1.0000]],
    
             [[0.1893],
              [0.8159],
              [1.0000],
              [1.0000]]],
    
    
            [[[0.0680],
              [0.7194],
              [1.0000],
              [1.0000]],
    
             [[0.0680],
              [0.7194],
              [1.0000],
              [1.0000]],
    
             [[0.0680],
              [0.7194],
              [1.0000],
              [1.0000]]]])
    

I initially went with torch.einsum: torch.einsum('bf,pib->pif', ST, mask). But I later realized that bf,pib->pif reduces nicely to a simple torch.Tensor.matmul operation if we switch the two operands, i.e. pib,bf->pif (subscript b is reduced in the middle).
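
As a quick sanity check of that equivalence (a minimal sketch, assuming ST and mask as built in the steps above):

>>> torch.allclose(torch.einsum('bf,pib->pif', ST, mask), mask @ ST)
True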

Ivan
  • Thanks for the reply. However, if the batch size is big (in my case the tensors S and T both have shape (2000, 40, 1)), the system crashes when we create the mask because it requires too much memory. Is there a way to extract the relevant indices from the stacked tensor without the mask? – user11128 Sep 03 '21 at 08:27
  • I will have a look at it with `torch.gather`, but there is no guarantee. – Ivan Sep 03 '21 at 08:46
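
Following up on that comment thread, one mask-free variant (a minimal sketch using plain integer indexing, not necessarily the torch.gather route mentioned above) is to index ST directly with indices, which avoids materialising the large one-hot mask:

>>> C = ST[indices].reshape(3, 3, 4, 1)  # same values as (mask @ ST).reshape(3, 3, 4, 1)
>>> C.shape
torch.Size([3, 3, 4, 1])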
0

In NumPy, this kind of all-pairs combination is done with np.meshgrid.

https://stackoverflow.com/a/35608701/3259896

So in PyTorch, it would be

torch.stack(
    torch.meshgrid(x, y)
).T.reshape(-1, 2)

where x and y are your two 1-D tensors. You can use any number of them: x, y, z, etc.

You then reshape according to the number of tensors you use: for three tensors use .reshape(-1, 3), for four use .reshape(-1, 4), and so on.

So for 5 tensors, use

torch.stack(
    torch.meshgrid(a, b, c, d, e)
).T.reshape(-1, 5)
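
Applied to the original question, a minimal sketch of this idea (assuming a PyTorch version that accepts the indexing='ij' argument, i.e. 1.10 or later) could look like:

import torch

S = torch.rand((3, 2, 1))
T = torch.ones((3, 2, 1))

# all (i, j) pairs of batch indices via meshgrid
i, j = torch.meshgrid(torch.arange(S.shape[0]), torch.arange(T.shape[0]), indexing='ij')

# index both tensors with the flattened grids, then concatenate the (2, 1) elements into (4, 1)
C = torch.cat((S[i.reshape(-1)], T[j.reshape(-1)]), dim=1).reshape(3, 3, 4, 1)
print(C.shape)  # torch.Size([3, 3, 4, 1])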
SantoshGupta7
0

My solution is to use torch.repeat_interleave and Tensor.repeat to reproduce a for loop.

For instance I have

>>> tensor_1 # shape(3, 4)
tensor([[0.1164, 0.6336, 0.7037, 0.1360],
        [0.9316, 0.9569, 0.4108, 0.5415],
        [0.6325, 0.3159, 0.3307, 0.0700]])
>>> tensor_2 # shape(2, 4)
tensor([[0.1687, 0.3315, 0.1523, 0.1123],
        [0.1792, 0.8289, 0.7350, 0.2479]])

To get the result of

for i in range(tensor_1.shape[0]):
    for j in range(tensor_2.shape[0]):
        torch.cat([tensor_1[i, ...], tensor_2[j, ...]], dim=0) # shape (8, )

We can do

b, h = tensor_1.shape
e, h = tensor_2.shape

result = torch.cat(
    [torch.repeat_interleave(tensor_1, repeats=e, dim=0), tensor_2.repeat(b, 1), ]
    , dim=-1,
).reshape(b, e, 2 * h)

(torch.repeat_interleave implements the outer for loop, repeating each row of tensor_1 e times element-wise; Tensor.repeat implements the inner for loop, repeating tensor_2 b times as a whole), which gives

>>> result
tensor([[[0.1164, 0.6336, 0.7037, 0.1360, 0.1687, 0.3315, 0.1523, 0.1123],
         [0.1164, 0.6336, 0.7037, 0.1360, 0.1792, 0.8289, 0.7350, 0.2479]],

        [[0.9316, 0.9569, 0.4108, 0.5415, 0.1687, 0.3315, 0.1523, 0.1123],
         [0.9316, 0.9569, 0.4108, 0.5415, 0.1792, 0.8289, 0.7350, 0.2479]],

        [[0.6325, 0.3159, 0.3307, 0.0700, 0.1687, 0.3315, 0.1523, 0.1123],
         [0.6325, 0.3159, 0.3307, 0.0700, 0.1792, 0.8289, 0.7350, 0.2479]]])
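
Mapped back to the (3, 2, 1) tensors from the original question, a minimal sketch of the same recipe (the concatenation axis becomes dim=1 because the batch elements are (2, 1) column vectors) might look like:

import torch

S = torch.rand((3, 2, 1))
T = torch.ones((3, 2, 1))

b = S.shape[0]  # batch size of S
e = T.shape[0]  # batch size of T

C = torch.cat(
    (S.repeat_interleave(e, dim=0),  # outer loop: each S[i] repeated e times
     T.repeat(b, 1, 1)),             # inner loop: the whole of T tiled b times
    dim=1,                           # concatenate the length-2 axes into length 4
).reshape(b, e, 4, 1)
print(C.shape)  # torch.Size([3, 3, 4, 1])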
mk6