
I have 2 tensors, A and B:

import torch

A = torch.randn([32,128,64,12],dtype=torch.float64)
B = torch.randn([64,12,64,12],dtype=torch.float64)
C = torch.tensordot(A,B,([2,3],[0,1]))
D = C.permute(0,2,1,3) # shape:[32,64,128,12]
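A minimal sketch showing where the shape comment comes from, using the same shapes as above. tensordot contracts dims 2 and 3 of A (sizes 64 and 12) against dims 0 and 1 of B (also 64 and 12). The result keeps the free dims of A followed by the free dims of B, and the permute then swaps the two middle axes:

```python
import torch

A = torch.randn([32, 128, 64, 12], dtype=torch.float64)
B = torch.randn([64, 12, 64, 12], dtype=torch.float64)

# Contract A's dims 2,3 (sizes 64, 12) with B's dims 0,1 (sizes 64, 12).
C = torch.tensordot(A, B, ([2, 3], [0, 1]))
# Free dims of A (32, 128) come first, then free dims of B (64, 12).
print(C.shape)  # torch.Size([32, 128, 64, 12])

# Swap axes 1 and 2 to obtain D.
D = C.permute(0, 2, 1, 3)
print(D.shape)  # torch.Size([32, 64, 128, 12])
```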

Tensor D comes from the two-step sequence "tensordot -> permute". How can I implement a new operation f() so that D is produced directly by a tensordot applied after f(), like this:

A_2 = f(A)
B_2 = f(B)
D = torch.tensordot(A_2,B_2)

1 Answer


Have you considered using torch.einsum, which is very flexible?

D = torch.einsum('ijab,abkl->ikjl', A, B)
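A quick sketch to check that the einsum call matches the original tensordot + permute result, assuming the shapes from the question. The subscripts a and b appear in both inputs but not in the output, so they are summed over; i, j come from A and k, l from B, and the output string 'ikjl' lays them out in the interleaved order directly:

```python
import torch

A = torch.randn([32, 128, 64, 12], dtype=torch.float64)
B = torch.randn([64, 12, 64, 12], dtype=torch.float64)

# Original two-step route: contract, then swap axes 1 and 2.
D_ref = torch.tensordot(A, B, ([2, 3], [0, 1])).permute(0, 2, 1, 3)

# One-step route: sum over a, b and emit the free indices as i, k, j, l.
D = torch.einsum('ijab,abkl->ikjl', A, B)

print(D.shape)                   # torch.Size([32, 64, 128, 12])
print(torch.allclose(D, D_ref))  # True
```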

The problem with tensordot is that it outputs all the remaining (uncontracted) dimensions of A before those of B, whereas what you are looking for (when permuting) is to "interleave" dimensions from A and B.
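Small, distinct sizes (chosen here purely for illustration) make the ordering visible. Whichever operand comes first, its free dimensions stay together as one block. Pre-permuting A or B, as a hypothetical f() would, only reorders dimensions within each block, so no such f() can produce the interleaved i, k, j, l layout from tensordot alone:

```python
import torch

A = torch.randn([2, 3, 4, 5], dtype=torch.float64)  # i=2, j=3, a=4, b=5
B = torch.randn([4, 5, 6, 7], dtype=torch.float64)  # a=4, b=5, k=6, l=7

# tensordot(A, B): A's free dims first, then B's -> i, j, k, l
AB = torch.tensordot(A, B, ([2, 3], [0, 1]))
print(AB.shape)  # torch.Size([2, 3, 6, 7])

# tensordot(B, A): B's free dims first, then A's -> k, l, i, j
BA = torch.tensordot(B, A, ([0, 1], [2, 3]))
print(BA.shape)  # torch.Size([6, 7, 2, 3])

# Neither is i, k, j, l; einsum produces that order in one call.
E = torch.einsum('ijab,abkl->ikjl', A, B)
print(E.shape)  # torch.Size([2, 6, 3, 7])
```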

Shai