
When I run several calculations with torch.einsum in a row, the first one is always much slower than the following calculations.

The following code and plot illustrate the problem:

import torch as tor
import numpy as np
from timeit import default_timer as timer

N = 1000
L = 10

time_arr = np.zeros(L)

for i in range(L):
    a = tor.randn(N, N).to("cuda:0")  # 3 random 1000x1000 matrices for each cycle
    b = tor.randn(N, N).to("cuda:0")
    c = tor.randn(N, N).to("cuda:0")

    time_start  = timer()
    tor.einsum("ij, kj", tor.einsum("ij, ja", a, b), c)
    time_end  = timer()

    time_arr[i] = time_end - time_start

Plot of the different times for each cycle of the loop
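For what it's worth, a common way to exclude one-time setup cost (CUDA context creation, kernel/equation caching) from such measurements is to run one untimed warmup call and synchronize the device before and after the timed region. This is a minimal sketch of that idea, falling back to CPU when no GPU is available; it is not the original benchmark, just an illustration:

```python
import torch
from timeit import default_timer as timer

device = "cuda:0" if torch.cuda.is_available() else "cpu"
N, L = 1000, 10
times = []

# Warmup: trigger one-time initialization outside the timed loop
w = torch.randn(N, N, device=device)
torch.einsum("ij, kj", torch.einsum("ij, ja", w, w), w)
if device.startswith("cuda"):
    torch.cuda.synchronize()

for i in range(L):
    a = torch.randn(N, N, device=device)
    b = torch.randn(N, N, device=device)
    c = torch.randn(N, N, device=device)
    if device.startswith("cuda"):
        torch.cuda.synchronize()  # make sure inputs are ready before timing
    t0 = timer()
    torch.einsum("ij, kj", torch.einsum("ij, ja", a, b), c)
    if device.startswith("cuda"):
        torch.cuda.synchronize()  # wait for the kernels to actually finish
    times.append(timer() - t0)
```

Note that without `torch.cuda.synchronize()`, CUDA kernel launches are asynchronous, so the timer may only be measuring launch overhead rather than the computation itself.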

Fynn
  • I imagine some delayed initialization happens during the first invocation. Is there a problem here? :) – AKX Jan 10 '23 at 16:37
  • Presumably "ij, ja" needs to be parsed and appropriately code generated. The first time through, this calculation is cached, so the next time is faster. It's also possible that `einsum` loads new packages. – Frank Yellin Jan 10 '23 at 16:50

0 Answers