As you said, you need to use `torch.cat`, but also `torch.reshape`. Assume the following:

```python
a = torch.rand(8, 2)
b = torch.rand(8, 4)
c = torch.rand(8, 6)
```

and assume that it is indeed possible to reshape the concatenated tensor to shape `(8, 3, -1)`, where `-1` stands for "as long as it needs to be". Then:
```python
d = torch.cat((a, b, c), dim=1)
e = torch.reshape(d, (8, 3, -1))
```
I'll explain. Because the tensors differ in their 1st dimension (`dim=1`, with sizes 2, 4, and 6) but agree in their 0th dimension (size 8), the concatenation has to be along `dim=1`, as seen in variable `d`, which has shape `(8, 12)`. Then you can reshape the tensor as seen in `e`, where the `-1` stands for "as long as it needs to be": PyTorch infers the last dimension as 12 / 3 = 4, so `e` has shape `(8, 3, 4)`.
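Putting the steps above together, here is a minimal runnable sketch that checks the resulting shapes:

```python
import torch

# Three tensors that agree in dim 0 (size 8) but differ in dim 1.
a = torch.rand(8, 2)
b = torch.rand(8, 4)
c = torch.rand(8, 6)

# Concatenation must be along dim=1, where the sizes differ: 2 + 4 + 6 = 12.
d = torch.cat((a, b, c), dim=1)
print(d.shape)  # torch.Size([8, 12])

# -1 lets PyTorch infer the remaining dimension: 12 / 3 = 4.
e = torch.reshape(d, (8, 3, -1))
print(e.shape)  # torch.Size([8, 3, 4])
```

Note that the reshape only succeeds because the concatenated size (12) is divisible by 3; if it were not, `torch.reshape` would raise an error.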