
I have been trying to debug a model that uses the torch.einsum operator in a layer that is repeated several times.

While analyzing the GPU memory usage of the model during training, I noticed that one einsum operation dramatically increases memory usage. I am dealing with multi-dimensional tensors. The operation is x = torch.einsum('b q f n, b f n d -> b q f d', A, B).
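For reference, here is a minimal sketch of what this contraction computes, assuming small hypothetical sizes so it runs on CPU. It treats `b` and `f` as batch dimensions and multiplies the `(q x n)` slice of `A` by the `(n x d)` slice of `B`, summing over `n`. The spaces are dropped from the subscripts here; the contraction is the same.

```python
import torch

# Hypothetical sizes, chosen small so the check runs quickly on CPU.
b, q, f, n, d = 2, 3, 4, 5, 6
A = torch.randn(b, q, f, n)
B = torch.randn(b, f, n, d)

# The contraction from the question; 'n' is summed away.
x = torch.einsum('bqfn,bfnd->bqfd', A, B)

# Same result as a batched matmul over the (b, f) dims:
# (b, f, q, n) @ (b, f, n, d) -> (b, f, q, d), then move f back after q.
x_mm = (A.permute(0, 2, 1, 3) @ B).permute(0, 2, 1, 3)

print(x.shape)                              # torch.Size([2, 3, 4, 6])
print(torch.allclose(x, x_mm, atol=1e-5))   # True
```

Either way, the result is a freshly allocated `(b, q, f, d)` tensor.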

It is also worth mentioning that:

  • x was previously assigned a tensor of the same shape.
  • In every layer (they are all identical), GPU memory usage increases linearly after this operation and is not freed until the end of the model iteration.

I have been wondering why this operation uses so much memory, and why the memory stays allocated after every iteration over that layer type.

Ivan
ofir1080
  • The operation ends up looping over dimensions `b`, `q`, `f`, `n`, and `d` and computes `A[b][q][f][n] * B[b][f][n][d]`, summing over `n`. Of course, depending on your input dimensions, you can expect this operation to take a while to compute. See [this answer](https://stackoverflow.com/questions/26089893/understanding-numpys-einsum/66007300#66007300) for more on the `einsum` operator. – Ivan Aug 30 '21 at 13:01
  • I understand that this operation is computationally expensive. But why is it expensive in terms of **memory**, and why is the memory not deallocated after this line? – ofir1080 Aug 30 '21 at 13:11
  • Well, you are still allocating `b*q*f*d` elements by the end of this operation, do you agree? – Ivan Aug 30 '21 at 13:17
  • Well, it is assigned to `x`, which already has this shape, so I'd think it would be overwritten, wouldn't it? – ofir1080 Aug 30 '21 at 14:06
  • Why would it? If `x` is still in the lexical scope then it will remain in memory. Can you give details on where exactly `x` is defined? – Ivan Aug 30 '21 at 14:31
  • `x` is the input data sample to the NN, much like [here](https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html), but in `forward` there is `x = torch.einsum('b q f n, b f n d -> b q f d', A, B)`. This operation is a sort of projection layer applied to `x`, which preserves its shape and overwrites its previous values. Am I seeing this wrong? – ofir1080 Aug 30 '21 at 15:09
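To make the loop description in the first comment concrete, here is an illustrative explicit-loop version of the same contraction. `einsum_loops` is a hypothetical helper written only for this sketch (far too slow for real use), with small made-up sizes. It also shows that the result alone occupies `b*q*f*d` elements, whatever name it ends up bound to.

```python
import torch

def einsum_loops(A, B):
    """Explicit-loop version of 'bqfn,bfnd->bqfd', for illustration only."""
    b, q, f, n = A.shape
    d = B.shape[-1]
    out = torch.zeros(b, q, f, d, dtype=A.dtype)  # the new b*q*f*d output buffer
    for ib in range(b):
        for iq in range(q):
            for i_f in range(f):
                for i_d in range(d):
                    for i_n in range(n):  # 'n' is the summed (contracted) index
                        out[ib, iq, i_f, i_d] += A[ib, iq, i_f, i_n] * B[ib, i_f, i_n, i_d]
    return out

b, q, f, n, d = 2, 3, 2, 4, 3
A = torch.randn(b, q, f, n, dtype=torch.float64)
B = torch.randn(b, f, n, d, dtype=torch.float64)
out = einsum_loops(A, B)

print(torch.allclose(out, torch.einsum('bqfn,bfnd->bqfd', A, B)))  # True
print(out.numel(), out.numel() * out.element_size())              # 36 288
```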

1 Answer


The variable x is indeed overwritten, but the tensor data it referred to is kept in memory (these saved tensors are the layer's activations) for later use in the backward pass.

So you are effectively allocating new memory for the result of torch.einsum, but the memory held for x's previous value is not released, even though x appears to have been overwritten.
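One way to observe this on CPU is a sketch using `torch.autograd.graph.saved_tensors_hooks` (available in PyTorch 1.10+), which calls a "pack" hook for every tensor autograd saves for backward. The helper `saved_bytes_for_depth`, the per-layer `B` tensors and the sizes are all hypothetical. The exact byte counts depend on how einsum is decomposed internally, but the total should grow with the number of stacked layers even though `x` is rebound each time.

```python
import torch

def saved_bytes_for_depth(layers, b=2, q=3, f=4, n=5):
    """Total bytes autograd saves for backward across `layers` identical einsum layers."""
    total = 0

    def pack(t):  # called once for every tensor autograd saves for the backward pass
        nonlocal total
        total += t.numel() * t.element_size()
        return t  # keep the tensor as-is, like autograd normally would

    def unpack(t):
        return t

    x = torch.randn(b, q, f, n, requires_grad=True)
    # Hypothetical per-layer tensors; d == n so x keeps its shape, as in the question.
    Bs = [torch.randn(b, f, n, n, requires_grad=True) for _ in range(layers)]
    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        for B in Bs:
            # Rebinding x does not free what the graph has already saved.
            x = torch.einsum('bqfn,bfnd->bqfd', x, B)
    return total

for depth in (1, 2, 4):
    print(depth, saved_bytes_for_depth(depth))
```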


To put this to the test, you can compute the forward pass under the torch.no_grad() context manager (where those activations won't be kept in memory) and compare the memory usage with that of a standard forward pass with gradients enabled.
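A rough sketch of that test, assuming a hypothetical stack of identical einsum layers (`run`, `allocated_after_forward` and the sizes are made up for illustration). On any device, it checks that the no_grad output has no `grad_fn`, i.e. no graph holding saved tensors. On a CUDA device, it also compares `torch.cuda.memory_allocated()` after each kind of forward pass.

```python
import torch

def run(x, Bs):
    # Hypothetical stack of identical layers; each one rebinds x, as in the question's forward().
    for B in Bs:
        x = torch.einsum('bqfn,bfnd->bqfd', x, B)
    return x

device = 'cuda' if torch.cuda.is_available() else 'cpu'
b, q, f, n = 2, 3, 4, 5
x0 = torch.randn(b, q, f, n, device=device, requires_grad=True)
# d == n so every layer preserves the shape of x.
Bs = [torch.randn(b, f, n, n, device=device, requires_grad=True) for _ in range(4)]

out_grad = run(x0, Bs)        # builds a graph that keeps saved tensors for backward
with torch.no_grad():
    out_nograd = run(x0, Bs)  # no graph: each intermediate is freed once x is rebound

print(out_grad.grad_fn is not None)  # True
print(out_nograd.grad_fn is None)    # True

def allocated_after_forward(use_grad):
    """GPU bytes still allocated while the output of one forward pass is alive."""
    torch.cuda.synchronize()
    before = torch.cuda.memory_allocated()
    with torch.set_grad_enabled(use_grad):
        out = run(x0, Bs)
    torch.cuda.synchronize()
    delta = torch.cuda.memory_allocated() - before
    del out  # drops the output and any graph it kept alive
    return delta

if device == 'cuda':
    print('grad enabled:', allocated_after_forward(True))
    print('no_grad:     ', allocated_after_forward(False))
```

In a real training loop, the saved tensors are released by `backward()` (unless `retain_graph=True`) or when the output goes out of scope. This matches the memory being freed only at the end of each iteration.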

Ivan