
I have a list of 3 tensors with the shapes (8, 2), (8, 4) and (8, 6).

I want to turn this list into a single tensor of shape (8, 3, x).

How do I do this? I know I need to use some combination of torch.cat, torch.stack and torch.transpose, but I can't figure it out.

Thanks in advance!

Mad Physicist
mimookies

1 Answer


As you said, you need torch.cat, combined with torch.reshape. Assume the following:

import torch

a = torch.rand(8, 2)
b = torch.rand(8, 4)
c = torch.rand(8, 6)

Assume also that the result can indeed be reshaped to (8, 3, -1), where -1 means "as long as it needs to be". That requires the combined width to be divisible by 3; here 2 + 4 + 6 = 12, so x = 4. Then:

d = torch.cat((a, b, c), dim=1)
e = torch.reshape(d, (8, 3, -1))

I'll explain. All three tensors have size 8 along dimension 0 but differ along dimension 1 (the second axis, counting from 0), so the concatenation has to be along dim=1, as in variable d, which ends up with shape (8, 12). Then you can reshape d as in e, where -1 tells torch.reshape to infer that dimension's size from the total number of elements.
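A minimal, self-contained version of the steps above that prints the shape after each step (random inputs, as in the answer). It also shows something worth knowing: reshape only regroups the 12 concatenated columns in order, 4 at a time, so each of the 3 slices mixes columns from different source tensors rather than holding exactly one of a, b or c.

```python
import torch

a = torch.rand(8, 2)
b = torch.rand(8, 4)
c = torch.rand(8, 6)

# Concatenate along dim=1, the axis whose sizes differ: 2 + 4 + 6 = 12 columns
d = torch.cat((a, b, c), dim=1)
print(d.shape)  # torch.Size([8, 12])

# -1 lets reshape infer the last size: (8 * 12) / (8 * 3) = 4
e = torch.reshape(d, (8, 3, -1))
print(e.shape)  # torch.Size([8, 3, 4])

# reshape regroups columns in order, so slot 0 holds
# all of a followed by the first 2 columns of b
print(torch.equal(e[:, 0, :2], a))         # True
print(torch.equal(e[:, 0, 2:], b[:, :2]))  # True
```

If the combined width were not divisible by 3 (for example widths 2, 4 and 5, giving 11 columns), torch.reshape would raise a RuntimeError instead of returning a tensor.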

Tomer Geva
  • Completely forgot that -1 existed for this, thank you! Found out I actually can't reshape it to what I wanted, but figured that my model might not work properly – mimookies Mar 13 '22 at 16:21
  • Sure, if you have more problems with your model I'll be happy to try and help – Tomer Geva Mar 13 '22 at 16:24
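If the aim is instead for each of the 3 slots to hold exactly one of the original tensors, one option (not part of the answer above) is to zero-pad every tensor to the widest width and then use torch.stack, which gives x = 6. This is a sketch assuming right-side zero-padding is acceptable for the model; the names `tensors`, `max_width`, `padded` and `out` are illustrative only.

```python
import torch
import torch.nn.functional as F

a = torch.rand(8, 2)
b = torch.rand(8, 4)
c = torch.rand(8, 6)
tensors = [a, b, c]

# Widest last dimension among the inputs (6 here)
max_width = max(t.shape[1] for t in tensors)

# F.pad's (left, right) pair applies to the last dimension:
# append zeros on the right so every tensor becomes (8, max_width)
padded = [F.pad(t, (0, max_width - t.shape[1])) for t in tensors]

# Stack along a new dim=1: three (8, 6) tensors -> (8, 3, 6)
out = torch.stack(padded, dim=1)
print(out.shape)  # torch.Size([8, 3, 6])

# Slot 1 holds b in its first 4 columns, followed by zeros
print(torch.equal(out[:, 1, :4], b))     # True
print((out[:, 1, 4:] == 0).all().item())  # True
```

Whether zero-padding is appropriate depends on the model; it may need the original widths (2, 4, 6) to ignore the padded columns.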